Commit 5a469df5bf54b238d31c2d676fa4c8929e18ed1c

Authored by Chunk
1 parent e585c8c2
Exists in master and in 1 other branch refactor

staged.

mdata/ILSVRC.py
... ... @@ -457,7 +457,7 @@ class DataILSVRC(DataDumperBase):
457 457 dict_dataset[hash] = (tag, im.getCoefMatrix(channel='Y'))
458 458  
459 459 for tag, feat in dict_dataset.values():
460   - X.append(feat.tolist())
  460 + X.append(feat)
461 461 Y.append(int(tag))
462 462  
463 463 else:
... ...
mmodel/theano/THEANO.py
... ... @@ -9,6 +9,11 @@ from .theanoutil import *
9 9 import numpy as np
10 10 from sklearn import cross_validation
11 11  
  12 +import gzip
  13 +import cPickle
  14 +
  15 +package_dir = os.path.dirname(os.path.abspath(__file__))
  16 +
12 17  
13 18 class ModelTHEANO(ModelBase):
14 19 def __init__(self, toolset='cnn', sc=None):
... ... @@ -17,9 +22,23 @@ class ModelTHEANO(ModelBase):
17 22 self.sparker = sc
18 23 self.model = None
19 24  
20   - def _train_cnn(self, X, Y, learning_rate=0.1, n_epochs=200, nkerns=[20, 50, 50],
  25 + def _train_cnn(self, X=None, Y=None, dataset=os.path.join(package_dir, '../../res/', 'ils_crop.pkl'),
  26 + learning_rate=0.1, n_epochs=200,
  27 + nkerns=[20, 50, 50],
21 28 batch_size=200):
22   - X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0)
  29 +
  30 + if X == None:
  31 + assert dataset != None
  32 + with open(dataset, 'rb') as f:
  33 + train_set, test_set = cPickle.load(f)
  34 +
  35 + X_train, Y_train = train_set
  36 + X_test, Y_test = test_set
  37 + else:
  38 + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0)
  39 +
  40 + print type(X), type(X_train), type(X_train[0])
  41 + return
23 42  
24 43 X_train, Y_train = np.array(X_train), np.array(Y_train)
25 44 X_test, Y_test = np.array(X_test), np.array(Y_test)
... ...
test/test_model.py
... ... @@ -7,8 +7,12 @@ from ..mdata import CV, ILSVRC, ILSVRC_S
7 7 from ..mmodel.svm import SVM
8 8 from ..mmodel.theano import THEANO
9 9  
10   -timer = Timer()
  10 +import gzip
  11 +import cPickle
  12 +
11 13  
  14 +timer = Timer()
  15 +package_dir = os.path.dirname(os.path.abspath(__file__))
12 16  
13 17 def test_SVM_CV():
14 18 timer.mark()
... ... @@ -151,10 +155,13 @@ def test_THEANO_crop():
151 155 dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil')
152 156 X, Y = dilc.load_data(mode='local', feattype='coef')
153 157 timer.report()
  158 + X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0)
  159 + with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f:
  160 + cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f)
154 161  
155 162 timer.mark()
156 163 mtheano = THEANO.ModelTHEANO(toolset='cnn')
157   - mtheano.train(X,Y)
  164 + mtheano._train_cnn(dataset='../../res/ils_crop.pkl')
158 165 timer.report()
159 166  
160 167  
... ...