# MSR.py
__author__ = 'chunk'

from mdata import *
from mfeat import *

import os, sys
import base64
from PIL import Image
from io import BytesIO
from hashlib import md5
import csv
import shutil
from common import *
import json
import collections
import itertools

import happybase


class DataMSR(DataDumperBase):
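    """
    Dumper for the MSR-IRC2014 image set: decodes the base64-encoded images
    shipped in the TSV release, stores them in an md5-sharded directory tree,
    and pushes the images and their retrieval tags into an HBase table via
    happybase.
    """
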
    def __init__(self, base_dir='/home/hadoop/data/MSR-IRC2014/', category='Dev',
                 data_file='DevSetImage.tsv', tag_file='DevSetLabel.tsv'):
        DataDumperBase.__init__(self, base_dir, category)

        self.data_file = self.base_dir + self.category + '/' + data_file
        self.tag_file = self.base_dir + self.category + '/' + tag_file
        self.map_file = self.base_dir + self.category + '/' + 'images_map.tsv'

        self.table_name = self.base_dir.split('/')[-2] + '-' + self.category

    def format(self):
        self.extract()

    def _load_base64(self):
        assert self.data_file is not None and os.path.exists(self.data_file)

        with open(self.data_file, 'rb') as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            for line in reader:
                # each row is (image name, base64-encoded JPEG data)
                yield (line[0], line[1])


    def _hash_dump(self, data):
        # decode the base64 payload and re-encode it as a JPEG file
        img = Image.open(BytesIO(base64.b64decode(data)))
        img.save('res/tmp.jpg', format='JPEG')

        # name the image by the md5 of its JPEG bytes
        with open('res/tmp.jpg', 'rb') as f:
            index = md5(f.read()).hexdigest()

        # shard images into sub-directories keyed by the first 3 hex digits
        sub_dir = self.img_dir + index[:3] + '/'
        if not os.path.exists(sub_dir):
            os.makedirs(sub_dir)
        image_path = sub_dir + index[3:] + '.jpg'
        print image_path

        if not os.path.exists(image_path):
            shutil.copy('res/tmp.jpg', image_path)
            # or:
            # img.save(image_path, format='JPEG')


    def extract(self):
        # decode every base64-encoded image in the TSV and dump it to disk
        for name, data in self._load_base64():
            self._hash_dump(data)


    def build_list(self):
        assert self.list_file is not None
        with open(self.list_file, 'wb') as f:
            for path, subdirs, files in os.walk(self.img_dir):
                for name in files:
                    # record each image as '<sub-dir>/<file name>'
                    entry = path.split('/')[-1] + '/' + name
                    print entry
                    f.write(entry + '\n')


    def get_table(self):
        if self.table is not None:
            return self.table

        if self.connection is None:
            self.connection = happybase.Connection('HPC-server')

        # create the table with its column families on first use
        tables = self.connection.tables()
        if self.table_name not in tables:
            families = {'cf_pic': dict(),
                        'cf_info': dict(max_versions=10),
                        'cf_tag': dict(),
                        'cf_feat': dict(),
                        }
            self.connection.create_table(name=self.table_name, families=families)

        table = self.connection.table(name=self.table_name)

        self.table = table

        return table


    def store_image(self):
        if self.table is None:
            self.table = self.get_table()

        # read every listed image into memory, keyed by its flattened
        # relative path (the row key, '<sub-dir><file name>')
        dict_buffer = {}
        with open(self.list_file, 'rb') as f:
            for line in f:
                path_img = line.strip('\n')
                if path_img:
                    with open(self.img_dir + path_img, 'rb') as fpic:
                        dict_buffer[path_img.replace('/', '')] = fpic.read()

        # write the raw JPEG bytes to the 'cf_pic:data' column in batches
        try:
            with self.table.batch(batch_size=5000) as b:
                for imgname, imgdata in dict_buffer.items():
                    b.put(imgname, {'cf_pic:data': imgdata})
        except ValueError:
            raise


    def store_tag(self, feattype='retrieve'):
        if self.table is None:
            self.table = self.get_table()

        dict_namebuf = {}
        dict_tagbuf = {}

        # map_file maps the md5 hash (row key without its '.jpg' suffix)
        # back to the original image name used in the label file
        with open(self.map_file, 'rb') as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            for line in reader:
                dict_namebuf[line[2]] = line[0]

        # tag_file maps the image name to (tag terms, evaluation label)
        with open(self.tag_file, 'rb') as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            for line in reader:
                dict_tagbuf[line[-2]] = (line[0].split(), line[-1])

        # attach the tags to the rows written by store_image()
        try:
            with self.table.batch(batch_size=5000) as b:
                for key, data in self.table.scan():
                    value = dict_tagbuf[dict_namebuf[key[:-4]]]
                    b.put(key, {'cf_tag:' + feattype: json.dumps(value[0]), 'cf_tag:eval': value[1]})
        except ValueError:
            raise

    def get_feat(self, feattype):
        pass

    def store_feat(self, feattype):
        pass
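

# Minimal usage sketch, assuming the default MSR-IRC2014 Dev-set layout above,
# a reachable 'HPC-server' HBase Thrift gateway, and the img_dir / list_file
# attributes provided by DataDumperBase:
if __name__ == '__main__':
    dumper = DataMSR()
    dumper.format()        # decode the base64 TSV rows into hashed JPEG files
    dumper.build_list()    # write the relative image paths to list_file
    dumper.store_image()   # push the raw JPEG bytes into 'cf_pic:data'
    dumper.store_tag()     # attach retrieval tags from the map and label files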