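"""Dumper for the ILSVRC2013 detection validation images.

Copies images into an MD5-addressed local tree, records per-image metadata
and a pseudo-random class tag, and stages images, info, tags, and feature
descriptors into an HBase table, with local, HBase, and Spark loaders.
"""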
__author__ = 'chunk'

from . import *
from ..mfeat import HOG, IntraBlockDiff
from ..mspark import SC
from ..common import *

import os, sys
from PIL import Image
from hashlib import md5
import csv
import shutil
import json
import collections
import happybase

from ..mjpeg import *
from ..msteg import *

import numpy as np
from numpy.random import randn
import pandas as pd
from scipy import stats


np.random.seed(sum(map(ord, "whoami")))


class DataILSVRC(DataDumperBase):
    def __init__(self, base_dir='/media/chunk/Elements/D/data/ImageNet/img/ILSVRC2013_DET_val', category='Train'):
        DataDumperBase.__init__(self, base_dir, category)

        self.base_dir = base_dir
        self.category = category
        self.data_dir = os.path.join(self.base_dir, self.category)

        self.dst_dir = os.path.join(self.base_dir, 'dst', self.category)
        self.list_file = os.path.join(self.dst_dir, 'file-tag.tsv')
        self.feat_dir = os.path.join(self.dst_dir, 'Feat')
        self.img_dir = os.path.join(self.dst_dir, 'Img')

        self.dict_data = {}

        self.table_name = self.base_dir.strip('/').split('/')[-1] + '-' + self.category
        self.sparkcontext = None

    def format(self):
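        """Run the full local formatting pass (delegates to extract())."""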
        self.extract()

    def _hash_copy(self, image):
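        """Copy one image into img_dir under an MD5-derived path
        (<hash[:3]>/<hash[3:]>.jpg), converting to JPEG if needed, and
        record its width/height/size/quality in dict_data."""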
        if not image.endswith('.jpg'):
            # Transcode non-JPEG inputs to a temporary JPEG first.
            img = Image.open(image)
            img.save('res/tmp.jpg', format='JPEG')
            image = 'res/tmp.jpg'

        with open(image, 'rb') as f:
            index = md5(f.read()).hexdigest()

        im = Jpeg(image, key=sample_key)
        self.dict_data[index] = [im.image_width, im.image_height, os.path.getsize(image), im.getQuality()]

        # originally: dir = base_dir + 'Img/Train/' + index[:3]
        subdir = os.path.join(self.img_dir, index[:3])
        if not os.path.exists(subdir):
            os.makedirs(subdir)
        image_path = os.path.join(subdir, index[3:] + '.jpg')
        # print image_path

        if not os.path.exists(image_path):
            shutil.copy(image, image_path)

    def _build_list(self):
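        """Write dict_data to file-tag.tsv as tab-separated rows of
        hash, width, height, size, quality, sorted by hash."""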
        assert self.list_file is not None

        ordict_img = collections.OrderedDict(sorted(self.dict_data.items(), key=lambda d: d[0]))

        with open(self.list_file, 'w') as f:
            tsvfile = csv.writer(f, delimiter='\t')
            for key, value in ordict_img.items():
                tsvfile.writerow([key] + value)

    def _analysis(self):
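        """Sort the image list by size and quality, assign each image a
        Bernoulli(0.3) pseudo-label in a trailing 'class' column, and
        rewrite the TSV in place (no header)."""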
        df_ILS = pd.read_csv(self.list_file, names=['hash', 'width', 'height', 'size', 'quality'], sep='\t')
        length = df_ILS.shape[0]
        # DataFrame.sort() is gone from current pandas; sort_values() replaces it.
        df_new = df_ILS.sort_values(['size', 'quality'], ascending=True)
        rand_class = stats.bernoulli.rvs(0.3, size=length)

        df_new['class'] = pd.Series(rand_class, index=df_new.index)
        df_new.to_csv(self.list_file, header=False, index=False, sep='\t')


    def extract(self):
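        """Walk data_dir, hash-copy every image into img_dir, then build
        and label the file-tag list."""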
        for path, subdirs, files in os.walk(self.data_dir):
            for name in files:
                imagepath = os.path.join(path, name)
                # print imagepath
                try:
                    self._hash_copy(imagepath)
                except Exception:
                    # Skip unreadable or corrupt images instead of aborting the walk.
                    pass

        self._build_list()
        self._analysis()

    def get_table(self):
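        """Return the cached HBase table, creating the connection and (if
        missing) the table with its four column families on first use."""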
        if self.table is not None:
            return self.table

        if self.connection is None:
            c = happybase.Connection('HPC-server')
            self.connection = c

        tables = self.connection.tables()
        if self.table_name not in tables:
            families = {'cf_pic': dict(),
                        'cf_info': dict(max_versions=10),
                        'cf_tag': dict(),
                        'cf_feat': dict(),
                        }
            self.connection.create_table(name=self.table_name, families=families)

        table = self.connection.table(name=self.table_name)

        self.table = table

        return table


    def store_image(self):
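        """Read every staged JPEG listed in file-tag.tsv and batch-put the
        raw bytes into cf_pic:data, keyed by '<hash>.jpg'."""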
        if self.table is None:
            self.table = self.get_table()

        dict_databuf = {}

        with open(self.list_file, 'rb') as tsvfile:
            tsvfile = csv.reader(tsvfile, delimiter='\t')
            for line in tsvfile:
                path_img = os.path.join(self.img_dir, line[0][:3], line[0][3:] + '.jpg')
                if os.path.exists(path_img):  # a non-empty path string is always truthy; test the file
                    with open(path_img, 'rb') as fpic:
                        dict_databuf[line[0] + '.jpg'] = fpic.read()

        with self.table.batch(batch_size=5000) as b:
            for imgname, imgdata in dict_databuf.items():
                b.put(imgname, {'cf_pic:data': imgdata})

    def store_info(self, infotype='all'):
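        """Batch-put the per-image metadata (width, height, size, quality)
        from file-tag.tsv into the cf_info column family."""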
        if self.table is None:
            self.table = self.get_table()

        dict_infobuf = {}

        with open(self.list_file, 'rb') as tsvfile:
            tsvfile = csv.reader(tsvfile, delimiter='\t')
            for line in tsvfile:
                # Row layout: hash, width, height, size, quality, class;
                # keep the four info fields between the hash and the class tag.
                dict_infobuf[line[0] + '.jpg'] = line[1:-1]

        if infotype == 'all':
            with self.table.batch(batch_size=5000) as b:
                for imgname, imginfo in dict_infobuf.items():
                    b.put(imgname,
                          {'cf_info:width': imginfo[0], 'cf_info:height': imginfo[1], 'cf_info:size': imginfo[2],
                           'cf_info:quality': imginfo[3]})
        else:
            raise Exception("Unknown infotype!")


    def store_tag(self, tagtype='class'):
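        """Batch-put the last TSV column (the Bernoulli class label) into
        cf_tag:<tagtype> for every image."""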
        if self.table is None:
            self.table = self.get_table()

        dict_tagbuf = {}

        with open(self.list_file, 'rb') as tsvfile:
            tsvfile = csv.reader(tsvfile, delimiter='\t')
            for line in tsvfile:
                dict_tagbuf[line[0] + '.jpg'] = line[-1]

        with self.table.batch(batch_size=5000) as b:
            for imgname, imgtag in dict_tagbuf.items():
                b.put(imgname, {'cf_tag:' + tagtype: imgtag})


    def get_feat(self, image, feattype='ibd', **kwargs):
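        """Compute a feature descriptor for one image, either 'hog'
        (resized via `size`, default (48, 48)) or 'ibd' (intra-block
        difference)."""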
        size = kwargs.get('size', (48, 48))

        if feattype == 'hog':
            feater = HOG.FeatHOG(size=size)
        elif feattype == 'ibd':
            feater = IntraBlockDiff.FeatIntraBlockDiff()
        else:
            raise Exception("Unknown feature type!")

        desc = feater.feat(image)

        return desc

    def extract_feat(self, feattype='ibd'):
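        """Compute a descriptor for every listed image and write it under
        feat_dir as JSON, mirroring the hash-prefix directory layout."""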

        if feattype == 'hog':
            feater = HOG.FeatHOG(size=(48, 48))
        elif feattype == 'ibd':
            feater = IntraBlockDiff.FeatIntraBlockDiff()
        else:
            raise Exception("Unknown feature type!")

        list_image = []
        with open(self.list_file, 'rb') as tsvfile:
            tsvfile = csv.reader(tsvfile, delimiter='\t')
            for line in tsvfile:
                list_image.append(line[0])

        dict_featbuf = {}
        for imgname in list_image:
            # if imgtag == 'True':
            image = os.path.join(self.img_dir, imgname[:3], imgname[3:] + '.jpg')
            desc = feater.feat(image)
            dict_featbuf[imgname] = desc

        for imgname, desc in dict_featbuf.items():
            # print imgname, desc
            subdir = os.path.join(self.feat_dir, imgname[:3])
            if not os.path.exists(subdir):
                os.makedirs(subdir)
            featpath = os.path.join(subdir, imgname[3:].split('.')[0] + '.' + feattype)
            with open(featpath, 'wb') as featfile:
                featfile.write(json.dumps(desc.tolist()))


    def store_feat(self, feattype='ibd'):
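        """Walk feat_dir and batch-put each serialized descriptor into
        cf_feat:<feattype>, keyed by the reconstructed '<hash>.jpg' name."""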
        if self.table is None:
            self.table = self.get_table()

        dict_featbuf = {}
        for path, subdirs, files in os.walk(self.feat_dir):
            for name in files:
                featpath = os.path.join(path, name)
                # print featpath
                with open(featpath, 'rb') as featfile:
                    imgname = path.split('/')[-1] + name.replace('.' + feattype, '.jpg')
                    dict_featbuf[imgname] = featfile.read()

        with self.table.batch(batch_size=5000) as b:
            for imgname, featdesc in dict_featbuf.items():
                b.put(imgname, {'cf_feat:' + feattype: featdesc})

    def load_data(self, mode='local', feattype='ibd', tagtype='class'):
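        """Load (X, Y) feature/label pairs from local JSON files ('local'),
        an HBase scan ('remote'/'hbase'), or a Spark job ('spark'/'cluster')."""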
        INDEX = []
        X = []
        Y = []

        if mode == "local":

            dict_tagbuf = {}
            with open(self.list_file, 'rb') as tsvfile:
                tsvfile = csv.reader(tsvfile, delimiter='\t')
                for line in tsvfile:
                    imgname = line[0] + '.jpg'
                    dict_tagbuf[imgname] = line[-1]

            dict_dataset = {}
            for path, subdirs, files in os.walk(self.feat_dir):
                for name in files:
                    featpath = os.path.join(path, name)
                    with open(featpath, 'rb') as featfile:
                        imgname = path.split('/')[-1] + name.replace('.' + feattype, '.jpg')
                        dict_dataset[imgname] = json.loads(featfile.read())

            for imgname, tag in dict_tagbuf.items():
                tag = 1 if tag == 'True' else 0
                INDEX.append(imgname)
                X.append(dict_dataset[imgname])
                Y.append(tag)

        elif mode in ("remote", "hbase"):
            if self.table is None:
                self.table = self.get_table()

            col_feat, col_tag = 'cf_feat:' + feattype, 'cf_tag:' + tagtype
            for key, data in self.table.scan(columns=[col_feat, col_tag]):
                X.append(json.loads(data[col_feat]))
                Y.append(1 if data[col_tag] == 'True' else 0)

        elif mode in ("spark", "cluster"):
            if self.sparkcontext is None:
                self.sparkcontext = SC.Sparker(host='HPC-server', appname='ImageCV', master='spark://HPC-server:7077')

            result = self.sparkcontext.read_habase(self.table_name)  # result: [(feat, tag), ...]
            for feat, tag in result:
                X.append(feat)
                Y.append(tag)

        else:
            raise Exception("Unknown mode!")

        return X, Y
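

if __name__ == '__main__':
    # Minimal usage sketch (assumptions: the default base_dir exists locally
    # and the HBase Thrift server 'HPC-server' is reachable).
    data = DataILSVRC(category='Train')
    data.format()                        # hash-copy images, build and label file-tag.tsv
    data.store_image()                   # raw JPEG bytes -> cf_pic:data
    data.store_info()                    # width/height/size/quality -> cf_info
    data.store_tag()                     # class label -> cf_tag:class
    data.extract_feat(feattype='ibd')    # write JSON descriptors under Feat/
    data.store_feat(feattype='ibd')      # descriptors -> cf_feat:ibd
    X, Y = data.load_data(mode='local', feattype='ibd')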