# SC.py (4.27 KB)
__author__ = 'chunk'

from ..common import *
from .dependencies import *
from . import *

import sys
from pyspark import SparkConf, SparkContext
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD
from pyspark.mllib.regression import LabeledPoint
from numpy import array
import json
import pickle


def parse_cv(raw_row):
    """Parse one HBase row into a ``(feature_vector, label)`` pair.

    The value string is a sequence of cells joined by ``'--%--'``; as in the
    example below, each cell looks like ``family:qualifier:payload``. The
    first cell's payload is decoded as JSON (the feature vector) and the last
    cell's payload is the tag: the literal ``'True'`` maps to label 1, any
    other value to 0.

    Example input (picture bytes abbreviated)::

        (u'key0', u'cf_feat:hog:[0.056273,...]--%--cf_pic:data:...--%--cf_tag:hog:True')

    returns ``([0.056273, ...], 1)``.

    :param raw_row: ``(row_key, value_string)`` tuple; the key is ignored.
    :return: ``(feat, tag)`` where ``feat`` is the JSON-decoded payload of the
        first cell and ``tag`` is ``1`` or ``0``.
    """
    cells = raw_row[1].split('--%--')
    # maxsplit=2 strips only "family:qualifier:" so any ':' inside the
    # payload itself (e.g. a JSON object) is preserved intact.
    feat = json.loads(cells[0].split(':', 2)[-1])
    tag = 1 if cells[-1].split(':', 2)[-1] == 'True' else 0
    return (feat, tag)


