Commit 22a89e4e authored by 魏博昱's avatar 魏博昱

first commit

parent 2eeefc25
//#include <torch/script.h> // One-stop header.
//#include <torch/csrc/jit/import.h>
//#include <iostream>
//#include <opencv2/opencv.hpp>
//#include <opencv2/core.hpp>
//#include <opencv2/imgcodecs.hpp>
//#include <opencv2/highgui.hpp>
//#include <memory>
//#include <string>
//
//cv::Mat loadImg(std::string path) {
// auto image = cv::imread(path, cv::ImreadModes::IMREAD_COLOR);
// cv::Mat image_transfomed = image;
// return image_transfomed;
//};
//
//
//torch::Tensor preProcess(cv::Mat img_mat) {
// // BGR2RGB
// // cv::cvtColor(img_mat, img_mat, cv::COLOR_BGR2RGB);
//
//
// // convert cv::Mat to at::Tensor (see https://pytorch.org/cppdocs/api/namespace_at.html#namespace-at)
// torch::Tensor tensor_image = torch::from_blob(img_mat.data, { img_mat.rows, img_mat.cols,3 }, torch::kByte);
// tensor_image = tensor_image.permute({ 2,0,1 });
// tensor_image = tensor_image.toType(torch::kFloat);
// tensor_image = tensor_image.unsqueeze(0).to(at::kCUDA);
// return tensor_image;
//}
//
//
//void postProcess(at::Tensor output) {
// output = output.squeeze().detach().permute({ 1, 2, 0 });
// output = output.clamp(0, 255).to(torch::kU8);
// output = output.to(torch::kCPU);
// cv::Mat result_mat(cv::Size(512, 512), CV_8UC3, output.data_ptr());
// cv::imwrite("result.jpg", result_mat);
//}
//
//int main() {
//
// // load model
// torch::NoGradGuard no_grad;
// torch::jit::script::Module model = torch::jit::load("E:\\project\\ConsoleApplication4\\DASR.pt");
// model.eval();
// model.to(at::kCUDA);
//
// // pre-process
// cv::Mat image_transfomed = loadImg("4.jpg");
// torch::Tensor tensor_image = preProcess(image_transfomed, 1);
//
// //forward
// at::Tensor output = model.forward({ tensor_image }).toTensor();
//
// // post-process
// postProcess(output, 1);
// return 0;
//}
\ No newline at end of file
from importlib import import_module
from dataloader import MSDataLoader
class Data:
    """Builds the train and test data loaders from the parsed CLI args."""

    def __init__(self, args):
        self.loader_train = None
        if not args.test_only:
            # Resolve the dataset class by name, e.g. 'DF2K' -> data.df2k.DF2K.
            train_module = import_module('data.' + args.data_train.lower())
            train_dataset = getattr(train_module, args.data_train)(args)
            self.loader_train = MSDataLoader(
                args,
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=not args.cpu
            )

        # The standard SR benchmark sets all share the Benchmark loader class.
        if args.data_test in ['Set5', 'Set14', 'B100', 'Manga109', 'Urban100']:
            test_module = import_module('data.benchmark')
            test_dataset = getattr(test_module, 'Benchmark')(args, name=args.data_test, train=False)
        else:
            test_module = import_module('data.' + args.data_test.lower())
            test_dataset = getattr(test_module, args.data_test)(args, train=False)

        self.loader_test = MSDataLoader(
            args,
            test_dataset,
            batch_size=1,
            shuffle=False,
            pin_memory=not args.cpu
        )
import os
from data import common
from data import multiscalesrdata as srdata
class Benchmark(srdata.SRData):
    """SRData subclass for the standard SR benchmark sets (Set5, Set14, ...).

    Benchmark data lives under <dir_data>/benchmark/<name> and is always
    constructed with benchmark=True so SRData skips binary caching.
    """

    def __init__(self, args, name='', train=True):
        super(Benchmark, self).__init__(
            args, name=name, train=train, benchmark=True
        )

    def _set_filesystem(self, dir_data):
        # Directory layout: <dir_data>/benchmark/<name>/{HR, LR_bicubic}
        self.apath = os.path.join(dir_data, 'benchmark', self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        # Bug fix: the third extension was 'bmp' (missing the leading dot),
        # so SRData._scan's glob ('*' + ext) matched any filename merely
        # ending in "bmp" instead of the '.bmp' suffix like the other entries.
        self.ext = ('.png', '.jpg', '.bmp')
        print(self.dir_hr)
        print(self.dir_lr)
import random
import numpy as np
import skimage.color as sc
import torch
def get_patch(img, patch_size=48, scale=1):
    """Crop a random square patch from an HR image.

    Args:
        img: HWC numpy array (the HR image).
        patch_size: base (LR) patch size; the HR crop is scale * patch_size.
        scale: SR scale factor.

    Returns:
        Patch of shape (tp, tp, C) with tp = round(scale * patch_size).
    """
    th, tw = img.shape[:2]  ## HR image
    tp = round(scale * patch_size)
    # Bug fix: randrange's upper bound is exclusive, so the original
    # randrange(0, tw - tp) raised ValueError whenever the image was exactly
    # patch-sized (tw == tp) and could never place the crop at the border.
    tx = random.randrange(0, tw - tp + 1)
    ty = random.randrange(0, th - tp + 1)
    return img[ty:ty + tp, tx:tx + tp, :]
def set_channel(img, n_channels=3):
    """Coerce *img* (HWC layout) to exactly *n_channels* channels."""
    # Promote a grayscale HxW array to HxWx1 so the channel axis exists.
    if img.ndim == 2:
        img = np.expand_dims(img, axis=2)

    current = img.shape[2]
    if n_channels == 1 and current == 3:
        # RGB -> luminance (Y channel of YCbCr).
        img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2)
    elif n_channels == 3 and current == 1:
        # Replicate the single channel across all three.
        img = np.concatenate([img] * n_channels, 2)
    return img
def np2Tensor(img, rgb_range=255):
    """Convert an HWC uint8 image into a CHW float tensor in [0, rgb_range]."""
    chw = np.ascontiguousarray(img.transpose((2, 0, 1)))
    out = torch.from_numpy(chw).float()
    # Rescale 0..255 pixel values onto 0..rgb_range.
    out.mul_(rgb_range / 255)
    return out
def augment(img, hflip=True, rot=True):
    """Randomly flip/transpose an HWC image; each transform fires with p=0.5."""
    # Draw the three coin flips in a fixed order so the RNG stream matches
    # the previous implementation exactly.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = rot and random.random() < 0.5
    do_rot90 = rot and random.random() < 0.5

    if do_hflip:
        img = img[:, ::-1, :]
    if do_vflip:
        img = img[::-1, :, :]
    if do_rot90:
        img = img.transpose(1, 0, 2)
    return img
import os
from data import multiscalesrdata
class DF2K(multiscalesrdata.SRData):
    """DF2K (DIV2K + Flickr2K) HR training dataset."""

    def __init__(self, args, name='DF2K', train=True, benchmark=False):
        super(DF2K, self).__init__(args, name=name, train=train, benchmark=benchmark)

    def _scan(self):
        # Restrict the HR file list to the configured [begin, end] range
        # (1-based, inclusive), parsed from args.data_range by SRData.
        full_list = super(DF2K, self)._scan()
        return full_list[self.begin - 1:self.end]

    def _set_filesystem(self, dir_data):
        super(DF2K, self)._set_filesystem(dir_data)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
import os
import glob
from data import common
import pickle
import numpy as np
import imageio
import torch
import torch.utils.data as data
class SRData(data.Dataset):
    """Base HR-only dataset for degradation-adaptive SR training.

    args.ext selects how images are materialized:
      * contains 'img': decode the image file on every access,
      * contains 'sep': one pickled .pt file per image under <apath>/bin,
      * contains 'bin': a single pickle holding every image for the split.
    During training, __getitem__ returns two random patches of the same HR
    image stacked along dim 0 (used for contrastive degradation learning);
    at test time it returns the full image.
    """
    def __init__(self, args, name='', train=True, benchmark=False):
        self.args = args
        self.name = name
        self.train = train
        self.split = 'train' if train else 'test'
        self.do_eval = True
        self.benchmark = benchmark
        self.scale = args.scale
        # Index into self.scale for multi-scale training; see set_scale().
        self.idx_scale = 0
        # args.data_range looks like 'a-b/c-d' (train range / test range),
        # with 1-based inclusive indices.
        data_range = [r.split('-') for r in args.data_range.split('/')]
        if train:
            data_range = data_range[0]
        else:
            if args.test_only and len(data_range) == 1:
                data_range = data_range[0]
            else:
                data_range = data_range[1]
        self.begin, self.end = list(map(lambda x: int(x), data_range))
        self._set_filesystem(args.dir_data)
        if args.ext.find('img') < 0:
            # Any cached mode ('sep'/'bin') keeps its files under apath/bin.
            path_bin = os.path.join(self.apath, 'bin')
            os.makedirs(path_bin, exist_ok=True)
        list_hr = self._scan()
        if args.ext.find('bin') >= 0:
            # Binary files are stored in 'bin' folder
            # If the binary file exists, load it. If not, make it.
            list_hr = self._scan()
            self.images_hr = self._check_and_load(
                args.ext, list_hr, self._name_hrbin()
            )
        else:
            if args.ext.find('img') >= 0 or benchmark:
                # Plain mode: keep file paths, decode lazily in _load_file.
                self.images_hr = list_hr
            elif args.ext.find('sep') >= 0:
                # One pickle per image, mirroring the HR directory tree.
                os.makedirs(
                    self.dir_hr.replace(self.apath, path_bin),
                    exist_ok=True
                )
                self.images_hr = []
                for h in list_hr:
                    b = h.replace(self.apath, path_bin)
                    b = b.replace(self.ext[0], '.pt')
                    self.images_hr.append(b)
                    self._check_and_load(
                        args.ext, [h], b, verbose=True, load=False
                    )
        if train:
            # Virtually enlarge the epoch so ~test_every batches fit in it.
            # NOTE(review): raises ZeroDivisionError when the dataset holds
            # fewer images than one batch — confirm len >= batch_size.
            self.repeat = args.test_every // (len(self.images_hr) // args.batch_size)
    # Below functions as used to prepare images
    def _scan(self):
        # Collect files matching any of the three configured extensions.
        names_hr = sorted(
            glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) + glob.glob(os.path.join(self.dir_hr, '*' + self.ext[1])) +glob.glob(os.path.join(self.dir_hr, '*' + self.ext[2]))
        )
        names_lr = sorted(
            glob.glob(os.path.join(self.dir_lr, '*' + self.ext[0])) + glob.glob(
                os.path.join(self.dir_lr, '*' + self.ext[1])) + glob.glob(os.path.join(self.dir_lr, '*' + self.ext[2]))
        )
        print(len(names_hr))
        # Training uses HR images only; testing iterates the LR directory.
        if self.train:
            return names_hr
        else:
            return names_lr
    def _set_filesystem(self, dir_data):
        # Default layout: <dir_data>/<name>/{HR, LR_bicubic}.
        # NOTE(review): only two extensions here, while _scan indexes
        # self.ext[2] — subclasses are expected to override with three.
        self.apath = os.path.join(dir_data, self.name)
        self.dir_hr = os.path.join(self.apath, 'HR')
        self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
        self.ext = ('.bmp', '.png')
    def _name_hrbin(self):
        # Path of the single-pickle HR cache for this split.
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_HR.pt'.format(self.split)
        )
    def _name_lrbin(self, scale):
        # Path of the single-pickle LR cache for one scale (currently unused
        # in this HR-only pipeline).
        return os.path.join(
            self.apath,
            'bin',
            '{}_bin_LR_X{}.pt'.format(self.split, scale)
        )
    def _check_and_load(self, ext, l, f, verbose=True, load=True):
        """Load pickle *f* if present (unless ext contains 'reset'),
        otherwise build it from the image paths in *l*."""
        if os.path.isfile(f) and ext.find('reset') < 0:
            if load:
                if verbose: print('Loading {}...'.format(f))
                with open(f, 'rb') as _f:
                    ret = pickle.load(_f)
                return ret
            else:
                # 'sep' mode only verifies existence; nothing to return.
                return None
        else:
            if verbose:
                if ext.find('reset') >= 0:
                    print('Making a new binary: {}'.format(f))
                else:
                    print('{} does not exist. Now making binary...'.format(f))
            # Each entry keeps the stem name plus the decoded pixels.
            b = [{
                'name': os.path.splitext(os.path.basename(_l))[0],
                'image': imageio.imread(_l)
            } for _l in l]
            with open(f, 'wb') as _f:
                pickle.dump(b, _f)
            return b
    def __getitem__(self, idx):
        # Returns (patches, filename): patches is (2, C, p, p) in training
        # (two crops of the same image) or (1, C, H, W) at test time.
        hr, filename = self._load_file(idx)
        hr = self.get_patch(hr)
        hr = [common.set_channel(img, n_channels=self.args.n_colors) for img in hr]
        hr_tensor = [common.np2Tensor(img, rgb_range=self.args.rgb_range)
                     for img in hr]
        return torch.stack(hr_tensor, 0), filename
    def __len__(self):
        # Training epochs are virtually enlarged by self.repeat.
        if self.train:
            return len(self.images_hr) * self.repeat
        else:
            return len(self.images_hr)
    def _get_index(self, idx):
        # Map a virtual (repeated) index back onto the real file list.
        if self.train:
            return idx % len(self.images_hr)
        else:
            return idx
    def _load_file(self, idx):
        """Return (HR image array, stem filename) for dataset index *idx*."""
        idx = self._get_index(idx)
        f_hr = self.images_hr[idx]
        if self.args.ext.find('bin') >= 0:
            # 'bin' entries are dicts produced by _check_and_load.
            filename = f_hr['name']
            hr = f_hr['image']
        else:
            filename, _ = os.path.splitext(os.path.basename(f_hr))
            if self.args.ext == 'img' or self.benchmark:
                hr = imageio.imread(f_hr)
            elif self.args.ext.find('sep') >= 0:
                # NOTE(review): the .pt files are pickles, read via np.load —
                # confirm allow_pickle behavior on the installed numpy.
                with open(f_hr, 'rb') as _f:
                    hr = np.load(_f)[0]['image']
        return hr, filename
    def get_patch(self, hr):
        """Return a list of crops: two augmented patches (train) or the
        whole image (test)."""
        scale = self.scale[self.idx_scale]
        if self.train:
            out = []
            hr = common.augment(hr) if not self.args.no_augment else hr
            # extract two patches from each image
            for _ in range(2):
                hr_patch = common.get_patch(
                    hr,
                    patch_size=self.args.patch_size,
                    scale=scale
                )
                out.append(hr_patch)
        else:
            out = [hr]
        return out
    def set_scale(self, idx_scale):
        # Chosen per batch by the multi-scale loader (_ms_loop).
        self.idx_scale = idx_scale
