modify: update modeling code
- configuration_baichuan.py   +1 -0
- modeling_baichuan.py        +4 -2
- quantizer.py                +4 -1
- requirements.txt            +6 -0
- tokenization_baichuan.py    +2 -1
configuration_baichuan.py
CHANGED
@@ -1,3 +1,4 @@
+# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
 
 from transformers.configuration_utils import PretrainedConfig
 
modeling_baichuan.py
CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
+
 import math
 from typing import List, Optional, Tuple, Union
 
@@ -238,7 +240,7 @@ class BaichuanModel(BaichuanPreTrainedModel):
         if self.first_run:
             self.first_run = False
             self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
-        if
+        if seq_length_with_past > self.max_cache_pos:
             self.max_cache_pos = seq_length_with_past
             self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
         mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
@@ -266,7 +268,6 @@ class BaichuanModel(BaichuanPreTrainedModel):
             raise ValueError("You need to provide input_ids or inputs_embeds")
 
         seq_length_with_past = seq_length
-        past_key_values_length = 0
 
         if past_key_values is not None:
             past_key_values_length = past_key_values[0][0].shape[2]
@@ -366,6 +367,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
+        **kwargs
     ) -> Union[Tuple, CausalLMOutputWithPast]:
 
 
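The second hunk restores the growth check on the cached ALiBi mask: future_mask is built lazily on first use for max_cache_pos positions, then regenerated only when seq_length_with_past outgrows it. Below is a minimal sketch of that buffer-caching pattern; MaskCache and _gen_causal_mask are hypothetical stand-ins (the real _gen_alibi_mask also adds per-head linear position biases):

import torch

def _gen_causal_mask(n_head: int, max_pos: int) -> torch.Tensor:
    # Simplified stand-in for _gen_alibi_mask: an upper-triangular -inf
    # mask repeated per head, without the per-head ALiBi slopes.
    mask = torch.full((max_pos, max_pos), float("-inf")).triu(1)
    return mask.unsqueeze(0).expand(n_head, -1, -1)

class MaskCache(torch.nn.Module):
    # Hypothetical module isolating BaichuanModel's mask-caching logic.
    def __init__(self, n_head: int, max_cache_pos: int = 4096):
        super().__init__()
        self.n_head = n_head
        self.max_cache_pos = max_cache_pos
        self.first_run = True

    def get_alibi_mask(self, tensor, seq_length_with_past):
        if self.first_run:
            # Build the buffer lazily so it lands on the same
            # device/dtype as the activations (the `.to(tensor)`).
            self.first_run = False
            self.register_buffer(
                "future_mask",
                _gen_causal_mask(self.n_head, self.max_cache_pos).to(tensor),
                persistent=False,
            )
        if seq_length_with_past > self.max_cache_pos:
            # The condition this commit repairs: regenerate the buffer
            # only when the sequence outgrows the cached size.
            self.max_cache_pos = seq_length_with_past
            self.register_buffer(
                "future_mask",
                _gen_causal_mask(self.n_head, self.max_cache_pos).to(tensor),
                persistent=False,
            )
        # Slice the cached mask down to the current sequence length.
        return self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]

The **kwargs added to the forward signature is plausibly a tolerance fix: generate() and chat wrappers may forward extra keyword arguments (for example token_type_ids produced by the tokenizer), which a strict signature would reject with a TypeError; absorbing them in **kwargs keeps generation working.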
quantizer.py
CHANGED
@@ -1,3 +1,5 @@
+# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
+
 import torch
 from typing import List
 import bz2
@@ -92,10 +94,11 @@ class QLinear(torch.nn.Module):
         super().__init__()
         self.quant_bits = bits
         self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1)
+        self.scale = self.scale.to(torch.float32)
         if self.quant_bits == 4:
             self.weight = quant4(weight, self.scale)
         elif self.quant_bits == 8:
-            self.weight = torch.round(weight / self.scale[:, None]).to(torch.int8)
+            self.weight = torch.round(weight.to(self.scale.dtype) / self.scale[:, None]).to(torch.int8)
         if self.quant_bits == 8:
             self.weight = self.weight.T
         self.bias = None
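The second quantizer hunk moves the absmax arithmetic to float32: the per-row scale is cast up, and the weights are cast to the scale's dtype before the division and rounding, avoiding the precision loss of dividing half-precision weights by a half-precision scale. A small round-trip sketch of the same absmax int8 scheme, with hypothetical helper names rather than the repo's API:

import torch

def quantize_int8(weight):
    # Per-row absmax scale, cast to float32 as in the commit.
    scale = weight.abs().max(dim=-1).values / ((2 ** (8 - 1)) - 1)
    scale = scale.to(torch.float32)
    # Divide in float32, then round into the int8 range [-127, 127].
    qweight = torch.round(weight.to(scale.dtype) / scale[:, None]).to(torch.int8)
    return qweight, scale

def dequantize_int8(qweight, scale, dtype=torch.float16):
    # Reverse the scaling to recover an approximation of the weights.
    return (qweight.to(torch.float32) * scale[:, None]).to(dtype)

w = torch.randn(4, 8, dtype=torch.float16)
q, s = quantize_int8(w)
print((w - dequantize_int8(q, s)).abs().max())  # small quantization error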
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+accelerate
+colorama
+cpm_kernels
+sentencepiece
+streamlit
+transformers_stream_generator
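These are the runtime extras for the demos and quantized inference: presumably cpm_kernels backs the int4/int8 kernels used by quantizer.py, sentencepiece the tokenizer model, streamlit the web demo, and transformers_stream_generator streaming output. All of them install with pip install -r requirements.txt.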
tokenization_baichuan.py
CHANGED
@@ -1,9 +1,10 @@
+# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
+
 import os
 from shutil import copyfile
 from typing import Any, Dict, List, Optional, Tuple
 
 import sentencepiece as spm
-
 from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
 from transformers.utils import logging
 
|
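All four Python files ship with the checkpoint as remote code, so the fixes above only take effect when the model is loaded with trust_remote_code=True. A usage sketch; the repo id is illustrative, not taken from this page:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "baichuan-inc/Baichuan-13B-Chat"  # illustrative repo id

# trust_remote_code=True pulls in tokenization_baichuan.py,
# configuration_baichuan.py and modeling_baichuan.py from the repo.
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    repo, device_map="auto", trust_remote_code=True  # device_map needs accelerate
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Any extra keys produced by the tokenizer are forwarded into forward(),
# where the newly added **kwargs accepts them harmlessly.
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))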