import time
import os
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models import create_model
from util.visualizer import save_images
from util import html
import torch
from options.train_options import TrainOptions
from collections import OrderedDict

def _as_int(value):
    """Return the first element of *value* (tensor, array or scalar) as a Python int."""
    return int(torch.as_tensor(value).reshape(-1)[0])


def _tile_span(index, count, size, tile):
    """Return ``(dst_start, dst_stop, src_start)`` for tile ``index`` along one axis.

    Tile origins sit every ``tile // 2`` pixels, except the last tile, which is
    pinned to the far edge at ``size - tile``.  The first tile is copied whole;
    every later tile skips its leading ``(tile // 2) // 2`` pixels and overwrites
    whatever earlier tiles wrote from that point on (presumably to keep tile
    border pixels out of the stitched result).
    """
    if index == 0:
        return 0, tile, 0
    step = tile // 2
    margin = step // 2
    origin = size - tile if index == count - 1 else index * step
    return origin + margin, origin + tile, margin


def concat(visuals_list, w, h):
    """Stitch per-tile visuals back into full-size images.

    ``visuals_list`` holds one visuals dict per tile, indexed row-major as
    ``row * n_cols + col``.  Each dict must contain 4-D tensors ``'real_A'``,
    ``'fake_B'`` and ``'real_B'`` shaped ``(1, C, tile_h, tile_w)``.  The tile
    size and channel counts are taken from the first tile.
    NOTE(review): assumes the data loader enumerated tiles row-major with rows
    along the output's height axis -- confirm against the loader's tiling code.

    Args:
        visuals_list: per-tile visual dicts as described above.
        w: extent of the output's last (column) axis; an int, or a tensor /
            array whose first element is used.
        h: extent of the output's row axis; same accepted forms as ``w``.

    Returns:
        OrderedDict with keys ``'real_A'``, ``'real_B'``, ``'fake_B'`` mapping to
        float32 CPU tensors of shape ``(1, C, h, w)``.

    Raises:
        ValueError: if ``visuals_list`` is empty or holds fewer tiles than the
            ``(2*h // tile_h) x (2*w // tile_w)`` grid requires.
    """
    if not visuals_list:
        raise ValueError('concat() needs at least one tile')
    width = _as_int(w)
    height = _as_int(h)

    first = visuals_list[0]
    tile_h, tile_w = (int(s) for s in first['real_A'].shape[-2:])
    # One tile per half tile size along each axis (the last one edge-pinned).
    n_rows = (2 * height) // tile_h
    n_cols = (2 * width) // tile_w
    if len(visuals_list) < n_rows * n_cols:
        raise ValueError('expected %d tiles for a %dx%d image, got %d'
                         % (n_rows * n_cols, width, height, len(visuals_list)))

    # Rows are the height axis (dim 2) and columns the width axis (dim 3), so
    # the canvas shape matches the slicing below for non-square images too.
    visual_ret = OrderedDict(
        (key, torch.zeros(1, first[key].shape[1], height, width, dtype=torch.float32))
        for key in ('real_A', 'real_B', 'fake_B'))

    for row in range(n_rows):
        r0, r1, rs = _tile_span(row, n_rows, height, tile_h)
        for col in range(n_cols):
            c0, c1, cs = _tile_span(col, n_cols, width, tile_w)
            tile = visuals_list[row * n_cols + col]
            # real_B is stitched as well; previously it was returned all-zero.
            for key, canvas in visual_ret.items():
                canvas[:, :, r0:r1, c0:c1] = tile[key][:, :, rs:tile_h, cs:tile_w]
    return visual_ret

import torchvision.transforms as transforms

if __name__ == "__main__":
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    # opt.loadSize = opt.fineSize  # Do not scale! (intentionally left disabled)

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)

    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

    # Per-tile tensors supplied by the data loader.  Each data[key][0] carries a
    # leading tile axis; every tile is fed to the model as its own batch of one.
    tile_keys = ('A', 'B', 'A_F', 'B_F', 'M')

    # test: run the model tile by tile, then stitch the tiles back together
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break

        tiles = {key: data[key][0] for key in tile_keys}
        w = data['im_size'][0]
        h = data['im_size'][1]

        visuals_list = []
        for j in range(tiles['A'].shape[0]):
            data_temp = {key: tiles[key][j].unsqueeze(0) for key in tile_keys}
            data_temp['A_paths'] = data['A_paths']
            t1 = time.perf_counter()
            model.set_input(data_temp)
            model.test()
            t2 = time.perf_counter()
            print(t2 - t1)
            visuals_list.append(model.get_current_visuals())

        img_path = model.get_image_paths()
        print('process image... %s' % img_path)
        visual_ret = concat(visuals_list, w, h)
        save_images(webpage, visual_ret, img_path, 0, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
    webpage.save()