import sys
import threading
import queue
import random
import collections
import torch
import torch.multiprocessing as multiprocessing
from torch._C import _set_worker_signal_handlers
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataloader import _DataLoaderIter
from torch.utils.data import _utils
if sys.version_info[0] == 2:
import Queue as queue
else:
import queue
def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id):
    """Worker-process loop: like torch's stock _worker_loop, but draws one
    random scale index per batch (multi-scale training) and appends it to
    the collated batch before sending it back.

    NOTE(review): init_fn and worker_id are accepted for signature
    compatibility but never used here.
    """
    global _use_shared_memory
    _use_shared_memory = True
    _set_worker_signal_handlers()
    # One BLAS thread per worker; per-worker deterministic seeding.
    torch.set_num_threads(1)
    torch.manual_seed(seed)
    while True:
        r = index_queue.get()
        if r is None:
            # None is the shutdown sentinel from the parent iterator.
            break
        idx, batch_indices = r
        try:
            idx_scale = 0
            if len(scale) > 1 and dataset.train:
                # Pick one scale for the entire batch.
                idx_scale = random.randrange(0, len(scale))
                dataset.set_scale(idx_scale)
            samples = collate_fn([dataset[i] for i in batch_indices])
            # Ship the chosen scale index along with the batch.
            samples.append(idx_scale)
        except Exception:
            # Forward the exception to the parent instead of dying silently.
            data_queue.put((idx, _utils.ExceptionWrapper(sys.exc_info())))
        else:
            data_queue.put((idx, samples))
class _MSDataLoaderIter(_DataLoaderIter):
    """Re-implementation of the legacy torch _DataLoaderIter.__init__ that
    spawns workers running _ms_loop instead of the stock worker loop, so a
    per-batch scale index travels with every batch.

    NOTE(review): relies on private torch internals (_DataLoaderIter,
    _utils.pin_memory, _utils.signal_handling) from the torch ~1.0/1.1 era —
    confirm the installed torch still exposes them.
    """
    def __init__(self, loader):
        self.dataset = loader.dataset
        self.scale = loader.scale
        self.collate_fn = loader.collate_fn
        self.batch_sampler = loader.batch_sampler
        self.num_workers = loader.num_workers
        self.pin_memory = loader.pin_memory and torch.cuda.is_available()
        self.timeout = loader.timeout
        self.done_event = threading.Event()
        self.sample_iter = iter(self.batch_sampler)
        if self.num_workers > 0:
            self.worker_init_fn = loader.worker_init_fn
            # One index queue per worker; a single shared result queue.
            self.index_queues = [
                multiprocessing.Queue() for _ in range(self.num_workers)
            ]
            self.worker_queue_idx = 0
            self.worker_result_queue = multiprocessing.Queue()
            self.batches_outstanding = 0
            self.worker_pids_set = False
            self.shutdown = False
            self.send_idx = 0
            self.rcvd_idx = 0
            self.reorder_dict = {}
            # Same base seed for every worker, offset by the worker index.
            base_seed = torch.LongTensor(1).random_()[0]
            self.workers = [
                multiprocessing.Process(
                    target=_ms_loop,
                    args=(
                        self.dataset,
                        self.index_queues[i],
                        self.worker_result_queue,
                        self.collate_fn,
                        self.scale,
                        base_seed + i,
                        self.worker_init_fn,
                        i
                    )
                )
                for i in range(self.num_workers)]
            if self.pin_memory or self.timeout > 0:
                # Results are funneled through a pinning thread before the
                # main thread sees them.
                self.data_queue = queue.Queue()
                if self.pin_memory:
                    maybe_device_id = torch.cuda.current_device()
                else:
                    # do not initialize cuda context if not necessary
                    maybe_device_id = None
                self.pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self.worker_result_queue, self.data_queue, maybe_device_id, self.done_event))
                self.pin_memory_thread.daemon = True
                self.pin_memory_thread.start()
            else:
                self.data_queue = self.worker_result_queue
            for w in self.workers:
                w.daemon = True  # ensure that the worker exits on process exit
                w.start()
            # Register worker PIDs for SIGCHLD-based failure detection.
            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self.workers))
            _utils.signal_handling._set_SIGCHLD_handler()
            self.worker_pids_set = True
            # prime the prefetch loop
            for _ in range(2 * self.num_workers):
                self._put_indices()
class MSDataLoader(DataLoader):
    """DataLoader that re-draws a random scale per batch via _MSDataLoaderIter."""

    def __init__(
            self, args, dataset, batch_size=1, shuffle=False,
            sampler=None, batch_sampler=None,
            collate_fn=_utils.collate.default_collate, pin_memory=False, drop_last=True,
            timeout=0, worker_init_fn=None):
        # The worker count comes from the global args, not a parameter.
        super(MSDataLoader, self).__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            batch_sampler=batch_sampler,
            num_workers=args.n_threads,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
            timeout=timeout,
            worker_init_fn=worker_init_fn,
        )
        self.scale = args.scale

    def __iter__(self):
        # Custom iterator that injects the per-batch scale index.
        return _MSDataLoaderIter(self)
import os
from torchvision import transforms
import model
import utility
from model.blindsr import BlindSR
import torch
from option import args
import cv2
import numpy as np
from utils import util
def to_img(tensor, arg):
    """Convert a 1xCxHxW network output tensor into an HxWxC uint8 BGR array.

    Args:
        tensor: batch of SR outputs; only the first element is used.
        arg: parsed args; arg.rgb_range is the value range of the tensor.

    Returns:
        HxWxC uint8 numpy array in BGR channel order, ready for cv2.imwrite.
    """
    # Rescale from [0, rgb_range] to [0, 255].
    normalized = tensor[0].data.mul(255 / arg.rgb_range)
    out = normalized.byte().permute(1, 2, 0).cpu().numpy()
    # The network produces RGB; OpenCV expects BGR, so reverse the channel
    # axis (replaces the previous per-channel copy/swap and debug prints).
    return out[:, :, ::-1].copy()
def main():
    """Single-image inference demo: load a trained BlindSR checkpoint,
    super-resolve one image from disk, and write 'result.jpg'."""
    model1 = BlindSR(args)
    # strict=False tolerates missing/unexpected keys in the checkpoint
    # (e.g. the MoCo queue buffers of the degradation encoder).
    model1.load_state_dict(torch.load('experiment/blindsr_x2_bicubic_iso/model/model_600.pt'), strict=False)
    # checkpoint = utility.checkpoint(args)
    #
    # if checkpoint.ok:
    #     model1 = model.Model(args, checkpoint)
    model1 = model1.cuda()
    model1.eval()
    # OpenCV loads BGR; swap channels 0 and 2 to get RGB for the network.
    img = cv2.imread('Figs/001 (2).png')
    temp = img[:, :, 0].copy()
    img[:, :, 0] = img[:, :, 2]
    img[:, :, 2] = temp
    transf = transforms.ToTensor()
    img = transf(img)
    # ToTensor scales to [0, 1]; the network expects values in [0, 255].
    img = img.unsqueeze(0).contiguous() * 255
    with torch.no_grad():
        img = img.cuda()
        sr = model1(img)
        sr = utility.quantize(sr, args.rgb_range)
        print(sr)
    result = to_img(sr, args)
    cv2.imwrite('result.jpg', result)
if __name__ == '__main__':
    main()
import os
from importlib import import_module
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Loss(nn.modules.loss._Loss):
    """Weighted multi-term loss parsed from args.loss ('w1*L1+w2*VGG54+...').

    Keeps a per-epoch running log (self.log) with one column per term plus
    bookkeeping columns 'DIS' (discriminator loss, GAN only) and 'Total'.
    start_log() must be called before forward() so self.log has a row.
    """
    def __init__(self, args, ckp):
        super(Loss, self).__init__()
        print('Preparing loss function:')
        self.n_GPUs = args.n_GPUs
        # self.loss: per-term metadata dicts; self.loss_module: the actual
        # nn.Module terms (registered so .to()/.half() propagate).
        self.loss = []
        self.loss_module = nn.ModuleList()
        for loss in args.loss.split('+'):
            weight, loss_type = loss.split('*')
            if loss_type == 'MSE':
                loss_function = nn.MSELoss()
            elif loss_type == 'L1':
                loss_function = nn.L1Loss()
            elif loss_type == 'CE':
                loss_function = nn.CrossEntropyLoss()
            elif loss_type.find('VGG') >= 0:
                # e.g. 'VGG54' -> loss.vgg.VGG('54', ...)
                module = import_module('loss.vgg')
                loss_function = getattr(module, 'VGG')(
                    loss_type[3:],
                    rgb_range=args.rgb_range
                )
            elif loss_type.find('GAN') >= 0:
                module = import_module('loss.adversarial')
                loss_function = getattr(module, 'Adversarial')(
                    args,
                    loss_type
                )
            # NOTE(review): an unrecognized loss_type leaves loss_function
            # unbound and raises NameError below — confirm args.loss values.
            self.loss.append({
                'type': loss_type,
                'weight': float(weight),
                'function': loss_function}
            )
            if loss_type.find('GAN') >= 0:
                # Extra log column for the discriminator loss.
                self.loss.append({'type': 'DIS', 'weight': 1, 'function': None})
        if len(self.loss) > 1:
            self.loss.append({'type': 'Total', 'weight': 0, 'function': None})
        for l in self.loss:
            if l['function'] is not None:
                print('{:.3f} * {}'.format(l['weight'], l['type']))
                self.loss_module.append(l['function'])
        self.log = torch.Tensor()
        device = torch.device('cpu' if args.cpu else 'cuda')
        self.loss_module.to(device)
        if args.precision == 'half': self.loss_module.half()
        if not args.cpu and args.n_GPUs > 1:
            self.loss_module = nn.DataParallel(
                self.loss_module, range(args.n_GPUs)
            )
        # '.' is the sentinel for "fresh run"; anything else resumes.
        if args.load != '.': self.load(ckp.dir, cpu=args.cpu)
    def forward(self, sr, hr):
        """Return the weighted sum of all terms, logging each into self.log."""
        losses = []
        for i, l in enumerate(self.loss):
            if l['function'] is not None:
                loss = l['function'](sr, hr)
                effective_loss = l['weight'] * loss
                losses.append(effective_loss)
                self.log[-1, i] += effective_loss.item()
            elif l['type'] == 'DIS':
                # The adversarial term stashes its discriminator loss on itself.
                self.log[-1, i] += self.loss[i - 1]['function'].loss
        loss_sum = sum(losses)
        if len(self.loss) > 1:
            self.log[-1, -1] += loss_sum.item()
        return loss_sum
    def step(self):
        # Step any term-internal schedulers (e.g. the GAN discriminator's).
        for l in self.get_loss_module():
            if hasattr(l, 'scheduler'):
                l.scheduler.step()
    def start_log(self):
        # Append a fresh all-zero row for the upcoming epoch.
        self.log = torch.cat((self.log, torch.zeros(1, len(self.loss))))
    def end_log(self, n_batches):
        # Turn the running sums of the last row into per-batch means.
        self.log[-1].div_(n_batches)
    def display_loss(self, batch):
        # Human-readable running averages for the progress bar.
        n_samples = batch + 1
        log = []
        for l, c in zip(self.loss, self.log[-1]):
            log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples))
        return ''.join(log)
    def plot_loss(self, apath, epoch):
        # One PDF per loss term, plotting its per-epoch history.
        axis = np.linspace(1, epoch, epoch)
        for i, l in enumerate(self.loss):
            label = '{} Loss'.format(l['type'])
            fig = plt.figure()
            plt.title(label)
            plt.plot(axis, self.log[:, i].numpy(), label=label)
            plt.legend()
            plt.xlabel('Epochs')
            plt.ylabel('Loss')
            plt.grid(True)
            plt.savefig('{}/loss_{}.pdf'.format(apath, l['type']))
            plt.close(fig)
    def get_loss_module(self):
        # Unwrap DataParallel when running on multiple GPUs.
        if self.n_GPUs == 1:
            return self.loss_module
        else:
            return self.loss_module.module
    def save(self, apath):
        torch.save(self.state_dict(), os.path.join(apath, 'loss.pt'))
        torch.save(self.log, os.path.join(apath, 'loss_log.pt'))
    def load(self, apath, cpu=False):
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}
        self.load_state_dict(torch.load(
            os.path.join(apath, 'loss.pt'),
            **kwargs
        ))
        self.log = torch.load(os.path.join(apath, 'loss_log.pt'))
        # Fast-forward schedulers to match the number of logged epochs.
        for l in self.loss_module:
            if hasattr(l, 'scheduler'):
                for _ in range(len(self.log)): l.scheduler.step()
