import os
import torch
from collections import OrderedDict
from torchsummary import summary



class BaseModel():
    """Common plumbing shared by concrete models.

    Subclasses are expected to fill ``loss_names``, ``model_names`` and
    ``visual_names`` and to expose matching attributes (``loss_<name>``,
    ``net<name>``, ``<visual name>``). They also provide ``self.optimizers``
    and ``self.schedulers`` for ``update_learning_rate``.
    """

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store options and set up device / checkpoint directory.

        Parameters:
            opt -- options object; this method reads ``gpu_ids``, ``isTrain``,
                   ``checkpoints_dir``, ``name`` and ``resize_or_crop``.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # cuDNN autotuning only pays off for fixed input sizes; presumably
        # 'scale_width' produces variable sizes, so it is skipped there.
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    def save_networks(self, which_epoch):
        """Save every network's state_dict to ``self.save_dir``.

        Files are named ``<which_epoch>_net_<name>.pth``. The generator ('G')
        is additionally exported as TorchScript to
        ``<which_epoch>_net_G_st.pt`` in the same directory.
        """
        on_gpu = len(self.gpu_ids) > 0 and torch.cuda.is_available()
        for name in self.model_names:
            if not isinstance(name, str):
                continue
            save_filename = '%s_net_%s.pth' % (which_epoch, name)
            save_path = os.path.join(self.save_dir, save_filename)
            net = getattr(self, 'net' + name)

            if on_gpu:
                # On GPU the net is assumed to be wrapped (e.g. DataParallel),
                # so the unwrapped module's weights are saved from CPU and the
                # wrapper is moved back to the first GPU afterwards.
                torch.save(net.module.cpu().state_dict(), save_path)
                net.cuda(self.gpu_ids[0])
            else:
                torch.save(net.cpu().state_dict(), save_path)

            if name == 'G':
                self._export_torchscript(net, which_epoch, name, use_script=on_gpu)

    def _export_torchscript(self, net, which_epoch, name, use_script):
        """Export ``net`` as a TorchScript file ``<which_epoch>_net_<name>_st.pt`` in save_dir.

        With ``use_script`` the unwrapped ``net.module`` is compiled via
        ``torch.jit.script``; otherwise ``net`` is traced with a fixed
        (1, 4, 128, 128) zero input. The network's previous train/eval mode is
        restored afterwards so training is not silently left in eval mode.
        """
        st_path = os.path.join(self.save_dir, '%s_net_%s_st.pt' % (which_epoch, name))
        was_training = net.training
        net.eval()
        try:
            if use_script:
                exported = torch.jit.script(net.module)
            else:
                example = torch.zeros(1, 4, 128, 128)
                exported = torch.jit.trace(net, example)
            exported.save(st_path)
        finally:
            net.train(was_training)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Drop legacy InstanceNorm buffers that the current module does not have.

        Walks ``module`` along the dotted path ``keys``. At the leaf, if the
        owning module is an InstanceNorm layer and the buffer
        (running_mean / running_var / num_batches_tracked) is ``None`` on it,
        the corresponding entry is removed from ``state_dict`` in place.
        """
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    key in ('running_mean', 'running_var', 'num_batches_tracked'):
                if getattr(module, key, None) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, which_epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (which_epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # print network information
    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = sum(param.numel() for param in net.parameters())
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requires_grad=False for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a network or a list of networks (None entries are skipped)
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad