import math

import torch
import torch.nn as nn

from typing import Set, List, Optional, Type

UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE

class LoraInjectedLinear(nn.Module):
    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()

        if r > min(in_features, out_features):
            print(
                f"LoRA rank {r} is too large. "
                f"Setting to: {min(in_features, out_features)}"
            )
            r = min(in_features, out_features)

        self.r = r
        # Frozen base projection plus a rank-r residual branch: down -> selector -> up.
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        self.selector = nn.Identity()
        self.cur_step = 0  # updated externally via update_step

        # lora_up starts at zero so the injected layer initially matches the base layer.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def update_step(self, cur_step):
        self.cur_step = cur_step

    def forward(self, input, return_format='linear'):
        assert return_format in ['linear', 'lora', 'added', 'full']
        if return_format == 'linear':
            return self.linear(input)
        elif return_format == 'lora':
            return self.dropout(self.lora_up(self.selector(self.lora_down(input))))
        elif return_format == 'added':
            return (
                self.linear(input)
                + self.dropout(self.lora_up(self.selector(self.lora_down(input))))
                * self.scale
            )
        # 'full': return the base output, the LoRA output, and the scale separately.
        linear_res = self.linear(input)
        lora_res = self.dropout(self.lora_up(self.selector(self.lora_down(input))))
        return linear_res, lora_res, self.scale

    def realize_as_lora(self):
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        # Replace the identity selector with a fixed diagonal gate over the r ranks.
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        self.selector.weight.data = torch.diag(diag)
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)

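# Illustrative sketch (not part of the original code): a minimal check of what
# LoraInjectedLinear computes. The layer sizes and batch shape are arbitrary
# assumptions for the example, and `_example_lora_injected_linear` is a
# hypothetical helper name, not part of the library API.
def _example_lora_injected_linear():
    layer = LoraInjectedLinear(64, 32, bias=False, r=4, dropout_p=0.0, scale=1.0)
    x = torch.randn(2, 64)
    base = layer(x, return_format='linear')    # W x only
    merged = layer(x, return_format='added')   # W x + scale * up(selector(down(x)))
    # lora_up is zero-initialized, so the merged output equals the base output
    # until the LoRA factors are trained.
    assert torch.allclose(base, merged)
    up_w, down_w = layer.realize_as_lora()     # shapes (out_features, r) and (r, in_features)
    return up_w.shape, down_w.shape
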
def _find_modules_v2(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_name='attn2',
    include_names=['to_q', 'to_k', 'to_v'],
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoraInjectedLinear,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Only submodules whose full name contains `search_name` and whose base name is
    in `include_names` are yielded. Returns all matching modules, along with the
    parent of those modules and the names they are referenced by.
    """

    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # No ancestor filter: search every module in the model.
        ancestors = [module for module in model.modules()]

    # Note: `exclude_children_of` is accepted for API compatibility but is not
    # applied in this search path.
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if search_name in fullname:
                *path, base_name = fullname.split('.')
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                if base_name in include_names:
                    assert isinstance(module, tuple(search_class))
                    yield parent, base_name, module

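# Illustrative sketch (not part of the original code): _find_modules_v2 applied to
# a toy module. `_ToyAttention` and `_example_find_modules` are hypothetical names
# used only for this example; in practice `model` would be something like a
# diffusers UNet whose cross-attention blocks are named `attn2` and expose
# `to_q`/`to_k`/`to_v` linear layers.
class _ToyAttention(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.attn2 = nn.Module()
        self.attn2.to_q = nn.Linear(dim, dim, bias=False)
        self.attn2.to_k = nn.Linear(dim, dim, bias=False)
        self.attn2.to_v = nn.Linear(dim, dim, bias=False)
        self.attn2.to_out = nn.Linear(dim, dim, bias=False)  # filtered out by include_names


def _example_find_modules():
    model = _ToyAttention()
    hits = list(
        _find_modules_v2(model, ancestor_class={"_ToyAttention"}, search_class=[nn.Linear])
    )
    # Yields (parent, name, module) triples for to_q, to_k and to_v only.
    return [(name, type(module).__name__) for _parent, name, module in hits]
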
def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
    lora_rank: int = 4,
    loras=None,  # path to a file of pretrained LoRA weights to load, if any
    verbose: bool = False,
    dropout: float = 0.0,
    scale: float = 1.0,
):
    """
    Inject LoRA into the model and return the LoRA parameter groups together
    with the names of the replaced modules.
    """

    require_grad_params = []
    names = []

    if loras is not None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules_v2(
        model, target_replace_module, search_class=[nn.Linear]
    ):
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        _tmp = LoraInjectedLinear(
            _child_module.in_features,
            _child_module.out_features,
            _child_module.bias is not None,
            r=lora_rank,
            dropout_p=dropout,
            scale=scale,
        )
        # Reuse the original parameters so the base projection is unchanged.
        _tmp.linear.weight = weight
        if bias is not None:
            _tmp.linear.bias = bias

        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            # Pretrained weights are stored as a flat list of alternating up/down factors.
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names

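# Illustrative sketch (not part of the original code): end-to-end use of
# inject_trainable_lora on a toy model, followed by an optimizer over the
# returned parameter groups. The class name and hyperparameters below are
# assumptions for the example only.
def _example_inject_trainable_lora():
    import itertools

    class _ToyCrossAttention(nn.Module):
        def __init__(self, dim=8):
            super().__init__()
            self.attn2 = nn.Module()
            self.attn2.to_q = nn.Linear(dim, dim, bias=False)
            self.attn2.to_k = nn.Linear(dim, dim, bias=False)
            self.attn2.to_v = nn.Linear(dim, dim, bias=False)

    model = _ToyCrossAttention()
    lora_params, lora_names = inject_trainable_lora(
        model,
        target_replace_module={"_ToyCrossAttention"},
        lora_rank=2,
        verbose=True,
    )
    # The optimizer only sees the LoRA factors; the base weights are simply not
    # passed to it here.
    optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)
    return model, optimizer, lora_names
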
# Default argument template for inject_trainable_lora; filled in by add_lora_to_model.
lora_args = dict(
    model=None,
    loras=None,
    target_replace_module=[],
    lora_rank=4,
    dropout=0,
    scale=0,
)

def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
    loras = []
    for _m, _n, _child_module in _find_modules_v2(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear],
    ):
        loras.append((_child_module.lora_up, _child_module.lora_down))

    if len(loras) == 0:
        raise ValueError("No lora injected.")

    return loras

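# Illustrative sketch (not part of the original code): one plausible way to persist
# LoRA weights in the flat, alternating up/down layout that the `loras=` argument
# of inject_trainable_lora above expects. `save_lora_weights` is a hypothetical
# helper name; it assumes the same model and target_replace_module are used when
# the file is loaded back, so the module iteration order matches.
def save_lora_weights(model, path, target_replace_module=DEFAULT_TARGET_REPLACE):
    weights = []
    for _up, _down in extract_lora_ups_down(model, target_replace_module=target_replace_module):
        weights.append(_up.weight)
        weights.append(_down.weight)
    torch.save(weights, path)
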
def do_lora_injection(model, replace_modules, lora_loader_args=None):
    REPLACE_MODULES = replace_modules

    params = None
    negation = None
    injector_args = lora_loader_args

    # inject_trainable_lora returns (trainable parameter groups, replaced module names).
    params, negation = inject_trainable_lora(**injector_args)

    # Sanity check: every injected layer must expose its lora_up/lora_down pair.
    success_inject = True
    for _up, _down in extract_lora_ups_down(model, target_replace_module=REPLACE_MODULES):
        if not all(x is not None for x in [_up, _down]):
            success_inject = False

    if success_inject:
        print(f"LoRA successfully injected into {model.__class__.__name__}.")
    else:
        print(f"Failed to inject LoRA into {model.__class__.__name__}.")
        exit(-1)

    return params, negation

def add_lora_to_model(model, dropout=0.0, lora_rank=16,
                      scale=0, replace_modules: List[str] = ["Transformer2DModel"]):
    '''
    replace_modules needs to be set to the proper block type(s) for the model
    being wrapped.
    '''
    params = None
    negation = None

    lora_loader_args = lora_args.copy()
    lora_loader_args.update({
        "model": model,
        "loras": None,
        "target_replace_module": replace_modules,
        "lora_rank": lora_rank,
        "dropout": dropout,
        "scale": scale,
    })

    params, negation = do_lora_injection(model, replace_modules, lora_loader_args=lora_loader_args)

    # If nothing was returned, fall back to the model itself as the parameter source.
    params = model if params is None else params
    return params, negation

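# Illustrative sketch (not part of the original code): a typical call into
# add_lora_to_model for a diffusers-style UNet whose attention layers live inside
# "Transformer2DModel" blocks. `unet`, the hyperparameters, and the helper name
# `_example_add_lora_to_unet` are assumptions for this example.
def _example_add_lora_to_unet(unet):
    import itertools

    lora_params, lora_names = add_lora_to_model(
        unet,
        dropout=0.1,
        lora_rank=16,
        scale=1.0,
        replace_modules=["Transformer2DModel"],
    )
    # lora_params is a list of parameter generators (one per injected lora_up/lora_down).
    optimizer = torch.optim.AdamW(itertools.chain(*lora_params), lr=1e-4)
    return optimizer, lora_names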