# diffae/predict.py
# Pre-download the weights for the 256-resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls.
# wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
# bunzip2 shape_predictor_68_face_landmarks.dat.bz2
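#
# Expected layout after downloading (these are the paths setup() below loads):
#   checkpoints/ffhq256_autoenc/last.ckpt
#   checkpoints/ffhq256_autoenc_cls/last.ckpt
#   shape_predictor_68_face_landmarks.dat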

import math
import os
import tempfile
from typing import List

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from templates import *
from templates_cls import *
from experiment_classifier import ClsModel
from align import LandmarksDetector, image_align
from cog import BasePredictor, Path, Input, BaseModel


class ModelOutput(BaseModel):
    image: Path


class Predictor(BasePredictor):
    def setup(self):
        self.aligned_dir = "aligned"
        os.makedirs(self.aligned_dir, exist_ok=True)
        self.device = "cuda:0"

        # Model initialization: the diffae autoencoder for FFHQ at 256x256.
        model_config = ffhq256_autoenc()
        self.model = LitModel(model_config)
        state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
        self.model.load_state_dict(state["state_dict"], strict=False)
        self.model.ema_model.eval()
        self.model.ema_model.to(self.device)

        # Classifier initialization: a linear classifier over semantic latents,
        # whose weight rows serve as attribute directions in predict().
        classifier_config = ffhq256_autoenc_cls()
        classifier_config.pretrain = None  # skip loading pretrained weights; a bit faster
        self.classifier = ClsModel(classifier_config)
        state_class = torch.load(
            "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
        )
        print("latent step:", state_class["global_step"])
        self.classifier.load_state_dict(state_class["state_dict"], strict=False)
        self.classifier.to(self.device)

        # dlib 68-point landmark detector used for face alignment.
        self.landmarks_detector = LandmarksDetector(
            "shape_predictor_68_face_landmarks.dat"
        )

    def predict(
        self,
        image: Path = Input(
            description="Input image for face manipulation. The image is aligned and "
            "cropped; both the aligned and the manipulated images are returned.",
        ),
        target_class: str = Input(
            default="Bangs",
            choices=[
                "5_o_Clock_Shadow",
                "Arched_Eyebrows",
                "Attractive",
                "Bags_Under_Eyes",
                "Bald",
                "Bangs",
                "Big_Lips",
                "Big_Nose",
                "Black_Hair",
                "Blond_Hair",
                "Blurry",
                "Brown_Hair",
                "Bushy_Eyebrows",
                "Chubby",
                "Double_Chin",
                "Eyeglasses",
                "Goatee",
                "Gray_Hair",
                "Heavy_Makeup",
                "High_Cheekbones",
                "Male",
                "Mouth_Slightly_Open",
                "Mustache",
                "Narrow_Eyes",
                "Beard",
                "Oval_Face",
                "Pale_Skin",
                "Pointy_Nose",
                "Receding_Hairline",
                "Rosy_Cheeks",
                "Sideburns",
                "Smiling",
                "Straight_Hair",
                "Wavy_Hair",
                "Wearing_Earrings",
                "Wearing_Hat",
                "Wearing_Lipstick",
                "Wearing_Necklace",
                "Wearing_Necktie",
                "Young",
            ],
description="Choose manipulation direction.",
),
manipulation_amplitude: float = Input(
default=0.3,
ge=-0.5,
le=0.5,
description="When set too strong it would result in artifact as it could dominate the original image information.",
),
T_step: int = Input(
default=100,
choices=[50, 100, 125, 200, 250, 500],
description="Number of step for generation.",
),
T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]),
) -> List[ModelOutput]:
        img_size = 256

        print("Aligning image...")
        for face_landmarks in self.landmarks_detector.get_landmarks(str(image)):
            image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)

        data = ImageDataset(
            self.aligned_dir,
            image_size=img_size,
            exts=["jpg", "jpeg", "JPG", "png"],
            do_augment=False,
        )
print("Encoding and Manipulating the aligned image...")
cls_manipulation_amplitude = manipulation_amplitude
interpreted_target_class = target_class
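        # Some attributes are stored in negated form ("No_<attr>") in
        # CelebAttrDataset; for those, flip the sign of the amplitude and walk
        # along the negated class direction instead.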
        if (
            target_class not in CelebAttrDataset.id_to_cls
            and f"No_{target_class}" in CelebAttrDataset.id_to_cls
        ):
            cls_manipulation_amplitude = -manipulation_amplitude
            interpreted_target_class = f"No_{target_class}"

        batch = data[0]["img"][None]
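        # diffae encodes an image into two parts: a semantic latent z_sem from the
        # learned encoder, and a stochastic latent x_T obtained by running the
        # deterministic DDIM process in reverse for T_inv steps.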
        semantic_latent = self.model.encode(batch.to(self.device))
        stochastic_latent = self.model.encode_stochastic(
            batch.to(self.device), semantic_latent, T=T_inv
        )
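        # Manipulation: shift the normalized semantic latent along the classifier's
        # weight row for the target attribute. The amplitude is scaled by sqrt(512),
        # the dimensionality of z_sem, and the result is mapped back to the
        # unnormalized latent space.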
        cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
        class_direction = self.classifier.classifier.weight[cls_id]
        normalized_class_direction = F.normalize(class_direction[None, :], dim=1)
        normalized_semantic_latent = self.classifier.normalize(semantic_latent)
        normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
        normalized_manipulated_semantic_latent = (
            normalized_semantic_latent
            + normalized_manipulation_amp * normalized_class_direction
        )
        manipulated_semantic_latent = self.classifier.denormalize(
            normalized_manipulated_semantic_latent
        )

        # Render the manipulated image from the original stochastic latent and the
        # edited semantic latent.
        manipulated_img = self.model.render(
            stochastic_latent, manipulated_semantic_latent, T=T_step
        )[0]
        original_img = data[0]["img"]

        model_output = []
        out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
        save_image(convert2rgb(original_img), str(out_path))
        model_output.append(ModelOutput(image=out_path))

        out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
        save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
        model_output.append(ModelOutput(image=out_path))
        return model_output


def convert2rgb(img, adjust_scale=True):
    # Dataset images are in [-1, 1] and need rescaling to [0, 1]; rendered images
    # already come out in [0, 1] (hence adjust_scale=False above).
    convert_img = img.detach().clone() if torch.is_tensor(img) else torch.tensor(img)
    if adjust_scale:
        convert_img = (convert_img + 1) / 2
    return convert_img.cpu()
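

# Minimal local smoke test: a sketch, not the cog entrypoint (cog constructs and
# calls Predictor itself). Assumes the checkpoints listed at the top are in place,
# a CUDA device is available, and "input.jpg" is a hypothetical face photo.
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    outputs = predictor.predict(
        image=Path("input.jpg"),
        target_class="Smiling",
        manipulation_amplitude=0.3,
        T_step=100,
        T_inv=200,
    )
    for output in outputs:
        print("saved:", output.image)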