#!/usr/bin/env python
__author__ = 'paul.duckworth'
import os, sys
import cv2
import numpy as np
from os import listdir, path
from os.path import isfile, join, isdir

class ImageCreator(object):
    """Overlay 2D skeleton tracks on the RGB frames of a set of videos.

    Every sub-directory of ``directory`` is treated as one video. For each
    video the per-frame files ``skeleton/skl_NNNNN.txt`` and
    ``rgb/rgb_NNNNN.jpg`` are read, and the annotated frame is written to
    ``rgb_sk/sk_NNNNN.jpg`` (frame numbers start at 1, zero-padded to 5).
    """

    # Pinhole intrinsics used to project 3D joint positions to pixels.
    # Class attributes so a subclass or instance can override them.
    # NOTE(review): presumably these match the camera that recorded the
    # dataset (640x480 principal point) - confirm.
    fx = 525.0
    fy = 525.0
    cx = 319.5
    cy = 239.5

    def __init__(self, directory):
        """Index the video sub-directories of ``directory``.

        The first video in sorted order becomes the current one. If there
        are no sub-directories, ``video`` and ``dir`` are ``None`` instead of
        raising ``IndexError``; ``create_sk_images`` then does nothing.
        """
        self.directory = directory
        self.video_iter = 0
        self.all_videos = [v for v in sorted(listdir(directory))
                           if isdir(join(directory, v))]
        self.video = self.all_videos[0] if self.all_videos else None
        self.dir = join(directory, self.video) if self.video is not None else None
        self.video_len = 0
        # joint name -> [x, y, z, x2d, y2d]; filled by get_2d_sk.
        self.skeleton_data = {}
        self.joints = [
            'head',
            'neck',
            'torso',
            'left_shoulder',
            'left_elbow',
            'left_hand',
            'left_hip',
            'left_knee',
            'left_foot',
            'right_shoulder',
            'right_elbow',
            'right_hand',
            'right_hip',
            'right_knee',
            'right_foot']

        # Pairs of joints joined by a "bone" line when plotting.
        self.connected_joints = [
            ['head', 'neck'],
            ['neck', 'torso'],
            ['neck', 'right_shoulder'],
            ['neck', 'left_shoulder'],
            ['torso', 'right_hip'],
            ['torso', 'left_hip'],
            ['left_shoulder', 'left_elbow'],
            ['left_elbow', 'left_hand'],
            ['left_hip', 'left_knee'],
            ['left_knee', 'left_foot'],
            ['right_shoulder', 'right_elbow'],
            ['right_elbow', 'right_hand'],
            ['right_hip', 'right_knee'],
            ['right_knee', 'right_foot']]

    @staticmethod
    def _frame_str(val):
        """Return frame number ``val`` zero-padded to 5 digits, as used in file names."""
        return '%05d' % val

    def create_sk_images(self):
        """Create the RGB images with overlaid skeleton tracks for every video.

        Each annotated frame is written to the video's ``rgb_sk`` folder
        (created if missing). It uses the RGB images released in the DOI, so
        images may be blurred for privacy reasons.
        """
        for video_iter, video in enumerate(self.all_videos):
            self.video_iter, self.video = video_iter, video
            self.dir = join(self.directory, video)
            print(self.dir)

            skeleton_dir = join(self.dir, "skeleton")
            # One frame per file in the skeleton folder, numbered from 1.
            self.video_len = len([f for f in listdir(skeleton_dir)
                                  if isfile(join(skeleton_dir, f))])
            for val in range(1, self.video_len + 1):
                val_str = self._frame_str(val)
                rgb_img = self.add_image(val_str)
                # plot_sk parses the skeleton file itself; parsing it here as
                # well would just read the same file twice.
                self.plot_sk(rgb_img, val_str)

    def add_image(self, val_str):
        """Load RGB frame ``val_str`` of the current video.

        Returns the image as loaded by ``cv2.imread`` (BGR order), or
        ``None`` when the file is missing or cannot be decoded.
        """
        img_loc = join(self.dir, "rgb", "rgb_" + val_str + '.jpg')
        return cv2.imread(img_loc)

    def plot_sk(self, img, val_str):
        """Draw the skeleton of frame ``val_str`` on ``img`` and save it.

        Output: ``<video dir>/rgb_sk/sk_<val_str>.jpg``. The folder is created
        on demand because ``cv2.imwrite`` does not create directories (it
        just returns False). Frames whose RGB image is ``None`` are skipped.
        """
        if img is None:
            print("skipping frame %s: no rgb image in %s" % (val_str, self.dir))
            return
        self.get_2d_sk(val_str)
        img = self.plot_2d_sk(img)
        out_dir = join(self.dir, "rgb_sk")
        if not isdir(out_dir):
            os.makedirs(out_dir)
        out_path = join(out_dir, "sk_" + val_str + '.jpg')
        if not cv2.imwrite(out_path, img):
            print("failed to write %s" % out_path)

    def get_2d_sk(self, val_str):
        """Parse ``skeleton/skl_<val_str>.txt`` into ``self.skeleton_data``.

        Layout read: one leading line (ignored), then 10-line joint records.
        The 1st line of a record is the joint name, and the 3rd-5th lines are
        ``label:value`` for x, y and z. Each joint is stored as
        ``[x, y, z, x2d, y2d]``, where (x2d, y2d) is the pinhole projection
        with the class intrinsics. y is negated because image rows grow
        downwards.

        NOTE(review): ``skeleton_data`` is not cleared between frames, so a
        joint absent from this file keeps its previous frame's values.
        """
        skl_path = join(self.dir, 'skeleton', 'skl_' + val_str + '.txt')
        with open(skl_path, 'r') as skl_file:
            joint = None
            for count, line in enumerate(skl_file):
                offset = (count - 1) % 10
                if offset == 0:
                    # strip() also drops '\r' / trailing spaces so the name
                    # matches the keys in self.joints.
                    joint = line.strip()
                    self.skeleton_data[joint] = [0, 0, 0, 0, 0]
                elif offset in (2, 3, 4):
                    # offsets 2, 3, 4 -> x, y, z (indices 0, 1, 2)
                    self.skeleton_data[joint][offset - 2] = float(line.split(':')[1])
                    if offset == 4:
                        x, y, z = self.skeleton_data[joint][:3]
                        self.skeleton_data[joint][3] = int(self.cx + x * self.fx / z)
                        self.skeleton_data[joint][4] = int(self.cy - y * self.fy / z)

    def plot_2d_sk(self, img):
        """Draw bones and joints from ``self.skeleton_data`` onto ``img``.

        Colours are BGR tuples (OpenCV order). The head gets a larger marker
        and the right hand a green one. The drawing calls modify ``img`` in
        place, and ``img`` is also returned.
        """
        for start, end in self.connected_joints:
            p0 = tuple(self.skeleton_data[start][3:5])
            p1 = tuple(self.skeleton_data[end][3:5])
            cv2.line(img, p0, p1, (240, 0, 180), 3)

        for joint in self.joints:
            centre = tuple(self.skeleton_data[joint][3:5])
            cv2.circle(img, centre, 5, (240, 0, 50), -1)
            if joint == 'head':
                cv2.circle(img, centre, 9, (255, 0, 50), -1)
            if joint == 'right_hand':
                cv2.circle(img, centre, 5, (0, 255, 50), -1)
        return img




# Recording dates (sub-folders of Data/) to process; dates missing on disk
# are skipped.
DATES = ['2016-04-05', '2016-04-06', '2016-04-07', '2016-04-08', '2016-04-11']

if __name__ == "__main__":
    # Only run the (slow, file-writing) batch when executed as a script,
    # not when this module is imported.
    directory = path.join(os.path.dirname(os.path.abspath(__file__)), "Data")
    print("directory = %s" % directory)

    all_dates = [f for f in DATES if isdir(join(directory, f))]

    for each_date in sorted(all_dates):
        ic = ImageCreator(join(directory, each_date))
        ic.create_sk_images()