\ No newline at end of file
import utility
from model import common
from loss import discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
class Adversarial(nn.Module):
    """Adversarial (GAN) loss with an internal discriminator and optimizer.

    Supports gan_type 'GAN' (BCE-with-logits), 'WGAN' (weight clipping) and
    'WGAN_GP' (gradient penalty).  forward() performs gan_k discriminator
    updates as a side effect, records the mean discriminator loss in
    self.loss, and returns the generator loss.
    """

    def __init__(self, args, gan_type):
        super(Adversarial, self).__init__()
        self.gan_type = gan_type
        self.gan_k = args.gan_k  # discriminator steps per generator step
        self.discriminator = discriminator.Discriminator(args, gan_type)
        if gan_type != 'WGAN_GP':
            self.optimizer = utility.make_optimizer(args, self.discriminator)
        else:
            # WGAN-GP prescribes Adam with beta1 = 0 (Gulrajani et al. 2017).
            self.optimizer = optim.Adam(
                self.discriminator.parameters(),
                betas=(0, 0.9), eps=1e-8, lr=1e-5
            )
        self.scheduler = utility.make_scheduler(args, self.optimizer)

    def forward(self, fake, real):
        # Detach so discriminator updates never back-propagate into G.
        fake_detach = fake.detach()

        self.loss = 0
        label_real = None
        for _ in range(self.gan_k):
            self.optimizer.zero_grad()
            d_fake = self.discriminator(fake_detach)
            d_real = self.discriminator(real)
            if self.gan_type == 'GAN':
                label_fake = torch.zeros_like(d_fake)
                label_real = torch.ones_like(d_real)
                loss_d \
                    = F.binary_cross_entropy_with_logits(d_fake, label_fake) \
                    + F.binary_cross_entropy_with_logits(d_real, label_real)
            elif self.gan_type.find('WGAN') >= 0:
                loss_d = (d_fake - d_real).mean()
                if self.gan_type.find('GP') >= 0:
                    # Bug fix: epsilon must be one scalar per sample with
                    # shape (B, 1, 1, 1).  The previous
                    # torch.rand_like(fake).view(-1, 1, 1, 1) produced
                    # B*C*H*W entries and broke broadcasting against the
                    # (B, C, H, W) inputs.
                    epsilon = torch.rand(
                        fake.size(0), 1, 1, 1,
                        device=fake.device, dtype=fake.dtype
                    )
                    # Random interpolate between fake and real samples.
                    hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon)
                    hat.requires_grad = True
                    d_hat = self.discriminator(hat)
                    gradients = torch.autograd.grad(
                        outputs=d_hat.sum(), inputs=hat,
                        retain_graph=True, create_graph=True, only_inputs=True
                    )[0]
                    gradients = gradients.view(gradients.size(0), -1)
                    gradient_norm = gradients.norm(2, dim=1)
                    # Two-sided penalty toward unit gradient norm, lambda=10.
                    gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean()
                    loss_d += gradient_penalty

            # Discriminator update
            self.loss += loss_d.item()
            loss_d.backward()
            self.optimizer.step()

            if self.gan_type == 'WGAN':
                # Vanilla WGAN enforces the Lipschitz constraint by clipping.
                for p in self.discriminator.parameters():
                    p.data.clamp_(-1, 1)

        self.loss /= self.gan_k

        # Generator loss (no discriminator update here; gradients flow to G).
        d_fake_for_g = self.discriminator(fake)
        if self.gan_type == 'GAN':
            # label_real was populated in the discriminator loop above.
            loss_g = F.binary_cross_entropy_with_logits(
                d_fake_for_g, label_real
            )
        elif self.gan_type.find('WGAN') >= 0:
            loss_g = -d_fake_for_g.mean()

        return loss_g

    def state_dict(self, *args, **kwargs):
        # NOTE(review): flattening optimizer state into the module state dict
        # relies on key names never colliding; kept for checkpoint
        # compatibility with existing runs.
        state_discriminator = self.discriminator.state_dict(*args, **kwargs)
        state_optimizer = self.optimizer.state_dict()
        return dict(**state_discriminator, **state_optimizer)
# Some references
# https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py
# OR
# https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
from model import common
import torch.nn as nn
class Discriminator(nn.Module):
    """Patch discriminator: a stack of strided conv blocks plus an MLP head."""

    def __init__(self, args, gan_type='GAN'):
        super(Discriminator, self).__init__()

        depth = 7
        #bn = not gan_type == 'WGAN_GP'
        bn = True
        act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Stem maps the input colors to the base feature width.
        out_channels = 64
        blocks = [
            common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act)
        ]
        # Alternate stride-2 (downsample) and stride-1 (widen) blocks.
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            blocks.append(common.BasicBlock(
                in_channels, out_channels, 3, stride=stride, bn=bn, act=act
            ))
        self.features = nn.Sequential(*blocks)

        # Spatial size after (depth + 1) // 2 stride-2 blocks.
        patch_size = args.patch_size // (2**((depth + 1) // 2))
        self.classifier = nn.Sequential(
            nn.Linear(out_channels * patch_size**2, 1024),
            act,
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        feats = self.features(x)
        # Flatten per sample before the linear head; returns raw logits.
        return self.classifier(feats.view(feats.size(0), -1))
from model import common
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.autograd import Variable
class VGG(nn.Module):
    """Perceptual loss: MSE between VGG19 feature maps of SR and HR.

    conv_index selects the feature depth: '22' -> first 8 modules (relu2_2),
    '54' -> first 35 modules (relu5_4).
    """
    def __init__(self, conv_index, rgb_range=1):
        super(VGG, self).__init__()
        vgg_features = models.vgg19(pretrained=True).features
        modules = [m for m in vgg_features]
        if conv_index == '22':
            self.vgg = nn.Sequential(*modules[:8])
        elif conv_index == '54':
            self.vgg = nn.Sequential(*modules[:35])
        # Normalize inputs with ImageNet statistics, scaled to rgb_range.
        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
        # Bug fix: setting requires_grad on the nn.Sequential itself is a
        # no-op for its parameters; freeze each parameter explicitly so the
        # pretrained backbone never receives gradient updates.
        for p in self.vgg.parameters():
            p.requires_grad = False
    def forward(self, sr, hr):
        """Return F.mse_loss between VGG features of sr and hr (hr detached)."""
        def _forward(x):
            x = self.sub_mean(x)
            x = self.vgg(x)
            return x
        vgg_sr = _forward(sr)
        # No gradient needed through the target branch.
        with torch.no_grad():
            vgg_hr = _forward(hr.detach())
        loss = F.mse_loss(vgg_sr, vgg_hr)
        return loss
from option import args
import torch
import utility
import data
import model
import loss
from trainer import Trainer
if __name__ == '__main__':
    # Training entry point: seed RNG, set up the checkpoint directory, then
    # wire data loaders, model and loss into the Trainer loop.
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)
    if checkpoint.ok:
        loader = data.Data(args)
        # NOTE: these assignments rebind (shadow) the imported modules
        # 'model' and 'loss' from here on.
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        # terminate() decides between "epochs exhausted" and test-only mode.
        while not t.terminate():
            t.train()
        checkpoint.done()
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
import torch.nn as nn
class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    Used here as the degradation-representation encoder for blind SR; the
    contrastive branch is currently disabled (see the commented forward).
    """
    def __init__(self, base_encoder, dim=256, K=32*256, m=0.999, T=0.07, mlp=False):
        """
        dim: feature dimension (default: 256)
        K: queue size; number of negative keys (default: 32*256 = 8192)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)
        """
        super(MoCo, self).__init__()
        self.K = K
        self.m = m
        self.T = T
        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()
        # Key encoder starts as an exact copy and is updated only by momentum.
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient
        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        # keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity
        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.transpose(0, 1)
        ptr = (ptr + batch_size) % self.K  # move pointer
        self.queue_ptr[0] = ptr
    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()
        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)
        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)
        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this], idx_unshuffle
    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]
        num_gpus = batch_size_all // batch_size_this
        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]
        return x_gather[idx_this]
    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images (currently unused — see below)
        Output:
            the query-encoder embedding only; the contrastive logits/labels
            path is disabled (commented out) in this inference-oriented copy.
        """
        # if self.training:
        #     # compute query features
        #     embedding, q = self.encoder_q(im_q)  # queries: NxC
        #     q = nn.functional.normalize(q, dim=1)
        #
        #     # compute key features
        #     with torch.no_grad():  # no gradient to keys
        #         self._momentum_update_key_encoder()  # update the key encoder
        #
        #         _, k = self.encoder_k(im_k)  # keys: NxC
        #         k = nn.functional.normalize(k, dim=1)
        #
        #     # compute logits
        #     # Einstein sum is more intuitive
        #     # positive logits: Nx1
        #     l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        #     # negative logits: NxK
        #     l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])
        #
        #     # logits: Nx(1+K)
        #     logits = torch.cat([l_pos, l_neg], dim=1)
        #
        #     # apply temperature
        #     logits /= self.T
        #
        #     # labels: positive key indicators
        #     labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
        #
        #     # dequeue and enqueue
        #     self._dequeue_and_enqueue(k)
        #
        #     return embedding, logits, labels
        # else:
        embedding, _ = self.encoder_q(im_q)
        return embedding
# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    Requires an initialized torch.distributed process group; returns the
    per-rank tensors concatenated along dim 0.
    """
    # Pre-allocate one receive buffer per rank, then gather into them.
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
import os
from importlib import import_module
import torch
import torch.nn as nn
class Model(nn.Module):
    """Wrapper around the concrete SR network (model.<args.model>.make_model).

    Handles device placement, DataParallel, half precision, checkpoint
    save/load, and test-time tricks: chopped forward for large inputs and
    x8 geometric self-ensemble.
    """
    def __init__(self, args, ckp):
        super(Model, self).__init__()
        print('Making model...')
        self.args = args
        self.scale = args.scale
        self.idx_scale = 0
        self.self_ensemble = args.self_ensemble
        self.chop = args.chop
        self.precision = args.precision
        self.cpu = args.cpu
        self.device = torch.device('cpu' if args.cpu else 'cuda')
        self.n_GPUs = args.n_GPUs
        self.save_models = args.save_models
        # NOTE(review): this instance attribute shadows the save() method
        # defined below (instance attrs win over class methods), so
        # instance.save(...) would raise — confirm how callers save models.
        self.save = args.save
        # Dynamically import the network module and build it.
        module = import_module('model.'+args.model)
        self.model = module.make_model(args).to(self.device)
        if args.precision == 'half': self.model.half()
        if not args.cpu and args.n_GPUs > 1:
            self.model = nn.DataParallel(self.model, range(args.n_GPUs))
        self.load(
            ckp.dir,
            pre_train=args.pre_train,
            resume=args.resume,
            cpu=args.cpu
        )
    def forward(self, x):
        # Test-time enhancements only; plain forward during training.
        if self.self_ensemble and not self.training:
            if self.chop:
                forward_function = self.forward_chop
            else:
                forward_function = self.model.forward
            return self.forward_x8(x, forward_function)
        elif self.chop and not self.training:
            return self.forward_chop(x)
        else:
            return self.model(x)
    def get_model(self):
        # Unwrap DataParallel when present.
        if self.n_GPUs <= 1 or self.cpu:
            return self.model
        else:
            return self.model.module
    def state_dict(self, **kwargs):
        target = self.get_model()
        return target.state_dict(**kwargs)
    def save(self, apath, epoch, is_best=False):
        # Always refresh 'latest'; optionally snapshot 'best' and per-epoch.
        target = self.get_model()
        torch.save(
            target.state_dict(),
            os.path.join(apath, 'model', 'model_latest.pt')
        )
        if is_best:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', 'model_best.pt')
            )
        if self.save_models:
            torch.save(
                target.state_dict(),
                os.path.join(apath, 'model', 'model_{}.pt'.format(epoch))
            )
    def load(self, apath, pre_train='.', resume=-1, cpu=False):
        """resume == -1: latest checkpoint; 0: optional pre-trained weights
        ('.' = none); > 0: the checkpoint of that epoch (non-strict)."""
        if cpu:
            kwargs = {'map_location': lambda storage, loc: storage}
        else:
            kwargs = {}
        if resume == -1:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_latest.pt'), **kwargs),
                strict=True
            )
        elif resume == 0:
            if pre_train != '.':
                self.get_model().load_state_dict(
                    torch.load(pre_train, **kwargs),
                    strict=True
                )
        elif resume > 0:
            self.get_model().load_state_dict(
                torch.load(os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), **kwargs),
                strict=False
            )
    def forward_chop(self, x, shave=10, min_size=160000):
        """Memory-friendly forward: split the input into four overlapping
        quadrants (overlap = shave px), super-resolve each (recursing while
        a quadrant still exceeds min_size pixels), then stitch the outputs."""
        scale = self.scale[self.idx_scale]
        n_GPUs = min(self.n_GPUs, 4)
        b, c, h, w = x.size()
        h_half, w_half = h // 2, w // 2
        # Quadrants extended by 'shave' so seams can be discarded later.
        h_size, w_size = h_half + shave, w_half + shave
        lr_list = [
            x[:, :, 0:h_size, 0:w_size],
            x[:, :, 0:h_size, (w - w_size):w],
            x[:, :, (h - h_size):h, 0:w_size],
            x[:, :, (h - h_size):h, (w - w_size):w]]
        if w_size * h_size < min_size:
            sr_list = []
            # Batch up to n_GPUs quadrants per forward pass.
            for i in range(0, 4, n_GPUs):
                lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
                sr_batch = self.model(lr_batch)
                sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
        else:
            # Still too large: recurse on each quadrant.
            sr_list = [
                self.forward_chop(patch, shave=shave, min_size=min_size) \
                for patch in lr_list
            ]
        # Scale all geometry into SR coordinates.
        h, w = scale * h, scale * w
        h_half, w_half = scale * h_half, scale * w_half
        h_size, w_size = scale * h_size, scale * w_size
        shave *= scale
        # Paste the non-overlapping core of each quadrant into the output.
        output = x.new(b, c, h, w)
        output[:, :, 0:h_half, 0:w_half] \
            = sr_list[0][:, :, 0:h_half, 0:w_half]
        output[:, :, 0:h_half, w_half:w] \
            = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
        output[:, :, h_half:h, 0:w_half] \
            = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
        output[:, :, h_half:h, w_half:w] \
            = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
        return output
    def forward_x8(self, x, forward_function):
        """Geometric self-ensemble: run the model on all 8 flip/transpose
        variants of x, undo each transform on the output, and average."""
        def _transform(v, op):
            # Transforms are applied on CPU numpy, then moved back.
            if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()
            ret = torch.Tensor(tfnp).to(self.device)
            if self.precision == 'half': ret = ret.half()
            return ret
        # Doubling trick builds all 8 combinations of v/h/t transforms.
        lr_list = [x]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        sr_list = [forward_function(aug) for aug in lr_list]
        # Invert the transforms in the opposite order they were applied.
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')
        output_cat = torch.cat(sr_list, dim=0)
        output = output_cat.mean(dim=0, keepdim=True)
        return output
import torch
from torch import nn
import model.common as common
import torch.nn.functional as F
from moco.builder import MoCo
def make_model(args):
    # Factory entry point used by model.Model via import_module('model.' + args.model).
    return BlindSR(args)
class DA_conv(nn.Module):
    """Degradation-aware convolution: a per-sample depthwise kernel predicted
    from the degradation representation, plus a channel-attention branch."""
    def __init__(self, channels_in, channels_out, kernel_size, reduction):
        super(DA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size
        # Maps the 64-d degradation vector to one k x k kernel per channel.
        # NOTE(review): the 64s are hard-coded — assumes channels_in == 64
        # (DASR's n_feats); confirm before reusing with other widths.
        self.kernel = nn.Sequential(
            nn.Linear(64, 64, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Linear(64, 64 * self.kernel_size * self.kernel_size, bias=False)
        )
        self.conv = common.default_conv(channels_in, channels_out, 1)
        self.ca = CA_layer(channels_in, channels_out, reduction)
        self.relu = nn.LeakyReLU(0.1, True)
    def forward(self, x):
        '''
        :param x[0]: feature map: B * C * H * W
        :param x[1]: degradation representation: B * C
        '''
        b, c, h, w = x[0].size()
        # branch 1
        # Grouped-conv trick: fold the batch into channels (1, B*C, H, W) and
        # apply the B*64 predicted single-channel kernels with groups=B*C,
        # i.e. one depthwise kernel per (sample, channel) pair.
        kernel = self.kernel(x[1]).view(-1, 1, self.kernel_size, self.kernel_size)
        out = self.relu(F.conv2d(x[0].view(1, -1, h, w), kernel, groups=b*c, padding=(self.kernel_size-1)//2))
        out = self.conv(out.view(b, -1, h, w))
        # branch 2
        # Channel attention conditioned on the same degradation vector.
        out = out + self.ca(x)
        return out
class CA_layer(nn.Module):
    """Channel attention conditioned on the degradation representation."""

    def __init__(self, channels_in, channels_out, reduction):
        super(CA_layer, self).__init__()
        # Bottleneck MLP (expressed as 1x1 convs) producing per-channel
        # gates in (0, 1).
        self.conv_du = nn.Sequential(
            nn.Conv2d(channels_in, channels_in // reduction, 1, 1, 0, bias=False),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(channels_in // reduction, channels_out, 1, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """x[0]: feature map B*C*H*W; x[1]: degradation representation B*C."""
        features, degradation = x[0], x[1]
        # Lift the B*C vector to B*C*1*1 so the 1x1 convs can consume it.
        gates = self.conv_du(degradation[:, :, None, None])
        return features * gates
class DAB(nn.Module):
    """Degradation-aware block: two DA convs interleaved with plain convs,
    wrapped in a residual connection."""

    def __init__(self, conv, n_feat, kernel_size, reduction):
        super(DAB, self).__init__()

        self.da_conv1 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.da_conv2 = DA_conv(n_feat, n_feat, kernel_size, reduction)
        self.conv1 = conv(n_feat, n_feat, kernel_size)
        self.conv2 = conv(n_feat, n_feat, kernel_size)

        self.relu = nn.LeakyReLU(0.1, True)

    def forward(self, x):
        """x[0]: feature map B*C*H*W; x[1]: degradation representation B*C."""
        features, degradation = x[0], x[1]

        out = self.relu(self.da_conv1([features, degradation]))
        out = self.relu(self.conv1(out))
        out = self.relu(self.da_conv2([out, degradation]))
        # residual connection around the whole block
        return self.conv2(out) + features
class DAG(nn.Module):
    """Degradation-aware group: ``n_blocks`` DABs plus a closing conv,
    all wrapped in a residual connection."""

    def __init__(self, conv, n_feat, kernel_size, reduction, n_blocks):
        super(DAG, self).__init__()
        self.n_blocks = n_blocks
        blocks = [DAB(conv, n_feat, kernel_size, reduction) for _ in range(n_blocks)]
        blocks.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*blocks)

    def forward(self, x):
        """x[0]: feature map B*C*H*W; x[1]: degradation representation B*C."""
        features, degradation = x[0], x[1]

        # DABs need the degradation vector alongside the features, so they
        # are invoked manually rather than through nn.Sequential.__call__.
        res = features
        for block in self.body[:self.n_blocks]:
            res = block([res, degradation])
        res = self.body[-1](res)
        return res + features
class DASR(nn.Module):
    """Degradation-aware SR generator.

    Pipeline: mean-shift -> head conv -> n_groups DAGs conditioned on a
    compressed degradation vector -> pixel-shuffle upsampler -> mean-shift.
    """

    def __init__(self, args, conv=common.default_conv):
        super(DASR, self).__init__()

        self.n_groups = 5
        n_blocks = 5
        n_feats = 64
        kernel_size = 3
        reduction = 8
        # args.scale is a list (see option parsing); only the first entry is used
        scale = int(args.scale[0])

        # RGB mean for DIV2K
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(255.0, rgb_mean, rgb_std)
        self.add_mean = common.MeanShift(255.0, rgb_mean, rgb_std, 1)

        # head module
        modules_head = [conv(3, n_feats, kernel_size)]
        self.head = nn.Sequential(*modules_head)

        # compress: 256-d encoder embedding -> 64-d conditioning vector
        self.compress = nn.Sequential(
            nn.Linear(256, 64, bias=False),
            nn.LeakyReLU(0.1, True)
        )

        # body
        modules_body = [
            DAG(common.default_conv, n_feats, kernel_size, reduction, n_blocks) \
            for _ in range(self.n_groups)
        ]
        modules_body.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*modules_body)

        # tail
        modules_tail = [common.Upsampler(conv, scale, n_feats, act=False),
                        conv(n_feats, 3, kernel_size)]
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x, k_v):
        """x: LR image batch (B*3*H*W); k_v: 256-d degradation representation."""
        k_v = self.compress(k_v)

        # sub mean
        x = self.sub_mean(x)

        # head
        x = self.head(x)

        # body: DAGs take [features, degradation]; the closing conv does not
        res = x
        for i in range(self.n_groups):
            res = self.body[i]([res, k_v])
        res = self.body[-1](res)
        res = res + x

        # tail
        x = self.tail(res)

        # add mean
        x = self.add_mean(x)

        return x
class Encoder(nn.Module):
    """Convolutional degradation encoder.

    A conv/BN/LeakyReLU trunk with two stride-2 stages and global average
    pooling yields a 256-d feature per image; an MLP head projects it for
    contrastive learning. ``forward`` returns (feature, projection).
    """

    def __init__(self):
        super(Encoder, self).__init__()

        def unit(c_in, c_out, stride=1):
            # conv -> BN -> LeakyReLU building block of the trunk
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.1, True),
            ]

        self.E = nn.Sequential(
            *unit(3, 64),
            *unit(64, 64),
            *unit(64, 128, stride=2),
            *unit(128, 128),
            *unit(128, 256, stride=2),
            *unit(256, 256),
            nn.AdaptiveAvgPool2d(1),
        )
        self.mlp = nn.Sequential(
            nn.Linear(256, 256),
            nn.LeakyReLU(0.1, True),
            nn.Linear(256, 256),
        )

    def forward(self, x):
        """Return (pooled 256-d feature, MLP projection), each B x 256."""
        pooled = self.E(x)
        fea = pooled.squeeze(-1).squeeze(-1)
        return fea, self.mlp(fea)
class BlindSR(nn.Module):
    """Blind SR wrapper: MoCo-based degradation encoder E + DASR generator G."""

    def __init__(self, args):
        super(BlindSR, self).__init__()

        # Generator
        self.G = DASR(args)

        # Encoder
        self.E = MoCo(base_encoder=Encoder)

    def forward(self, x):
        # Original training path, kept for reference:
        # if self.training:
        #     x_query = x[:, 0, ...]                          # b, c, h, w
        #     x_key = x[:, 1, ...]                            # b, c, h, w
        #
        #     # degradation-aware representation learning
        #     fea, logits, labels = self.E(x_query, x_key)
        #
        #     # degradation-aware SR
        #     sr = self.G(x_query, fea)
        #
        #     return sr, logits, labels
        # else:
        # degradation-aware representation learning
        # NOTE(review): assumes MoCo's eval-mode forward returns only the
        # query embedding (not the (fea, logits, labels) triple) — confirm
        # against moco.builder.
        fea = self.E(x, x)

        # degradation-aware SR
        sr = self.G(x, fea)

        return sr
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Conv2d with "same" padding so odd kernel sizes preserve H and W."""
    pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=pad, bias=bias)
class MeanShift(nn.Conv2d):
    """Frozen 1x1 conv that subtracts (sign=-1, default) or adds (sign=1)
    the dataset RGB mean, scaled by ``rgb_range`` and normalised by
    the per-channel std."""

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # identity mixing between channels, scaled by 1/std
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # the shift is fixed, never trained
        for p in self.parameters():
            p.requires_grad = False
