Commit 6f94e4410fc543a4a1485f295032f7698e4a5b3e

Authored by Chunk
1 parent d2603183
Exists in master and in 1 other branch refactor

staged.

Showing 1 changed file with 6 additions and 7 deletions   Show diff stats
mmodel/theano/THEANO.py
... ... @@ -37,16 +37,15 @@ class ModelTHEANO(ModelBase):
37 37 else:
38 38 X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0)
39 39  
40   - print type(X), type(X_train), type(X_train[0])
41   - return
42   -
43   - X_train, Y_train = np.array(X_train), np.array(Y_train)
44   - X_test, Y_test = np.array(X_test), np.array(Y_test)
  40 + X_train = theano.shared(np.asarray(X_train, dtype=theano.config.floatX), borrow=True)
  41 + Y_train = theano.shared(np.asarray(Y_train, dtype=theano.config.floatX), borrow=True)
  42 + X_test = theano.shared(np.asarray(X_test, dtype=theano.config.floatX), borrow=True)
  43 + Y_test = theano.shared(np.asarray(Y_test, dtype=theano.config.floatX), borrow=True)
45 44  
46 45 n_train_batches = X_train.shape[0] / batch_size
47 46 n_test_batches = X_test.shape[0] / batch_size
48 47  
49   - rng = np.random.RandomState("whoami")
  48 + rng = np.random.RandomState(12306)
50 49 index = T.lscalar()
51 50 x = T.matrix('x')
52 51 y = T.ivector('y')
... ... @@ -89,7 +88,7 @@ class ModelTHEANO(ModelBase):
89 88 layer2 = ConvPoolLayer(
90 89 rng,
91 90 input=layer1.output,
92   - image_shape=(batch_size, nkerns[0], 16, 16),
  91 + image_shape=(batch_size, nkerns[1], 16, 16),
93 92 filter_shape=(nkerns[2], nkerns[1], 5, 5),
94 93 poolsize=(3, 3)
95 94 )
... ...