Commit 5a469df5bf54b238d31c2d676fa4c8929e18ed1c
1 parent
e585c8c2
Exists in
master
and in
1 other branch
staged.
Showing
3 changed files
with
31 additions
and
5 deletions
Show diff stats
mdata/ILSVRC.py
mmodel/theano/THEANO.py
... | ... | @@ -9,6 +9,11 @@ from .theanoutil import * |
9 | 9 | import numpy as np |
10 | 10 | from sklearn import cross_validation |
11 | 11 | |
12 | +import gzip | |
13 | +import cPickle | |
14 | + | |
15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | |
16 | + | |
12 | 17 | |
13 | 18 | class ModelTHEANO(ModelBase): |
14 | 19 | def __init__(self, toolset='cnn', sc=None): |
... | ... | @@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase): |
17 | 22 | self.sparker = sc |
18 | 23 | self.model = None |
19 | 24 | |
20 | - def _train_cnn(self, X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50], | |
25 | + def _train_cnn(self, X=None, Y=None, dataset=os.path.join(package_dir, '../../res/', 'ils_crop.pkl'), | |
26 | + learning_rate=0.1, n_epochs=200, | |
27 | + nkerns=[20, 50, 50], | |
21 | 28 | batch_size=200): |
22 | - X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | |
29 | + | |
30 | + if X == None: | |
31 | + assert dataset != None | |
32 | + with open(dataset, 'rb') as f: | |
33 | + train_set, test_set = cPickle.load(f) | |
34 | + | |
35 | + X_train, Y_train = train_set | |
36 | + X_test, Y_test = test_set | |
37 | + else: | |
38 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | |
39 | + | |
40 | + print type(X), type(X_train), type(X_train[0]) | |
41 | + return | |
23 | 42 | |
24 | 43 | X_train, Y_train = np.array(X_train), np.array(Y_train) |
25 | 44 | X_test, Y_test = np.array(X_test), np.array(Y_test) | ... | ... |
test/test_model.py
... | ... | @@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S |
7 | 7 | from ..mmodel.svm import SVM |
8 | 8 | from ..mmodel.theano import THEANO |
9 | 9 | |
10 | -timer = Timer() | |
10 | +import gzip | |
11 | +import cPickle | |
12 | + | |
11 | 13 | |
14 | +timer = Timer() | |
15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | |
12 | 16 | |
13 | 17 | def test_SVM_CV(): |
14 | 18 | timer.mark() |
... | ... | @@ -151,10 +155,13 @@ def test_THEANO_crop(): |
151 | 155 | dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') |
152 | 156 | X, Y = dilc.load_data(mode='local', feattype='coef') |
153 | 157 | timer.report() |
158 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | |
159 | + with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f: | |
160 | + cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) | |
154 | 161 | |
155 | 162 | timer.mark() |
156 | 163 | mtheano = THEANO.ModelTHEANO(toolset='cnn') |
157 | - mtheano.train(X,Y) | |
164 | + mtheano._train_cnn(dataset='../../res/ils_crop.pkl') | |
158 | 165 | timer.report() |
159 | 166 | |
160 | 167 | ... | ... |