import math
import torch
import torch.nn as nn
from typing import Set, List, Optional, Type
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE
class LoraInjectedLinear(nn.Module):
def __init__(
self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
):
super().__init__()
        if r > min(in_features, out_features):
            # Clamp oversized ranks instead of raising, so injection degrades gracefully.
            print(f"LoRA rank {r} is too large; clamping to {min(in_features, out_features)}")
            r = min(in_features, out_features)
self.r = r
self.linear = nn.Linear(in_features, out_features, bias)
self.lora_down = nn.Linear(in_features, r, bias=False)
self.dropout = nn.Dropout(dropout_p)
self.lora_up = nn.Linear(r, out_features, bias=False)
self.scale = scale
self.selector = nn.Identity()
nn.init.normal_(self.lora_down.weight, std=1 / r)
nn.init.zeros_(self.lora_up.weight)
def update_step(self, cur_step):
self.cur_step = cur_step
def forward(self, input, return_format='linear'):
assert return_format in ['linear', 'lora', 'added', 'full']
if return_format == 'linear': return self.linear(input)
elif return_format == 'lora': return self.dropout(self.lora_up(self.selector(self.lora_down(input))))
elif return_format == 'added':
return (
self.linear(input)
+ self.dropout(self.lora_up(self.selector(self.lora_down(input))))
* self.scale
)
        # return_format == 'full': expose the base output, the LoRA output, and the scale separately.
        linear_res = self.linear(input)
        lora_res = self.dropout(self.lora_up(self.selector(self.lora_down(input))))
        return linear_res, lora_res, self.scale
def realize_as_lora(self):
return self.lora_up.weight.data * self.scale, self.lora_down.weight.data
def set_selector_from_diag(self, diag: torch.Tensor):
# diag is a 1D tensor of size (r,)
assert diag.shape == (self.r,)
self.selector = nn.Linear(self.r, self.r, bias=False)
self.selector.weight.data = torch.diag(diag)
self.selector.weight.data = self.selector.weight.data.to(
self.lora_up.weight.device
).to(self.lora_up.weight.dtype)
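# Illustrative sketch (not from the original file): because lora_up is initialised to
# zeros, the 'added' output of a freshly constructed layer matches the plain linear
# output until the LoRA weights are trained, e.g.
#
#   layer = LoraInjectedLinear(64, 64, bias=True, r=4, scale=1.0)
#   x = torch.randn(2, 64)
#   assert torch.allclose(layer(x, return_format="linear"),
#                         layer(x, return_format="added"))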
def _find_modules_v2(
model,
ancestor_class: Optional[Set[str]] = None,
    search_name: str = 'attn2',
    include_names: List[str] = ['to_q', 'to_k', 'to_v'],
search_class: List[Type[nn.Module]] = [nn.Linear],
exclude_children_of: Optional[List[Type[nn.Module]]] = [
LoraInjectedLinear,
],
):
"""
Find all modules of a certain class (or union of classes) that are direct or
indirect descendants of other modules of a certain class (or union of classes).
    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
"""
# Get the targets we should replace all linears under
if ancestor_class is not None:
ancestors = (
module
for module in model.modules()
if module.__class__.__name__ in ancestor_class
)
    else:
        # Fall back to naively iterating over every module in the model.
        ancestors = list(model.modules())
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
for ancestor in ancestors:
for fullname, module in ancestor.named_modules():
if search_name in fullname:
*path, base_name = fullname.split('.')
parent = ancestor
while path:
parent = parent.get_submodule(path.pop(0))
                if base_name in include_names:
                    # The matched projection must be an instance of one of the requested classes.
                    assert any(isinstance(module, _class) for _class in search_class)
                    yield parent, base_name, module
# if any([isinstance(module, _class) for _class in search_class]):
# # Find the direct parent if this is a descendant, not a child, of target
# *path, name = fullname.split(".")
# parent = ancestor
# while path:
# parent = parent.get_submodule(path.pop(0))
# # Skip this linear if it's a child of a LoraInjectedLinear
# if exclude_children_of and any(
# [isinstance(parent, _class) for _class in exclude_children_of]
# ):
# continue
# # Otherwise, yield it
# yield parent, name, module
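# Example of what _find_modules_v2 yields (assuming a diffusers-style UNet, where
# cross-attention projections sit under names such as
# "transformer_blocks.0.attn2.to_q"): tuples of the form
#   (parent attention module, "to_q" / "to_k" / "to_v", the nn.Linear to be replaced)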
def inject_trainable_lora(
model: nn.Module,
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
lora_rank: int = 4,
loras=None, # path to lora .pt
verbose: bool = False,
dropout: float = 0.0,
scale: float = 1.0,
):
"""
inject lora into model, and returns lora parameter groups.
"""
require_grad_params = []
names = []
    if loras is not None:
        loras = torch.load(loras)
for _module, name, _child_module in _find_modules_v2(
model, target_replace_module, search_class=[nn.Linear]
):
weight = _child_module.weight
bias = _child_module.bias
if verbose:
print("LoRA Injection : injecting lora into ", name)
print("LoRA Injection : weight shape", weight.shape)
_tmp = LoraInjectedLinear(
_child_module.in_features,
_child_module.out_features,
_child_module.bias is not None,
r=lora_rank,
dropout_p=dropout,
scale=scale,
)
_tmp.linear.weight = weight
if bias is not None:
_tmp.linear.bias = bias
# switch the module
_tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
_module._modules[name] = _tmp
require_grad_params.append(_module._modules[name].lora_up.parameters())
require_grad_params.append(_module._modules[name].lora_down.parameters())
        if loras is not None:
_module._modules[name].lora_up.weight = loras.pop(0)
_module._modules[name].lora_down.weight = loras.pop(0)
_module._modules[name].lora_up.weight.requires_grad = True
_module._modules[name].lora_down.weight.requires_grad = True
names.append(name)
return require_grad_params, names
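# Illustrative follow-up (assumed usage, not from the original file): the returned
# parameter-group generators are typically chained into a single iterable for the
# optimizer; `unet` below stands for any model containing the target attention blocks.
#
#   import itertools
#   lora_params, lora_names = inject_trainable_lora(unet, lora_rank=4)
#   optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)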
lora_args = dict(
model = None,
loras = None,
target_replace_module = [],
lora_rank = 4,
dropout = 0,
scale = 0
)
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
loras = []
for _m, _n, _child_module in _find_modules_v2(
model,
target_replace_module,
search_class=[LoraInjectedLinear],
):
loras.append((_child_module.lora_up, _child_module.lora_down))
if len(loras) == 0:
raise ValueError("No lora injected.")
return loras
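# Illustrative sketch (assumed usage): the extracted (up, down) module pairs can be
# flattened and saved so that inject_trainable_lora(..., loras=path) can restore them
# later in the same iteration order, e.g.
#
#   weights = []
#   for up, down in extract_lora_ups_down(model, target_replace_module=replace_modules):
#       weights += [up.weight, down.weight]
#   torch.save(weights, "lora_weights.pt")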
def do_lora_injection(model, replace_modules, lora_loader_args=None):
REPLACE_MODULES = replace_modules
params = None
negation = None
injector_args = lora_loader_args
params, negation = inject_trainable_lora(**injector_args)
success_inject = True
for _up, _down in extract_lora_ups_down(model, target_replace_module=REPLACE_MODULES):
if not all(x is not None for x in [_up, _down]): success_inject = False
    if success_inject:
        print(f"LoRA successfully injected into {model.__class__.__name__}.")
    else:
        print(f"Failed to inject LoRA into {model.__class__.__name__}.")
        exit(-1)
return params, negation
def add_lora_to_model(model, dropout=0.0, lora_rank=16,
                      scale=0, replace_modules: List[str] = ["Transformer2DModel"]):
    '''
    replace_modules must name the ancestor block class(es) that contain the attention
    layers to adapt (e.g. "Transformer2DModel" for diffusers UNets).
    '''
params = None
negation = None
lora_loader_args = lora_args.copy()
lora_loader_args.update({
"model": model,
"loras": None,
"target_replace_module": replace_modules,
"lora_rank": lora_rank,
"dropout": dropout,
"scale": scale
})
params, negation = do_lora_injection(model, replace_modules, lora_loader_args=lora_loader_args)
params = model if params is None else params
return params, negation
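if __name__ == "__main__":
    # Hypothetical smoke test (illustrative, not part of the original module): builds a
    # toy block whose layout matches what _find_modules_v2 searches for -- an 'attn2'
    # submodule holding to_q / to_k / to_v linears -- and checks that injection swaps
    # them for LoraInjectedLinear layers.
    class ToyCrossAttention(nn.Module):
        def __init__(self, dim=32):
            super().__init__()
            self.to_q = nn.Linear(dim, dim, bias=False)
            self.to_k = nn.Linear(dim, dim, bias=False)
            self.to_v = nn.Linear(dim, dim, bias=False)

    class ToyBlock(nn.Module):
        def __init__(self, dim=32):
            super().__init__()
            self.attn2 = ToyCrossAttention(dim)

    toy = ToyBlock()
    params, negation = add_lora_to_model(
        toy, dropout=0.0, lora_rank=4, scale=1.0, replace_modules=["ToyBlock"]
    )
    assert isinstance(toy.attn2.to_q, LoraInjectedLinear)
    print("Smoke test passed; replaced modules:", negation)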