import argparse
import math
import random
import warnings

import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils
import yaml
from torch.optim.lr_scheduler import _LRScheduler

from metrics import calculate_psnr, calculate_ssim

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
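
# A minimal usage sketch for AverageMeter (model, criterion, and loader below are
# assumed placeholders from a typical training loop, not defined in this file):
#
#   loss_meter = AverageMeter()
#   for inputs, targets in loader:
#       loss = criterion(model(inputs), targets)
#       loss_meter.update(loss.item(), n=inputs.size(0))
#   print(loss_meter.avg)  # running average weighted by batch size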

def calculate_metrics(imgs_1, imgs_2):
    """Return the mean PSNR and SSIM over a batch of image tensors."""
    psnrs = []
    ssims = []
    assert imgs_1.shape[0] == imgs_2.shape[0]
    batch_size = imgs_1.shape[0]
    for i in range(batch_size):
        img1 = imgs_1[i]
        img2 = imgs_2[i]
        # Convert (C, H, W) tensors to H x W x C uint8 arrays before computing metrics.
        img1 = np.asarray(transforms.ToPILImage()(img1))
        img2 = np.asarray(transforms.ToPILImage()(img2))
        psnr = calculate_psnr(img1, img2, 0)
        ssim = calculate_ssim(img1, img2, 0)
        psnrs.append(psnr)
        ssims.append(ssim)
    return np.asarray(psnrs).mean(), np.asarray(ssims).mean()
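
# A usage sketch for calculate_metrics, assuming two batches of image tensors of
# shape (B, C, H, W) with values in [0, 1] (e.g. a restored batch and its ground
# truth); ToPILImage expects this layout before the metrics are computed:
#
#   psnr, ssim = calculate_metrics(restored.cpu(), target.cpu())
#   print(f"PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}")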

def read_args(config_file):
    """Build an argparse parser whose defaults come from a YAML config file."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default=config_file)
    with open(config_file) as f:
        config = yaml.safe_load(f)
    for k, v in config.items():
        parser.add_argument(f"--{k}", default=v)
    return parser
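
# A usage sketch for read_args, assuming a YAML file such as "config.yaml" with
# keys like `lr` and `batch_size`; every key becomes a CLI flag whose default is
# the YAML value, so it can still be overridden on the command line:
#
#   args = read_args("config.yaml").parse_args()
#   print(args.lr, args.batch_size)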

def save_checkpoint(state, filename):
    torch.save(state, filename)
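
# A usage sketch for save_checkpoint; the dict keys below are a common convention
# (not required by torch.save itself), and model/optimizer are assumed to exist:
#
#   save_checkpoint({"epoch": epoch,
#                    "state_dict": model.state_dict(),
#                    "optimizer": optimizer.state_dict()},
#                   "checkpoint.pth")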

class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = 0 if last_epoch < 0 else last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)
        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step()  # scheduler.step(27), instead of scheduler(20)
        """
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
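
# A minimal usage sketch for CosineAnnealingWarmRestarts (optimizer is assumed to
# be any torch.optim optimizer); with T_0=10 and T_mult=2 the cosine periods are
# 10, 20, 40, ... epochs, each annealing from the base lr down to eta_min:
#
#   scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
#   for epoch in range(num_epochs):
#       train_one_epoch(...)
#       scheduler.step()  # or scheduler.step(epoch + i / iters) per batch, see docstring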

def set_seed(seed):
    """Seed Python, NumPy, and PyTorch RNGs for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
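
# A minimal usage sketch: call set_seed once at program start (before building the
# model and dataloaders) so Python, NumPy, and PyTorch CPU/CUDA RNGs agree:
#
#   set_seed(42)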