import numpy as np
import torch
import torch.nn as nn
from time import time


def unfold(img, patch_size, stride, with_indexes=torch.tensor(False, dtype=torch.bool)):
    """Cut a 4-D tensor into square sliding windows, one entry per window position.

    Windows of ``patch_size`` x ``patch_size`` are taken along dims 2 and 3
    with step ``stride``. Window positions are flattened into a single axis
    (dim-2 position varying slowest) and dim 1 of ``img`` is moved behind it.

    Args:
        img: 4-D tensor; any other rank trips the assertion.
        patch_size: Side length of the square window.
        stride: Step between consecutive windows along both dims.
        with_indexes: When truthy, each window is flattened away and the
            window-grid sizes are returned as well. The flattened view has
            exactly ``img.size(1)`` values per position, so it only fits when
            every window holds a single element (``patch_size`` of 1).

    Returns:
        A 4-tuple. With ``with_indexes``: the windows shaped
        ``(batch, positions, img.size(1))``, the number of window positions
        along dim 2, along dim 3, and ``img.size(1)``. Otherwise: the windows
        shaped ``(batch, positions, img.size(1), patch_size, patch_size)``
        followed by three zeros.
    """
    assert img.dim() == 4, 'image must be of dimension 4.'

    # Slide along dim 2 first, then dim 3; each unfold appends a window axis.
    patches = img.unfold(2, patch_size, stride).unfold(3, patch_size, stride)
    batch, depth, rows, cols, win_h, win_w = patches.size()

    # Put the two window-grid axes right after the batch axis so they can be
    # merged; contiguous() is required before view() on a permuted tensor.
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()

    if with_indexes:
        return patches.view(batch, rows * cols, depth), rows, cols, depth
    return patches.view(batch, rows * cols, depth, win_h, win_w), 0, 0, 0


def filter(input_windows, flag, value):
    """Keep only the windows whose flag equals ``value``, regrouped per sample.

    Args:
        input_windows: Tensor whose first two dims line up with ``flag``.
        flag: 2-D tensor (batch, positions) labelling every window.
        value: Label to select; compared against ``flag`` with ``==``.

    Returns:
        Tensor of shape ``(batch, matches // batch, -1)``: the selected
        windows with all trailing dims flattened. The regrouping divides the
        total match count evenly across the batch, so every sample is
        expected to contain the same number of matches.
    """
    assert flag.dim() == 2, "flag should be batch version"

    batch = flag.size(0)
    # Boolean-mask indexing merges the batch and position axes into one.
    selected = input_windows[flag == value]
    per_sample = selected.size(0) // batch
    return selected.view(batch, per_sample, -1)


def cosine_similarity(former, latter, patch_size, stride, flag, with_former=torch.tensor(False, dtype=torch.bool)):
    """Cosine similarity between flagged-1 windows of ``former`` and flagged-0 windows of ``latter``.

    Both inputs are cut into windows with ``unfold``; ``former`` keeps only
    the windows whose flag is 1 and ``latter`` only those whose flag is 0.
    Every kept ``former`` window is then compared with every kept ``latter``
    window.

    Args:
        former: 4-D tensor supplying the flag-1 windows.
        latter: 4-D tensor supplying the flag-0 windows.
        patch_size: Window side length forwarded to ``unfold``.
        stride: Window step forwarded to ``unfold``.
        flag: 2-D tensor (batch, positions) of 0/1 labels.
        with_former: When truthy, the unfiltered windows of ``former`` are
            returned too; otherwise a scalar zero tensor takes their place.

    Returns:
        A 6-tuple: the similarity tensor of shape
        ``(batch, flag-1 count, flag-0 count)``, the unfiltered windows of
        ``latter``, the unfiltered windows of ``former`` (or
        ``torch.tensor(0)``), and the three grid values reported by
        ``unfold`` for ``latter``.

    Note:
        The denominator is not guarded: a window with zero norm makes the
        corresponding similarities non-finite.
    """
    former_windows, _, _, _ = unfold(former, patch_size, stride)
    latter_windows, i_2, i_3, i_1 = unfold(
        latter, patch_size, stride, with_indexes=torch.tensor(True, dtype=torch.bool))

    masked_rows = filter(former_windows, flag, torch.tensor(1))
    known_rows = filter(latter_windows, flag, torch.tensor(0))

    # Pairwise dot products between every masked row and every known row.
    dots = torch.einsum('bik,bjk->bij', [masked_rows, known_rows])

    # Squared L2 norm of each row; their outer product is the squared denominator.
    sq_masked = torch.einsum("bij,bij->bi", [masked_rows, masked_rows])
    sq_known = torch.einsum("bij,bij->bi", [known_rows, known_rows])
    scale = torch.sqrt(torch.einsum('bi,bj->bij', [sq_masked, sq_known]))

    similarity = dots / scale
    former_out = former_windows if with_former else torch.tensor(0)
    return similarity, latter_windows, former_out, i_2, i_3, i_1


