|
from typing import List, Optional |
|
from PIL import Image |
|
import torch |
|
from transformers import AutoProcessor, AutoModel |
|
from safetensors.torch import load_file |
|
import os |
|
from typing import Union, List |
|
from .config import MODEL_PATHS |
|
|
|
class MLP(torch.nn.Module): |
|
def __init__(self, input_size: int, xcol: str = "emb", ycol: str = "avg_rating"): |
|
super().__init__() |
|
self.input_size = input_size |
|
self.xcol = xcol |
|
self.ycol = ycol |
|
self.layers = torch.nn.Sequential( |
|
torch.nn.Linear(self.input_size, 1024), |
|
|
|
torch.nn.Dropout(0.2), |
|
torch.nn.Linear(1024, 128), |
|
|
|
torch.nn.Dropout(0.2), |
|
torch.nn.Linear(128, 64), |
|
|
|
torch.nn.Dropout(0.1), |
|
torch.nn.Linear(64, 16), |
|
|
|
torch.nn.Linear(16, 1), |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
return self.layers(x) |
|
|
|
def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: |
|
x = batch[self.xcol] |
|
y = batch[self.ycol].reshape(-1, 1) |
|
x_hat = self.layers(x) |
|
loss = torch.nn.functional.mse_loss(x_hat, y) |
|
return loss |
|
|
|
def validation_step(self, batch: dict, batch_idx: int) -> torch.Tensor: |
|
x = batch[self.xcol] |
|
y = batch[self.ycol].reshape(-1, 1) |
|
x_hat = self.layers(x) |
|
loss = torch.nn.functional.mse_loss(x_hat, y) |
|
return loss |
|
|
|
def configure_optimizers(self) -> torch.optim.Optimizer: |
|
return torch.optim.Adam(self.parameters(), lr=1e-3) |
|
|
|
|
|
class AestheticScore(torch.nn.Module): |
|
def __init__(self, device: torch.device, path: str = MODEL_PATHS): |
|
super().__init__() |
|
self.device = device |
|
self.aes_model_path = path.get("aesthetic_predictor") |
|
|
|
self.model = MLP(768) |
|
try: |
|
if self.aes_model_path.endswith(".safetensors"): |
|
state_dict = load_file(self.aes_model_path) |
|
else: |
|
state_dict = torch.load(self.aes_model_path) |
|
self.model.load_state_dict(state_dict) |
|
except Exception as e: |
|
raise ValueError(f"Error loading model weights from {self.aes_model_path}: {e}") |
|
|
|
self.model.to(device) |
|
self.model.eval() |
|
|
|
|
|
clip_model_name = path.get('clip-large') |
|
self.model2 = AutoModel.from_pretrained(clip_model_name).eval().to(device) |
|
self.processor = AutoProcessor.from_pretrained(clip_model_name) |
|
|
|
def _calculate_score(self, image: torch.Tensor) -> float: |
|
"""Calculate the aesthetic score for a single image. |
|
|
|
Args: |
|
image (torch.Tensor): The processed image tensor. |
|
|
|
Returns: |
|
float: The aesthetic score. |
|
""" |
|
with torch.no_grad(): |
|
|
|
image_embs = self.model2.get_image_features(image) |
|
image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True) |
|
|
|
|
|
score = self.model(image_embs).cpu().flatten().item() |
|
|
|
return score |
|
|
|
@torch.no_grad() |
|
def score(self, images: Union[str, List[str], Image.Image, List[Image.Image]], prompt: str = "") -> List[float]: |
|
"""Score the images based on their aesthetic quality. |
|
|
|
Args: |
|
images (Union[str, List[str], Image.Image, List[Image.Image]]): Path(s) to the image(s) or PIL image(s). |
|
|
|
Returns: |
|
List[float]: List of scores for the images. |
|
""" |
|
try: |
|
if isinstance(images, (str, Image.Image)): |
|
|
|
if isinstance(images, str): |
|
pil_image = Image.open(images) |
|
else: |
|
pil_image = images |
|
|
|
|
|
image_inputs = self.processor( |
|
images=pil_image, |
|
padding=True, |
|
truncation=True, |
|
max_length=77, |
|
return_tensors="pt", |
|
).to(self.device) |
|
|
|
return [self._calculate_score(image_inputs["pixel_values"])] |
|
elif isinstance(images, list): |
|
|
|
scores = [] |
|
for one_image in images: |
|
if isinstance(one_image, str): |
|
pil_image = Image.open(one_image) |
|
elif isinstance(one_image, Image.Image): |
|
pil_image = one_image |
|
else: |
|
raise TypeError("The type of parameter images is illegal.") |
|
|
|
|
|
image_inputs = self.processor( |
|
images=pil_image, |
|
padding=True, |
|
truncation=True, |
|
max_length=77, |
|
return_tensors="pt", |
|
).to(self.device) |
|
|
|
scores.append(self._calculate_score(image_inputs["pixel_values"])) |
|
return scores |
|
else: |
|
raise TypeError("The type of parameter images is illegal.") |
|
except Exception as e: |
|
raise RuntimeError(f"Error in scoring images: {e}") |
|
|