Commit 5a469df5bf54b238d31c2d676fa4c8929e18ed1c
1 parent
e585c8c2
Exists in
master
and in
1 other branch
staged.
Showing
3 changed files
with
31 additions
and
5 deletions
Show diff stats
mdata/ILSVRC.py
@@ -457,7 +457,7 @@ class DataILSVRC(DataDumperBase): | @@ -457,7 +457,7 @@ class DataILSVRC(DataDumperBase): | ||
457 | dict_dataset[hash] = (tag, im.getCoefMatrix(channel='Y')) | 457 | dict_dataset[hash] = (tag, im.getCoefMatrix(channel='Y')) |
458 | 458 | ||
459 | for tag, feat in dict_dataset.values(): | 459 | for tag, feat in dict_dataset.values(): |
460 | - X.append(feat.tolist()) | 460 | + X.append(feat) |
461 | Y.append(int(tag)) | 461 | Y.append(int(tag)) |
462 | 462 | ||
463 | else: | 463 | else: |
mmodel/theano/THEANO.py
@@ -9,6 +9,11 @@ from .theanoutil import * | @@ -9,6 +9,11 @@ from .theanoutil import * | ||
9 | import numpy as np | 9 | import numpy as np |
10 | from sklearn import cross_validation | 10 | from sklearn import cross_validation |
11 | 11 | ||
12 | +import gzip | ||
13 | +import cPickle | ||
14 | + | ||
15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | ||
16 | + | ||
12 | 17 | ||
13 | class ModelTHEANO(ModelBase): | 18 | class ModelTHEANO(ModelBase): |
14 | def __init__(self, toolset='cnn', sc=None): | 19 | def __init__(self, toolset='cnn', sc=None): |
@@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase): | @@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase): | ||
17 | self.sparker = sc | 22 | self.sparker = sc |
18 | self.model = None | 23 | self.model = None |
19 | 24 | ||
20 | - def _train_cnn(self, X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50], | 25 | + def _train_cnn(self, X=None, Y=None, dataset=os.path.join(package_dir, '../../res/', 'ils_crop.pkl'), |
26 | + learning_rate=0.1, n_epochs=200, | ||
27 | + nkerns=[20, 50, 50], | ||
21 | batch_size=200): | 28 | batch_size=200): |
22 | - X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | 29 | + |
30 | + if X == None: | ||
31 | + assert dataset != None | ||
32 | + with open(dataset, 'rb') as f: | ||
33 | + train_set, test_set = cPickle.load(f) | ||
34 | + | ||
35 | + X_train, Y_train = train_set | ||
36 | + X_test, Y_test = test_set | ||
37 | + else: | ||
38 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | ||
39 | + | ||
40 | + print type(X), type(X_train), type(X_train[0]) | ||
41 | + return | ||
23 | 42 | ||
24 | X_train, Y_train = np.array(X_train), np.array(Y_train) | 43 | X_train, Y_train = np.array(X_train), np.array(Y_train) |
25 | X_test, Y_test = np.array(X_test), np.array(Y_test) | 44 | X_test, Y_test = np.array(X_test), np.array(Y_test) |
test/test_model.py
@@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S | @@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S | ||
7 | from ..mmodel.svm import SVM | 7 | from ..mmodel.svm import SVM |
8 | from ..mmodel.theano import THEANO | 8 | from ..mmodel.theano import THEANO |
9 | 9 | ||
10 | -timer = Timer() | 10 | +import gzip |
11 | +import cPickle | ||
12 | + | ||
11 | 13 | ||
14 | +timer = Timer() | ||
15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | ||
12 | 16 | ||
13 | def test_SVM_CV(): | 17 | def test_SVM_CV(): |
14 | timer.mark() | 18 | timer.mark() |
@@ -151,10 +155,13 @@ def test_THEANO_crop(): | @@ -151,10 +155,13 @@ def test_THEANO_crop(): | ||
151 | dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') | 155 | dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') |
152 | X, Y = dilc.load_data(mode='local', feattype='coef') | 156 | X, Y = dilc.load_data(mode='local', feattype='coef') |
153 | timer.report() | 157 | timer.report() |
158 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | ||
159 | + with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f: | ||
160 | + cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) | ||
154 | 161 | ||
155 | timer.mark() | 162 | timer.mark() |
156 | mtheano = THEANO.ModelTHEANO(toolset='cnn') | 163 | mtheano = THEANO.ModelTHEANO(toolset='cnn') |
157 | - mtheano.train(X,Y) | 164 | + mtheano._train_cnn(dataset='../../res/ils_crop.pkl') |
158 | timer.report() | 165 | timer.report() |
159 | 166 | ||
160 | 167 |