import os

import numpy as np
import torch
import torch.nn as nn
from diffusers.models.attention_processor import Attention

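# Small transformer block that maps visual feature tokens into the prompt
# embedding space: an MLP projection (meta_token_trans) followed by a pre-norm
# self-attention + feed-forward block with residual connections.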
class VisualTokenSelfAttn(nn.Module):
    def __init__(self, in_dim=2792, out_dim=768, num_heads=8):
        super().__init__()

        self.meta_token_trans = nn.Sequential(
            nn.Linear(in_dim, out_dim * 4),
            nn.LayerNorm(out_dim * 4),
            nn.GELU(),
            nn.Linear(out_dim * 4, out_dim),
            nn.LayerNorm(out_dim)
        )

        self.norm1 = nn.LayerNorm(out_dim, eps=1e-6)
        self.attn = Attention(query_dim=out_dim, heads=num_heads)
        self.norm2 = nn.LayerNorm(out_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(out_dim, out_dim * 4),
            nn.GELU(),
            nn.Linear(out_dim * 4, out_dim)
        )

    def forward(self, x):
        x = self.meta_token_trans(x)
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

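# Per-emotion prompt tokens, initialised from k-means centroids of pre-extracted
# visual features (CLIP / VGG / DINOv2) and refined by VisualTokenSelfAttn on
# every forward pass.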
class EmotionEmbedding(nn.Module):
    def __init__(self, emotions, prompts_dir, feature_names, output_dim, prompt_len=16):
        super().__init__()

        input_dim = self.get_input_dim(feature_names=feature_names)
        self.self_attn = VisualTokenSelfAttn(in_dim=input_dim, out_dim=output_dim)

        self.emotions = emotions
        self.emotion2idx = {emotion: idx for idx, emotion in enumerate(emotions)}
        self.emotion_params = nn.ParameterList()

        self.emotion_init_features = self.get_features(emotions, prompts_dir, feature_names, prompt_len)

        for emotion in self.emotions:
            init_params = self.emotion_init_features[emotion]
            param = nn.Parameter(init_params)
            self.emotion_params.append(param)

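    # Load the pre-extracted visual features for each emotion, concatenate them
    # along the feature axis, and cluster them into prompt_len centroids that
    # serve as the initial prompt tokens.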
    def get_features(self, emotions, prompts_dir, feature_names, prompt_len):
        from sklearn.cluster import KMeans

        emotion_init_features = {}
        for emotion in emotions:
            emotion_features = []
            for feature_name in feature_names:
                features = np.load(os.path.join(prompts_dir, f'{emotion}_{feature_name}.npy'), allow_pickle=True)
                emotion_features.append(features)
            emotion_features = np.concatenate(emotion_features, axis=1)

            kmeans = KMeans(n_clusters=prompt_len, random_state=42)
            kmeans.fit(emotion_features)
            # cluster_centers_ may come back as float64; cast so the resulting
            # nn.Parameter matches the float32 layers downstream.
            token = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32).unsqueeze(0)

            emotion_init_features[emotion] = token
        return emotion_init_features

    def get_input_dim(self, feature_names):
        # Input dimensionality is the sum of the concatenated feature dims:
        # CLIP = 768, VGG = 1000, DINOv2 = 1024.
        if feature_names == ["clip"]:
            in_dim = 768
        elif feature_names == ["vgg"]:
            in_dim = 1000
        elif feature_names == ["dinov2"]:
            in_dim = 1024
        elif feature_names == ["clip", "vgg"]:
            in_dim = 1768
        elif feature_names == ["clip", "dinov2"]:
            in_dim = 1792
        elif feature_names == ["vgg", "dinov2"]:
            in_dim = 2024
        elif feature_names == ["clip", "vgg", "dinov2"]:
            in_dim = 2792
        else:
            raise ValueError("Invalid feature names")
        return in_dim

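    # Re-run the self-attention refiner over the current emotion parameters so
    # that each forward pass produces up-to-date prompts and gradients flow
    # back into both the parameters and the refiner.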
    def params_to_prompts(self):
        self.emotion_prompts = {}
        for emotion in self.emotions:
            prompt = self.self_attn(self.emotion_params[self.emotion2idx[emotion]])
            prompt = prompt.squeeze(0)
            self.emotion_prompts[emotion] = prompt

    def forward(self, emotion):
        if isinstance(emotion, str):
            emotions = [emotion]
        else:
            emotions = emotion

        self.params_to_prompts()
        selected_prompts = [self.emotion_prompts[emotion] for emotion in emotions]
        prompts = torch.stack(selected_prompts, dim=0)
        # Drop the cached prompts so a stale mapping is never reused.
        del self.emotion_prompts

        return prompts

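# Example usage of EmotionEmbedding (a sketch: assumes pre-extracted feature
# files such as features/origin/awe_clip.npy exist under prompts_dir):
#   embed = EmotionEmbedding(emotions, prompts_dir="features/origin",
#                            feature_names=["clip"], output_dim=768, prompt_len=16)
#   prompts = embed("awe")  # shape: (1, prompt_len, output_dim)


# Variant that learns one embedding vector per emotion directly via nn.Embedding
# instead of initialising prompt tokens from clustered visual features.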
class EmotionEmbedding2(nn.Module):
    def __init__(self, emotions, input_dim, output_dim):
        super().__init__()
        self.self_attn = VisualTokenSelfAttn(in_dim=input_dim, out_dim=output_dim)
        self.emotions = emotions
        self.emotion2idx = {emotion: idx for idx, emotion in enumerate(emotions)}
        self.emotion_params = nn.Embedding(len(emotions), input_dim)

    def forward(self, emotion):
        if isinstance(emotion, str):
            emotions = [emotion]
        else:
            emotions = emotion

        emotions = [self.emotion2idx[emotion] for emotion in emotions]
        emotions = torch.tensor(emotions, device=self.emotion_params.weight.device)
        # (batch, input_dim) -> (batch, 1, input_dim): each emotion becomes one token.
        prompts = self.emotion_params(emotions).unsqueeze(1)
        prompts = self.self_attn(prompts)
        return prompts

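# Minimal smoke test: run one optimisation step on EmotionEmbedding2 and verify
# that gradients reach the per-emotion embedding table.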
if __name__ == "__main__":
    emotions = ["amusement", "anger", "awe", "contentment",
                "disgust", "excitement", "fear", "sadness"]
    prompts_dir = "features/origin"
    # EmotionEmbedding2 learns a single token per emotion, so it takes no
    # prompt_len argument.
    model = EmotionEmbedding2(emotions, input_dim=2048, output_dim=2048).to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    output = model('awe')
    target = torch.ones_like(output)
    loss = ((output - target) ** 2).mean()
    print(output)

    loss.backward()

    # Check that gradients reach every trainable parameter, including the
    # emotion embedding table.
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"{name} has gradient ✅, grad mean: {param.grad.mean().item()}")
            if name == "emotion_params.weight":
                print(param.grad)
        else:
            print(f"{name} has NO gradient ❌")

    optimizer.step()
    print(output)