# mspark/SC.py
# -*- coding: utf-8 -*-
__author__ = 'chunk'

from ..common import *
from .dependencies import *
from . import *
# from ..mdata import MSR, CV, ILSVRC, ILSVRC_S

from ..mjpeg import *
from ..msteg import *
from ..msteg.steganography import LSB, F3, F4, F5
from ..mfeat import IntraBlockDiff
from ..mmodel.svm import SVM2

import os  # used below (package_dir, res/ paths); may also be provided by the wildcard imports above
import sys
from pyspark import RDD
from pyspark import SparkConf, SparkContext
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD
from pyspark.mllib.regression import LabeledPoint
from numpy import array
import json
import pickle
import tempfile

import numpy as np
from scipy import stats
from hashlib import md5

np.random.seed(sum(map(ord, "whoami")))
package_dir = os.path.dirname(os.path.abspath(__file__))
classifier = SVM2.ModelSVM(toolset='sklearn')


def rddparse_data_CV(raw_row):
    """
    input: (u'key0',u'cf_feat:hog:[0.056273,...]--%--cf_pic:data:\ufffd\ufffd\...--%--cf_tag:hog:True')
    return: ([0.056273,...],1)
    """
    data = raw_row[1].split('--%--')
    feat = json.loads(data[0].split(':')[-1])
    tag = 1 if data[-1].split(':')[-1] == 'True' else 0
    return (feat, tag)
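
# For the sample row shown in the docstring above, the parsing proceeds roughly as:
#   raw_row[1].split('--%--')  -> ['cf_feat:hog:[0.056273,...]', 'cf_pic:data:...', 'cf_tag:hog:True']
#   data[0].split(':')[-1]     -> '[0.056273,...]'   (JSON-decoded into the feature list)
#   data[-1].split(':')[-1]    -> 'True'             (mapped to the integer tag 1)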


def rddparse_data_ILS(raw_row):
    """
    input: (u'key0',u'cf_feat:hog:[0.056273,...]--%--cf_pic:data:\ufffd\ufffd\...--%--cf_tag:hog:True')
    return: ([0.056273,...],1)

    In fact we can also use mapValues.
    """
    key = raw_row[0]
    # if key == '04650c488a2b163ca8a1f52da6022f03.jpg':
    # with open('/tmp/hhhh','wb') as f:
    # f.write(raw_row[1].decode('unicode-escape')).encode('latin-1')
    items = raw_row[1].decode('unicode-escape').encode('latin-1').split('--%--')
    data = items[0].split('cf_pic:data:')[-1]
    return (key, data)


def rddparse_all_ILS(raw_row):
    """
    Deprecated
    """
    key = raw_row[0]
    items = raw_row[1].decode('unicode-escape').encode('latin-1').split('--%--')

    # @TODO
    # N.B. "ValueError: No JSON object could be decoded" occurs because the spark-hbase IO is string-based,
    # and the order of items is not as expected. See ../res/row-sample.txt or check in the hbase shell.

    data = [items[0].split('cf_pic:data:')[-1]] + [json.loads(item.split(':')[-1]) for item in
                                                   items[1:]]

    return (key, data)


def rddparse_dataset_ILS(raw_row):
    if raw_row[0] == '04650c488a2b163ca8a1f52da6022f03.jpg':
        print raw_row
    items = raw_row[1].decode('unicode-escape').encode('latin-1').split('--%--')
    # tag = int(items[-2].split('cf_tag:' + tagtype)[-1])
    # feat = [item for sublist in json.loads(items[-1].split('cf_feat:' + feattype)[-1]) for subsublist in sublist for item in subsublist]
    tag = int(items[-1].split(':')[-1])
    feat = [item for sublist in json.loads(items[0].split(':')[-1]) for subsublist in sublist for
            item in subsublist]

    return (tag, feat)
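
# The nested comprehension above flattens a doubly nested feature block, e.g.
# [[[1, 2], [3]], [[4, 5]]] -> [1, 2, 3, 4, 5], which is then paired with the
# integer tag parsed from the last '--%--' item.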


def rddinfo_ILS(img, info_rate=None, tag_chosen=None, tag_class=None):
    """
    Tempfile is our friend. (?)
    """
    info_rate = info_rate if info_rate != None else 0.0
    tag_chosen = tag_chosen if tag_chosen != None else stats.bernoulli.rvs(0.8)
    tag_class = tag_class if tag_class != None else 0
    try:
        tmpf = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b', delete=True)
        tmpf.write(img)
        tmpf.seek(0)
        im = Jpeg(tmpf.name, key=sample_key)
        info = [
            im.image_width,
            im.image_height,
            im.image_width * im.image_height,
            im.getCapacity(),
            im.getQuality(),
            info_rate,
            tag_chosen,
            tag_class
        ]
        return info
    except Exception as e:
        print e
        raise
    finally:
        tmpf.close()


def rddembed_ILS(row, rate=None):
    """
    input:
        e.g. row =('row1',[1,3400,'hello'])
    return:
        newrow = ('row2',[34,5400,'embeded'])
    """
    items = row[1]
    capacity, chosen = int(items[4]), int(items[7])
    if chosen == 0:
        return None
    try:
        tmpf_src = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
        tmpf_src.write(items[0])
        tmpf_src.seek(0)
        tmpf_dst = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')

        steger = F5.F5(sample_key, 1)

        if rate == None:
            embed_rate = steger.embed_raw_data(tmpf_src.name,
                                               os.path.join(package_dir, '../res/toembed'),
                                               tmpf_dst.name)
        else:
            assert (rate >= 0 and rate < 1)
            # print capacity
            hidden = np.random.bytes(int(int(capacity) * rate) / 8)
            embed_rate = steger.embed_raw_data(tmpf_src.name, hidden, tmpf_dst.name, frommem=True)

        tmpf_dst.seek(0)
        raw = tmpf_dst.read()
        index = md5(raw).hexdigest()

        return (index + '.jpg', [raw] + rddinfo_ILS(raw, embed_rate, 0, 1))

    except Exception as e:
        print e
        raise
    finally:
        tmpf_src.close()
        tmpf_dst.close()
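
# rddembed_ILS returns None for rows that were not chosen for embedding, so a typical
# pipeline filters those out afterwards. Illustrative sketch only; 'rdd_img' stands for
# an RDD of (key, [raw_jpeg] + info) rows as produced via rddinfo_ILS:
#
# rdd_stego = rdd_img.map(lambda x: rddembed_ILS(x, rate=0.1)) \
#     .filter(lambda x: x is not None)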


def rddembed_ILS_EXT(row, rate=None):
    """
    input:
        e.g. row =('row1',[1,3400,'hello'])
    return:
        newrow = ('row2',[34,5400,'embeded']) or NULL
        [row,newrow]
    """
    items = row[1]
    capacity, chosen = int(items[4]), int(items[7])
    if chosen == 0:
        return [row]
    try:
        tmpf_src = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
        tmpf_src.write(items[0])
        tmpf_src.seek(0)
        tmpf_dst = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')

        steger = F5.F5(sample_key, 2)

        if rate == None:
            embed_rate = steger.embed_raw_data(tmpf_src.name,
                                               os.path.join(package_dir, '../res/toembed'),
                                               tmpf_dst.name)
        else:
            assert (rate >= 0 and rate < 1)
            # print capacity
            hidden = np.random.bytes(int(int(capacity) * rate) / 8)
            embed_rate = steger.embed_raw_data(tmpf_src.name, hidden, tmpf_dst.name, frommem=True)

        tmpf_dst.seek(0)
        raw = tmpf_dst.read()
        index = md5(raw).hexdigest()

        return [row, (index + '.jpg', [raw] + rddinfo_ILS(raw, embed_rate, 0, 1))]

    except Exception as e:
        print e
        raise
    finally:
        tmpf_src.close()
        tmpf_dst.close()
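
# Unlike rddembed_ILS, the _EXT variant always keeps the original row and returns a list,
# so it is meant for flatMap: the resulting RDD holds both the cover image and (when
# chosen) the stego copy. Illustrative sketch only; 'rdd_img' is a placeholder:
#
# rdd_all = rdd_img.flatMap(lambda x: rddembed_ILS_EXT(x, rate=0.1))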


def _get_feat(image, feattype='ibd', **kwargs):
    if feattype == 'ibd':
        feater = IntraBlockDiff.FeatIntraBlockDiff()
    else:
        raise Exception("Unknown feature type!")

    desc = feater.feat(image)

    return desc


def rddfeat_ILS(items, feattype='ibd', **kwargs):
    try:
        tmpf_src = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
        tmpf_src.write(items[0])
        tmpf_src.seek(0)

        desc = json.dumps(_get_feat(tmpf_src.name, feattype=feattype).tolist())
        # print 'desccccccccccccccccccc',desc
        return items + [desc]

    except Exception as e:
        print e
        raise
    finally:
        tmpf_src.close()
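
# rddfeat_ILS works on the value list alone (items[0] must be the raw JPEG bytes) and
# appends the JSON-serialized feature vector, so it pairs naturally with mapValues.
# Illustrative sketch only; 'rdd_all' is a placeholder:
#
# rdd_feat = rdd_all.mapValues(lambda items: rddfeat_ILS(items, feattype='ibd'))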


def rddanalysis_ILS(items, feattype='ibd', **kwargs):
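    # A JPEG stream starts with the SOI marker 0xFF 0xD8 (i.e. 255, 216); rows whose data
    # blob does not start with it are tagged 0 without running the classifier.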
    head = np.fromstring(items[0][:2], dtype=np.uint8)
    if not np.array_equal(head, [255, 216]):
        return items + [0]
    try:
        tmpf_src = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
        tmpf_src.write(items[0])
        tmpf_src.seek(0)

        desc = _get_feat(tmpf_src.name, feattype=feattype)
        tag = classifier.predict(desc.ravel())[0]
        # print 'desccccccccccccccccccc',desc
        return items + [tag]

    except Exception as e:
        print e
        raise
    finally:
        tmpf_src.close()

        # return items + classifier.predict(items[-1])


def format_out(row, cols, withdata=False):
    """
    input:
        e.g. row =('row1',[1,3400,'hello'])
            cols = [['cf_info', 'id'], ['cf_info', 'size'], ['cf_tag', 'desc']]
    return:
        [('row1',['row1', 'cf_info', 'id', '1']),('row1',['row1', 'cf_info', 'size', '3400']),('row1',['row1', 'cf_tag', 'desc', 'hello'])]
    """
    puts = []
    key = row[0]
    # if key == '04650c488a2b163ca8a1f52da6022f03.jpg':
    # print row
    if not withdata:
        for data, col in zip(row[1][1:], cols[1:]):
            puts.append((key, [key] + col + [str(data)]))
    else:
        for data, col in zip(row[1], cols):
            puts.append((key, [key] + col + [str(data)]))
    return puts


# scconf = SparkConf()
# scconf.setSparkHome("HPC-server") \
# .setMaster("spark://HPC-server:7077") \
# .setAppName("example")
# sc = SparkContext(conf=scconf)
#
#
# def read_hbase(table_name, func=None, collect=False):
# """
#     ref - http://happybase.readthedocs.org/en/latest/user.html#retrieving-data
#
#     Filter format:
#         columns=['cf1:col1', 'cf1:col2']
#         or
#         columns=['cf1']
#
#     """
#
#     hconf = {
#         "hbase.zookeeper.quorum": "HPC-server, HPC, HPC2",
#         # "hbase.zookeeper.quorum": self.host,
#         "hbase.mapreduce.inputtable": table_name,
#     }
#
#     hbase_rdd = sc.newAPIHadoopRDD(inputFormatClass=hparams["inputFormatClass"],
#                                            keyClass=hparams["readKeyClass"],
#                                            valueClass=hparams["readValueClass"],
#                                            keyConverter=hparams["readKeyConverter"],
#                                            valueConverter=hparams["readValueConverter"],
#                                            conf=hconf)
#
#     parser = func if func != None else rddparse_data_CV
#     hbase_rdd = hbase_rdd.map(lambda x: parser(x))
#
#     if collect:
#         return hbase_rdd.collect()
#     else:
#         return hbase_rdd
#
#
# def write_hbase(table_name, data, fromrdd=False, columns=None, withdata=False):
#     """
#     Data Format: (Deprecated)
#         e.g. [["row8", "f1", "", "caocao cao"], ["row9", "f1", "c1", "asdfg hhhh"]]
#
#     Data(from dictionary):
#         e.g. data ={'row1':[1,3400,'hello'], 'row2':[34,5000,'here in mine']},
#             cols = ['cf_info:id', 'cf_info:size', 'cf_tag:desc']
#     Data(from Rdd):
#         e.g. data =[('row1',[1,3400,'hello']), ('row2',[34,5000,'here in mine'])],
#             cols = ['cf_info:id', 'cf_info:size', 'cf_tag:desc']
#     """
#     hconf = {
#         "hbase.zookeeper.quorum": "HPC-server, HPC, HPC2",  # "hbase.zookeeper.quorum": self.host,
#         "hbase.mapreduce.inputtable": table_name,
#         "hbase.mapred.outputtable": table_name,
#         "mapreduce.outputformat.class": hparams["outputFormatClass"],
#         "mapreduce.job.output.key.class": hparams["writeKeyClass"],
#         "mapreduce.job.output.value.class": hparams["writeValueClass"],
#     }
#     cols = [col.split(':') for col in columns]
#     if not fromrdd:
#         rdd_data = sc.parallelize(data)
#     else:
#         rdd_data = data
#
#     rdd_data.flatMap(lambda x: format_out(x, cols, withdata=withdata)).saveAsNewAPIHadoopDataset(
#         conf=hconf,
#         keyConverter=hparams["writeKeyConverter"],
#         valueConverter=hparams["writeValueConverter"])


class Sparker(object):
    def __init__(self, host='HPC-server', appname='NewPySparkApp', **kwargs):
        load_env()
        self.host = host
        self.appname = appname
        self.master = kwargs.get('master', 'spark://%s:7077' % self.host)
        self.conf = SparkConf()
        self.conf.setSparkHome(self.host) \
            .setMaster(self.master) \
            .setAppName(self.appname)

        # self.conf.set("spark.akka.frameSize","10685760")
        # self.conf.set("spark.driver.extraClassPath", extraClassPath) \
        # .set("spark.executor.extraClassPath", extraClassPath) \
        # .set("SPARK_CLASSPATH", extraClassPath) \
        # .set("spark.driver.memory", "1G") \
        # .set("spark.yarn.jar", sparkJar)

        self.sc = SparkContext(conf=self.conf)

        self.model = None

    def read_hbase(self, table_name, func=None, collect=False, parallelism=30):
        """
        ref - http://happybase.readthedocs.org/en/latest/user.html#retrieving-data

        Filter format:
            columns=['cf1:col1', 'cf1:col2']
            or
            columns=['cf1']

        """

        hconf = {
            "hbase.zookeeper.quorum": "HPC-server, HPC, HPC2",
            # "hbase.zookeeper.quorum": self.host,
            "hbase.mapreduce.inputtable": table_name,
        }

        hbase_rdd = self.sc.newAPIHadoopRDD(inputFormatClass=hparams["inputFormatClass"],
                                            keyClass=hparams["readKeyClass"],
                                            valueClass=hparams["readValueClass"],
                                            keyConverter=hparams["readKeyConverter"],
                                            valueConverter=hparams["readValueConverter"],
                                            conf=hconf)

        parser = func if func != None else rddparse_data_CV
        hbase_rdd = hbase_rdd.map(lambda x: parser(x))

        if collect:
            return hbase_rdd.collect()
        else:
            """
            RDD-hbase bug fixed.(with 'repartition()')
            <http://stackoverflow.com/questions/29011574/how-is-spark-partitioning-from-hdfs>

            When Spark reads a file from HDFS, it creates a single partition for a single input split. Input split is set by the Hadoop InputFormat used to read this file. For instance, if you use textFile() it would be TextInputFormat in Hadoop, which would return you a single partition for a single block of HDFS (but the split between partitions would be done on line split, not the exact block split), unless you have a compressed text file. In case of compressed file you would get a single partition for a single file (as compressed text files are not splittable).
            When you call rdd.repartition(x) it would perform a shuffle of the data from N partititons you have in rdd to x partitions you want to have, partitioning would be done on round robin basis.
            If you have a 30GB uncompressed text file stored on HDFS, then with the default HDFS block size setting (128MB) it would be stored in 235 blocks, which means that the RDD you read from this file would have 235 partitions. When you call repartition(1000) your RDD would be marked as to be repartitioned, but in fact it would be shuffled to 1000 partitions only when you will execute an action on top of this RDD (lazy execution concept)

            """
            return hbase_rdd.repartition(parallelism)

    def write_hbase(self, table_name, data, fromrdd=False, columns=None, withdata=False):
        """
        Data Format: (Deprecated)
            e.g. [["row8", "f1", "", "caocao cao"], ["row9", "f1", "c1", "asdfg hhhh"]]

        Data(from dictionary):
            e.g. data ={'row1':[1,3400,'hello'], 'row2':[34,5000,'here in mine']},
                cols = ['cf_info:id', 'cf_info:size', 'cf_tag:desc']
        Data(from Rdd):
            e.g. data =[('row1',[1,3400,'hello']), ('row2',[34,5000,'here in mine'])],
                cols = ['cf_info:id', 'cf_info:size', 'cf_tag:desc']
        """
        hconf = {
            "hbase.zookeeper.quorum": "HPC-server, HPC, HPC2",
            # "hbase.zookeeper.quorum": self.host,
            "hbase.mapreduce.inputtable": table_name,
            "hbase.mapred.outputtable": table_name,
            "mapreduce.outputformat.class": hparams["outputFormatClass"],
            "mapreduce.job.output.key.class": hparams["writeKeyClass"],
            "mapreduce.job.output.value.class": hparams["writeValueClass"],
        }
        cols = [col.split(':') for col in columns]
        if not fromrdd:
            rdd_data = self.sc.parallelize(data)
        else:
            rdd_data = data

        rdd_data.flatMap(
            lambda x: format_out(x, cols, withdata=withdata)).saveAsNewAPIHadoopDataset(
            conf=hconf,
            keyConverter=hparams["writeKeyConverter"],
            valueConverter=hparams["writeValueConverter"])

    def train_svm(self, X, Y=None):

        if Y == None:
            # From rdd_labeled
            assert isinstance(X, RDD)
            svm = SVMWithSGD.train(X)
        else:
            # data = []
            # for feat, tag in zip(X, Y):
            # data.append(LabeledPoint(tag, feat))
            # svm = SVMWithSGD.train(self.sc.parallelize(data))
            hdd_data = self.sc.parallelize(zip(X, Y), 30).map(lambda x: LabeledPoint(x[1], x[0]))
            svm = SVMWithSGD.train(hdd_data)
        self.model = svm
        # with open('res/svm_spark.model', 'wb') as modelfile:
        # model = pickle.dump(svm, modelfile)

        return self.model

    def predict_svm(self, x, collect=False, model=None):
        """
        From pyspark.mllib.classification.py:

            >>> svm.predict([1.0])
            1
            >>> svm.predict(sc.parallelize([[1.0]])).collect()
            [1]
            >>> svm.clearThreshold()
            >>> svm.predict(array([1.0]))
            1.25...
        """
        if model is None:
            if self.model != None:
                model = self.model
            else:
                # with open('res/svm_spark.model', 'rb') as modelfile:
                # model = pickle.load(modelfile)
                raise Exception("No model available!")

        res = model.predict(x)
        if collect:
            return res.collect()
        else:
            return res

    def test_svm(self, X, Y=None, model=None):
        if model is None:
            if self.model != None:
                model = self.model
            else:
                # with open('res/svm_spark.model', 'rb') as modelfile:
                # model = pickle.load(modelfile)
                raise Exception("No model available!")

        if Y == None:
            assert isinstance(X, RDD)
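            # TODO: evaluation for an RDD of LabeledPoint is not implemented yet (falls through and returns None)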
            pass
        else:
            result_Y = np.array(self.predict_svm(X, collect=True))
            return np.mean(Y == result_Y)
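

# Illustrative end-to-end sketch, kept commented out like the examples above.
# The table name 'ILSVRC_dataset' is a placeholder, not a table this module guarantees to exist.
#
# if __name__ == '__main__':
#     sparker = Sparker(host='HPC-server', appname='ILSVRC-svm')
#
#     # Read (tag, feat) pairs from HBase and train an SVM on them.
#     rdd_dataset = sparker.read_hbase('ILSVRC_dataset', func=rddparse_dataset_ILS)
#     rdd_labeled = rdd_dataset.map(lambda x: LabeledPoint(x[0], x[1]))
#     model = sparker.train_svm(rdd_labeled)
#
#     # Predict on a single feature vector (the dimension here is arbitrary).
#     print sparker.predict_svm([0.0] * 64)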