# MMP_Diffusion/models/visual_prompts.py
import os

import numpy as np
import torch
import torch.nn as nn
from diffusers.models.attention_processor import Attention
from sklearn.cluster import KMeans
class VisualTokenSelfAttn(nn.Module):
    """Projects concatenated visual features into the prompt space and refines them
    with a single pre-norm self-attention + MLP block."""
    def __init__(self, in_dim=2792, out_dim=768, num_heads=8):
        super().__init__()
        # Project the concatenated visual features (e.g. CLIP + VGG + DINOv2) to out_dim.
        self.meta_token_trans = nn.Sequential(
nn.Linear(in_dim, out_dim * 4),
nn.LayerNorm(out_dim * 4),
nn.GELU(),
nn.Linear(out_dim * 4, out_dim),
nn.LayerNorm(out_dim)
)
self.norm1 = nn.LayerNorm(out_dim, eps=1e-6) # important to avoid attention collapsing
self.attn = Attention(query_dim=out_dim, heads=num_heads)
self.norm2 = nn.LayerNorm(out_dim, eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(out_dim, out_dim * 4),
nn.GELU(),
nn.Linear(out_dim * 4, out_dim)
)
    def forward(self, x):
        # x: (batch, num_tokens, in_dim)
        x = self.meta_token_trans(x)
        x = x + self.attn(self.norm1(x))  # self-attention with residual connection
        x = x + self.mlp(self.norm2(x))   # feed-forward with residual connection
return x
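
# A minimal usage sketch for VisualTokenSelfAttn (illustrative shapes only, assuming
# the default clip+vgg+dinov2 input width of 2792):
#
#     block = VisualTokenSelfAttn(in_dim=2792, out_dim=768)
#     tokens = torch.randn(1, 16, 2792)   # (batch, num_tokens, in_dim)
#     out = block(tokens)                 # -> (1, 16, 768)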
class EmotionEmbedding(nn.Module):
    """Learnable per-emotion visual prompts, initialised from clustered image features."""
    def __init__(self, emotions, prompts_dir, feature_names, output_dim, prompt_len=16):
        super().__init__()
        input_dim = self.get_input_dim(feature_names=feature_names)
        self.self_attn = VisualTokenSelfAttn(in_dim=input_dim, out_dim=output_dim)
self.emotions = emotions
self.emotion2idx = {emotion: idx for idx, emotion in enumerate(emotions)}
self.emotion_params = nn.ParameterList()
self.emotion_init_features = self.get_features(emotions, prompts_dir, feature_names, prompt_len)
        # One learnable prompt per emotion, initialised from the clustered features above.
        for emotion in self.emotions:
            init_params = self.emotion_init_features[emotion]
            param = nn.Parameter(init_params)
            self.emotion_params.append(param)
    def get_features(self, emotions, prompts_dir, feature_names, prompt_len):
        """Cluster the pre-extracted per-image features of each emotion into
        `prompt_len` centers that serve as the initial prompt tokens."""
        emotion_init_features = {}
        for emotion in emotions:
            emotion_features = []
            for feature_name in feature_names:
                features = np.load(os.path.join(prompts_dir, f'{emotion}_{feature_name}.npy'), allow_pickle=True)
                emotion_features.append(features)
            # Concatenate the extractors' features along the feature dimension.
            emotion_features = np.concatenate(emotion_features, axis=1)
            kmeans = KMeans(n_clusters=prompt_len, random_state=42)
            kmeans.fit(emotion_features)
            # (1, prompt_len, input_dim); cast to float32 to match the model weights.
            token = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32).unsqueeze(0)
            emotion_init_features[emotion] = token
        return emotion_init_features
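
    # A minimal sketch of the feature files get_features expects (hypothetical
    # filenames and sizes, not part of this file):
    #
    #     feats = np.random.randn(500, 768).astype(np.float32)   # e.g. 500 CLIP vectors
    #     np.save("features/origin/awe_clip.npy", feats)
    #
    # i.e. each `{emotion}_{feature_name}.npy` holds one row per image, and rows from
    # different extractors are concatenated along axis=1 before clustering.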
    def get_input_dim(self, feature_names):
        """Total dimensionality of the concatenated features
        (CLIP: 768, VGG: 1000, DINOv2: 1024)."""
        feature_dims = {"clip": 768, "vgg": 1000, "dinov2": 1024}
        if not feature_names or any(name not in feature_dims for name in feature_names):
            raise ValueError("Invalid feature names")
        return sum(feature_dims[name] for name in feature_names)
    def params_to_prompts(self):
        # Rebuild the prompt cache so gradients flow through self_attn on every forward pass.
        self.emotion_prompts = {}
for emotion in self.emotions:
prompt = self.self_attn(self.emotion_params[self.emotion2idx[emotion]])
prompt = prompt.squeeze(0)
self.emotion_prompts[emotion] = prompt
    def forward(self, emotion):
        # Accept either a single emotion name or a list of names.
        if isinstance(emotion, str):
            emotions = [emotion]
        else:
            emotions = emotion
        self.params_to_prompts()
        selected_prompts = [self.emotion_prompts[e] for e in emotions]
        prompts = torch.stack(selected_prompts, dim=0)  # (batch, prompt_len, output_dim)
        # Drop the cache so stale prompts are never reused across steps.
        del self.emotion_prompts
        return prompts
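
# Shape sketch for EmotionEmbedding (illustrative): forward("awe") returns a tensor of
# shape (1, prompt_len, output_dim), and forward(["awe", "fear"]) returns
# (2, prompt_len, output_dim).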
class EmotionEmbedding2(nn.Module):
    """Learnable per-emotion prompt embeddings initialised randomly (no feature files)."""
    def __init__(self, emotions, input_dim, output_dim):
        super().__init__()
        self.self_attn = VisualTokenSelfAttn(in_dim=input_dim, out_dim=output_dim)
        self.emotions = emotions
        self.emotion2idx = {emotion: idx for idx, emotion in enumerate(emotions)}
        # One learnable embedding vector per emotion.
        self.emotion_params = nn.Embedding(len(emotions), input_dim)
    def forward(self, emotion):
        # Accept either a single emotion name or a list of names.
        if isinstance(emotion, str):
            emotions = [emotion]
        else:
            emotions = emotion
        indices = torch.tensor([self.emotion2idx[e] for e in emotions],
                               device=self.emotion_params.weight.device)
        prompts = self.emotion_params(indices).unsqueeze(1)  # (batch, 1, input_dim)
        prompts = self.self_attn(prompts)                    # (batch, 1, output_dim)
        return prompts
if __name__ == "__main__":
# emotions = ["amusement", "anger", "awe", "contentment",
# "disgust", "excitement", "fear", "sadness"]
# feature_names = ["clip", "vgg", "dinov2"]
# prompts_dir = "features/origin"
# model = EmotionEmbedding(emotions, prompts_dir, feature_names, output_dim=2048, prompt_len=16).to("cuda")
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# output = model('awe')
# target = torch.ones_like(output)
# loss = ((output - target) ** 2).mean()
# print(output)
    emotions = ["amusement", "anger", "awe", "contentment",
                "disgust", "excitement", "fear", "sadness"]
    model = EmotionEmbedding2(emotions, input_dim=2048, output_dim=2048).to("cuda")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    output = model('awe')
    target = torch.ones_like(output)
    loss = ((output - target) ** 2).mean()
    print(output)
    # Backpropagation
    loss.backward()
    # Inspect which parameters received gradients
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f"{name} has gradient ✅, grad mean: {param.grad.mean().item()}")
            if name == "emotion_params.weight":
                print(param.grad)
        else:
            print(f"{name} has NO gradient ❌")
    # Apply the parameter update
optimizer.step()
    # Re-run the forward pass to confirm the prompts change after the update
    print(model('awe'))