diff --git a/mmodel/theano/THEANO.py b/mmodel/theano/THEANO.py index f3370d5..dd51d16 100644 --- a/mmodel/theano/THEANO.py +++ b/mmodel/theano/THEANO.py @@ -17,7 +17,7 @@ class ModelTHEANO(ModelBase): self.sparker = sc self.model = None - def _train_cnn(X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50], + def _train_cnn(self, X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50], batch_size=200): X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) -- libgit2 0.21.2