import math
import os
import tempfile
from typing import List

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

from templates import *
from templates_cls import *
from experiment_classifier import ClsModel
from align import LandmarksDetector, image_align
from cog import BasePredictor, Path, Input, BaseModel


class ModelOutput(BaseModel):
    image: Path


class Predictor(BasePredictor):
    def setup(self):
        self.aligned_dir = "aligned"
        os.makedirs(self.aligned_dir, exist_ok=True)
        self.device = "cuda:0"
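
        # Load the pretrained FFHQ-256 diffusion autoencoder; inference runs
        # on the EMA copy of the weights.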
        model_config = ffhq256_autoenc()
        self.model = LitModel(model_config)
        state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
        self.model.load_state_dict(state["state_dict"], strict=False)
        self.model.ema_model.eval()
        self.model.ema_model.to(self.device)
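
        # Load the classifier trained on the semantic latents; each row of its
        # linear layer's weight matrix is one CelebA attribute direction.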
        classifier_config = ffhq256_autoenc_cls()
        classifier_config.pretrain = None
        self.classifier = ClsModel(classifier_config)
        state_class = torch.load(
            "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
        )
        print("latent step:", state_class["global_step"])
        self.classifier.load_state_dict(state_class["state_dict"], strict=False)
        self.classifier.to(self.device)
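
        # dlib 68-point landmark model used to detect and align faces.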
        self.landmarks_detector = LandmarksDetector(
            "shape_predictor_68_face_landmarks.dat"
        )

    def predict(
        self,
        image: Path = Input(
            description="Input image for face manipulation. The image is aligned "
            "and cropped; both the aligned and the manipulated images are returned.",
        ),
        target_class: str = Input(
            default="Bangs",
            choices=[
                "5_o_Clock_Shadow",
                "Arched_Eyebrows",
                "Attractive",
                "Bags_Under_Eyes",
                "Bald",
                "Bangs",
                "Big_Lips",
                "Big_Nose",
                "Black_Hair",
                "Blond_Hair",
                "Blurry",
                "Brown_Hair",
                "Bushy_Eyebrows",
                "Chubby",
                "Double_Chin",
                "Eyeglasses",
                "Goatee",
                "Gray_Hair",
                "Heavy_Makeup",
                "High_Cheekbones",
                "Male",
                "Mouth_Slightly_Open",
                "Mustache",
                "Narrow_Eyes",
                "Beard",
                "Oval_Face",
                "Pale_Skin",
                "Pointy_Nose",
                "Receding_Hairline",
                "Rosy_Cheeks",
                "Sideburns",
                "Smiling",
                "Straight_Hair",
                "Wavy_Hair",
                "Wearing_Earrings",
                "Wearing_Hat",
                "Wearing_Lipstick",
                "Wearing_Necklace",
                "Wearing_Necktie",
                "Young",
            ],
description="Choose manipulation direction.", |
|
), |
|
manipulation_amplitude: float = Input( |
|
default=0.3, |
|
ge=-0.5, |
|
le=0.5, |
|
description="When set too strong it would result in artifact as it could dominate the original image information.", |
|
), |
|
T_step: int = Input( |
|
default=100, |
|
choices=[50, 100, 125, 200, 250, 500], |
|
description="Number of step for generation.", |
|
), |
|
T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]), |
|
) -> List[ModelOutput]: |
|
|
|
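        # Align and crop the face to the FFHQ template. When several faces are
        # detected, each alignment overwrites the same file, so only the last
        # detected face ends up being manipulated.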
        img_size = 256
        print("Aligning image...")
        for face_landmarks in self.landmarks_detector.get_landmarks(str(image)):
            image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)
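
        # Read the aligned image back through ImageDataset so it gets the
        # model's eval-time preprocessing (resize to 256, [-1, 1] scaling that
        # convert2rgb undoes below).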
        data = ImageDataset(
            self.aligned_dir,
            image_size=img_size,
            exts=["jpg", "jpeg", "JPG", "png"],
            do_augment=False,
        )
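
        # CelebA stores some attributes in negated form (e.g. "No_Beard"). If
        # only the negated label exists, manipulate along that direction with
        # the sign of the amplitude flipped.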
print("Encoding and Manipulating the aligned image...") |
|
cls_manipulation_amplitude = manipulation_amplitude |
|
interpreted_target_class = target_class |
|
if ( |
|
target_class not in CelebAttrDataset.id_to_cls |
|
and f"No_{target_class}" in CelebAttrDataset.id_to_cls |
|
): |
|
cls_manipulation_amplitude = -manipulation_amplitude |
|
interpreted_target_class = f"No_{target_class}" |
|
|
|
batch = data[0]["img"][None] |
|
|
|
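
        # DiffAE encodes the image twice: a semantic latent from the encoder,
        # and a stochastic latent recovered by inverting the diffusion process
        # for T_inv steps.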
        semantic_latent = self.model.encode(batch.to(self.device))
        stochastic_latent = self.model.encode_stochastic(
            batch.to(self.device), semantic_latent, T=T_inv
        )
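
        # Shift the normalized semantic latent along the classifier's unit
        # direction for the target attribute; sqrt(512) scales the amplitude
        # to the latent's dimensionality.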
        cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
        class_direction = self.classifier.classifier.weight[cls_id]
        normalized_class_direction = F.normalize(class_direction[None, :], dim=1)

        normalized_semantic_latent = self.classifier.normalize(semantic_latent)
        normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
        normalized_manipulated_semantic_latent = (
            normalized_semantic_latent
            + normalized_manipulation_amp * normalized_class_direction
        )

        manipulated_semantic_latent = self.classifier.denormalize(
            normalized_manipulated_semantic_latent
        )
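
        # Decode the shifted semantic latent together with the original
        # stochastic latent, so everything except the target attribute stays
        # close to the input.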
        manipulated_img = self.model.render(
            stochastic_latent, manipulated_semantic_latent, T=T_step
        )[0]
        original_img = data[0]["img"]
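
        # Save both outputs. The dataset image is in [-1, 1] and needs
        # rescaling; the rendered image is saved as-is (adjust_scale=False).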
        model_output = []
        out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
        save_image(convert2rgb(original_img), str(out_path))
        model_output.append(ModelOutput(image=out_path))

        out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
        save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
        model_output.append(ModelOutput(image=out_path))
        return model_output


def convert2rgb(img, adjust_scale=True):
    # Copy to a fresh tensor; if the image is in [-1, 1], rescale to [0, 1].
    convert_img = img.clone().detach() if torch.is_tensor(img) else torch.tensor(img)
    if adjust_scale:
        convert_img = (convert_img + 1) / 2
    return convert_img.cpu()