import math import torch import torch.nn as nn from typing import Set, List, Optional, Type UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE class LoraInjectedLinear(nn.Module): def __init__( self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0 ): super().__init__() if r > min(in_features, out_features): #raise ValueError( # f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}" #) print(f"LoRA rank {r} is too large. setting to: {min(in_features, out_features)}") r = min(in_features, out_features) self.r = r self.linear = nn.Linear(in_features, out_features, bias) self.lora_down = nn.Linear(in_features, r, bias=False) self.dropout = nn.Dropout(dropout_p) self.lora_up = nn.Linear(r, out_features, bias=False) self.scale = scale self.selector = nn.Identity() nn.init.normal_(self.lora_down.weight, std=1 / r) nn.init.zeros_(self.lora_up.weight) def update_step(self, cur_step): self.cur_step = cur_step def forward(self, input, return_format='linear'): assert return_format in ['linear', 'lora', 'added', 'full'] if return_format == 'linear': return self.linear(input) elif return_format == 'lora': return self.dropout(self.lora_up(self.selector(self.lora_down(input)))) elif return_format == 'added': return ( self.linear(input) + self.dropout(self.lora_up(self.selector(self.lora_down(input)))) * self.scale ) linear_res = self.linear(input) lora_res = self.dropout(self.lora_up(self.selector(self.lora_down(input)))) return linear_res, lora_res, self.scale def realize_as_lora(self): return self.lora_up.weight.data * self.scale, self.lora_down.weight.data def set_selector_from_diag(self, diag: torch.Tensor): # diag is a 1D tensor of size (r,) assert diag.shape == (self.r,) self.selector = nn.Linear(self.r, self.r, bias=False) self.selector.weight.data = torch.diag(diag) self.selector.weight.data = self.selector.weight.data.to( self.lora_up.weight.device ).to(self.lora_up.weight.dtype) def _find_modules_v2( model, ancestor_class: Optional[Set[str]] = None, search_name = 'attn2', include_names = ['to_q', 'to_k', 'to_v'], search_class: List[Type[nn.Module]] = [nn.Linear], exclude_children_of: Optional[List[Type[nn.Module]]] = [ LoraInjectedLinear, ], ): """ Find all modules of a certain class (or union of classes) that are direct or indirect descendants of other modules of a certain class (or union of classes). Returns all matching modules, along with the parent of those moduless and the names they are referenced by. """ # Get the targets we should replace all linears under if ancestor_class is not None: ancestors = ( module for module in model.modules() if module.__class__.__name__ in ancestor_class ) else: # this, incase you want to naively iterate over all modules. ancestors = [module for module in model.modules()] # For each target find every linear_class module that isn't a child of a LoraInjectedLinear for ancestor in ancestors: for fullname, module in ancestor.named_modules(): if search_name in fullname: *path, base_name = fullname.split('.') parent = ancestor while path: parent = parent.get_submodule(path.pop(0)) if base_name in include_names: assert isinstance(module, search_class[0]) yield parent, base_name, module # if any([isinstance(module, _class) for _class in search_class]): # # Find the direct parent if this is a descendant, not a child, of target # *path, name = fullname.split(".") # parent = ancestor # while path: # parent = parent.get_submodule(path.pop(0)) # # Skip this linear if it's a child of a LoraInjectedLinear # if exclude_children_of and any( # [isinstance(parent, _class) for _class in exclude_children_of] # ): # continue # # Otherwise, yield it # yield parent, name, module def inject_trainable_lora( model: nn.Module, target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE, lora_rank: int = 4, loras=None, # path to lora .pt verbose: bool = False, dropout: float = 0.0, scale: float = 1.0, ): """ inject lora into model, and returns lora parameter groups. """ require_grad_params = [] names = [] if loras != None: loras = torch.load(loras) for _module, name, _child_module in _find_modules_v2( model, target_replace_module, search_class=[nn.Linear] ): weight = _child_module.weight bias = _child_module.bias if verbose: print("LoRA Injection : injecting lora into ", name) print("LoRA Injection : weight shape", weight.shape) _tmp = LoraInjectedLinear( _child_module.in_features, _child_module.out_features, _child_module.bias is not None, r=lora_rank, dropout_p=dropout, scale=scale, ) _tmp.linear.weight = weight if bias is not None: _tmp.linear.bias = bias # switch the module _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype) _module._modules[name] = _tmp require_grad_params.append(_module._modules[name].lora_up.parameters()) require_grad_params.append(_module._modules[name].lora_down.parameters()) if loras != None: _module._modules[name].lora_up.weight = loras.pop(0) _module._modules[name].lora_down.weight = loras.pop(0) _module._modules[name].lora_up.weight.requires_grad = True _module._modules[name].lora_down.weight.requires_grad = True names.append(name) return require_grad_params, names lora_args = dict( model = None, loras = None, target_replace_module = [], lora_rank = 4, dropout = 0, scale = 0 ) def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE): loras = [] for _m, _n, _child_module in _find_modules_v2( model, target_replace_module, search_class=[LoraInjectedLinear], ): loras.append((_child_module.lora_up, _child_module.lora_down)) if len(loras) == 0: raise ValueError("No lora injected.") return loras def do_lora_injection(model, replace_modules, lora_loader_args=None): REPLACE_MODULES = replace_modules params = None negation = None injector_args = lora_loader_args params, negation = inject_trainable_lora(**injector_args) success_inject = True for _up, _down in extract_lora_ups_down(model, target_replace_module=REPLACE_MODULES): if not all(x is not None for x in [_up, _down]): success_inject = False if success_inject: print(f"Lora successfully injected into {model.__class__.__name__}.") else: print(f'Fail to inject Lora into {model.__class__.__name__}') exit(-1) return params, negation def add_lora_to_model(model, dropout=0.0, lora_rank=16, scale=0, replace_modules: str = ["Transformer2DModel"]): ''' replace_modules needs to be fixed to the proper block ''' params = None negation = None lora_loader_args = lora_args.copy() lora_loader_args.update({ "model": model, "loras": None, "target_replace_module": replace_modules, "lora_rank": lora_rank, "dropout": dropout, "scale": scale }) params, negation = do_lora_injection(model, replace_modules, lora_loader_args=lora_loader_args) params = model if params is None else params return params, negation