# ILSVRC-S.py
__author__ = 'chunk'

from . import *
from ..mfeat import HOG, IntraBlockDiff
from ..mspark import SC
from ..common import *

import os, sys
from PIL import Image
from hashlib import md5
import csv
import shutil
import json
import collections
import happybase

from ..mjpeg import *
from ..msteg import *
from ..msteg.steganography import LSB, F3, F4, F5

import numpy as np
from numpy.random import randn
import pandas as pd
from scipy import stats

from subprocess import Popen, PIPE, STDOUT
import tempfile

# Deterministic RNG seed so the Bernoulli tag draws and random payloads
# in this module are reproducible across runs.
np.random.seed(sum(map(ord, "whoami")))

# Absolute directory of this package; used to locate the '../res/toembed'
# payload file for F5 embedding.
package_dir = os.path.dirname(os.path.abspath(__file__))


class DataILSVRCS(DataDumperBase):
    """
    This module is specially for ILSVRC data processing under Spark & HBase.

    We posit that the DB (e.g. HBase) initially has only the image data with
    md5 name as row key. The task is to generate info (size, capacity,
    quality, etc.) and class & chosen tags, then to perform F5 embedding and
    finally to calculate ibd features.

    Each step includes reading from & writing to HBase (through the PC).
    And each step must have a 'spark' mode option, which means that the
    operation is performed by Spark with reading & writing through RDDs.

    Cached rows in self.dict_data are lists laid out as:
    [raw_jpeg, width, height, size, capacity, quality, rate, chosen, class,
     (feat)]

    chunkplus@gmail.com
    """

    def __init__(self, base_dir='/media/chunk/Elements/D/data/ImageNet/img/ILSVRC2013_DET_val', category='Train'):
        DataDumperBase.__init__(self, base_dir, category)

        self.base_dir = base_dir
        self.category = category

        # In-memory row cache shared by the _extract/_embed/_feat steps.
        self.dict_data = {}

        # Table name derived from the data directory, e.g.
        # 'ILSVRC2013_DET_val-Train'.
        self.table_name = self.base_dir.strip('/').split('/')[-1] + '-' + self.category
        self.sparkcontex = None


    def _get_table(self):
        """Return the HBase table, lazily connecting and creating it if needed."""
        if self.table is not None:
            return self.table

        if self.connection is None:
            self.connection = happybase.Connection('HPC-server')

        if self.table_name not in self.connection.tables():
            families = {'cf_pic': dict(),                    # raw jpeg bytes
                        'cf_info': dict(max_versions=10),    # image metadata
                        'cf_tag': dict(),                    # chosen/class tags
                        'cf_feat': dict(),                   # feature vectors
                        }
            self.connection.create_table(name=self.table_name, families=families)

        self.table = self.connection.table(name=self.table_name)
        return self.table

    def _get_info(self, img, info_rate=None, tag_chosen=None, tag_class=None):
        """
        Compute the info vector for raw JPEG bytes.

        The bytes are spooled through a named temp file because Jpeg() reads
        from a file path. ("Tempfile is our friend.")

        :param img: raw JPEG data.
        :param info_rate: embedding rate; defaults to 0.0 (cover image).
        :param tag_chosen: 0/1 'chosen' tag; defaults to a Bernoulli(0.8) draw.
        :param tag_class: class tag; defaults to 0 (cover).
        :return: [width, height, size, capacity, quality, rate, chosen, class],
                 or None if the image cannot be parsed (error is printed).
        """
        info_rate = info_rate if info_rate is not None else 0.0
        tag_chosen = tag_chosen if tag_chosen is not None else stats.bernoulli.rvs(0.8)
        tag_class = tag_class if tag_class is not None else 0
        # Create the temp file before the try so the finally-close cannot hit
        # an unbound name if creation itself fails.
        tmpf = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
        try:
            tmpf.write(img)
            # Ensure the bytes are on disk before Jpeg() re-reads by name.
            tmpf.flush()
            im = Jpeg(tmpf.name, key=sample_key)
            info = [im.image_width,
                    im.image_height,
                    im.image_width * im.image_height,
                    im.getCapacity(),
                    im.getQuality(),
                    info_rate,
                    tag_chosen,
                    tag_class]
            return info
        except Exception as e:
            # Best effort: report and fall through to an implicit None.
            print(e)
        finally:
            tmpf.close()

    def _get_feat(self, image, feattype='ibd', **kwargs):
        """
        Compute a feature descriptor for an image.

        :param image: whatever the feater accepts (a file object here).
        :param feattype: 'hog' or 'ibd'.
        :param size: (HOG only) window size; default (48, 48).
        :return: the feature descriptor.
        :raises Exception: on unknown feattype.
        """
        size = kwargs.get('size', (48, 48))

        if feattype == 'hog':
            feater = HOG.FeatHOG(size=size)
        elif feattype == 'ibd':
            feater = IntraBlockDiff.FeatIntraBlockDiff()
        else:
            raise Exception("Unknown feature type!")

        return feater.feat(image)

    def _writeback(self, items, with_pic=True, feattype=None):
        """
        Batch-put cached rows back to HBase.

        :param items: iterable of (row_key, value_list) pairs following the
                      class-level row layout.
        :param with_pic: also write the raw bytes to 'cf_pic:data'.
        :param feattype: when given, also write value_list[9] to
                         'cf_feat:<feattype>'.
        """
        # NOTE(review): HBase cell values should be byte strings; several of
        # these are ints/floats — confirm the happybase version coerces them.
        with self.table.batch(batch_size=5000) as b:
            for imgname, imginfo in items:
                row = {'cf_info:width': imginfo[1],
                       'cf_info:height': imginfo[2],
                       'cf_info:size': imginfo[3],
                       'cf_info:capacity': imginfo[4],
                       'cf_info:quality': imginfo[5],
                       'cf_info:rate': imginfo[6],
                       'cf_tag:chosen': imginfo[7],
                       'cf_tag:class': imginfo[8]}
                if with_pic:
                    row['cf_pic:data'] = imginfo[0]
                if feattype is not None:
                    row['cf_feat:' + feattype] = imginfo[9]
                b.put(imgname, row)

    def _extract_data(self, mode='hbase', writeback=False):
        """
        Get info barely out of image data.

        Scans 'cf_pic:data' for every row, computes the info vector and
        caches [raw] + info in self.dict_data.

        :param writeback: write the info/tag columns back to HBase (the raw
                          picture bytes are left untouched).
        """
        if mode == 'hbase':
            if self.table is None:
                self.table = self._get_table()    # fixed: was self.get_table()

            cols = ['cf_pic:data']
            for key, data in self.table.scan(columns=cols, scan_batching=True):
                info = self._get_info(data)
                if info is None:
                    # Unparsable image: already reported, skip the row.
                    continue
                self.dict_data[key] = [data] + info

            if not writeback:
                return self.dict_data
            self._writeback(self.dict_data.items(), with_pic=False)

        elif mode == 'spark':
            pass
        else:
            raise Exception("Unknown mode!")


    def _embed_data(self, mode='hbase', rate=None, readforward=False, writeback=False):
        """
        F5-embed hidden data into every cached cover image, producing one new
        stego row (md5-of-bytes + '.jpg') per cover row.

        :param rate: target embedding rate in [0, 1); None embeds the fixed
                     '../res/toembed' payload instead.
        :param readforward: re-read self.dict_data from HBase first.
                            NOTE(review): the scan stores a column dict, but
                            the loop below indexes imgdata[0] — these rows
                            would only fail into the printed exception path;
                            confirm intended layout.
        :param writeback: write the new stego rows (including raw bytes) back.
        """
        f5 = F5.F5(sample_key, 1)
        if mode == 'hbase':
            if self.table is None:
                self.table = self._get_table()    # fixed: was self.get_table()

            if readforward:
                self.dict_data = {}
                cols = ['cf_pic:data',
                        'cf_info:width',
                        'cf_info:height',
                        'cf_info:size',
                        'cf_info:capacity',
                        'cf_info:quality',
                        'cf_info:rate',
                        'cf_tag:chosen',
                        'cf_tag:class']
                for key, data in self.table.scan(columns=cols, scan_batching=True):
                    self.dict_data[key] = data

            dict_data_ext = {}

            for imgname, imgdata in self.dict_data.items():
                # Pre-bind so the finally-close is safe if creation fails.
                tmpf_src = None
                tmpf_dst = None
                try:
                    tmpf_src = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
                    tmpf_src.write(imgdata[0])
                    tmpf_src.flush()
                    tmpf_dst = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')

                    if rate is None:
                        embed_rate = f5.embed_raw_data(tmpf_src, os.path.join(package_dir, '../res/toembed'), tmpf_dst)
                    else:
                        if not (0 <= rate < 1):
                            raise ValueError("rate must be in [0, 1)")
                        # imgdata[4] is the capacity — presumably in bits,
                        # hence the /8 to get a byte count; TODO confirm.
                        hidden = np.random.bytes(int(imgdata[4] * rate) // 8)
                        embed_rate = f5.embed_raw_data(tmpf_src, hidden, tmpf_dst, frommem=True)

                    tmpf_dst.seek(0)
                    raw = tmpf_dst.read()
                    index = md5(raw).hexdigest()
                    # Stego rows: rate = actual embed rate, chosen = 0, class = 1.
                    dict_data_ext[index + '.jpg'] = [raw] + self._get_info(raw, embed_rate, 0, 1)

                except Exception as e:
                    print(e)
                finally:
                    if tmpf_src is not None:
                        tmpf_src.close()
                    if tmpf_dst is not None:
                        tmpf_dst.close()

            self.dict_data.update(dict_data_ext)

            if not writeback:
                return self.dict_data
            self._writeback(dict_data_ext.items(), with_pic=True)

        elif mode == 'spark':
            pass
        else:
            raise Exception("Unknown mode!")


    def _extract_feat(self, mode='hbase', feattype='ibd', readforward=False, writeback=False, **kwargs):
        """
        Compute and append a JSON feature descriptor for every cached row.

        :param feattype: feature type passed to _get_feat ('hog' or 'ibd').
        :param readforward: re-read self.dict_data from HBase first.
                            NOTE(review): scanned rows are column dicts, not
                            lists, so imgdata[0]/.append below would fail for
                            them — confirm intended layout.
        :param writeback: write all columns incl. 'cf_feat:<feattype>' back.
        """
        if mode == 'hbase':
            if self.table is None:
                self.table = self._get_table()    # fixed: was self.get_table()

            if readforward:
                self.dict_data = {}
                cols = ['cf_pic:data',
                        'cf_info:width',
                        'cf_info:height',
                        'cf_info:size',
                        'cf_info:capacity',
                        'cf_info:quality',
                        'cf_info:rate',
                        'cf_tag:chosen',
                        'cf_tag:class']
                for key, data in self.table.scan(columns=cols, scan_batching=True):
                    self.dict_data[key] = data

            for imgname, imgdata in self.dict_data.items():
                tmpf_src = None
                try:
                    tmpf_src = tempfile.NamedTemporaryFile(suffix='.jpg', mode='w+b')
                    tmpf_src.write(imgdata[0])
                    tmpf_src.flush()

                    desc = json.dumps(self._get_feat(tmpf_src, feattype=feattype))
                    self.dict_data[imgname].append(desc)

                except Exception as e:
                    print(e)
                finally:
                    if tmpf_src is not None:
                        tmpf_src.close()

            if not writeback:
                return self.dict_data
            self._writeback(self.dict_data.items(), with_pic=True, feattype=feattype)

        elif mode == 'spark':
            pass
        else:
            raise Exception("Unknown mode!")


    def format(self):
        """Run the full pipeline in memory, writing back only the final step."""
        self._extract_data(mode='hbase', writeback=False)
        self._embed_data(mode='hbase', rate=0.1, readforward=False, writeback=False)
        self._extract_feat(mode='hbase', feattype='ibd', readforward=False, writeback=True)


    def load_data(self, mode='local', feattype='ibd', tagtype='class'):
        """
        Load the dataset as parallel lists (X, Y).

        :param mode: 'local' (tsv list + per-hash feature files on disk),
                     'remote'/'hbase' (scan feature & tag columns), or
                     'spark'/'cluster' (read through Spark RDDs).
        :param feattype: feature column/file suffix.
        :param tagtype: tag column qualifier (hbase modes only).
        :return: (X, Y) — feature vectors and integer labels.
        :raises Exception: on unknown mode.
        """
        X = []
        Y = []

        if mode == "local":

            dict_dataset = {}

            with open(self.list_file, 'rb') as tsvfile:
                tsvfile = csv.reader(tsvfile, delimiter='\t')
                for line in tsvfile:
                    img_hash = line[0]
                    tag = line[-1]
                    # Features are sharded by the first 3 hex chars of the hash.
                    path_feat = os.path.join(self.feat_dir, img_hash[:3], img_hash[3:] + '.' + feattype)
                    # fixed: was 'if path_feat:' which is always true for a
                    # nonempty path string; an existence check was intended.
                    if os.path.exists(path_feat):
                        with open(path_feat, 'rb') as featfile:
                            dict_dataset[img_hash] = (tag, json.loads(featfile.read()))

            for tag, feat in dict_dataset.values():
                # Flatten the 3-level nested feature structure into one vector.
                X.append([item for sublist in feat for subsublist in sublist for item in subsublist])
                Y.append(int(tag))

        elif mode == "remote" or mode == "hbase":
            if self.table is None:
                self.table = self._get_table()    # fixed: was self.get_table()

            col_feat, col_tag = 'cf_feat:' + feattype, 'cf_tag:' + tagtype
            for key, data in self.table.scan(columns=[col_feat, col_tag]):
                X.append(json.loads(data[col_feat]))
                Y.append(1 if data[col_tag] == 'True' else 0)

        elif mode == "spark" or mode == "cluster":
            if self.sparkcontex is None:
                self.sparkcontex = SC.Sparker(host='HPC-server', appname='ImageCV', master='spark://HPC-server:7077')

            # result = {key: [feat, tag], ...}
            result = self.sparkcontex.read_hbase(self.table_name)
            for feat, tag in result:
                X.append(feat)
                Y.append(tag)

        else:
            raise Exception("Unknown mode!")

        return X, Y