File size: 3,322 Bytes
			
			| 1fae98d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | from typing import List, Tuple
from scipy import interpolate
import numpy as np
import torch
import matplotlib.pyplot as plt
from IPython.display import clear_output
import abc
class GuideModel(torch.nn.Module, abc.ABC):
    """Abstract base for the model whose loss steers guidance.

    Subclasses supply two hooks: :meth:`preprocess`, which maps a decoded
    image to the model's input, and :meth:`compute_loss`, which scores that
    input. The class cannot be instantiated until both are overridden.
    """

    def __init__(self) -> None:
        # Only torch.nn.Module initialisation is needed here.
        super().__init__()

    @abc.abstractmethod
    def preprocess(self, x_img):
        """Convert a decoded image ``x_img`` into this model's input."""

    @abc.abstractmethod
    def compute_loss(self, inp):
        """Return the guidance loss for the preprocessed input ``inp``."""
class Guider(torch.nn.Module):
    """Classifier guidance applied to a diffusion sampler's noise prediction.

    :meth:`modify_score` takes ``(model, e_t, x, t, c)`` and returns a
    corrected ``e_t``; presumably it is installed as the sampler's score
    corrector — confirm against the sampler in use.
    """

    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance.

        Args:
            sampler: Must expose ``ddim_timesteps``, ``ddpm_num_timesteps``
                and ``ddim_sqrt_one_minus_alphas``. The first two are captured
                here, so rebuild the Guider if the sampler's schedule changes.
            guide_model: Provides ``preprocess(x_img)`` and
                ``compute_loss(inp)`` (see ``GuideModel``).
            scale: Guidance scale, either a scalar or a schedule given as a
                list/tuple of ``(time, scale)`` pairs with ``time`` running
                0 -> 1, e.g. ``[(0, 10), (0.5, 20), (1, 50)]``. ``time`` is
                matched against ``timestep / ddpm_num_timesteps``.
            verbose: Print, record (in ``self.history``) and plot diagnostics
                on every guided step.
        """
        super().__init__()
        self.sampler = sampler
        # Number of modify_score calls so far; no longer used for lookups.
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []
        if isinstance(scale, (tuple, list)):
            times = np.array([point[0] for point in scale])
            values = np.array([point[1] for point in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)
        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        """Return the guidance scale for each entry of ``ddim_timesteps``.

        A scalar scale is repeated. A schedule is linearly interpolated at
        ``ddim_timesteps / ddpm_num_timesteps``; ``interp1d`` raises
        ``ValueError`` if any of those fractions lie outside the schedule's
        time range.
        """
        if isinstance(self.scale_schedule, float):
            return [self.scale_schedule] * len(self.ddim_timesteps)
        interpolator = interpolate.interp1d(
            self.scale_schedule["times"], self.scale_schedule["values"]
        )
        fractional_steps = np.asarray(self.ddim_timesteps) / self.ddpm_num_timesteps
        return interpolator(fractional_steps)

    def _timestep_index(self, t):
        """Return the index of the ``ddim_timesteps`` entry nearest to ``t``."""
        if isinstance(t, torch.Tensor):
            # Assumes every batch element shares the same timestep — TODO confirm.
            t_value = float(t.reshape(-1)[0].item())
        else:
            t_value = float(t)
        steps = np.asarray(self.ddim_timesteps, dtype=np.float64)
        return int(np.argmin(np.abs(steps - t_value)))

    def modify_score(self, model, e_t, x, t, c):
        """Return ``e_t`` corrected by the guide model's gradient at step ``t``.

        The schedule position is looked up from ``t`` itself (resolving the
        old "look up index by t" TODO) instead of from a call counter. The
        counter got out of step whenever the sampler visited timesteps in a
        different order than ``ddim_timesteps``, never advanced after a
        zero-scale step (so guidance stayed off for the rest of the run), and
        ran past the end on a second sampling run.

        Args:
            model: Provides ``predict_start_from_noise`` and
                ``first_stage_model.decode``.
            e_t: Noise prediction for ``x``.
            x: Current latent.
            t: Current timestep, an int or a tensor (only its first element
                is used).
            c: Conditioning; unused.

        Returns:
            ``e_t - sqrt_1ma * scale * grad(loss, x)``, or ``e_t`` unchanged
            when this step's scale is 0.
        """
        self.index += 1
        index = self._timestep_index(t)
        scale = self.get_scales()[index]
        if scale == 0:
            return e_t
        # NOTE(review): assumes ddim_sqrt_one_minus_alphas is ordered like
        # ddim_timesteps — confirm against the sampler.
        sqrt_1ma = torch.as_tensor(
            self.sampler.ddim_sqrt_one_minus_alphas[index], device=x.device
        )
        with torch.enable_grad():
            # Differentiate the loss w.r.t. a detached copy of the latent.
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            # Presumably undoes the first-stage latent scale factor before decoding.
            x_img = model.first_stage_model.decode((1 / 0.18215) * pred_x0)
            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            # Summing reduces a per-sample (batched) loss to the scalar grad needs.
            grads = torch.autograd.grad(loss.sum(), x_in)[0]
            correction = grads * scale
            if self.show:
                # .sum() so a batched loss can be reported (.item() needs one element).
                loss_value = loss.sum().item()
                clear_output(wait=True)
                print(loss_value, scale, correction.abs().max().item(), e_t.abs().max().item())
                self.history.append(
                    [loss_value, scale, correction.min().item(), correction.max().item()]
                )
                # Assumes inp is (batch, channel, H, W) with values in [-1, 1] — TODO confirm.
                plt.imshow((inp[0].detach().permute(1, 2, 0).clamp(-1, 1).cpu() + 1) / 2)
                plt.axis('off')
                plt.show()
                plt.imshow(correction[0][0].detach().cpu())
                plt.axis('off')
                plt.show()
        e_t_mod = e_t - sqrt_1ma * correction
        if self.show:
            _, axs = plt.subplots(1, 3)
            for ax, img in zip(axs, (e_t, e_t_mod, correction)):
                ax.imshow(img[0][0].detach().cpu(), vmin=-2, vmax=+2)
            plt.show()
        return e_t_mod
