# MSR.py 4.37 KB
# Module author metadata.
__author__ = 'chunk'

from mdata import *
from mfeat import *

import os, sys
import base64
from PIL import Image
from io import BytesIO
from hashlib import md5
import csv
import shutil
from common import *
import json
import collections
import itertools

import happybase


class DataMSR(DataDumperBase):
    """Dumper for the MSR-IRC2014 image data set.

    Decodes the base64 images of a TSV dump into an md5-addressed directory
    tree, builds a listing of dumped files, and loads images and tags into an
    HBase table through happybase.
    """

    def __init__(self, base_dir='/home/hadoop/data/MSR-IRC2014/', sub_dir='Dev/', data_file='DevSetImage.tsv'):
        """Configure the data locations.

        :param base_dir: root directory of the data set; must end with '/'.
        :param sub_dir: sub-directory of base_dir (e.g. 'Dev/'); must end with '/'.
        :param data_file: TSV file inside base_dir + sub_dir holding the
            base64-encoded images.
        """
        DataDumperBase.__init__(self)
        self.base_dir = base_dir
        self.sub_dir = sub_dir
        # Plain string concatenation: base_dir and sub_dir carry their own '/'.
        self.data_file = self.base_dir + self.sub_dir + data_file

    def format(self):
        """Decode and dump every image of the data file (delegates to extract)."""
        self.extract()

    def _load_base64(self):
        """Yield (name, base64_data) from the first two columns of each TSV row.

        Blank rows are skipped.

        :raises IOError: if data_file is unset or does not exist (raised on the
            first iteration, since this is a generator).
        """
        if self.data_file is None or not os.path.exists(self.data_file):
            raise IOError('data file not found: %s' % (self.data_file,))

        # Binary mode is what the Python 2 csv module expects.
        with open(self.data_file, 'rb') as tsvfile:
            for row in csv.reader(tsvfile, delimiter='\t'):
                if not row:
                    continue
                yield (row[0], row[1])

    def _hash_dump(self, data):
        """Decode one base64 image, re-encode it as JPEG and store it by md5.

        The file is written to <base_dir><sub_dir>Img/<md5[:3]>/<md5[3:]>.jpg,
        where md5 is computed over the re-encoded JPEG bytes. An existing file
        is left untouched, so identical images are stored only once.

        :param data: base64-encoded image data.
        :returns: the destination path.
        """
        img = Image.open(BytesIO(base64.b64decode(data)))
        # JPEG has no alpha channel or palette; Pillow refuses to save these
        # modes as JPEG, so flatten them to RGB first.
        if img.mode in ('RGBA', 'LA', 'P'):
            img = img.convert('RGB')

        # Encode in memory: no scratch file, no dependency on a writable
        # 'res/' directory in the CWD, and no collisions between runs.
        buf = BytesIO()
        img.save(buf, format='JPEG')
        jpeg = buf.getvalue()
        index = md5(jpeg).hexdigest()

        img_dir = self.base_dir + self.sub_dir + 'Img/' + index[:3]
        if not os.path.exists(img_dir):
            os.makedirs(img_dir)
        path = img_dir + '/' + index[3:] + '.jpg'
        print(path)

        if not os.path.exists(path):
            with open(path, 'wb') as f:
                f.write(jpeg)
        return path

    def extract(self):
        """Dump every image yielded by _load_base64 through _hash_dump."""
        for _name, data in self._load_base64():
            self._hash_dump(data)

    def build_list(self):
        """Write <base_dir><sub_dir>Image.lst listing every file under that directory.

        Each line is '<last component of the containing path>/<file name>'
        (for files directly under a root given with a trailing '/', that
        component is empty). The list file itself is not listed.
        """
        root = self.base_dir + self.sub_dir
        lst = root + 'Image.lst'
        lst_norm = os.path.normpath(os.path.abspath(lst))
        with open(lst, 'wb') as f:
            for path, _subdirs, files in os.walk(root):
                for name in files:
                    # The list file is created before the walk starts, so it
                    # would otherwise appear among the entries.
                    if os.path.normpath(os.path.abspath(os.path.join(path, name))) == lst_norm:
                        continue
                    entry = path.split('/')[-1] + '/' + name
                    print(entry)
                    f.write(entry + '\n')

    def get_table(self, tablename, connection=None, host='HPC-server'):
        """Return the happybase table `tablename`, creating it if it is missing.

        A newly created table gets the column families 'cf_pic', 'cf_info'
        (max_versions=10) and 'cf_tag'.

        :param tablename: HBase table name.
        :param connection: existing happybase.Connection to reuse; when None,
            a new connection to `host` is opened.
        :param host: server host used only when no connection is given.
        :returns: the happybase table object.
        """
        # The connection is intentionally left open: the returned table uses it.
        c = connection if connection is not None else happybase.Connection(host)
        if tablename not in c.tables():
            families = {
                'cf_pic': dict(),
                'cf_info': dict(max_versions=10),
                'cf_tag': dict(),
            }
            c.create_table(name=tablename, families=families)
        return c.table(name=tablename)

    def store_image(self, table):
        """Load the listed image files into `table` as column 'cf_pic:data'.

        Reads <base_dir><sub_dir>Img2/Image.lst; each non-empty line is a path
        relative to <base_dir><sub_dir>Img2/Dev/. The row key is that path with
        every '/' removed. The read phase and the write phase are each timed
        through timer.mark()/timer.report().

        :param table: happybase table, e.g. as returned by get_table().
        """
        timer.mark()
        img_dir = self.base_dir + self.sub_dir + 'Img2/'
        lst = img_dir + 'Image.lst'
        dict_buffer = {}
        with open(lst, 'rb') as f:
            for line in f:
                path_img = line.strip('\n')
                if path_img:
                    with open(img_dir + 'Dev/' + path_img, 'rb') as fpic:
                        dict_buffer[path_img.replace('/', '')] = fpic.read()
        timer.report()  # 1.507566s
        timer.mark()
        # Non-transactional batch (happybase default): mutations are flushed
        # every 5000 puts and the remainder is sent when the block exits.
        with table.batch(batch_size=5000) as b:
            for imgname, imgdata in dict_buffer.items():
                b.put(imgname, {'cf_pic:data': imgdata})
        timer.report()  # 228.003684s

    def store_tag(self, table):
        """Load image tags into `table` under the 'cf_tag' column family.

        Reads two TSV files (blank rows are skipped):
          * <base_dir><sub_dir>Img2/Image.tsv maps column 0 to column 2; the
            latter plus '.jpg' is used as the row key (presumably the md5 file
            name, matching the keys written by store_image -- verify).
          * <base_dir><sub_dir>Dev/DevSetLabel.tsv: per row, the second-to-last
            column is the lookup key into that map, the preceding columns
            concatenated form the column qualifier, and the last column is the
            stored value. A later row with the same key overwrites an earlier one.

        :param table: happybase table, e.g. as returned by get_table().
        :raises KeyError: if a label key has no entry in Image.tsv.
        """
        timer.mark()
        img_dir = self.base_dir + self.sub_dir + 'Img2/'
        maplst = img_dir + 'Image.tsv'
        taglist = self.base_dir + self.sub_dir + 'Dev/DevSetLabel.tsv'
        dict_namebuf = {}
        dict_tagbuf = {}

        with open(maplst, 'rb') as tsvfile:
            for row in csv.reader(tsvfile, delimiter='\t'):
                if row:
                    dict_namebuf[row[0]] = row[2]

        with open(taglist, 'rb') as tsvfile:
            for row in csv.reader(tsvfile, delimiter='\t'):
                if row:
                    dict_tagbuf[row[-2]] = (row[:-2], row[-1])

        timer.report()  # 0.148540s
        timer.mark()
        # Non-transactional batch: anything already queued is still sent if
        # a KeyError escapes the loop.
        with table.batch(batch_size=5000) as b:
            for key, value in dict_tagbuf.items():
                b.put(dict_namebuf[key] + '.jpg', {'cf_tag:' + ''.join(value[0]): value[1]})
        timer.report()  # 3.280105s

    def get_feat(self, category):
        """Placeholder: not implemented, returns None."""
        pass

    def store_feat(self, table, category):
        """Placeholder: not implemented, returns None."""
        pass