Commit f731b60bffd67122b09f6044600147f00ca59c44
1 parent
163652ee
Exists in
refactor
SVM2
Showing
2 changed files
with
227 additions
and
2 deletions
Show diff stats
... | ... | @@ -0,0 +1,225 @@ |
1 | +''' | |
2 | +SVM Model. | |
3 | + | |
4 | +@author: chunk | |
5 | +chunkplus@gmail.com | |
6 | +2014 Dec | |
7 | +''' | |
8 | +import os, sys | |
9 | +# from ...mfeat import * | |
10 | +from ...mmodel import * | |
11 | +# from ...mmodel.svm.svmutil import * | |
12 | +from ...common import * | |
13 | + | |
14 | +import numpy as np | |
15 | +import csv | |
16 | +import json | |
17 | +import pickle | |
18 | +# import cv2 | |
19 | +from sklearn import svm | |
20 | + | |
21 | +package_dir = os.path.dirname(os.path.abspath(__file__)) | |
22 | + | |
23 | +dict_Train = {} | |
24 | +dict_databuf = {} | |
25 | +dict_tagbuf = {} | |
26 | +dict_featbuf = {} | |
27 | + | |
28 | + | |
29 | +class ModelSVM(ModelBase): | |
30 | + def __init__(self, toolset='sklearn', sc=None): | |
31 | + ModelBase.__init__(self) | |
32 | + self.toolset = toolset | |
33 | + self.sparker = sc | |
34 | + | |
35 | + | |
36 | + def _train_sklearn(self, X, Y): | |
37 | + clf = svm.SVC(C=4, kernel='linear', shrinking=False, verbose=True) | |
38 | + clf.fit(X, Y) | |
39 | + with open(os.path.join(package_dir, '../..', 'res/svm_sklearn.model'), 'wb') as modelfile: | |
40 | + model = pickle.dump(clf, modelfile) | |
41 | + | |
42 | + self.model = clf | |
43 | + | |
44 | + return clf | |
45 | + | |
46 | + | |
47 | + def _predict_sklearn(self, feat, model=None): | |
48 | + """N.B. sklearn.svm.base.predict : | |
49 | + Perform classification on samples in X. | |
50 | + Parameters | |
51 | + ---------- | |
52 | + X : {array-like, sparse matrix}, shape = [n_samples, n_features] | |
53 | + | |
54 | + Returns | |
55 | + ------- | |
56 | + y_pred : array, shape = [n_samples] | |
57 | + Class labels for samples in X. | |
58 | + """ | |
59 | + if model is None: | |
60 | + if self.model != None: | |
61 | + model = self.model | |
62 | + else: | |
63 | + print 'loading model ...' | |
64 | + with open(os.path.join(package_dir, '../..', 'res/svm_sklearn.model'), 'rb') as modelfile: | |
65 | + model = pickle.load(modelfile) | |
66 | + | |
67 | + return model.predict(feat) | |
68 | + | |
69 | + def __test_sklearn(self, X, Y, model=None): | |
70 | + if model is None: | |
71 | + if self.model != None: | |
72 | + model = self.model | |
73 | + else: | |
74 | + print 'loading model ...' | |
75 | + with open(os.path.join(package_dir, '../..', 'res/svm_sklearn.model'), 'rb') as modelfile: | |
76 | + model = pickle.load(modelfile) | |
77 | + | |
78 | + result_Y = np.array(self._predict_sklearn(X, model)) | |
79 | + | |
80 | + fp = 0 | |
81 | + tp = 0 | |
82 | + sum = np.sum(np.array(Y) == 1) | |
83 | + positive, negative = np.sum(np.array(Y) == 1), np.sum(np.array(Y) == 0) | |
84 | + print positive, negative | |
85 | + for i in range(len(Y)): | |
86 | + if Y[i] == 0 and result_Y[i] == 1: | |
87 | + fp += 1 | |
88 | + elif Y[i] == 1 and result_Y[i] == 1: | |
89 | + tp += 1 | |
90 | + return float(fp) / negative, float(tp) / positive, np.mean(Y == result_Y) | |
91 | + | |
92 | + def _test_sklearn(self, X, Y, model=None): | |
93 | + if model is None: | |
94 | + if self.model != None: | |
95 | + model = self.model | |
96 | + else: | |
97 | + print 'loading model ...' | |
98 | + with open(os.path.join(package_dir, '../..', 'res/svm_sklearn.model'), 'rb') as modelfile: | |
99 | + model = pickle.load(modelfile) | |
100 | + | |
101 | + return model.score(X, Y) | |
102 | + | |
103 | + # def _train_libsvm(self, X, Y): | |
104 | + # X, Y = list(X), list(Y) | |
105 | + # # X, Y = [float(i) for i in X], [float(i) for i in Y] | |
106 | + # prob = svm_problem(Y, X) | |
107 | + # param = svm_parameter('-t 0 -c 4 -b 1 -h 0') | |
108 | + # # param = svm_parameter(kernel_type=LINEAR, C=10) | |
109 | + # m = svm_train(prob, param) | |
110 | + # svm_save_model(os.path.join(package_dir, '../..', 'res/svm_libsvm.model'), m) | |
111 | + # | |
112 | + # self.model = m | |
113 | + # | |
114 | + # return m | |
115 | + # | |
116 | + # def _predict_libsvm(self, feat, model=None): | |
117 | + # if model is None: | |
118 | + # if self.model != None: | |
119 | + # model = self.model | |
120 | + # else: | |
121 | + # print 'loading model ...' | |
122 | + # model = svm_load_model(os.path.join(package_dir, '../..', 'res/svm_libsvm.model')) | |
123 | + # | |
124 | + # feat = [list(feat)] | |
125 | + # # print len(feat),[0] * len(feat) | |
126 | + # label, _, _ = svm_predict([0] * len(feat), feat, model) | |
127 | + # return label | |
128 | + # | |
129 | + # | |
130 | + # def _test_libsvm(self, X, Y, model=None): | |
131 | + # if model is None: | |
132 | + # if self.model != None: | |
133 | + # model = self.model | |
134 | + # else: | |
135 | + # print 'loading model ...' | |
136 | + # model = svm_load_model(os.path.join(package_dir, '../..', 'res/svm_libsvm.model')) | |
137 | + # | |
138 | + # X, Y = list(X), list(Y) | |
139 | + # p_labs, p_acc, p_vals = svm_predict(Y, X, model) | |
140 | + # # ACC, MSE, SCC = evaluations(Y, p_labs) | |
141 | + # | |
142 | + # return p_acc | |
143 | + | |
144 | + # def _train_opencv(self, X, Y): | |
145 | + # svm_params = dict(kernel_type=cv2.SVM_LINEAR, | |
146 | + # svm_type=cv2.SVM_C_SVC, | |
147 | + # C=4) | |
148 | + # | |
149 | + # X, Y = np.array(X, dtype=np.float32), np.array(Y, dtype=np.float32) | |
150 | + # | |
151 | + # svm = cv2.SVM() | |
152 | + # svm.train(X, Y, params=svm_params) | |
153 | + # svm.save(os.path.join(package_dir, '../..', 'res/svm_opencv.model')) | |
154 | + # | |
155 | + # self.model = svm | |
156 | + # | |
157 | + # return svm | |
158 | + # | |
159 | + # | |
160 | + # def _predict_opencv(self, feat, model=None): | |
161 | + # if model is None: | |
162 | + # if self.model != None: | |
163 | + # model = self.model | |
164 | + # else: | |
165 | + # print 'loading model ...' | |
166 | + # with open(os.path.join(package_dir, '../..', 'res/svm_opencv.model'), 'rb') as modelfile: | |
167 | + # model = pickle.load(modelfile) | |
168 | + # feat = np.array(feat, dtype=np.float32) | |
169 | + # | |
170 | + # return model.predict(feat) | |
171 | + # | |
172 | + # | |
173 | + # def _test_opencv(self, X, Y, model=None): | |
174 | + # if model is None: | |
175 | + # if self.model != None: | |
176 | + # model = self.model | |
177 | + # else: | |
178 | + # print 'loading model ...' | |
179 | + # with open(os.path.join(package_dir, '../..', 'res/svm_opencv.model'), 'rb') as modelfile: | |
180 | + # model = pickle.load(modelfile) | |
181 | + # | |
182 | + # X, Y = np.array(X, dtype=np.float32), np.array(Y, dtype=np.float32) | |
183 | + # | |
184 | + # # result_Y = np.array([self._predict_opencv(x, model) for x in X]) | |
185 | + # result_Y = np.array(model.predict_all(X)).ravel() | |
186 | + # # print X[0] | |
187 | + # # print result_Y,Y | |
188 | + # return np.mean(Y == result_Y) | |
189 | + | |
190 | + def train(self, X, Y=None): | |
191 | + | |
192 | + if self.toolset == 'sklearn': | |
193 | + return self._train_sklearn(X, Y) | |
194 | + else: | |
195 | + raise Exception("Unknown toolset!") | |
196 | + | |
197 | + def predict(self, feat, model=None): | |
198 | + | |
199 | + if self.toolset == 'sklearn': | |
200 | + return self._predict_sklearn(feat, model) | |
201 | + else: | |
202 | + raise Exception("Unknown toolset!") | |
203 | + | |
204 | + | |
205 | + def test(self, X, Y=None, model=None): | |
206 | + | |
207 | + if self.toolset == 'sklearn': | |
208 | + return self.__test_sklearn(X, Y, model) | |
209 | + else: | |
210 | + raise Exception("Unknown toolset!") | |
211 | + | |
212 | + | |
213 | + | |
214 | + | |
215 | + | |
216 | + | |
217 | + | |
218 | + | |
219 | + | |
220 | + | |
221 | + | |
222 | + | |
223 | + | |
224 | + | |
225 | + | ... | ... |
mspark/rdd.py
... | ... | @@ -6,7 +6,7 @@ from ..mjpeg import * |
6 | 6 | from ..msteg import * |
7 | 7 | from ..msteg.steganography import LSB, F3, F4, F5 |
8 | 8 | from ..mfeat import IntraBlockDiff |
9 | -from ..mmodel.svm import SVM | |
9 | +from ..mmodel.svm import SVM2 | |
10 | 10 | |
11 | 11 | from numpy import array |
12 | 12 | import json |
... | ... | @@ -19,7 +19,7 @@ from hashlib import md5 |
19 | 19 | |
20 | 20 | np.random.seed(sum(map(ord, "whoami"))) |
21 | 21 | package_dir = os.path.dirname(os.path.abspath(__file__)) |
22 | -classifier = SVM.ModelSVM(toolset='sklearn') | |
22 | +classifier = SVM2.ModelSVM(toolset='sklearn') | |
23 | 23 | |
24 | 24 | def rddparse_data_CV(raw_row): |
25 | 25 | """ | ... | ... |