Commit d2603183fb96948bdea2aef75d851c6423770e59

Authored by Chunk
1 parent f1aad98a
Exists in master and in 1 other branch refactor

staged.

Showing 1 changed file with 5 additions and 6 deletions   Show diff stats
test/test_model.py
@@ -5,7 +5,7 @@ from sklearn import cross_validation @@ -5,7 +5,7 @@ from sklearn import cross_validation
5 from ..common import * 5 from ..common import *
6 from ..mdata import CV, ILSVRC, ILSVRC_S 6 from ..mdata import CV, ILSVRC, ILSVRC_S
7 from ..mmodel.svm import SVM 7 from ..mmodel.svm import SVM
8 -from ..mmodel.theano import THEANO 8 +from ..mmodel.theano import THEANO
9 9
10 import gzip 10 import gzip
11 import cPickle 11 import cPickle
@@ -14,6 +14,7 @@ import cPickle @@ -14,6 +14,7 @@ import cPickle
14 timer = Timer() 14 timer = Timer()
15 package_dir = os.path.dirname(os.path.abspath(__file__)) 15 package_dir = os.path.dirname(os.path.abspath(__file__))
16 16
  17 +
17 def test_SVM_CV(): 18 def test_SVM_CV():
18 timer.mark() 19 timer.mark()
19 dcv = CV.DataCV() 20 dcv = CV.DataCV()
@@ -89,7 +90,7 @@ def test_SVM_ILSVRC_HBASE(): @@ -89,7 +90,7 @@ def test_SVM_ILSVRC_HBASE():
89 X1, Y1 = dil.load_data(mode='local') 90 X1, Y1 = dil.load_data(mode='local')
90 91
91 X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.4, random_state=0) 92 X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.4, random_state=0)
92 - print Y,np.sum(np.array(Y)==0),np.sum(np.array(Y)==1) 93 + print Y, np.sum(np.array(Y) == 0), np.sum(np.array(Y) == 1)
93 print np.array(Y).shape, np.array(X).shape 94 print np.array(Y).shape, np.array(X).shape
94 print np.array(X_train).shape, np.array(Y_train).shape 95 print np.array(X_train).shape, np.array(Y_train).shape
95 print np.array(X_test).shape, np.array(Y_test).shape 96 print np.array(X_test).shape, np.array(Y_test).shape
@@ -148,9 +149,7 @@ def test_SVM_ILSVRC_S(): @@ -148,9 +149,7 @@ def test_SVM_ILSVRC_S():
148 # test_SVM_ILSVRC_SPARK() 149 # test_SVM_ILSVRC_SPARK()
149 150
150 151
151 -  
152 def test_THEANO_crop(): 152 def test_THEANO_crop():
153 -  
154 timer.mark() 153 timer.mark()
155 dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil') 154 dilc = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category='Test_crop_pil')
156 X, Y = dilc.load_data(mode='local', feattype='coef') 155 X, Y = dilc.load_data(mode='local', feattype='coef')
@@ -158,11 +157,11 @@ def test_THEANO_crop(): @@ -158,11 +157,11 @@ def test_THEANO_crop():
158 157
159 # X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0) 158 # X_train, X_test, Y_train, Y_test = cross_validation.train_test_split(X, Y, test_size=0.2, random_state=0)
160 # with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f: 159 # with open(os.path.join(package_dir,'../res/','ils_crop.pkl'),'wb') as f:
161 - # cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f) 160 + # cPickle.dump([(X_train,Y_train),(X_test,Y_test)], f)
162 161
163 timer.mark() 162 timer.mark()
164 mtheano = THEANO.ModelTHEANO(toolset='cnn') 163 mtheano = THEANO.ModelTHEANO(toolset='cnn')
165 - mtheano._train_cnn(dataset='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val/ils_crop.pkl') 164 + mtheano._train_cnn(X, Y)
166 timer.report() 165 timer.report()
167 166
168 167