Upload OLMoForCausalLM
Browse files- README.md +199 -0
- config.json +56 -0
- configuration_olmo.py +43 -0
- generation_config.json +6 -0
- model-00001-of-00006.safetensors +3 -0
- model-00002-of-00006.safetensors +3 -0
- model-00003-of-00006.safetensors +3 -0
- model-00004-of-00006.safetensors +3 -0
- model-00005-of-00006.safetensors +3 -0
- model-00006-of-00006.safetensors +3 -0
- model.safetensors.index.json +137 -0
- modeling_olmo.py +228 -0
README.md
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
tags: []
|
4 |
+
---
|
5 |
+
|
6 |
+
# Model Card for Model ID
|
7 |
+
|
8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
## Model Details
|
13 |
+
|
14 |
+
### Model Description
|
15 |
+
|
16 |
+
<!-- Provide a longer summary of what this model is. -->
|
17 |
+
|
18 |
+
This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
|
19 |
+
|
20 |
+
- **Developed by:** [More Information Needed]
|
21 |
+
- **Funded by [optional]:** [More Information Needed]
|
22 |
+
- **Shared by [optional]:** [More Information Needed]
|
23 |
+
- **Model type:** [More Information Needed]
|
24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
25 |
+
- **License:** [More Information Needed]
|
26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
27 |
+
|
28 |
+
### Model Sources [optional]
|
29 |
+
|
30 |
+
<!-- Provide the basic links for the model. -->
|
31 |
+
|
32 |
+
- **Repository:** [More Information Needed]
|
33 |
+
- **Paper [optional]:** [More Information Needed]
|
34 |
+
- **Demo [optional]:** [More Information Needed]
|
35 |
+
|
36 |
+
## Uses
|
37 |
+
|
38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
39 |
+
|
40 |
+
### Direct Use
|
41 |
+
|
42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
43 |
+
|
44 |
+
[More Information Needed]
|
45 |
+
|
46 |
+
### Downstream Use [optional]
|
47 |
+
|
48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
49 |
+
|
50 |
+
[More Information Needed]
|
51 |
+
|
52 |
+
### Out-of-Scope Use
|
53 |
+
|
54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
55 |
+
|
56 |
+
[More Information Needed]
|
57 |
+
|
58 |
+
## Bias, Risks, and Limitations
|
59 |
+
|
60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
61 |
+
|
62 |
+
[More Information Needed]
|
63 |
+
|
64 |
+
### Recommendations
|
65 |
+
|
66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
67 |
+
|
68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
69 |
+
|
70 |
+
## How to Get Started with the Model
|
71 |
+
|
72 |
+
Use the code below to get started with the model.
|
73 |
+
|
74 |
+
[More Information Needed]
|
75 |
+
|
76 |
+
## Training Details
|
77 |
+
|
78 |
+
### Training Data
|
79 |
+
|
80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
81 |
+
|
82 |
+
[More Information Needed]
|
83 |
+
|
84 |
+
### Training Procedure
|
85 |
+
|
86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
87 |
+
|
88 |
+
#### Preprocessing [optional]
|
89 |
+
|
90 |
+
[More Information Needed]
|
91 |
+
|
92 |
+
|
93 |
+
#### Training Hyperparameters
|
94 |
+
|
95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
96 |
+
|
97 |
+
#### Speeds, Sizes, Times [optional]
|
98 |
+
|
99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
100 |
+
|
101 |
+
[More Information Needed]
|
102 |
+
|
103 |
+
## Evaluation
|
104 |
+
|
105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
106 |
+
|
107 |
+
### Testing Data, Factors & Metrics
|
108 |
+
|
109 |
+
#### Testing Data
|
110 |
+
|
111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
112 |
+
|
113 |
+
[More Information Needed]
|
114 |
+
|
115 |
+
#### Factors
|
116 |
+
|
117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
118 |
+
|
119 |
+
[More Information Needed]
|
120 |
+
|
121 |
+
#### Metrics
|
122 |
+
|
123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
124 |
+
|
125 |
+
[More Information Needed]
|
126 |
+
|
127 |
+
### Results
|
128 |
+
|
129 |
+
[More Information Needed]
|
130 |
+
|
131 |
+
#### Summary
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
## Model Examination [optional]
|
136 |
+
|
137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
138 |
+
|
139 |
+
[More Information Needed]
|
140 |
+
|
141 |
+
## Environmental Impact
|
142 |
+
|
143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
144 |
+
|
145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
146 |
+
|
147 |
+
- **Hardware Type:** [More Information Needed]
|
148 |
+
- **Hours used:** [More Information Needed]
|
149 |
+
- **Cloud Provider:** [More Information Needed]
|
150 |
+
- **Compute Region:** [More Information Needed]
|
151 |
+
- **Carbon Emitted:** [More Information Needed]
|
152 |
+
|
153 |
+
## Technical Specifications [optional]
|
154 |
+
|
155 |
+
### Model Architecture and Objective
|
156 |
+
|
157 |
+
[More Information Needed]
|
158 |
+
|
159 |
+
### Compute Infrastructure
|
160 |
+
|
161 |
+
[More Information Needed]
|
162 |
+
|
163 |
+
#### Hardware
|
164 |
+
|
165 |
+
[More Information Needed]
|
166 |
+
|
167 |
+
#### Software
|
168 |
+
|
169 |
+
[More Information Needed]
|
170 |
+
|
171 |
+
## Citation [optional]
|
172 |
+
|
173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
174 |
+
|
175 |
+
**BibTeX:**
|
176 |
+
|
177 |
+
[More Information Needed]
|
178 |
+
|
179 |
+
**APA:**
|
180 |
+
|
181 |
+
[More Information Needed]
|
182 |
+
|
183 |
+
## Glossary [optional]
|
184 |
+
|
185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
186 |
+
|
187 |
+
[More Information Needed]
|
188 |
+
|
189 |
+
## More Information [optional]
|
190 |
+
|
191 |
+
[More Information Needed]
|
192 |
+
|
193 |
+
## Model Card Authors [optional]
|
194 |
+
|
195 |
+
[More Information Needed]
|
196 |
+
|
197 |
+
## Model Card Contact
|
198 |
+
|
199 |
+
[More Information Needed]
|
config.json
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "/home/itay.itzhak/projects/proj2/finetuning/open-instruct/output/allenai/tulu-v2-sft-mixture_allenai/OLMo-7B_lora_r128_alpha256_LR2e-5_seed_2/merged",
|
3 |
+
"activation_type": "swiglu",
|
4 |
+
"alibi": false,
|
5 |
+
"alibi_bias_max": 8.0,
|
6 |
+
"architectures": [
|
7 |
+
"OLMoForCausalLM"
|
8 |
+
],
|
9 |
+
"attention_dropout": 0.0,
|
10 |
+
"attention_layer_norm": false,
|
11 |
+
"attention_layer_norm_with_affine": false,
|
12 |
+
"auto_map": {
|
13 |
+
"AutoConfig": "configuration_olmo.OLMoConfig",
|
14 |
+
"AutoModelForCausalLM": "modeling_olmo.OLMoForCausalLM",
|
15 |
+
"AutoTokenizer": [
|
16 |
+
"allenai/OLMo-7B--tokenization_olmo_fast.OLMoTokenizerFast",
|
17 |
+
"allenai/OLMo-7B--tokenization_olmo_fast.OLMoTokenizerFast"
|
18 |
+
]
|
19 |
+
},
|
20 |
+
"bias_for_layer_norm": false,
|
21 |
+
"block_group_size": 1,
|
22 |
+
"block_type": "sequential",
|
23 |
+
"clip_qkv": null,
|
24 |
+
"d_model": 4096,
|
25 |
+
"embedding_dropout": 0.0,
|
26 |
+
"embedding_size": 50304,
|
27 |
+
"eos_token_id": 50279,
|
28 |
+
"flash_attention": true,
|
29 |
+
"include_bias": false,
|
30 |
+
"init_cutoff_factor": null,
|
31 |
+
"init_device": "meta",
|
32 |
+
"init_fn": "mitchell",
|
33 |
+
"init_std": 0.02,
|
34 |
+
"layer_norm_eps": 1e-05,
|
35 |
+
"layer_norm_type": "default",
|
36 |
+
"layer_norm_with_affine": false,
|
37 |
+
"max_sequence_length": 2048,
|
38 |
+
"mlp_hidden_size": 22016,
|
39 |
+
"mlp_ratio": 4,
|
40 |
+
"model_type": "hf_olmo",
|
41 |
+
"multi_query_attention": false,
|
42 |
+
"n_heads": 32,
|
43 |
+
"n_kv_heads": null,
|
44 |
+
"n_layers": 32,
|
45 |
+
"pad_token_id": 1,
|
46 |
+
"precision": "amp_bf16",
|
47 |
+
"residual_dropout": 0.0,
|
48 |
+
"rope": true,
|
49 |
+
"rope_full_precision": true,
|
50 |
+
"scale_logits": false,
|
51 |
+
"torch_dtype": "float32",
|
52 |
+
"transformers_version": "4.42.4",
|
53 |
+
"use_cache": true,
|
54 |
+
"vocab_size": 50280,
|
55 |
+
"weight_tying": false
|
56 |
+
}
|
configuration_olmo.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
OLMo configuration
|
3 |
+
"""
|
4 |
+
|
5 |
+
from transformers import AutoConfig, PretrainedConfig
|
6 |
+
from transformers.utils import logging
|
7 |
+
|
8 |
+
from olmo.config import ModelConfig
|
9 |
+
|
10 |
+
logger = logging.get_logger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class OLMoConfig(PretrainedConfig):
|
14 |
+
model_type = "hf_olmo"
|
15 |
+
keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
|
16 |
+
|
17 |
+
def __init__(self, use_cache: bool = False, **kwargs):
|
18 |
+
model_config = ModelConfig()
|
19 |
+
all_kwargs = model_config.asdict()
|
20 |
+
all_kwargs.update(kwargs)
|
21 |
+
all_kwargs.update({"use_cache": use_cache})
|
22 |
+
all_kwargs.update(
|
23 |
+
{"architectures": all_kwargs.get("architectures", ["OLMoForCausalLM"]) or ["OLMoForCausalLM"]}
|
24 |
+
)
|
25 |
+
super().__init__(**all_kwargs)
|
26 |
+
|
27 |
+
@property
|
28 |
+
def num_attention_heads(self):
|
29 |
+
return self.n_heads
|
30 |
+
|
31 |
+
@property
|
32 |
+
def num_hidden_layers(self):
|
33 |
+
return self.n_layers
|
34 |
+
|
35 |
+
@property
|
36 |
+
def hidden_size(self):
|
37 |
+
return self.d_model
|
38 |
+
|
39 |
+
|
40 |
+
# Register the config class so that it is available for transformer pipelines, auto-loading etc.
|
41 |
+
# OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
|
42 |
+
# may not support the newest architectures we create.
|
43 |
+
AutoConfig.register("hf_olmo", OLMoConfig)
|
generation_config.json
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_from_model_config": true,
|
3 |
+
"eos_token_id": 50279,
|
4 |
+
"pad_token_id": 1,
|
5 |
+
"transformers_version": "4.42.4"
|
6 |
+
}
|
model-00001-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:abe292042dd739164fae4b63a63cddb1d3d490964824f419348dc31c70a93e96
|
3 |
+
size 4938795616
|
model-00002-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f660dbba3daacf89b7fe9f299a6c375bc6eafcd2e5ef04e4a0760953c199101f
|
3 |
+
size 4857006944
|
model-00003-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b76a9f4dc3462d8a0d0b445cd171dbea8bb5c7fff53266483ae7f468bbeb7ced
|
3 |
+
size 4857006960
|
model-00004-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2c61a7e01d5b8e317b3efb70122ae26339f85022cf7ec5a83a50650b522fc144
|
3 |
+
size 4857006960
|
model-00005-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b3ff7a3da588aa195dab4caa1d608c7d7a2c07661b373db2e4e38b2233496ed4
|
3 |
+
size 4857006960
|
model-00006-of-00006.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48f88262c6801df4f22fc82d34548c0d1b660de5576ae70ea3b18833eb1f39e4
|
3 |
+
size 3185575352
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {
|
3 |
+
"total_size": 27552382976
|
4 |
+
},
|
5 |
+
"weight_map": {
|
6 |
+
"model.transformer.blocks.0.att_proj.weight": "model-00001-of-00006.safetensors",
|
7 |
+
"model.transformer.blocks.0.attn_out.weight": "model-00001-of-00006.safetensors",
|
8 |
+
"model.transformer.blocks.0.ff_out.weight": "model-00001-of-00006.safetensors",
|
9 |
+
"model.transformer.blocks.0.ff_proj.weight": "model-00001-of-00006.safetensors",
|
10 |
+
"model.transformer.blocks.1.att_proj.weight": "model-00001-of-00006.safetensors",
|
11 |
+
"model.transformer.blocks.1.attn_out.weight": "model-00001-of-00006.safetensors",
|
12 |
+
"model.transformer.blocks.1.ff_out.weight": "model-00001-of-00006.safetensors",
|
13 |
+
"model.transformer.blocks.1.ff_proj.weight": "model-00001-of-00006.safetensors",
|
14 |
+
"model.transformer.blocks.10.att_proj.weight": "model-00002-of-00006.safetensors",
|
15 |
+
"model.transformer.blocks.10.attn_out.weight": "model-00002-of-00006.safetensors",
|
16 |
+
"model.transformer.blocks.10.ff_out.weight": "model-00002-of-00006.safetensors",
|
17 |
+
"model.transformer.blocks.10.ff_proj.weight": "model-00002-of-00006.safetensors",
|
18 |
+
"model.transformer.blocks.11.att_proj.weight": "model-00003-of-00006.safetensors",
|
19 |
+
"model.transformer.blocks.11.attn_out.weight": "model-00002-of-00006.safetensors",
|
20 |
+
"model.transformer.blocks.11.ff_out.weight": "model-00003-of-00006.safetensors",
|
21 |
+
"model.transformer.blocks.11.ff_proj.weight": "model-00003-of-00006.safetensors",
|
22 |
+
"model.transformer.blocks.12.att_proj.weight": "model-00003-of-00006.safetensors",
|
23 |
+
"model.transformer.blocks.12.attn_out.weight": "model-00003-of-00006.safetensors",
|
24 |
+
"model.transformer.blocks.12.ff_out.weight": "model-00003-of-00006.safetensors",
|
25 |
+
"model.transformer.blocks.12.ff_proj.weight": "model-00003-of-00006.safetensors",
|
26 |
+
"model.transformer.blocks.13.att_proj.weight": "model-00003-of-00006.safetensors",
|
27 |
+
"model.transformer.blocks.13.attn_out.weight": "model-00003-of-00006.safetensors",
|
28 |
+
"model.transformer.blocks.13.ff_out.weight": "model-00003-of-00006.safetensors",
|
29 |
+
"model.transformer.blocks.13.ff_proj.weight": "model-00003-of-00006.safetensors",
|
30 |
+
"model.transformer.blocks.14.att_proj.weight": "model-00003-of-00006.safetensors",
|
31 |
+
"model.transformer.blocks.14.attn_out.weight": "model-00003-of-00006.safetensors",
|
32 |
+
"model.transformer.blocks.14.ff_out.weight": "model-00003-of-00006.safetensors",
|
33 |
+
"model.transformer.blocks.14.ff_proj.weight": "model-00003-of-00006.safetensors",
|
34 |
+
"model.transformer.blocks.15.att_proj.weight": "model-00003-of-00006.safetensors",
|
35 |
+
"model.transformer.blocks.15.attn_out.weight": "model-00003-of-00006.safetensors",
|
36 |
+
"model.transformer.blocks.15.ff_out.weight": "model-00003-of-00006.safetensors",
|
37 |
+
"model.transformer.blocks.15.ff_proj.weight": "model-00003-of-00006.safetensors",
|
38 |
+
"model.transformer.blocks.16.att_proj.weight": "model-00003-of-00006.safetensors",
|
39 |
+
"model.transformer.blocks.16.attn_out.weight": "model-00003-of-00006.safetensors",
|
40 |
+
"model.transformer.blocks.16.ff_out.weight": "model-00003-of-00006.safetensors",
|
41 |
+
"model.transformer.blocks.16.ff_proj.weight": "model-00003-of-00006.safetensors",
|
42 |
+
"model.transformer.blocks.17.att_proj.weight": "model-00004-of-00006.safetensors",
|
43 |
+
"model.transformer.blocks.17.attn_out.weight": "model-00003-of-00006.safetensors",
|
44 |
+
"model.transformer.blocks.17.ff_out.weight": "model-00004-of-00006.safetensors",
|
45 |
+
"model.transformer.blocks.17.ff_proj.weight": "model-00004-of-00006.safetensors",
|
46 |
+
"model.transformer.blocks.18.att_proj.weight": "model-00004-of-00006.safetensors",
|
47 |
+
"model.transformer.blocks.18.attn_out.weight": "model-00004-of-00006.safetensors",
|
48 |
+
"model.transformer.blocks.18.ff_out.weight": "model-00004-of-00006.safetensors",
|
49 |
+
"model.transformer.blocks.18.ff_proj.weight": "model-00004-of-00006.safetensors",
|
50 |
+
"model.transformer.blocks.19.att_proj.weight": "model-00004-of-00006.safetensors",
|
51 |
+
"model.transformer.blocks.19.attn_out.weight": "model-00004-of-00006.safetensors",
|
52 |
+
"model.transformer.blocks.19.ff_out.weight": "model-00004-of-00006.safetensors",
|
53 |
+
"model.transformer.blocks.19.ff_proj.weight": "model-00004-of-00006.safetensors",
|
54 |
+
"model.transformer.blocks.2.att_proj.weight": "model-00001-of-00006.safetensors",
|
55 |
+
"model.transformer.blocks.2.attn_out.weight": "model-00001-of-00006.safetensors",
|
56 |
+
"model.transformer.blocks.2.ff_out.weight": "model-00001-of-00006.safetensors",
|
57 |
+
"model.transformer.blocks.2.ff_proj.weight": "model-00001-of-00006.safetensors",
|
58 |
+
"model.transformer.blocks.20.att_proj.weight": "model-00004-of-00006.safetensors",
|
59 |
+
"model.transformer.blocks.20.attn_out.weight": "model-00004-of-00006.safetensors",
|
60 |
+
"model.transformer.blocks.20.ff_out.weight": "model-00004-of-00006.safetensors",
|
61 |
+
"model.transformer.blocks.20.ff_proj.weight": "model-00004-of-00006.safetensors",
|
62 |
+
"model.transformer.blocks.21.att_proj.weight": "model-00004-of-00006.safetensors",
|
63 |
+
"model.transformer.blocks.21.attn_out.weight": "model-00004-of-00006.safetensors",
|
64 |
+
"model.transformer.blocks.21.ff_out.weight": "model-00004-of-00006.safetensors",
|
65 |
+
"model.transformer.blocks.21.ff_proj.weight": "model-00004-of-00006.safetensors",
|
66 |
+
"model.transformer.blocks.22.att_proj.weight": "model-00004-of-00006.safetensors",
|
67 |
+
"model.transformer.blocks.22.attn_out.weight": "model-00004-of-00006.safetensors",
|
68 |
+
"model.transformer.blocks.22.ff_out.weight": "model-00004-of-00006.safetensors",
|
69 |
+
"model.transformer.blocks.22.ff_proj.weight": "model-00004-of-00006.safetensors",
|
70 |
+
"model.transformer.blocks.23.att_proj.weight": "model-00005-of-00006.safetensors",
|
71 |
+
"model.transformer.blocks.23.attn_out.weight": "model-00004-of-00006.safetensors",
|
72 |
+
"model.transformer.blocks.23.ff_out.weight": "model-00005-of-00006.safetensors",
|
73 |
+
"model.transformer.blocks.23.ff_proj.weight": "model-00005-of-00006.safetensors",
|
74 |
+
"model.transformer.blocks.24.att_proj.weight": "model-00005-of-00006.safetensors",
|
75 |
+
"model.transformer.blocks.24.attn_out.weight": "model-00005-of-00006.safetensors",
|
76 |
+
"model.transformer.blocks.24.ff_out.weight": "model-00005-of-00006.safetensors",
|
77 |
+
"model.transformer.blocks.24.ff_proj.weight": "model-00005-of-00006.safetensors",
|
78 |
+
"model.transformer.blocks.25.att_proj.weight": "model-00005-of-00006.safetensors",
|
79 |
+
"model.transformer.blocks.25.attn_out.weight": "model-00005-of-00006.safetensors",
|
80 |
+
"model.transformer.blocks.25.ff_out.weight": "model-00005-of-00006.safetensors",
|
81 |
+
"model.transformer.blocks.25.ff_proj.weight": "model-00005-of-00006.safetensors",
|
82 |
+
"model.transformer.blocks.26.att_proj.weight": "model-00005-of-00006.safetensors",
|
83 |
+
"model.transformer.blocks.26.attn_out.weight": "model-00005-of-00006.safetensors",
|
84 |
+
"model.transformer.blocks.26.ff_out.weight": "model-00005-of-00006.safetensors",
|
85 |
+
"model.transformer.blocks.26.ff_proj.weight": "model-00005-of-00006.safetensors",
|
86 |
+
"model.transformer.blocks.27.att_proj.weight": "model-00005-of-00006.safetensors",
|
87 |
+
"model.transformer.blocks.27.attn_out.weight": "model-00005-of-00006.safetensors",
|
88 |
+
"model.transformer.blocks.27.ff_out.weight": "model-00005-of-00006.safetensors",
|
89 |
+
"model.transformer.blocks.27.ff_proj.weight": "model-00005-of-00006.safetensors",
|
90 |
+
"model.transformer.blocks.28.att_proj.weight": "model-00005-of-00006.safetensors",
|
91 |
+
"model.transformer.blocks.28.attn_out.weight": "model-00005-of-00006.safetensors",
|
92 |
+
"model.transformer.blocks.28.ff_out.weight": "model-00005-of-00006.safetensors",
|
93 |
+
"model.transformer.blocks.28.ff_proj.weight": "model-00005-of-00006.safetensors",
|
94 |
+
"model.transformer.blocks.29.att_proj.weight": "model-00006-of-00006.safetensors",
|
95 |
+
"model.transformer.blocks.29.attn_out.weight": "model-00005-of-00006.safetensors",
|
96 |
+
"model.transformer.blocks.29.ff_out.weight": "model-00006-of-00006.safetensors",
|
97 |
+
"model.transformer.blocks.29.ff_proj.weight": "model-00006-of-00006.safetensors",
|
98 |
+
"model.transformer.blocks.3.att_proj.weight": "model-00001-of-00006.safetensors",
|
99 |
+
"model.transformer.blocks.3.attn_out.weight": "model-00001-of-00006.safetensors",
|
100 |
+
"model.transformer.blocks.3.ff_out.weight": "model-00001-of-00006.safetensors",
|
101 |
+
"model.transformer.blocks.3.ff_proj.weight": "model-00001-of-00006.safetensors",
|
102 |
+
"model.transformer.blocks.30.att_proj.weight": "model-00006-of-00006.safetensors",
|
103 |
+
"model.transformer.blocks.30.attn_out.weight": "model-00006-of-00006.safetensors",
|
104 |
+
"model.transformer.blocks.30.ff_out.weight": "model-00006-of-00006.safetensors",
|
105 |
+
"model.transformer.blocks.30.ff_proj.weight": "model-00006-of-00006.safetensors",
|
106 |
+
"model.transformer.blocks.31.att_proj.weight": "model-00006-of-00006.safetensors",
|
107 |
+
"model.transformer.blocks.31.attn_out.weight": "model-00006-of-00006.safetensors",
|
108 |
+
"model.transformer.blocks.31.ff_out.weight": "model-00006-of-00006.safetensors",
|
109 |
+
"model.transformer.blocks.31.ff_proj.weight": "model-00006-of-00006.safetensors",
|
110 |
+
"model.transformer.blocks.4.att_proj.weight": "model-00001-of-00006.safetensors",
|
111 |
+
"model.transformer.blocks.4.attn_out.weight": "model-00001-of-00006.safetensors",
|
112 |
+
"model.transformer.blocks.4.ff_out.weight": "model-00001-of-00006.safetensors",
|
113 |
+
"model.transformer.blocks.4.ff_proj.weight": "model-00001-of-00006.safetensors",
|
114 |
+
"model.transformer.blocks.5.att_proj.weight": "model-00002-of-00006.safetensors",
|
115 |
+
"model.transformer.blocks.5.attn_out.weight": "model-00001-of-00006.safetensors",
|
116 |
+
"model.transformer.blocks.5.ff_out.weight": "model-00002-of-00006.safetensors",
|
117 |
+
"model.transformer.blocks.5.ff_proj.weight": "model-00002-of-00006.safetensors",
|
118 |
+
"model.transformer.blocks.6.att_proj.weight": "model-00002-of-00006.safetensors",
|
119 |
+
"model.transformer.blocks.6.attn_out.weight": "model-00002-of-00006.safetensors",
|
120 |
+
"model.transformer.blocks.6.ff_out.weight": "model-00002-of-00006.safetensors",
|
121 |
+
"model.transformer.blocks.6.ff_proj.weight": "model-00002-of-00006.safetensors",
|
122 |
+
"model.transformer.blocks.7.att_proj.weight": "model-00002-of-00006.safetensors",
|
123 |
+
"model.transformer.blocks.7.attn_out.weight": "model-00002-of-00006.safetensors",
|
124 |
+
"model.transformer.blocks.7.ff_out.weight": "model-00002-of-00006.safetensors",
|
125 |
+
"model.transformer.blocks.7.ff_proj.weight": "model-00002-of-00006.safetensors",
|
126 |
+
"model.transformer.blocks.8.att_proj.weight": "model-00002-of-00006.safetensors",
|
127 |
+
"model.transformer.blocks.8.attn_out.weight": "model-00002-of-00006.safetensors",
|
128 |
+
"model.transformer.blocks.8.ff_out.weight": "model-00002-of-00006.safetensors",
|
129 |
+
"model.transformer.blocks.8.ff_proj.weight": "model-00002-of-00006.safetensors",
|
130 |
+
"model.transformer.blocks.9.att_proj.weight": "model-00002-of-00006.safetensors",
|
131 |
+
"model.transformer.blocks.9.attn_out.weight": "model-00002-of-00006.safetensors",
|
132 |
+
"model.transformer.blocks.9.ff_out.weight": "model-00002-of-00006.safetensors",
|
133 |
+
"model.transformer.blocks.9.ff_proj.weight": "model-00002-of-00006.safetensors",
|
134 |
+
"model.transformer.ff_out.weight": "model-00006-of-00006.safetensors",
|
135 |
+
"model.transformer.wte.weight": "model-00001-of-00006.safetensors"
|
136 |
+
}
|
137 |
+
}
|
modeling_olmo.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from dataclasses import fields
|
3 |
+
from typing import List, Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from transformers import PreTrainedModel
|
7 |
+
from transformers.cache_utils import Cache
|
8 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
9 |
+
from transformers.models.auto import AutoModelForCausalLM
|
10 |
+
|
11 |
+
from olmo.config import ModelConfig
|
12 |
+
from olmo.model import OLMo
|
13 |
+
|
14 |
+
from .configuration_olmo import OLMoConfig
|
15 |
+
|
16 |
+
log = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
def create_model_config_from_pretrained_config(config: OLMoConfig):
|
20 |
+
"""
|
21 |
+
Utility function
|
22 |
+
"""
|
23 |
+
|
24 |
+
kwargs = {}
|
25 |
+
for field in fields(ModelConfig):
|
26 |
+
kwargs[field.name] = getattr(config, field.name)
|
27 |
+
|
28 |
+
model_config = ModelConfig(**kwargs)
|
29 |
+
return model_config
|
30 |
+
|
31 |
+
|
32 |
+
class OLMoForCausalLM(PreTrainedModel):
|
33 |
+
"""
|
34 |
+
Extremely barebones HF model wrapper.
|
35 |
+
"""
|
36 |
+
|
37 |
+
config_class = OLMoConfig
|
38 |
+
base_model_prefix = "model"
|
39 |
+
_no_split_modules = ["OLMoBlock"]
|
40 |
+
|
41 |
+
def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False):
|
42 |
+
super().__init__(config)
|
43 |
+
|
44 |
+
if not model:
|
45 |
+
model_config = create_model_config_from_pretrained_config(config)
|
46 |
+
# Initialize model (always on CPU to start with so we don't run out of GPU memory).
|
47 |
+
model_config.init_device = "cpu"
|
48 |
+
self.model = OLMo(model_config, init_params=init_params)
|
49 |
+
else:
|
50 |
+
self.model = model
|
51 |
+
|
52 |
+
def forward(
|
53 |
+
self,
|
54 |
+
input_ids: torch.LongTensor = None,
|
55 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
56 |
+
attention_mask: Optional[torch.Tensor] = None,
|
57 |
+
attention_bias: Optional[torch.Tensor] = None,
|
58 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
59 |
+
labels: Optional[torch.LongTensor] = None,
|
60 |
+
use_cache: Optional[bool] = None,
|
61 |
+
output_attentions: Optional[bool] = None,
|
62 |
+
output_hidden_states: Optional[bool] = None,
|
63 |
+
return_dict: Optional[bool] = None,
|
64 |
+
cache_position: Optional[
|
65 |
+
Cache
|
66 |
+
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
|
67 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
68 |
+
if use_cache is None:
|
69 |
+
use_cache = self.config.use_cache
|
70 |
+
|
71 |
+
if output_attentions:
|
72 |
+
raise ValueError("output_attentions is not yet supported in OLMo")
|
73 |
+
|
74 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
75 |
+
|
76 |
+
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
77 |
+
outputs = self.model.forward(
|
78 |
+
input_ids=input_ids,
|
79 |
+
input_embeddings=inputs_embeds,
|
80 |
+
attention_mask=attention_mask,
|
81 |
+
attention_bias=attention_bias,
|
82 |
+
past_key_values=past_key_values,
|
83 |
+
use_cache=use_cache,
|
84 |
+
output_hidden_states=output_hidden_states,
|
85 |
+
)
|
86 |
+
|
87 |
+
logits = outputs.logits
|
88 |
+
hidden_states = outputs.hidden_states
|
89 |
+
|
90 |
+
loss = None
|
91 |
+
if labels is not None:
|
92 |
+
# Shift so that tokens < n predict n
|
93 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
94 |
+
shift_labels = labels[..., 1:].contiguous()
|
95 |
+
# Flatten the tokens
|
96 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
97 |
+
shift_logits = shift_logits.view(-1, self.config.embedding_size)
|
98 |
+
shift_labels = shift_labels.view(-1)
|
99 |
+
# Enable model parallelism
|
100 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
101 |
+
loss = loss_fct(shift_logits, shift_labels)
|
102 |
+
|
103 |
+
if not return_dict:
|
104 |
+
output = (logits,) + outputs[1:]
|
105 |
+
return (loss,) + output if loss is not None else output
|
106 |
+
|
107 |
+
return CausalLMOutputWithPast(
|
108 |
+
loss=loss,
|
109 |
+
logits=logits,
|
110 |
+
past_key_values=outputs.attn_key_values,
|
111 |
+
hidden_states=hidden_states,
|
112 |
+
)
|
113 |
+
|
114 |
+
def can_generate(self) -> bool:
|
115 |
+
return True
|
116 |
+
|
117 |
+
def prepare_inputs_for_generation(
|
118 |
+
self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
|
119 |
+
):
|
120 |
+
if past_key_values:
|
121 |
+
# This is because we want the model to only process the last generated token.
|
122 |
+
input_ids = input_ids[:, -1:]
|
123 |
+
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
|
124 |
+
|
125 |
+
model_inputs.update(kwargs)
|
126 |
+
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
|
127 |
+
return model_inputs
|
128 |
+
|
129 |
+
# TODO: these are required to make the implementation complete.
|
130 |
+
# def resize_position_embeddings(self, new_num_position_embeddings: int):
|
131 |
+
# pass
|
132 |
+
#
|
133 |
+
# def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
134 |
+
# pass
|
135 |
+
#
|
136 |
+
# def _reorder_cache(self, past_key_values, beam_idx):
|
137 |
+
# pass
|
138 |
+
|
139 |
+
def get_input_embeddings(self) -> torch.nn.Module:
|
140 |
+
return self.model.transformer.wte
|
141 |
+
|
142 |
+
def set_input_embeddings(self, value: torch.nn.Module):
|
143 |
+
self.model.transformer.wte = value
|
144 |
+
|
145 |
+
def get_output_embeddings(self):
|
146 |
+
if self.config.weight_tying:
|
147 |
+
return self.model.transformer.wte
|
148 |
+
else:
|
149 |
+
return self.model.transformer.ff_out
|
150 |
+
|
151 |
+
def set_output_embeddings(self, value: torch.nn.Module):
|
152 |
+
if self.config.weight_tying:
|
153 |
+
self.model.transformer.wte = value
|
154 |
+
else:
|
155 |
+
self.model.transformer.ff_out = value
|
156 |
+
|
157 |
+
def tie_weights(self):
|
158 |
+
"""
|
159 |
+
This function is intentionally left as a no-op.
|
160 |
+
|
161 |
+
Weight tying is handled as follows:
|
162 |
+
- When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
|
163 |
+
See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
|
164 |
+
- When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
|
165 |
+
See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
|
166 |
+
|
167 |
+
Therefore, there is no need to explicitly tie the weights in this function.
|
168 |
+
"""
|
169 |
+
pass
|
170 |
+
|
171 |
+
def resize_token_embeddings(
|
172 |
+
self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
|
173 |
+
) -> torch.nn.Embedding:
|
174 |
+
"""
|
175 |
+
Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
|
176 |
+
|
177 |
+
Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
|
178 |
+
|
179 |
+
Arguments:
|
180 |
+
new_num_tokens (`int`, *optional*):
|
181 |
+
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
|
182 |
+
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
|
183 |
+
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
|
184 |
+
pad_to_multiple_of (`int`, *optional*):
|
185 |
+
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
|
186 |
+
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
|
187 |
+
|
188 |
+
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
|
189 |
+
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
|
190 |
+
details about this, or help on choosing the correct value for resizing, refer to this guide:
|
191 |
+
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
|
192 |
+
|
193 |
+
Return:
|
194 |
+
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
|
195 |
+
|
196 |
+
Note:
|
197 |
+
This method differs from the base class implementation by resizing the `embedding_size` attribute of the
|
198 |
+
model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
|
199 |
+
is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token
|
200 |
+
embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary.
|
201 |
+
"""
|
202 |
+
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
203 |
+
if new_num_tokens is None and pad_to_multiple_of is None:
|
204 |
+
return model_embeds
|
205 |
+
|
206 |
+
# Update base model and current model config
|
207 |
+
self.config.embedding_size = model_embeds.weight.shape[0]
|
208 |
+
self.model.config.embedding_size = model_embeds.weight.shape[0]
|
209 |
+
|
210 |
+
# Check if the embedding size is less than the vocab size
|
211 |
+
if self.config.embedding_size < self.config.vocab_size:
|
212 |
+
warning_message = (
|
213 |
+
f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
|
214 |
+
f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
|
215 |
+
"size is less than or equal to the new token embedding size."
|
216 |
+
)
|
217 |
+
log.warning(warning_message)
|
218 |
+
|
219 |
+
# Tie weights again if needed
|
220 |
+
self.tie_weights()
|
221 |
+
|
222 |
+
return model_embeds
|
223 |
+
|
224 |
+
|
225 |
+
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
|
226 |
+
# OLMo is integrated directly in transformers from v4.40.0 onwards, but the version in transformers
|
227 |
+
# may not support the newest architectures we create.
|
228 |
+
AutoModelForCausalLM.register(OLMoConfig, OLMoForCausalLM)
|