class Upsampler(nn.Sequential):
    """Pixel-shuffle upsampler supporting scale 2^n or 3.

    Raises NotImplementedError for any other scale.
    """

    def __init__(self, conv, scale, n_feat, act=False, bias=True):
        layers = []
        if (scale & (scale - 1)) == 0:  # power of two: chain x2 stages
            for _ in range(int(math.log(scale, 2))):
                layers.append(conv(n_feat, 4 * n_feat, 3, bias))
                layers.append(nn.PixelShuffle(2))
                if act:
                    layers.append(act())
        elif scale == 3:
            layers.append(conv(n_feat, 9 * n_feat, 3, bias))
            layers.append(nn.PixelShuffle(3))
            if act:
                layers.append(act())
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*layers)
import argparse
import template
# Command-line options shared by the DASR training / testing scripts.
parser = argparse.ArgumentParser(description='EDSR and MDSR')

parser.add_argument('--debug', action='store_true',
                    help='Enables debug mode')
parser.add_argument('--template', default='.',
                    help='You can set various templates in option.py')

# Hardware specifications
parser.add_argument('--n_threads', type=int, default=4,
                    help='number of threads for data loading')
# WARNING: type=bool is an argparse pitfall — any non-empty CLI string
# (even 'False') parses as True; only the default behaves as expected.
parser.add_argument('--cpu', type=bool, default=False,
                    help='use cpu only')
parser.add_argument('--n_GPUs', type=int, default=1,
                    help='number of GPUs')
parser.add_argument('--seed', type=int, default=1,
                    help='random seed')

# Data specifications
parser.add_argument('--dir_data', type=str, default='/mnt/data4/weiboyu/',
                    help='dataset directory')
parser.add_argument('--dir_demo', type=str, default='../test',
                    help='demo image directory')
parser.add_argument('--data_train', type=str, default='DF2K',
                    help='train dataset name')
parser.add_argument('--data_test', type=str, default='Set14',
                    help='test dataset name')
parser.add_argument('--data_range', type=str, default='1-3450/801-810',
                    help='train/test data range')
parser.add_argument('--ext', type=str, default='sep',
                    help='dataset file extension')
# kept as a string here; converted to a list of floats at the bottom of this file
parser.add_argument('--scale', type=str, default='2',
                    help='super resolution scale')
parser.add_argument('--patch_size', type=int, default=48,
                    help='output patch size')
parser.add_argument('--rgb_range', type=int, default=255,
                    help='maximum value of RGB')
parser.add_argument('--n_colors', type=int, default=3,
                    help='number of color channels to use')
parser.add_argument('--chop', action='store_true',
                    help='enable memory-efficient forward')
parser.add_argument('--no_augment', action='store_true',
                    help='do not use data augmentation')

# Degradation specifications
parser.add_argument('--blur_kernel', type=int, default=21,
                    help='size of blur kernels')
parser.add_argument('--blur_type', type=str, default='iso_gaussian',
                    help='blur types (iso_gaussian | aniso_gaussian)')
parser.add_argument('--mode', type=str, default='bicubic',
                    help='downsampler (bicubic | s-fold)')
parser.add_argument('--noise', type=float, default=0.0,
                    help='noise level')
## isotropic Gaussian blur
parser.add_argument('--sig_min', type=float, default=0.2,
                    help='minimum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig_max', type=float, default=4.0,
                    help='maximum sigma of isotropic Gaussian blurs')
parser.add_argument('--sig', type=float, default=1.2,
                    help='specific sigma of isotropic Gaussian blurs')
## anisotropic Gaussian blur
parser.add_argument('--lambda_min', type=float, default=0.2,
                    help='minimum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_max', type=float, default=4.0,
                    help='maximum value for the eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_1', type=float, default=0.2,
                    help='one eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--lambda_2', type=float, default=4.0,
                    help='another eigenvalue of anisotropic Gaussian blurs')
parser.add_argument('--theta', type=float, default=0.0,
                    help='rotation angle of anisotropic Gaussian blurs')

# Model specifications
parser.add_argument('--model', default='blindsr',
                    help='model name')
parser.add_argument('--pre_train', type=str, default='.',
                    help='pre-trained model directory')
parser.add_argument('--extend', type=str, default='.',
                    help='pre-trained model directory')
# no type/action: a value supplied on the CLI would arrive as a string
parser.add_argument('--shift_mean', default=True,
                    help='subtract pixel mean from the input')
parser.add_argument('--dilation', action='store_true',
                    help='use dilated convolution')
parser.add_argument('--precision', type=str, default='single',
                    choices=('single', 'half'),
                    help='FP precision for test (single | half)')

# Training specifications
parser.add_argument('--reset', action='store_true',
                    help='reset the training')
parser.add_argument('--test_every', type=int, default=1000,
                    help='do test per every N batches')
parser.add_argument('--epochs_encoder', type=int, default=100,
                    help='number of epochs to train the degradation encoder')
parser.add_argument('--epochs_sr', type=int, default=500,
                    help='number of epochs to train the whole network')
parser.add_argument('--batch_size', type=int, default=32,
                    help='input batch size for training')
parser.add_argument('--split_batch', type=int, default=1,
                    help='split the batch into smaller chunks')
parser.add_argument('--self_ensemble', action='store_true',
                    help='use self-ensemble method for test')
# NOTE: default=True means this flag is effectively always on as committed
parser.add_argument('--test_only', action='store_true', default=True,
                    help='set this option to test the model')

# Optimization specifications
parser.add_argument('--lr_encoder', type=float, default=1e-3,
                    help='learning rate to train the degradation encoder')
parser.add_argument('--lr_sr', type=float, default=1e-4,
                    help='learning rate to train the whole network')
parser.add_argument('--lr_decay_encoder', type=int, default=60,
                    help='learning rate decay per N epochs')
parser.add_argument('--lr_decay_sr', type=int, default=125,
                    help='learning rate decay per N epochs')
parser.add_argument('--decay_type', type=str, default='step',
                    help='learning rate decay type')
parser.add_argument('--gamma_encoder', type=float, default=0.1,
                    help='learning rate decay factor for step decay')
parser.add_argument('--gamma_sr', type=float, default=0.5,
                    help='learning rate decay factor for step decay')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--momentum', type=float, default=0.9,
                    help='SGD momentum')
parser.add_argument('--beta1', type=float, default=0.9,
                    help='ADAM beta1')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='ADAM beta2')
parser.add_argument('--epsilon', type=float, default=1e-8,
                    help='ADAM epsilon for numerical stability')
parser.add_argument('--weight_decay', type=float, default=0,
                    help='weight decay')
parser.add_argument('--start_epoch', type=int, default=0,
                    help='resume from the snapshot, and the start_epoch')

# Loss specifications
parser.add_argument('--loss', type=str, default='1*L1',
                    help='loss function configuration')
# string default is run through type=float by argparse, yielding 1e6
parser.add_argument('--skip_threshold', type=float, default='1e6',
                    help='skipping batch that has large error')

# Log specifications
parser.add_argument('--save', type=str, default='blindsr',
                    help='file name to save')
parser.add_argument('--load', type=str, default='.',
                    help='file name to load')
parser.add_argument('--resume', type=int, default=600,
                    help='resume from specific checkpoint')
parser.add_argument('--save_models', action='store_true',
                    help='save all intermediate models')
parser.add_argument('--print_every', type=int, default=200,
                    help='how many batches to wait before logging training status')
# no type/action: a value supplied on the CLI would arrive as a string
parser.add_argument('--save_results', default=True,
                    help='save output results')

args = parser.parse_args()
template.set_template(args)

# '2' or '2+3+4' -> [2.0, 3.0, 4.0]; downstream code indexes args.scale[0]
args.scale = list(map(lambda x: float(x), args.scale.split('+')))
from model.blindsr import BlindSR
import torch
import numpy as np
import imageio
import argparse
import os
import utility
import cv2
def parse_args():
    """Build and parse the demo-script command-line options."""
    parser = argparse.ArgumentParser()
    # (flag, type, default, help) table keeps the options easy to scan
    options = (
        ('--img_dir', str, 'D:/LongguangWang/Data/test.png', 'image directory'),
        ('--scale', str, '2', 'super resolution scale'),
        ('--resume', int, 600, 'resume from specific checkpoint'),
        ('--blur_type', str, 'iso_gaussian', 'blur types (iso_gaussian | aniso_gaussian)'),
    )
    for flag, kind, default, text in options:
        parser.add_argument(flag, type=kind, default=default, help=text)
    return parser.parse_args()
def main():
    """Single-image blind SR demo.

    Loads the checkpoint matching the requested scale/blur type, runs the
    model on one image and writes '<name>_sr.png' under the experiment's
    results directory. Requires a CUDA device.
    """
    args = parse_args()

    # checkpoint directory depends on the degradation type
    if args.blur_type == 'iso_gaussian':
        dir = './experiment/blindsr_x' + str(int(args.scale[0])) + '_bicubic_iso'
    elif args.blur_type == 'aniso_gaussian':
        dir = './experiment/blindsr_x' + str(int(args.scale[0])) + '_bicubic_aniso'
    else:
        # was: fell through silently and crashed below with an undefined name
        raise ValueError('unknown blur_type: {}'.format(args.blur_type))

    # path to save sr images
    save_dir = dir + '/results'
    if not os.path.exists(save_dir):
        # was os.mkdir, which fails when the experiment directory is missing
        os.makedirs(save_dir)

    DASR = BlindSR(args).cuda()
    # strict=False: the checkpoint omits the MoCo key-encoder/queue buffers
    DASR.load_state_dict(torch.load(dir + '/model/model_' + str(args.resume) + '.pt'), strict=False)
    DASR.eval()

    # HWC uint8 image -> 1 x 1 x C x H x W float tensor on GPU
    lr = imageio.imread(args.img_dir)
    lr = np.ascontiguousarray(lr.transpose((2, 0, 1)))
    lr = torch.from_numpy(lr).float().cuda().unsqueeze(0).unsqueeze(0)

    # inference
    sr = DASR(lr[:, 0, ...])
    sr = utility.quantize(sr, 255.0)

    # save sr results
    img_name = args.img_dir.split('.png')[0].split('/')[-1]
    sr = np.array(sr.squeeze(0).permute(1, 2, 0).data.cpu())
    sr = sr[:, :, [2, 1, 0]]  # RGB -> BGR for cv2.imwrite
    cv2.imwrite(save_dir + '/' + img_name + '_sr.png', sr)
\ No newline at end of file
import os
import torch
from option import args
import torchvision.transforms as transforms
import cv2 as cv
import numpy as np
# os.environ["CUDA_VISIBLE_DEVICES"] = '8'
#
# img = cv.imread('Figs/512result.bmp')
# print(img.shape)  # numpy array layout is (H, W, C)
# np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
# img_tensor = torch.from_numpy(np_transpose).float()
# img_tensor.mul_(255 / 255)
# img_tensor = img_tensor.unsqueeze(0).cuda()
# print(img_tensor.size())
# TorchScript export script: load a trained BlindSR checkpoint, run a dummy
# forward pass, trace the model and save it as "model.pt".
from model.blindsr import BlindSR

dasr = BlindSR(args)
# strict=False: the checkpoint omits the MoCo key-encoder/queue buffers
dasr.load_state_dict(torch.load('experiment/blindsr_x2_bicubic_iso/model/model_600.pt'), strict=False)

# dummy input used both for the sanity forward and for tracing
example = torch.ones(1, 3, 96, 96)
dasr.eval()
r1 = dasr(example)
print(r1)

# (kept for reference) earlier experiments with saving the full module and
# inspecting/saving outputs as images
# dasr = dasr.cuda()
# torch.save(dasr, 'experiment/blindsr_x2_bicubic_iso/model/DASR.pt')
# model = torch.load('experiment/blindsr_x2_bicubic_iso/model/DASR.pt')
# with torch.no_grad():
#     sr = dasr(img_tensor)
# print(sr.data)
# fea = dasr.E(img_tensor, img_tensor)
# sr = dasr.G(img_tensor, fea)
# normalized = sr[0].data.byte()
# ndarr = normalized.permute(1, 2, 0).cpu().numpy()
# # ndarr = np.transpose(ndarr, (1, 2, 0))
# # cv2.cvtColor(ndarr, cv2.COLOR_BGR2)
# b = ndarr[:, :, 0].copy()
# g = ndarr[:, :, 1].copy()
# r = ndarr[:, :, 2].copy()
#
# ndarr[:, :, 0] = r
# ndarr[:, :, 1] = g
# ndarr[:, :, 2] = b
# print(ndarr.shape)
# cv.imwrite('{}{}.png'.format('result', 'SR'), ndarr)

