Commit 5a469df5bf54b238d31c2d676fa4c8929e18ed1c
1 parent: e585c8c2
Exists in master and in 1 other branch: staged.
Showing 3 changed files with 31 additions and 5 deletions.
Show diff stats
mdata/ILSVRC.py
| @@ -457,7 +457,7 @@ class DataILSVRC(DataDumperBase): | @@ -457,7 +457,7 @@ class DataILSVRC(DataDumperBase): | ||
| 457 | dict_dataset[hash] = (tag, im.getCoefMatrix(channel='Y')) | 457 | dict_dataset[hash] = (tag, im.getCoefMatrix(channel='Y')) |
| 458 | 458 | ||
| 459 | for tag, feat in dict_dataset.values(): | 459 | for tag, feat in dict_dataset.values(): |
| 460 | - X.append(feat.tolist()) | 460 | + X.append(feat) |
| 461 | Y.append(int(tag)) | 461 | Y.append(int(tag)) |
| 462 | 462 | ||
| 463 | else: | 463 | else: |
mmodel/theano/THEANO.py
| @@ -9,6 +9,11 @@ from .theanoutil import * | @@ -9,6 +9,11 @@ from .theanoutil import * | ||
| 9 | import numpy as np | 9 | import numpy as np |
| 10 | from sklearn import cross_validation | 10 | from sklearn import cross_validation |
| 11 | 11 | ||
| 12 | +import gzip | ||
| 13 | +import cPickle | ||
| 14 | + | ||
| 15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| 16 | + | ||
| 12 | 17 | ||
| 13 | class ModelTHEANO(ModelBase): | 18 | class ModelTHEANO(ModelBase): |
| 14 | def __init__(self, toolset='cnn', sc=None): | 19 | def __init__(self, toolset='cnn', sc=None): |
| @@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase): | @@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase): | ||
| 17 | self.sparker = sc | 22 | self.sparker = sc |
| 18 | self.model = None | 23 | self.model = None |
| 19 | 24 | ||
| 20 | - def _train_cnn(self, X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50], | 25 | + def _train_cnn(self, X=None, Y=None, dataset=os.path.join(package_dir, '../../res/', 'ils_crop.pkl'), |
| 26 | + learning_rate=0.1, n_epochs=200, | ||
| 27 | + nkerns=[20, 50, 50], | ||
| 21 | batch_size=200): | 28 | batch_size=200): |
| 22 | - X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | 29 | + |
| 30 | + if X == None: | ||
| 31 | + assert dataset != None | ||
| 32 | + with open(dataset, 'rb') as f: | ||
| 33 | + train_set, test_set = cPickle.load(f) | ||
| 34 | + | ||
| 35 | + X_train, Y_train = train_set | ||
| 36 | + X_test, Y_test = test_set | ||
| 37 | + else: | ||
| 38 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | ||
| 39 | + | ||
| 40 | + print type(X), type(X_train), type(X_train[0]) | ||
| 41 | + return | ||
| 23 | 42 | ||
| 24 | X_train, Y_train = np.array(X_train), np.array(Y_train) | 43 | X_train, Y_train = np.array(X_train), np.array(Y_train) |
| 25 | X_test, Y_test = np.array(X_test), np.array(Y_test) | 44 | X_test, Y_test = np.array(X_test), np.array(Y_test) |
test/test_model.py
| @@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S | @@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S | ||
| 7 | from ..mmodel.svm import SVM | 7 | from ..mmodel.svm import SVM |
| 8 | from ..mmodel.theano import THEANO | 8 | from ..mmodel.theano import THEANO |
| 9 | 9 | ||
| 10 | -timer = Timer() | 10 | +import gzip |
| 11 | +import cPickle | ||
| 12 | + | ||
| 11 | 13 | ||
| 14 | +timer = Timer() | ||
| 15 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| 12 | 16 | ||
| 13 | def test_SVM_CV(): | 17 | def test_SVM_CV(): |
| 14 | timer.mark() | 18 | timer.mark() |
| @@ -151,10 +155,13 @@ def test_THEANO_crop(): | @@ -151,10 +155,13 @@ def test_THEANO_crop(): | ||
| 151 | dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') | 155 | dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') |
| 152 | X, Y = dilc.load_data(mode='local', feattype='coef') | 156 | X, Y = dilc.load_data(mode='local', feattype='coef') |
| 153 | timer.report() | 157 | timer.report() |
| 158 | + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) | ||
| 159 | + with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f: | ||
| 160 | + cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) | ||
| 154 | 161 | ||
| 155 | timer.mark() | 162 | timer.mark() |
| 156 | mtheano = THEANO.ModelTHEANO(toolset='cnn') | 163 | mtheano = THEANO.ModelTHEANO(toolset='cnn') |
| 157 | - mtheano.train(X,Y) | 164 | + mtheano._train_cnn(dataset='../../res/ils_crop.pkl') |
| 158 | timer.report() | 165 | timer.report() |
| 159 | 166 | ||
| 160 | 167 |