import torch
import torch.nn as nn

from transformers import LlamaForCausalLM

from .configuration_pruned_llama import LlamaPrunedConfig


class LlamaPrunedForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM variant whose layers in [begin_pruned_layer, end_pruned_layer)
    use narrower attention and MLP projections."""

    config_class = LlamaPrunedConfig

    def __init__(self, config: LlamaPrunedConfig):
        super().__init__(config)

        # Replace the attention and MLP projections of the pruned layers with
        # narrower ones: a 3072-dim attention width for queries/outputs,
        # a 768-dim width for keys/values, and a 10752-dim MLP intermediate
        # size, all against the original 4096-dim hidden size.
        for layer in self.model.layers[config.begin_pruned_layer:config.end_pruned_layer]:
            layer.self_attn.hidden_size = 3072
            layer.self_attn.q_proj = nn.Linear(4096, 3072, bias=False)
            layer.self_attn.k_proj = nn.Linear(4096, 768, bias=False)
            layer.self_attn.v_proj = nn.Linear(4096, 768, bias=False)
            layer.self_attn.o_proj = nn.Linear(3072, 4096, bias=False)
            layer.mlp.gate_proj = nn.Linear(4096, 10752, bias=False)
            layer.mlp.up_proj = nn.Linear(4096, 10752, bias=False)
            layer.mlp.down_proj = nn.Linear(10752, 4096, bias=False)

        # Recompute the head counts from the (possibly pruned) projection
        # shapes so the attention reshapes stay consistent in every layer.
        for layer in self.model.layers:
            layer.self_attn.num_heads = layer.self_attn.q_proj.out_features // layer.self_attn.head_dim
            layer.self_attn.num_key_value_heads = (
                layer.self_attn.k_proj.out_features // layer.self_attn.head_dim
            )
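

# Minimal usage sketch (assumptions: the pruned weights were saved with
# save_pretrained() next to a LlamaPrunedConfig, and the checkpoint path
# below is a hypothetical placeholder, not a real repository).
if __name__ == "__main__":
    config = LlamaPrunedConfig.from_pretrained("path/to/pruned-llama")  # hypothetical path
    model = LlamaPrunedForCausalLM.from_pretrained(
        "path/to/pruned-llama",  # hypothetical path
        config=config,
        torch_dtype=torch.bfloat16,
    )
    model.eval()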