# trace with the same dummy input and verify the traced module matches
traced_script_module = torch.jit.trace(dasr, example)
r2 = traced_script_module(example)
print(r2)
traced_script_module.save("model.pt")
def set_template(args):
    """Apply preset option bundles selected by substrings of ``args.template``.

    Checks are independent, so one template string can trigger several
    presets (same semantics as the original ``find() >= 0`` tests).
    Mutates ``args`` in place.
    """
    if 'jpeg' in args.template:
        args.data_train = 'DIV2K_jpeg'
        args.data_test = 'DIV2K_jpeg'
        args.epochs = 200
        args.lr_decay = 100

    if 'EDSR_paper' in args.template:
        args.model = 'EDSR'
        args.n_resblocks = 32
        args.n_feats = 256
        args.res_scale = 0.1

    if 'MDSR' in args.template:
        args.model = 'MDSR'
        args.patch_size = 48
        args.epochs = 650

    if 'DDBPN' in args.template:
        args.model = 'DDBPN'
        args.patch_size = 128
        args.scale = '4'
        args.data_test = 'Set5'
        args.batch_size = 20
        args.epochs = 1000
        args.lr_decay = 500
        args.gamma = 0.1
        args.weight_decay = 1e-4
        args.loss = '1*MSE'

    if 'GAN' in args.template:
        args.epochs = 200
        args.lr = 5e-5
        args.lr_decay = 150

    if 'RCAN' in args.template:
        args.model = 'RCAN'
        args.n_resgroups = 10
        args.n_resblocks = 20
        args.n_feats = 64
        args.chop = True
import os
from option import args
import torch
import utility
import data
import model
import loss
from trainer import Trainer
# Entry point: builds data/model/loss/trainer from the parsed options and
# (as committed) runs a single evaluation pass.
os.environ["CUDA_VISIBLE_DEVICES"] = '7, 8'

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    checkpoint = utility.checkpoint(args)

    if checkpoint.ok:
        loader = data.Data(args)
        # NOTE: rebinding the imported modules `model` and `loss` to
        # instances shadows the module names from here on.
        model = model.Model(args, checkpoint)
        loss = loss.Loss(args, checkpoint) if not args.test_only else None
        t = Trainer(args, loader, model, loss, checkpoint)
        # training loop disabled; test-only run:
        # while not t.terminate():
        t.test()

        checkpoint.done()

