# test_whole.py
__author__ = 'chunk'

from ..mspark import SC
from ..common import *
from ..mdata import ILSVRC, ILSVRC_S

from pyspark.mllib.regression import LabeledPoint
import happybase


def test_whole():
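    """End-to-end pipeline test: read images from HBase, attach image info,
    embed hidden data at rate 0.2, extract features, and count the results.
    """
    # Column sets for HBase I/O: cols1 extends cols0 with the extracted
    # feature column 'cf_feat:bid'.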
    cols0 = [
        'cf_pic:data',
        'cf_info:width',
        'cf_info:height',
        'cf_info:size',
        'cf_info:capacity',
        'cf_info:quality',
        'cf_info:rate',
        'cf_tag:chosen',
        'cf_tag:class'
    ]
    cols1 = [
        'cf_pic:data',
        'cf_info:width',
        'cf_info:height',
        'cf_info:size',
        'cf_info:capacity',
        'cf_info:quality',
        'cf_info:rate',
        'cf_tag:chosen',
        'cf_tag:class',
        'cf_feat:bid',
    ]

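    # Spark driver helper bound to the standalone master on HPC-server.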
    sparker = SC.Sparker(host='HPC-server', appname='ImageILSVRC-S', master='spark://HPC-server:7077')

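    # Alternative single-pass variant using rddembed_ILS_EXT (kept for reference):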
    # rdd_data = sparker.read_hbase("ILSVRC2013_DET_val-Test_1", func=SC.rddparse_data_ILS, collect=False) \
    # .mapValues(lambda data: [data] + SC.rddinfo_ILS(data)) \
    # .flatMap(lambda x: SC.rddembed_ILS_EXT(x, rate=0.2)) \
    # .mapValues(lambda items: SC.rddfeat_ILS(items))

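    # Parse HBase rows into (key, data) pairs, then attach the image info
    # fields (width, height, size, ...) alongside the raw data.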
    rdd_data = sparker.read_hbase("ILSVRC2013_DET_val-Test_1", func=SC.rddparse_data_ILS, collect=False).mapValues(
        lambda data: [data] + SC.rddinfo_ILS(data))
    # Embed at rate 0.2; rddembed_ILS returns None for items it skips.
    rdd_data_ext = rdd_data.map(lambda x: SC.rddembed_ILS(x, rate=0.2)).filter(lambda x: x is not None)

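    # Union the originals with their embedded variants, then compute features.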
    rdd_data = rdd_data.union(rdd_data_ext).mapValues(lambda items: SC.rddfeat_ILS(items))

    # Force evaluation of the whole pipeline and report the record count.
    print rdd_data.count()

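    # Optional: persist the processed records (including features) back to HBase.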
    # sparker.write_hbase("ILSVRC2013_DET_val-Test_1", rdd_data, fromrdd=True, columns=cols1,
    #                     withdata=True)


def test_whole_ext(category='Train_100'):
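    """End-to-end test with timing: (re)build the HBase table for `category`,
    scan the images back, run the embed/feature pipeline on Spark, and write
    the results to HBase, reporting the time spent in each stage.
    """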
    timer = Timer()

    print '[time]category:', category

    print '[time]formatting table...'
    timer.mark()
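    # Rebuild the HBase table for this category from the local ILSVRC image files.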
    dil = ILSVRC.DataILSVRC(base_dir='/data/hadoop/ImageNet/ILSVRC/ILSVRC2013_DET_val', category=category)
    dil.delete_table()
    dil.format()
    dil.store_img()
    timer.report()

    print '[time]reading table...'
    timer.mark()
    table_name = dil.table_name
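    # happybase talks to HBase over the Thrift gateway (default port 9090);
    # create the table with its four column families if it is missing.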
    connection = happybase.Connection('HPC-server')
    tables = connection.tables()
    if table_name not in tables:
        families = {'cf_pic': dict(),
                    'cf_info': dict(max_versions=10),
                    'cf_tag': dict(),
                    'cf_feat': dict(),
                    }
        connection.create_table(name=table_name, families=families)
    table = connection.table(name=table_name)

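    # Scan only the raw image column into driver memory so Spark can
    # parallelize it below.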
    cols = ['cf_pic:data']
    list_data = []
    for key, data in table.scan(columns=cols):
        data = data['cf_pic:data']
        list_data.append((key, data))
    timer.report()

    print '[time]processing...'
    timer.mark()
    sparker = SC.Sparker(host='HPC-server', appname='ImageILSVRC-S', master='spark://HPC-server:7077')
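    # Distribute the scanned images across 40 partitions and run the full
    # info -> embed (rate 0.2) -> feature pipeline in one pass.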
    rdd_data = sparker.sc.parallelize(list_data, 40) \
        .mapValues(lambda data: [data] + SC.rddinfo_ILS(data)) \
        .flatMap(lambda x: SC.rddembed_ILS_EXT(x, rate=0.2)) \
        .mapValues(lambda items: SC.rddfeat_ILS(items))
    timer.report()

    print '[time]writing table...'
    timer.mark()
    # happybase batches puts client-side; the batch() context manager flushes
    # every 5000 mutations and sends any remainder on exit.
    with table.batch(batch_size=5000) as b:
        for imgname, imginfo in rdd_data.collect():
            b.put(imgname,
                  {
                      'cf_pic:data': imginfo[0],
                      'cf_info:width': str(imginfo[1]),
                      'cf_info:height': str(imginfo[2]),
                      'cf_info:size': str(imginfo[3]),
                      'cf_info:capacity': str(imginfo[4]),
                      'cf_info:quality': str(imginfo[5]),
                      'cf_info:rate': str(imginfo[6]),
                      'cf_tag:chosen': str(imginfo[7]),
                      'cf_tag:class': str(imginfo[8]),
                      'cf_feat:bid': imginfo[9],
                  })
    timer.report()
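

if __name__ == '__main__':
    # Minimal entry-point sketch (an assumption, not part of the original test
    # harness). Because of the relative imports above, this module must be run
    # from the package root, e.g.: python -m <package>.test_whole
    test_whole_ext(category='Train_100')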