# NOTE: paste takes no i_4 argument; it was dropped because i_4 is always 1 here.
def paste(input_windows, transition_matrix, i_2, i_3, i_1):
    """Re-route window rows with a transition matrix and fold them back to 4-D.

    Args:
        input_windows: Tensor of shape (batch, positions, features).
        transition_matrix: Tensor of shape (batch, positions_out, positions);
            row ``r`` gives the weights of the input rows that make up output
            row ``r``.
        i_2: Size of the first spatial axis of the result.
        i_3: Size of the second spatial axis of the result.
        i_1: Size of the feature axis of the result.

    Returns:
        Tensor of shape ``(batch, i_1, i_2, i_3)``.
    """
    batch = input_windows.size(0)

    # Batched matmul: each output row is a weighted sum of input rows.
    moved = torch.bmm(transition_matrix, input_windows)

    # Split the position axis back into its grid, then move the feature axis
    # to position 1.
    return moved.view(batch, i_2, i_3, i_1).permute(0, 3, 1, 2)


def InnerShiftTripleFunction(input, shift_sz, stride, triple_w, flag, show_flow):
    """Append a shifted copy of the encoder features to ``input``.

    The channel axis is split in half: the first ``c // 2`` channels are the
    decoder features and the next ``c // 2`` the encoder features. For every
    location flagged 1, the flagged-0 location whose encoder feature has the
    highest cosine similarity to that location's decoder feature is looked
    up, and its encoder feature is copied to the flagged location. Locations
    flagged 0 are left at zero in the shifted map.

    Args:
        input: 4-D tensor (batch, channels, height, width).
        shift_sz: Not used by this implementation; kept for interface
            compatibility.
        stride: Window step forwarded to ``cosine_similarity``. The one-hot
            matrix below is sized ``h * w``, so the window grid is expected
            to have ``h * w`` positions (stride 1 with the patch size of 1
            used here).
        triple_w: Not used by this implementation.
        flag: 2-D tensor (batch, positions) with 1 for locations to fill and
            0 for locations to copy from. Every sample is expected to have
            the same number of flagged locations, because the indices are
            regrouped with ``view(bz, -1)``.
        show_flow: Not used by this implementation.

    Returns:
        Tensor concatenating, along dim 1, the decoder half, the encoder
        half, and the shifted encoder features.
    """
    assert input.dim() == 4, "Input Dim has to be 4"

    bz, c, h, w = input.size()
    half = c // 2

    former_all = input.narrow(1, 0, half)  # decoder feature
    latter_all = input.narrow(1, half, half)  # encoder feature

    flag = flag.to(input).long()

    # Batch version: similarity of every masked decoder location against
    # every known encoder location, using 1x1 patches.
    cosine, latter_windows, _, i_2, i_3, i_1 = cosine_similarity(
        former_all.clone(), latter_all.clone(), torch.tensor(1), stride, flag)

    # Best-matching known location for each masked location; the index is
    # relative to the list of flag-0 positions.
    _, indexes = torch.max(cosine, dim=2)

    # Absolute positions of the masked locations and of their matches.
    mask_indexes = (flag == 1).nonzero()[:, 1].view(bz, -1)
    non_mask_indexes = (flag == 0).nonzero()[:, 1].view(bz, -1).gather(1, indexes)

    # Batch index for every (masked, matched) pair, created on the same
    # device as the other index tensors.
    idx_b = torch.arange(bz, device=input.device).unsqueeze(1).expand(bz, mask_indexes.size(1))

    # One-hot routing matrix: row = destination (masked) position,
    # column = source (known) position.
    ind_lst = input.new_zeros(bz, h * w, h * w)
    ind_lst[idx_b, mask_indexes, non_mask_indexes] = 1

    shift_masked_all = paste(latter_windows, ind_lst, torch.tensor(i_2), torch.tensor(i_3), torch.tensor(i_1))

    return torch.cat((former_all, latter_all, shift_masked_all), 1)
