Commit 6f94e4410fc543a4a1485f295032f7698e4a5b3e
1 parent
d2603183
Exists in
master
and in
1 other branch
staged.
Showing
1 changed file
with
6 additions
and
7 deletions
Show diff stats
mmodel/theano/THEANO.py
@@ -37,16 +37,15 @@ class ModelTHEANO(ModelBase): | @@ -37,16 +37,15 @@ class ModelTHEANO(ModelBase): | ||
37 | else: | 37 | else: |
38 | X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | 38 | X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) |
39 | 39 | ||
40 | - print type(X), type(X_train), type(X_train[0]) | ||
41 | - return | ||
42 | - | ||
43 | - X_train, Y_train = np.array(X_train), np.array(Y_train) | ||
44 | - X_test, Y_test = np.array(X_test), np.array(Y_test) | 40 | + X_train = theano.shared(np.asarray(X_train, dtype=theano.config.floatX), borrow=True) |
41 | + Y_train = theano.shared(np.asarray(Y_train, dtype=theano.config.floatX), borrow=True) | ||
42 | + X_test = theano.shared(np.asarray(X_test, dtype=theano.config.floatX), borrow=True) | ||
43 | + Y_test = theano.shared(np.asarray(Y_test, dtype=theano.config.floatX), borrow=True) | ||
45 | 44 | ||
46 | n_train_batches = X_train.shape[0] / batch_size | 45 | n_train_batches = X_train.shape[0] / batch_size |
47 | n_test_batches = X_test.shape[0] / batch_size | 46 | n_test_batches = X_test.shape[0] / batch_size |
48 | 47 | ||
49 | - rng = np.random.RandomState("whoami") | 48 | + rng = np.random.RandomState(12306) |
50 | index = T.lscalar() | 49 | index = T.lscalar() |
51 | x = T.matrix('x') | 50 | x = T.matrix('x') |
52 | y = T.ivector('y') | 51 | y = T.ivector('y') |
@@ -89,7 +88,7 @@ class ModelTHEANO(ModelBase): | @@ -89,7 +88,7 @@ class ModelTHEANO(ModelBase): | ||
89 | layer2 = ConvPoolLayer( | 88 | layer2 = ConvPoolLayer( |
90 | rng, | 89 | rng, |
91 | input=layer1.output, | 90 | input=layer1.output, |
92 | - image_shape=(batch_size, nkerns[0], 16, 16), | 91 | + image_shape=(batch_size, nkerns[1], 16, 16), |
93 | filter_shape=(nkerns[2], nkerns[1], 5, 5), | 92 | filter_shape=(nkerns[2], nkerns[1], 5, 5), |
94 | poolsize=(3, 3) | 93 | poolsize=(3, 3) |
95 | ) | 94 | ) |