# import cv2
#
# img = cv2.imread('resultSR.png')
# print(img)
import os
import utility
import torch
from decimal import Decimal
import torch.nn.functional as F
from utils import util
class Trainer():
    """Drives the two-stage BlindSR training (encoder first, then the full
    network) and evaluation. Requires CUDA."""

    def __init__(self, args, loader, my_model, my_loss, ckp):
        self.args = args
        self.scale = args.scale

        self.ckp = ckp
        self.loader_train = loader.loader_train
        self.loader_test = loader.loader_test
        self.model = my_model
        # direct handle on the degradation encoder for stage-1 training
        self.model_E = torch.nn.DataParallel(self.model.get_model().E, range(self.args.n_GPUs))
        self.loss = my_loss
        self.contrast_loss = torch.nn.CrossEntropyLoss().cuda()
        self.optimizer = utility.make_optimizer(args, self.model)
        self.scheduler = utility.make_scheduler(args, self.optimizer)

        if self.args.load != '.':
            self.optimizer.load_state_dict(
                torch.load(os.path.join(ckp.dir, 'optimizer.pt'))
            )
            # fast-forward the scheduler to the resumed epoch
            for _ in range(len(ckp.log)): self.scheduler.step()

    def train(self):
        """Run one training epoch.

        Epochs up to args.epochs_encoder train only the degradation encoder
        with the contrastive loss; later epochs train the whole network with
        SR loss + contrastive loss.
        """
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1

        # lr stepwise: separate schedules for the two training stages
        if epoch <= self.args.epochs_encoder:
            lr = self.args.lr_encoder * (self.args.gamma_encoder ** (epoch // self.args.lr_decay_encoder))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        else:
            lr = self.args.lr_sr * (self.args.gamma_sr ** ((epoch - self.args.epochs_encoder) // self.args.lr_decay_sr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)))
        self.loss.start_log()
        self.model.train()

        # random degradations synthesised on the fly from HR patches
        degrade = util.SRMDPreprocessing(
            self.scale[0],
            kernel_size=self.args.blur_kernel,
            blur_type=self.args.blur_type,
            sig_min=self.args.sig_min,
            sig_max=self.args.sig_max,
            noise=self.args.noise
        )

        timer = utility.timer()
        losses_contrast, losses_sr = utility.AverageMeter(), utility.AverageMeter()
        for batch, (hr, _, idx_scale) in enumerate(self.loader_train):
            hr = hr.cuda()                              # b, n, c, h, w
            lr, b_kernels = degrade(hr)                 # bn, c, h, w

            self.optimizer.zero_grad()

            timer.tic()
            # forward
            ## train degradation encoder: two degraded views per sample
            ## form the MoCo query/key pair
            if epoch <= self.args.epochs_encoder:
                _, output, target = self.model_E(im_q=lr[:,0,...], im_k=lr[:,1,...])
                loss_constrast = self.contrast_loss(output, target)
                loss = loss_constrast

                losses_contrast.update(loss_constrast.item())

            ## train the whole network
            else:
                sr, output, target = self.model(lr)

                loss_SR = self.loss(sr, hr[:,0,...])
                loss_constrast = self.contrast_loss(output, target)
                loss = loss_constrast + loss_SR

                losses_sr.update(loss_SR.item())
                losses_contrast.update(loss_constrast.item())

            # backward
            loss.backward()
            self.optimizer.step()
            timer.hold()

            if epoch <= self.args.epochs_encoder:
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log(
                        'Epoch: [{:03d}][{:04d}/{:04d}]\t'
                        'Loss [contrastive loss: {:.3f}]\t'
                        'Time [{:.1f}s]'.format(
                            epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                            losses_contrast.avg,
                            timer.release()
                        ))
            else:
                if (batch + 1) % self.args.print_every == 0:
                    self.ckp.write_log(
                        'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                        'Loss [SR loss:{:.3f} | contrastive loss: {:.3f}]\t'
                        'Time [{:.1f}s]'.format(
                            epoch, (batch + 1) * self.args.batch_size, len(self.loader_train.dataset),
                            losses_sr.avg, losses_contrast.avg,
                            timer.release(),
                        ))

        self.loss.end_log(len(self.loader_train))

        # save model, dropping the MoCo key encoder and queue buffers so
        # checkpoints stay loadable with strict=False at inference
        target = self.model.get_model()
        model_dict = target.state_dict()
        keys = list(model_dict.keys())
        for key in keys:
            if 'E.encoder_k' in key or 'queue' in key:
                del model_dict[key]
        torch.save(
            model_dict,
            os.path.join(self.ckp.dir, 'model', 'model_{}.pt'.format(epoch))
        )

    def test(self):
        """Evaluate on the test loader and optionally save SR result images.

        PSNR/SSIM computation is currently commented out; only image saving
        remains active.
        """
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()
        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                self.loader_test.dataset.set_scale(idx_scale)
                eval_psnr = 0
                eval_ssim = 0
                # fixed-sigma degradation for deterministic evaluation
                degrade = util.SRMDPreprocessing(
                    self.scale[0],
                    kernel_size=self.args.blur_kernel,
                    blur_type=self.args.blur_type,
                    sig=self.args.sig,
                    noise=self.args.noise
                )
                for idx_img, (hr, filename, _) in enumerate(self.loader_test):
                    hr = hr.cuda()                      # b, 1, c, h, w
                    hr = self.crop_border(hr, scale)
                    # lr, _ = degrade(hr, random=False) # b, 1, c, h, w
                    # NOTE(review): degradation is disabled here, so the
                    # network is fed the HR image directly — confirm intent.
                    lr = hr
                    hr = hr[:, 0, ...]                  # b, c, h, w

                    # inference
                    timer_test.tic()
                    # img = torch.ones(1, 3, 96, 96).cuda()
                    sr = self.model(lr[:, 0, ...])
                    timer_test.hold()

                    sr = utility.quantize(sr, self.args.rgb_range)
                    hr = utility.quantize(hr, self.args.rgb_range)

                    # metrics (currently disabled)
                    # eval_psnr += utility.calc_psnr(
                    #     sr, hr, scale, self.args.rgb_range,
                    #     benchmark=self.loader_test.dataset.benchmark
                    # )
                    # eval_ssim += utility.calc_ssim(
                    #     sr, hr, scale,
                    #     benchmark=self.loader_test.dataset.benchmark
                    # )

                    # save results
                    if self.args.save_results:
                        save_list = [sr]
                        filename = filename[0]
                        self.ckp.save_results(filename, save_list, scale)

                # self.ckp.log[-1, idx_scale] = eval_psnr / len(self.loader_test)
                # self.ckp.write_log(
                #     '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f}'.format(
                #         self.args.resume,
                #         self.args.data_test,
                #         scale,
                #         eval_psnr / len(self.loader_test),
                #         eval_ssim / len(self.loader_test),
                #     ))

    def crop_border(self, img_hr, scale):
        """Trim H and W so both become divisible by ``scale``."""
        b, n, c, h, w = img_hr.size()

        img_hr = img_hr[:, :, :, :int(h//scale*scale), :int(w//scale*scale)]

        return img_hr

    def terminate(self):
        """Return True when the run should stop (also runs test() in test-only mode)."""
        if self.args.test_only:
            self.test()
            return True
        else:
            epoch = self.scheduler.last_epoch + 1
            return epoch >= self.args.epochs_encoder + self.args.epochs_sr
import os
import math
import time
import datetime
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc as misc
import cv2
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class timer():
    """Accumulating stopwatch: tic() starts an interval, hold() folds it into
    the accumulator, release() returns the total and clears it."""

    def __init__(self):
        self.acc = 0
        self.tic()

    def tic(self):
        """Start (or restart) the current interval."""
        self.t0 = time.time()

    def toc(self):
        """Seconds elapsed since the last tic()."""
        return time.time() - self.t0

    def hold(self):
        """Add the current interval to the accumulator."""
        self.acc += self.toc()

    def release(self):
        """Return the accumulated time and reset it to zero."""
        total, self.acc = self.acc, 0
        return total

    def reset(self):
        """Drop any accumulated time."""
        self.acc = 0
class checkpoint():
    """Experiment bookkeeping: directories, text log, PSNR log, result images."""

    def __init__(self, args):
        self.args = args
        self.ok = True
        # per-epoch PSNR log; one column per scale
        self.log = torch.Tensor()
        now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

        # experiment directory is keyed by save name / scale / downsampler / blur type
        if args.blur_type == 'iso_gaussian':
            self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_iso'
        elif args.blur_type == 'aniso_gaussian':
            self.dir = './experiment/' + args.save + '_x' + str(int(args.scale[0])) + '_' + args.mode + '_aniso'

        def _make_dir(path):
            if not os.path.exists(path): os.makedirs(path)

        _make_dir(self.dir)
        _make_dir(self.dir + '/model')
        _make_dir(self.dir + '/results')

        # append to existing logs when the experiment directory already exists
        open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w'
        self.log_file = open(self.dir + '/log.txt', open_type)
        with open(self.dir + '/config.txt', open_type) as f:
            f.write(now + '\n\n')
            for arg in vars(args):
                f.write('{}: {}\n'.format(arg, getattr(args, arg)))
            f.write('\n')

    def save(self, trainer, epoch, is_best=False):
        """Persist model, loss state/plots, the PSNR log and optimizer state."""
        trainer.model.save(self.dir, epoch, is_best=is_best)
        trainer.loss.save(self.dir)
        trainer.loss.plot_loss(self.dir, epoch)

        self.plot_psnr(epoch)
        torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt'))
        torch.save(
            trainer.optimizer.state_dict(),
            os.path.join(self.dir, 'optimizer.pt')
        )

    def add_log(self, log):
        # append a new row (one epoch) to the PSNR log
        self.log = torch.cat([self.log, log])

    def write_log(self, log, refresh=False):
        """Print ``log`` and append it to log.txt; reopen the file when ``refresh``."""
        print(log)
        self.log_file.write(log + '\n')
        if refresh:
            self.log_file.close()
            self.log_file = open(self.dir + '/log.txt', 'a')

    def done(self):
        """Close the log file."""
        self.log_file.close()

    def plot_psnr(self, epoch):
        """Save a PSNR-vs-epoch plot (one curve per scale) as a PDF."""
        axis = np.linspace(1, epoch, epoch)
        label = 'SR on {}'.format(self.args.data_test)
        fig = plt.figure()
        plt.title(label)
        for idx_scale, scale in enumerate(self.args.scale):
            plt.plot(
                axis,
                self.log[:, idx_scale].numpy(),
                label='Scale {}'.format(scale)
            )
        plt.legend()
        plt.xlabel('Epochs')
        plt.ylabel('PSNR')
        plt.grid(True)
        plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test))
        plt.close(fig)

    def save_results(self, filename, save_list, scale):
        """Write the first tensor of ``save_list`` (CHW, RGB) as a PNG under results/."""
        filename = '{}/results/{}_x{}_'.format(self.dir, filename, scale)
        normalized = save_list[0][0].data.mul(255 / self.args.rgb_range)
        print(255 / self.args.rgb_range)
        print(normalized)
        ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy()
        # ndarr = np.transpose(ndarr, (1, 2, 0))
        # cv2.cvtColor(ndarr, cv2.COLOR_BGR2)
        # manual RGB -> BGR channel swap for cv2.imwrite
        b = ndarr[:, :, 0].copy()
        g = ndarr[:, :, 1].copy()
        r = ndarr[:, :, 2].copy()

        ndarr[:, :, 0] = r
        ndarr[:, :, 1] = g
        ndarr[:, :, 2] = b
        cv2.imwrite('{}{}.png'.format(filename, 'SR'), ndarr)
        # misc.imsave('{}{}.png'.format(filename, 'SR'), ndarr)
def quantize(img, rgb_range):
    """Snap ``img`` to the nearest valid 8-bit level, keeping the original
    value range (e.g. [0, rgb_range])."""
    scale = 255 / rgb_range
    levels = img.mul(scale).clamp(0, 255).round()
    return levels.div(scale)
def calc_psnr(sr, hr, scale, rgb_range, benchmark=False):
    """PSNR (dB) between ``sr`` and ``hr`` (B*C*H*W tensors).

    A border of ``scale`` (benchmark sets, with RGB converted to the Y
    channel first) or ``scale + 6`` pixels is shaved before computing the
    MSE. Returns ``inf`` for identical images (previously this raised a
    math domain error on log10(0)).
    """
    diff = (sr - hr).data.div(rgb_range)
    if benchmark:
        shave = scale
        if diff.size(1) > 1:
            # RGB -> Y using BT.601 luma weights (scaled by 256)
            convert = diff.new(1, 3, 1, 1)
            convert[0, 0, 0, 0] = 65.738
            convert[0, 1, 0, 0] = 129.057
            convert[0, 2, 0, 0] = 25.064
            diff.mul_(convert).div_(256)
            diff = diff.sum(dim=1, keepdim=True)
    else:
        shave = scale + 6

    shave = math.ceil(shave)
    valid = diff[:, :, shave:-shave, shave:-shave]
    mse = valid.pow(2).mean()
    if mse == 0:
        # identical inputs: PSNR is unbounded
        return float('inf')
    return -10 * math.log10(mse)
def calc_ssim(img1, img2, scale=2, benchmark=False):
    '''calculate SSIM
    the same outputs as MATLAB's
    img1, img2: [0, 255]

    Inputs are CHW colour tensors; the score is computed on the Y (luma)
    channel with a border of ``scale`` (benchmark) or ``scale + 6`` pixels
    cropped first.
    '''
    if benchmark:
        border = math.ceil(scale)
    else:
        border = math.ceil(scale) + 6

    # NOTE(review): the squeeze/transpose below assumes CHW colour inputs —
    # confirm callers never pass grayscale tensors.
    img1 = img1.data.squeeze().float().clamp(0, 255).round().cpu().numpy()
    img1 = np.transpose(img1, (1, 2, 0))
    img2 = img2.data.squeeze().cpu().numpy()
    img2 = np.transpose(img2, (1, 2, 0))

    # check shapes before any arithmetic (previously np.dot could fail first)
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')

    # RGB -> Y using BT.601 luma weights
    img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 255.0 + 16.0
    img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 255.0 + 16.0

    h, w = img1.shape[:2]
    img1_y = img1_y[border:h - border, border:w - border]
    img2_y = img2_y[border:h - border, border:w - border]

    # img1_y is always 2-D after np.dot, so this branch normally returns;
    # the per-channel fallback below is kept for completeness.
    if img1_y.ndim == 2:
        return ssim(img1_y, img2_y)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                # fixed: compare per-channel slices (previously the full
                # 3-channel arrays were compared three times)
                ssims.append(ssim(img1[..., i], img2[..., i]))
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
def ssim(img1, img2):
    """SSIM between two single-channel images in [0, 255] (MATLAB-style).

    Uses an 11x11 Gaussian window (sigma 1.5) and crops 5 pixels on each
    side so only 'valid' filter responses contribute.
    """
    # stabilisation constants from the SSIM paper: (K1*L)^2, (K2*L)^2 with L=255
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    # local (co)variances via E[x^2] - E[x]^2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()
def make_optimizer(args, my_model):
    """Create the optimizer named by ``args.optimizer`` ('SGD' | 'ADAM' |
    'RMSprop') over the model's trainable parameters, with the weight decay
    from ``args``."""
    trainable = (p for p in my_model.parameters() if p.requires_grad)

    if args.optimizer == 'SGD':
        cls, kwargs = optim.SGD, {'momentum': args.momentum}
    elif args.optimizer == 'ADAM':
        cls, kwargs = optim.Adam, {
            'betas': (args.beta1, args.beta2),
            'eps': args.epsilon,
        }
    elif args.optimizer == 'RMSprop':
        cls, kwargs = optim.RMSprop, {'eps': args.epsilon}

    kwargs['weight_decay'] = args.weight_decay

    return cls(trainable, **kwargs)
def make_scheduler(args, my_optimizer):
    """Build the LR scheduler described by ``args.decay_type``.

    'step'           -> StepLR using the SR decay interval/factor
    'step_<m1>_...'  -> MultiStepLR at the listed epochs (factor ``args.gamma``)

    Any other value now raises ValueError (previously the function fell
    through and crashed with UnboundLocalError). The scheduler is advanced
    to ``args.start_epoch - 1`` so training can resume mid-run.
    """
    if args.decay_type == 'step':
        scheduler = lrs.StepLR(
            my_optimizer,
            step_size=args.lr_decay_sr,
            gamma=args.gamma_sr,
        )
    elif args.decay_type.find('step') >= 0:
        milestones = args.decay_type.split('_')
        milestones.pop(0)
        milestones = list(map(lambda x: int(x), milestones))
        scheduler = lrs.MultiStepLR(
            my_optimizer,
            milestones=milestones,
            gamma=args.gamma
        )
    else:
        raise ValueError('unknown decay_type: {}'.format(args.decay_type))

    scheduler.step(args.start_epoch - 1)

    return scheduler
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def cal_sigma(sig_x, sig_y, radians):
    """Covariance matrices (B x 2 x 2) of anisotropic Gaussians.

    Builds diag(sig_x^2, sig_y^2) and rotates it by ``radians``: U D U^T.
    All inputs are length-B 1-D tensors.
    """
    sx = sig_x.view(-1, 1, 1)
    sy = sig_y.view(-1, 1, 1)
    theta = radians.view(-1, 1, 1)

    # D = diag(sx^2, sy^2), assembled row by row via zero-padding
    row_x = F.pad(sx ** 2, [0, 1, 0, 0])  # [sx^2, 0]
    row_y = F.pad(sy ** 2, [1, 0, 0, 0])  # [0, sy^2]
    D = torch.cat([row_x, row_y], 1)

    # rotation matrix U
    cos, sin = theta.cos(), theta.sin()
    U = torch.cat([torch.cat([cos, -sin], 2),
                   torch.cat([sin, cos], 2)], 1)

    return torch.bmm(U, torch.bmm(D, U.transpose(1, 2)))
def anisotropic_gaussian_kernel(batch, kernel_size, covar):
    """Batch of anisotropic Gaussian kernels from 2x2 covariance matrices.

    Returns a (batch, kernel_size, kernel_size) CUDA tensor; each kernel
    is normalised to sum to 1.
    """
    # integer grid centred on the kernel midpoint
    ax = torch.arange(kernel_size).float().cuda() - kernel_size // 2
    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    xy = torch.stack([xx, yy], -1).view(batch, -1, 2)

    # unnormalised density exp(-0.5 * p^T Sigma^-1 p) per grid point
    inverse_sigma = torch.inverse(covar)
    kernel = torch.exp(- 0.5 * (torch.bmm(xy, inverse_sigma) * xy).sum(2)).view(batch, kernel_size, kernel_size)

    return kernel / kernel.sum([1, 2], keepdim=True)
def isotropic_gaussian_kernel(batch, kernel_size, sigma):
    """Batch of isotropic Gaussian kernels with per-sample sigma.

    Returns a (batch, kernel_size, kernel_size) CUDA tensor; each kernel
    is normalised to sum to 1.
    """
    # integer grid centred on the kernel midpoint
    ax = torch.arange(kernel_size).float().cuda() - kernel_size//2
    xx = ax.repeat(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    yy = ax.repeat_interleave(kernel_size).view(1, kernel_size, kernel_size).expand(batch, -1, -1)
    kernel = torch.exp(-(xx ** 2 + yy ** 2) / (2. * sigma.view(-1, 1, 1) ** 2))

    return kernel / kernel.sum([1,2], keepdim=True)
def random_anisotropic_gaussian_kernel(batch=1, kernel_size=21, lambda_min=0.2, lambda_max=4.0):
    """Sample random anisotropic Gaussian kernels: rotation uniform in
    [0, pi), eigen-parameters uniform in [lambda_min, lambda_max].

    Note: cal_sigma squares its first two arguments, so the lambdas act as
    standard-deviation-like quantities.
    """
    theta = torch.rand(batch).cuda() * math.pi
    lambda_1 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min
    lambda_2 = torch.rand(batch).cuda() * (lambda_max - lambda_min) + lambda_min

    covar = cal_sigma(lambda_1, lambda_2, theta)
    kernel = anisotropic_gaussian_kernel(batch, kernel_size, covar)
    return kernel
def stable_anisotropic_gaussian_kernel(kernel_size=21, theta=0, lambda_1=0.2, lambda_2=4.0):
    """Single deterministic anisotropic Gaussian kernel.

    ``theta`` is a fraction of pi (rotation = theta * pi radians).
    Returns a (1, kernel_size, kernel_size) CUDA tensor.
    """
    theta = torch.ones(1).cuda() * theta * math.pi
    lambda_1 = torch.ones(1).cuda() * lambda_1
    lambda_2 = torch.ones(1).cuda() * lambda_2

    covar = cal_sigma(lambda_1, lambda_2, theta)
    kernel = anisotropic_gaussian_kernel(1, kernel_size, covar)
    return kernel
def random_isotropic_gaussian_kernel(batch=1, kernel_size=21, sig_min=0.2, sig_max=4.0):
    """Sample random isotropic Gaussian kernels with sigma uniform in
    [sig_min, sig_max]. Returns a (batch, kernel_size, kernel_size) CUDA tensor."""
    x = torch.rand(batch).cuda() * (sig_max - sig_min) + sig_min
    k = isotropic_gaussian_kernel(batch, kernel_size, x)
    return k
def stable_isotropic_gaussian_kernel(kernel_size=21, sig=4.0):
    """Single deterministic isotropic Gaussian kernel with the given sigma.
    Returns a (1, kernel_size, kernel_size) CUDA tensor."""
    x = torch.ones(1).cuda() * sig
    k = isotropic_gaussian_kernel(1, kernel_size, x)
    return k
def random_gaussian_kernel(batch, kernel_size=21, blur_type='iso_gaussian', sig_min=0.2, sig_max=4.0, lambda_min=0.2, lambda_max=4.0):
    """Sample a batch of random Gaussian blur kernels of the given type.

    blur_type: 'iso_gaussian' or 'aniso_gaussian'. Any other value now
    raises ValueError (previously the function silently returned None).
    """
    if blur_type == 'iso_gaussian':
        return random_isotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, sig_min=sig_min, sig_max=sig_max)
    elif blur_type == 'aniso_gaussian':
        return random_anisotropic_gaussian_kernel(batch=batch, kernel_size=kernel_size, lambda_min=lambda_min, lambda_max=lambda_max)
    raise ValueError('unknown blur_type: {}'.format(blur_type))
def stable_gaussian_kernel(kernel_size=21, blur_type='iso_gaussian', sig=2.6, lambda_1=0.2, lambda_2=4.0, theta=0):
    """Dispatch to the deterministic kernel generator selected by ``blur_type``.

    Args:
        kernel_size: side length of the kernel.
        blur_type: 'iso_gaussian' or 'aniso_gaussian'.
        sig: width of the isotropic kernel.
        lambda_1, lambda_2, theta: covariance eigenvalues and rotation for
            the anisotropic kernel.

    Returns:
        Tensor of shape (1, kernel_size, kernel_size).

    Raises:
        ValueError: if ``blur_type`` is not a recognised mode (previously
            this fell through and silently returned None).
    """
    if blur_type == 'iso_gaussian':
        return stable_isotropic_gaussian_kernel(kernel_size=kernel_size, sig=sig)
    elif blur_type == 'aniso_gaussian':
        return stable_anisotropic_gaussian_kernel(kernel_size=kernel_size, lambda_1=lambda_1, lambda_2=lambda_2, theta=theta)
    else:
        raise ValueError("unknown blur_type: {!r} (expected 'iso_gaussian' or 'aniso_gaussian')".format(blur_type))
# implementation of matlab bicubic interpolation in pytorch
class bicubic(nn.Module):
    """Port of MATLAB's imresize bicubic interpolation to PyTorch.

    Unlike torch's built-in bicubic mode, this reproduces MATLAB's
    antialiasing behaviour when downscaling (scale < 1): the cubic kernel
    is widened by 1/scale and the tap weights re-normalised.
    All index bookkeeping is 1-based (MATLAB convention) until `forward`
    converts to 0-based tensor indices.
    """
    def __init__(self):
        super(bicubic, self).__init__()
    def cubic(self, x):
        # Keys cubic convolution kernel with a = -0.5 (MATLAB's default):
        #   f(x) =  1.5|x|^3 - 2.5|x|^2 + 1          for |x| <= 1
        #   f(x) = -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2   for 1 < |x| <= 2
        #   f(x) = 0                                  otherwise
        absx = torch.abs(x)
        absx2 = torch.abs(x) * torch.abs(x)
        absx3 = torch.abs(x) * torch.abs(x) * torch.abs(x)
        condition1 = (absx <= 1).to(torch.float32)
        condition2 = ((1 < absx) & (absx <= 2)).to(torch.float32)
        f = (1.5 * absx3 - 2.5 * absx2 + 1) * condition1 + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * condition2
        return f
    def contribute(self, in_size, out_size, scale):
        # Compute, for each output row/column, which input samples contribute
        # and with what normalised cubic weight (MATLAB's "contributions"
        # step).  Axis 0 is vertical (rows), axis 1 horizontal (columns).
        kernel_width = 4
        if scale < 1:
            # Antialiasing: widen the kernel when shrinking the image.
            kernel_width = 4 / scale
        # 1-based output pixel coordinates.
        x0 = torch.arange(start=1, end=out_size[0] + 1).to(torch.float32).cuda()
        x1 = torch.arange(start=1, end=out_size[1] + 1).to(torch.float32).cuda()
        # Back-project output pixel centres into input coordinates.
        u0 = x0 / scale + 0.5 * (1 - 1 / scale)
        u1 = x1 / scale + 0.5 * (1 - 1 / scale)
        # Left-most input sample that can contribute to each output pixel.
        left0 = torch.floor(u0 - kernel_width / 2)
        left1 = torch.floor(u1 - kernel_width / 2)
        # Number of candidate taps per output pixel (+2 safety margin).
        P = np.ceil(kernel_width) + 2
        indice0 = left0.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda()
        indice1 = left1.unsqueeze(1) + torch.arange(start=0, end=P).to(torch.float32).unsqueeze(0).cuda()
        # Signed distance of every candidate tap from the projected centre.
        mid0 = u0.unsqueeze(1) - indice0.unsqueeze(0)
        mid1 = u1.unsqueeze(1) - indice1.unsqueeze(0)
        if scale < 1:
            # Antialiased kernel: evaluate at scale*x and rescale.
            weight0 = scale * self.cubic(mid0 * scale)
            weight1 = scale * self.cubic(mid1 * scale)
        else:
            weight0 = self.cubic(mid0)
            weight1 = self.cubic(mid1)
        # Normalise so each output pixel's tap weights sum to 1.
        weight0 = weight0 / (torch.sum(weight0, 2).unsqueeze(2))
        weight1 = weight1 / (torch.sum(weight1, 2).unsqueeze(2))
        # Clamp out-of-range taps to the image border (replicate padding).
        indice0 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice0), torch.FloatTensor([in_size[0]]).cuda()).unsqueeze(0)
        indice1 = torch.min(torch.max(torch.FloatTensor([1]).cuda(), indice1), torch.FloatTensor([in_size[1]]).cuda()).unsqueeze(0)
        # Drop taps that carry zero weight; the zero pattern is read from the
        # first output pixel and assumed identical for all others
        # (translation-invariant sampling grid).
        kill0 = torch.eq(weight0, 0)[0][0]
        kill1 = torch.eq(weight1, 0)[0][0]
        weight0 = weight0[:, :, kill0 == 0]
        weight1 = weight1[:, :, kill1 == 0]
        indice0 = indice0[:, :, kill0 == 0]
        indice1 = indice1[:, :, kill1 == 0]
        return weight0, weight1, indice0, indice1
    def forward(self, input, scale=1/4):
        # Resize `input` (B, C, H, W) by `scale` with separable bicubic
        # interpolation: vertical pass first, then horizontal.
        b, c, h, w = input.shape
        weight0, weight1, indice0, indice1 = self.contribute([h, w], [int(h * scale), int(w * scale)], scale)
        weight0 = weight0[0]
        weight1 = weight1[0]
        # Convert 1-based MATLAB-style indices to 0-based below (indice - 1).
        indice0 = indice0[0].long()
        indice1 = indice1[0].long()
        # Vertical pass: gather contributing rows and blend with weight0.
        out = input[:, :, (indice0 - 1), :] * (weight0.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = (torch.sum(out, dim=3))
        # Horizontal pass: transpose H/W, gather, blend, transpose back.
        A = out.permute(0, 1, 3, 2)
        out = A[:, :, (indice1 - 1), :] * (weight1.unsqueeze(0).unsqueeze(1).unsqueeze(4))
        out = out.sum(3).permute(0, 1, 3, 2)
        return out
class Gaussin_Kernel(object):
    """Configuration holder that produces Gaussian blur kernels on demand.

    Stores the hyper-parameters for both the isotropic and the anisotropic
    generators; calling the object draws either a random kernel batch
    (training) or a fixed deterministic kernel (testing).
    """
    def __init__(self, kernel_size=21, blur_type='iso_gaussian',
                 sig=2.6, sig_min=0.2, sig_max=4.0,
                 lambda_1=0.2, lambda_2=4.0, theta=0, lambda_min=0.2, lambda_max=4.0):
        # common settings
        self.kernel_size, self.blur_type = kernel_size, blur_type
        # isotropic-blur settings (fixed width / random-width range)
        self.sig, self.sig_min, self.sig_max = sig, sig_min, sig_max
        # anisotropic-blur settings (fixed eigenvalues + angle / random range)
        self.lambda_1, self.lambda_2, self.theta = lambda_1, lambda_2, theta
        self.lambda_min, self.lambda_max = lambda_min, lambda_max
    def __call__(self, batch, random):
        """Return `batch` random kernels when `random` is True, else one
        deterministic kernel built from the stored parameters."""
        if random == True:
            # training phase: a fresh random kernel per batch element
            return random_gaussian_kernel(batch, kernel_size=self.kernel_size, blur_type=self.blur_type,
                                          sig_min=self.sig_min, sig_max=self.sig_max,
                                          lambda_min=self.lambda_min, lambda_max=self.lambda_max)
        # test phase: one stable kernel shared by the whole batch
        return stable_gaussian_kernel(kernel_size=self.kernel_size, blur_type=self.blur_type,
                                      sig=self.sig,
                                      lambda_1=self.lambda_1, lambda_2=self.lambda_2, theta=self.theta)
class BatchBlur(nn.Module):
    """Blur a batch of images with Gaussian kernels via grouped convolution.

    Accepts either one shared kernel for the whole batch (2-D kernel
    tensor of shape (ks, ks)) or one kernel per batch element (3-D tensor
    of shape (B, ks, ks)).  Reflection padding keeps the spatial size of
    the output equal to the input.
    """
    def __init__(self, kernel_size=21):
        super(BatchBlur, self).__init__()
        self.kernel_size = kernel_size
        half = kernel_size // 2
        # Odd kernels pad symmetrically; even kernels need one pixel less
        # on the right/bottom so the conv output matches the input size.
        if kernel_size % 2 == 1:
            self.pad = nn.ReflectionPad2d(half)
        else:
            self.pad = nn.ReflectionPad2d((half, half - 1, half, half - 1))
    def forward(self, input, kernel):
        B, C, H, W = input.size()
        padded = self.pad(input)
        H_p, W_p = padded.size()[-2:]
        ks = self.kernel_size
        if kernel.dim() == 2:
            # Single shared kernel: fold batch and channels into one axis
            # and run a single-channel convolution.
            flat = padded.view((C * B, 1, H_p, W_p))
            shared = kernel.contiguous().view((1, 1, ks, ks))
            return F.conv2d(flat, shared, padding=0).view((B, C, H, W))
        # Per-image kernels: grouped conv so kernel b is applied to every
        # channel of image b.
        flat = padded.view((1, C * B, H_p, W_p))
        per_image = kernel.contiguous().view((B, 1, ks, ks))
        per_image = per_image.repeat(1, C, 1, 1).view((B * C, 1, ks, ks))
        return F.conv2d(flat, per_image, groups=B * C).view((B, C, H, W))
class SRMDPreprocessing(object):
    """Synthesise degraded LR images from HR batches.

    Pipeline: (optional) Gaussian blur -> downsampling (bicubic or s-fold
    subsampling) -> (optional) additive Gaussian noise -> quantisation to
    [0, 255].  Returns the LR batch together with the blur kernels used
    (None when only downsampling was performed).
    """
    def __init__(self,
                 scale,
                 mode='bicubic',
                 kernel_size=21,
                 blur_type='iso_gaussian',
                 sig=2.6,
                 sig_min=0.2,
                 sig_max=4.0,
                 lambda_1=0.2,
                 lambda_2=4.0,
                 theta=0,
                 lambda_min=0.2,
                 lambda_max=4.0,
                 noise=0.0
                 ):
        '''
        # scale: downsampling factor; mode: 'bicubic' or 's-fold'
        # noise: maximum noise level added after downsampling; 0 disables it
        # sig, sig_min and sig_max are used for isotropic Gaussian blurs
        During training phase (random=True):
            the width of the blur kernel is randomly selected from [sig_min, sig_max]
        During test phase (random=False):
            the width of the blur kernel is set to sig
        # lambda_1, lambda_2, theta, lambda_min and lambda_max are used for anisotropic Gaussian blurs
        During training phase (random=True):
            the eigenvalues of the covariance is randomly selected from [lambda_min, lambda_max]
            the angle value is randomly selected from [0, pi]
        During test phase (random=False):
            the eigenvalues of the covariance are set to lambda_1 and lambda_2
            the angle value is set to theta
        '''
        self.kernel_size = kernel_size
        self.scale = scale
        self.mode = mode
        self.noise = noise
        self.gen_kernel = Gaussin_Kernel(
            kernel_size=kernel_size, blur_type=blur_type,
            sig=sig, sig_min=sig_min, sig_max=sig_max,
            lambda_1=lambda_1, lambda_2=lambda_2, theta=theta, lambda_min=lambda_min, lambda_max=lambda_max
        )
        self.blur = BatchBlur(kernel_size=kernel_size)
        self.bicubic = bicubic()
    def __call__(self, hr_tensor, random=True):
        # hr_tensor: 5-D tensor (B, N, C, H, W); presumably B kernel groups
        # of N patches each, all N patches of a group sharing one kernel --
        # TODO confirm with the caller.
        with torch.no_grad():
            # only downsampling (sig == 0 is treated as "no blur")
            if self.gen_kernel.blur_type == 'iso_gaussian' and self.gen_kernel.sig == 0:
                B, N, C, H, W = hr_tensor.size()
                hr_blured = hr_tensor.view(-1, C, H, W)
                b_kernels = None
            # gaussian blur + downsampling
            else:
                B, N, C, H, W = hr_tensor.size()
                b_kernels = self.gen_kernel(B, random)  # B degradations
                # blur: fold N into the channel axis so every patch of a
                # group is blurred by that group's kernel in one conv
                hr_blured = self.blur(hr_tensor.view(B, -1, H, W), b_kernels)
                hr_blured = hr_blured.view(-1, C, H, W)  # BN, C, H, W
            # downsampling
            if self.mode == 'bicubic':
                lr_blured = self.bicubic(hr_blured, scale=1/self.scale)
            elif self.mode == 's-fold':
                # s-fold subsampling: keep the top-left pixel of each
                # scale x scale cell
                lr_blured = hr_blured.view(-1, C, H//self.scale, self.scale, W//self.scale, self.scale)[:, :, :, 0, :, 0]
            # add noise
            if self.noise > 0:
                _, C, H_lr, W_lr = lr_blured.size()
                # random=True: per-group level in [0, self.noise] broadcast
                # over the N patches; random=False: fixed scalar level
                noise_level = torch.rand(B, 1, 1, 1, 1).to(lr_blured.device) * self.noise if random else self.noise
                noise = torch.randn_like(lr_blured).view(-1, N, C, H_lr, W_lr).mul_(noise_level).view(-1, C, H_lr, W_lr)
                lr_blured.add_(noise)
            # quantise to valid 8-bit pixel values (pixel range is [0, 255])
            lr_blured = torch.clamp(lr_blured.round(), 0, 255)
            return lr_blured.view(B, N, C, H//int(self.scale), W//int(self.scale)), b_kernels
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment