# MSR.py
# -*- coding: utf-8 -*-
__author__ = 'chunk'

from . import *
from ..mfeat import HOG, IntraBlockDiff
from ..mspark import SC
from ..common import *

import os, sys
import base64
from PIL import Image
from io import BytesIO
from hashlib import md5
import csv
import shutil
import json
import collections
import itertools

import happybase


class DataMSR(DataDumperBase):
    def __init__(self, base_dir='/home/hadoop/data/MSR-IRC2014/', category='Dev',
                 data_file='DevSetImage.tsv', tag_file='DevSetLabel.tsv'):
        """Dumper for the MSR-IRC2014 dataset: resolves the data/tag/map file
        paths under <base>/<category>/ and derives the HBase table name."""
        DataDumperBase.__init__(self, base_dir, category)

        # All three input files live under the same category directory.
        prefix = self.base + self.category + '/'
        self.data_file = prefix + data_file
        self.tag_file = prefix + tag_file
        self.map_file = prefix + 'images_map.tsv'

        # e.g. 'MSR-IRC2014-Dev' (base ends with '/', so [-2] is the dir name).
        self.table_name = self.base.split('/')[-2] + '-' + self.category

    def format(self):
        """Normalize the raw dataset by extracting every base64 image to disk."""
        self.extract()

    def _load_base64(self):
        """Yield (name, base64_data) pairs from the tab-separated data file.

        Raises AssertionError when the data file is unset or missing.
        """
        assert self.data_file is not None and os.path.exists(self.data_file)

        with open(self.data_file, 'rb') as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            for line in reader:
                # column 0: image name, column 1: base64-encoded image payload
                yield line[0], line[1]


    def _hash_dump(self, data):
        img = Image.open(BytesIO(base64.b64decode(data)))
        img.save('res/tmp.jpg', format='JPEG')

        with open('res/tmp.jpg', 'rb') as f:
            index = md5(f.read()).hexdigest()

        dir = self.img_dir + index[:3] + '/'
        if not os.path.exists(dir):
            os.makedirs(dir)
        image_path = dir + index[3:] + '.jpg'
        print image_path

        if not os.path.exists(image_path):
            shutil.copy('res/tmp.jpg', image_path)
            # or :
            # img.save(image, format='JPEG')


    def extract(self):
        """Dump every image from the base64 TSV into the hashed image store."""
        for _name, b64data in self._load_base64():
            self._hash_dump(b64data)


    def build_list(self):
        assert self.list_file != None
        with open(self.list_file, 'wb') as f:
            for path, subdirs, files in os.walk(self.img_dir):
                for name in files:
                    entry = path.split('/')[-1] + '/' + name
                    print entry
                    f.write(entry + '\n')


    def get_table(self):
        """Return the HBase table for this dataset, caching it on self.table.

        Lazily opens the Thrift connection and creates the table with the
        expected column families on first use.
        """
        if self.table is not None:
            return self.table

        if self.connection is None:
            self.connection = happybase.Connection('HPC-server')

        if self.table_name not in self.connection.tables():
            families = {
                'cf_pic': dict(),
                'cf_info': dict(max_versions=10),  # keep history for metadata cells
                'cf_tag': dict(),
                'cf_feat': dict(),
            }
            self.connection.create_table(name=self.table_name, families=families)

        self.table = self.connection.table(name=self.table_name)
        return self.table


    def store_img(self):
        """Upload every image listed in list_file into the 'cf_pic:data' column.

        The row key is the listed path with its '/' separator removed.
        Errors raised by the HBase batch propagate to the caller.
        """
        if self.table is None:
            self.table = self.get_table()

        dict_buffer = {}
        with open(self.list_file, 'rb') as f:
            for line in f:
                path_img = line.strip('\n')
                if path_img:
                    with open(self.img_dir + path_img, 'rb') as fpic:
                        dict_buffer[path_img.replace('/', '')] = fpic.read()

        # batch() flushes every 5000 puts and sends the remainder on exit.
        with self.table.batch(batch_size=5000) as b:
            for imgname, imgdata in dict_buffer.items():
                b.put(imgname, {'cf_pic:data': imgdata})


    def store_tag(self, tagtype='retrieve'):
        """Attach query tags and eval labels to every row already in the table.

        map_file maps image name -> query key; tag_file maps query key ->
        (query terms, eval label). Tags go to 'cf_tag:<tagtype>' (JSON list)
        and labels to 'cf_tag:eval'. Batch errors propagate to the caller.
        """
        if self.table is None:
            self.table = self.get_table()

        dict_namebuf = {}
        dict_tagbuf = {}

        with open(self.map_file, 'rb') as tsvfile:
            for line in csv.reader(tsvfile, delimiter='\t'):
                # column 2: image name, column 0: query key
                dict_namebuf[line[2]] = line[0]

        with open(self.tag_file, 'rb') as tsvfile:
            for line in csv.reader(tsvfile, delimiter='\t'):
                # second-to-last column keys the (query terms, eval label) pair
                dict_tagbuf[line[-2]] = (line[0].split(), line[-1])

        with self.table.batch(batch_size=5000) as b:
            for key, data in self.table.scan():
                # row keys end in '.jpg'; strip the 4-char extension to look up
                value = dict_tagbuf[dict_namebuf[key[:-4]]]
                b.put(key, {'cf_tag:' + tagtype: json.dumps(value[0]), 'cf_tag:eval': value[1]})


    def get_feat(self, image, feattype='ibd', **kwargs):
        """Compute a feature descriptor for one image.

        image    -- image path handed to the feature extractor
        feattype -- 'hog' or 'ibd'
        size     -- (w, h) window for HOG, default (48, 48)

        Returns the extractor's descriptor. Raises ValueError for an unknown
        feature type (was a bare Exception; ValueError is a subclass so
        existing `except Exception` callers still work).
        """
        size = kwargs.get('size', (48, 48))

        if feattype == 'hog':
            feater = HOG.FeatHOG(size=size)
        elif feattype == 'ibd':
            feater = IntraBlockDiff.FeatIntraBlockDiff()
        else:
            raise ValueError("Unknown feature type: %s" % feattype)

        return feater.feat(image)

    def extract_feat(self, feattype='ibd'):
        """Compute descriptors for every listed image and dump them as JSON.

        Feature files mirror the image shard layout:
        <feat_dir>/<shard>/<name>.<feattype>. Raises ValueError for an
        unknown feature type.
        """
        if feattype == 'hog':
            feater = HOG.FeatHOG(size=(48, 48))
        elif feattype == 'ibd':
            feater = IntraBlockDiff.FeatIntraBlockDiff()
        else:
            raise ValueError("Unknown feature type: %s" % feattype)

        list_image = []
        with open(self.list_file, 'rb') as tsvfile:
            for line in csv.reader(tsvfile, delimiter='\t'):
                list_image.append(line[0])

        dict_featbuf = {}
        for imgname in list_image:
            image = self.img_dir + imgname
            dict_featbuf[imgname] = feater.feat(image)

        for imgname, desc in dict_featbuf.items():
            # imgname is '<3-char shard>/<name>.jpg': [:4] keeps 'abc/',
            # [4:] is the bare file name whose extension gets replaced.
            feat_subdir = self.feat_dir + imgname[:4]
            if not os.path.exists(feat_subdir):
                os.makedirs(feat_subdir)
            featpath = feat_subdir + imgname[4:].split('.')[0] + '.' + feattype
            with open(featpath, 'wb') as featfile:
                featfile.write(json.dumps(desc.tolist()))

    def store_feat(self, feattype='ibd'):
        """Upload dumped feature files into 'cf_feat:<feattype>' per table row.

        Only rows already present in the table receive a feature cell (the
        scan drives the puts). KeyError surfaces if a row has no dumped
        feature file; batch errors propagate to the caller.
        """
        if self.table is None:
            self.table = self.get_table()

        dict_featbuf = {}
        for path, subdirs, files in os.walk(self.feat_dir):
            for name in files:
                featpath = os.path.join(path, name)
                with open(featpath, 'rb') as featfile:
                    # rebuild the HBase row key: '<shard><name>.jpg'
                    imgname = path.split('/')[-1] + name.replace('.' + feattype, '.jpg')
                    dict_featbuf[imgname] = featfile.read()

        with self.table.batch(batch_size=5000) as b:
            for key, data in self.table.scan():
                b.put(key, {'cf_feat:' + feattype: dict_featbuf[key]})

    def load_data(self, mode='local', feattype='hog'):
        INDEX = []
        X = []
        Y = []

        if mode == "local":

            dict_tagbuf = {}
            with open(self.list_file, 'rb') as tsvfile:
                tsvfile = csv.reader(tsvfile, delimiter='\t')
                for line in tsvfile:
                    imgname = line[0] + '.jpg'
                    dict_tagbuf[imgname] = line[1]

            dict_dataset = {}
            for path, subdirs, files in os.walk(self.feat_dir):
                for name in files:
                    featpath = os.path.join(path, name)
                    with open(featpath, 'rb') as featfile:
                        imgname = path.split('/')[-1] + name.replace('.' + feattype, '.jpg')
                        dict_dataset[imgname] = json.loads(featfile.read())

            for imgname, tag in dict_tagbuf.items():
                tag = 1 if tag == 'True' else 0
                INDEX.append(imgname)
                X.append(dict_dataset[imgname])
                Y.append(tag)

        elif mode == "remote" or mode == "hbase":
            if self.table == None:
                self.table = self.get_table()

            for key, data in self.table.scan(columns=['cf_feat:' + feattype, 'cf_tag:' + feattype]):
                print key, data
                X.append(data['cf_feat:' + feattype])
                Y.append(data['cf_tag:' + feattype])

        elif mode == "spark" or mode == "cluster":
            if self.sparkcontex == None:
                self.sparkcontex = SC.Sparker(host='HPC-server', appname='ImageCV', master='spark://HPC-server:7077')

            result = self.sparkcontex.read_hbase(self.table_name)  # result = {key:[feat,tag],...}
            for key, data in result.items():
                X.append(data[0])
                Y.append(data[1])

        return X, Y