Commit 5a469df5bf54b238d31c2d676fa4c8929e18ed1c
1 parent
e585c8c2
Exists in
master
and in
1 other branch
staged.
Showing
3 changed files
with
31 additions
and
5 deletions
Show diff stats
mdata/ILSVRC.py
mmodel/theano/THEANO.py
| ... | ... | @@ -9,6 +9,11 @@ from .theanoutil import * |
| 9 | 9 | import numpy as np |
| 10 | 10 | from sklearn import cross_validation |
| 11 | 11 | |
| 12 | +import gzip | |
| 13 | +import cPickle | |
| 14 | + | |
| 15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | |
| 16 | + | |
| 12 | 17 | |
| 13 | 18 | class ModelTHEANO(ModelBase): |
| 14 | 19 | def __init__(self, toolset='cnn', sc=None): |
| ... | ... | @@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase): |
| 17 | 22 | self.sparker = sc |
| 18 | 23 | self.model = None |
| 19 | 24 | |
| 20 | - def _train_cnn(self, X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50], | |
| 25 | + def _train_cnn(self, X=None, Y=None, dataset=os.path.join(package_dir, '../../res/', 'ils_crop.pkl'), | |
| 26 | + learning_rate=0.1, n_epochs=200, | |
| 27 | + nkerns=[20, 50, 50], | |
| 21 | 28 | batch_size=200): |
| 22 | - X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | |
| 29 | + | |
| 30 | + if X == None: | |
| 31 | + assert dataset != None | |
| 32 | + with open(dataset, 'rb') as f: | |
| 33 | + train_set, test_set = cPickle.load(f) | |
| 34 | + | |
| 35 | + X_train, Y_train = train_set | |
| 36 | + X_test, Y_test = test_set | |
| 37 | + else: | |
| 38 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | |
| 39 | + | |
| 40 | + print type(X), type(X_train), type(X_train[0]) | |
| 41 | + return | |
| 23 | 42 | |
| 24 | 43 | X_train, Y_train = np.array(X_train), np.array(Y_train) |
| 25 | 44 | X_test, Y_test = np.array(X_test), np.array(Y_test) | ... | ... |
test/test_model.py
| ... | ... | @@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S |
| 7 | 7 | from ..mmodel.svm import SVM |
| 8 | 8 | from ..mmodel.theano import THEANO |
| 9 | 9 | |
| 10 | -timer = Timer() | |
| 10 | +import gzip | |
| 11 | +import cPickle | |
| 12 | + | |
| 11 | 13 | |
| 14 | +timer = Timer() | |
| 15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | |
| 12 | 16 | |
| 13 | 17 | def test_SVM_CV(): |
| 14 | 18 | timer.mark() |
| ... | ... | @@ -151,10 +155,13 @@ def test_THEANO_crop(): |
| 151 | 155 | dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') |
| 152 | 156 | X, Y = dilc.load_data(mode='local', feattype='coef') |
| 153 | 157 | timer.report() |
| 158 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | |
| 159 | + with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f: | |
| 160 | + cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) | |
| 154 | 161 | |
| 155 | 162 | timer.mark() |
| 156 | 163 | mtheano = THEANO.ModelTHEANO(toolset='cnn') |
| 157 | - mtheano.train(X,Y) | |
| 164 | + mtheano._train_cnn(dataset='../../res/ils_crop.pkl') | |
| 158 | 165 | timer.report() |
| 159 | 166 | |
| 160 | 167 | ... | ... |