class Sparker(object):
    """Wrapper around a SparkContext for HBase I/O and MLlib SVM training.

    The SparkContext created in ``__init__`` is held in ``self.sc`` until
    :meth:`stop` is called; the instance can also be used as a context
    manager, which calls :meth:`stop` on exit.
    """

    def __init__(self, host='HPC-server', appname='NewPySparkApp', **kwargs):
        """Build the Spark configuration and start a SparkContext.

        :param host: cluster host name; used to build the default standalone
            master URL and, by the HBase helpers, as the ZooKeeper quorum.
        :param appname: Spark application name.
        :keyword master: Spark master URL (default ``spark://<host>:7077``).
        :keyword spark_home: optional Spark installation path on the worker
            nodes, passed to ``SparkConf.setSparkHome``; left unset otherwise.
        :keyword conf: optional dict of extra Spark properties applied with
            ``SparkConf.set`` (e.g. ``{"spark.driver.memory": "1G"}``).
        """
        # load_env comes from a star import; presumably it prepares the
        # environment Spark/HBase need -- confirm in the defining module.
        load_env()
        self.host = host
        self.appname = appname
        self.master = kwargs.get('master', 'spark://%s:7077' % self.host)
        print(self.master)

        self.conf = SparkConf()
        self.conf.setMaster(self.master).setAppName(self.appname)
        # setSparkHome expects an installation path, not a host name, so it is
        # only applied when the caller supplies one explicitly.
        spark_home = kwargs.get('spark_home')
        if spark_home is not None:
            self.conf.setSparkHome(spark_home)
        for key, value in (kwargs.get('conf') or {}).items():
            self.conf.set(key, value)

        self.sc = SparkContext(conf=self.conf)
        self.model = None

    def stop(self):
        """Stop the SparkContext held by this instance; safe to call twice."""
        sc = getattr(self, 'sc', None)
        if sc is not None:
            sc.stop()
            self.sc = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()
        return False  # never swallow exceptions raised inside the with-block

    def read_habase(self, table_name, columns=None):
        """Scan ``table_name`` in HBase and return the rows parsed by parse_cv.

        Rows are read with ``SparkContext.newAPIHadoopRDD`` using the input
        format, key/value classes and converters from ``hparams``, decoded
        into ``(feat, tag)`` pairs, and collected to the driver.

        :param table_name: HBase table to scan.
        :param columns: optional column filter, stored as the space-delimited
            ``hbase.mapreduce.scan.columns`` property (read by HBase's
            TableInputFormat -- assumes ``hparams["inputFormatClass"]`` is that
            class). Either a list such as ``['cf1:col1', 'cf1:col2']`` or
            ``['cf1']``, or an already space-delimited string. ``None`` scans
            all columns. parse_cv reads the first and last cells of each row,
            so the filter should keep the feature and tag columns.
        :return: list of ``(feat, tag)`` pairs.
        """
        hconf = {
            "hbase.zookeeper.quorum": self.host,
            "hbase.mapreduce.inputtable": table_name,
        }
        if columns:
            hconf["hbase.mapreduce.scan.columns"] = (
                columns if hasattr(columns, 'split') else ' '.join(columns))

        hbase_rdd = self.sc.newAPIHadoopRDD(
            inputFormatClass=hparams["inputFormatClass"],
            keyClass=hparams["readKeyClass"],
            valueClass=hparams["readValueClass"],
            keyConverter=hparams["readKeyConverter"],
            valueConverter=hparams["readValueConverter"],
            conf=hconf)
        # collect() materializes the whole result on the driver.
        return hbase_rdd.map(parse_cv).collect()

    def write_habase(self, table_name, data):
        """Write ``data`` into the HBase table ``table_name``.

        :param table_name: target HBase table.
        :param data: list of rows, e.g.
            ``[["row8", "f1", "", "caocao cao"], ["row9", "f1", "c1", "asdfg hhhh"]]``.
            Each row is keyed by its first element and sent whole as the
            value; presumably ``[row_key, family, qualifier, value]``, but the
            exact interpretation is up to ``hparams["writeValueConverter"]``.
        """
        hconf = {
            "hbase.zookeeper.quorum": self.host,
            "hbase.mapreduce.inputtable": table_name,
            "hbase.mapred.outputtable": table_name,
            "mapreduce.outputformat.class": hparams["outputFormatClass"],
            "mapreduce.job.output.key.class": hparams["writeKeyClass"],
            "mapreduce.job.output.value.class": hparams["writeValueClass"],
        }
        keyed_rows = self.sc.parallelize(data).map(lambda row: (row[0], row))
        keyed_rows.saveAsNewAPIHadoopDataset(
            conf=hconf,
            keyConverter=hparams["writeKeyConverter"],
            valueConverter=hparams["writeValueConverter"])

    # Correctly spelled aliases; the original names are kept for callers.
    read_hbase = read_habase
    write_hbase = write_habase

    def train_svm(self, X, Y=None, num_slices=20, **train_kwargs):
        """Train an MLlib linear SVM with SGD and store it as ``self.model``.

        Two call forms are supported:

        * ``train_svm(rdd_labeled)`` -- ``X`` is already an RDD of
          ``LabeledPoint``.
        * ``train_svm(X, Y)`` -- ``X`` is a sequence of feature vectors and
          ``Y`` the matching labels; the pairs are distributed over
          ``num_slices`` partitions and converted to ``LabeledPoint``.

        Extra keyword arguments (e.g. ``iterations``, ``step``, ``regParam``)
        are forwarded to ``SVMWithSGD.train``.

        :return: the trained model (also stored in ``self.model``).
        """
        if Y is None:
            labeled = X
        else:
            # zip() silently stops at the shorter of X and Y.
            labeled = self.sc.parallelize(zip(X, Y), num_slices).map(
                lambda pair: LabeledPoint(pair[1], pair[0]))
        svm = SVMWithSGD.train(labeled, **train_kwargs)
        self.model = svm
        return svm

    def predict_svm(self, x, model=None):
        """Predict ``x`` with ``model``, defaulting to the last trained model.

        :param x: input passed unchanged to ``model.predict``.
        :param model: optional model; falls back to ``self.model``.
        :raises RuntimeError: if no model is given and none has been trained.
        """
        if model is None:
            model = self.model
        if model is None:
            raise RuntimeError("No model available!")
        return model.predict(x)

    def test_svm(self, X, Y, model=None):
        """Placeholder for evaluating a model on ``(X, Y)``.

        Not implemented: currently does nothing and returns ``None``.
        """
        # TODO: implement evaluation (e.g. accuracy of predict_svm on X vs Y).
        return None