---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
license: mit
datasets:
- kjj0/fineweb100B-gpt2
language:
- en
---

A 3.2B-parameter base model trained on ~64B tokens from the FineWeb dataset.

Uses the GPT-2 tokenizer from tiktoken.

[wandb training metrics](https://api.wandb.ai/links/teammapo-mapo-labs/zooq3iig)
- Note: the batch size was increased from 8 to 512 at step 2,160,000.
- Final checkpoint: step 2,187,000, val_loss: 2.7489
- Trained on an 8xH100 80GB node using data parallelism.

Model config:
```
"d_head": 128,
"d_model": 8192,
"n_heads": 64,
"n_layers": 3,
"n_vocab": 50257
```
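As a rough sanity check on the "3.2B" figure, the config above can be turned into an approximate parameter count. This is a sketch under assumptions the card does not state: a standard GPT-2-style block (four `d_model x d_model` attention projections plus an MLP with 4x expansion), untied input and output embeddings, and no biases, LayerNorm weights, or positional embeddings counted. The helper `approx_param_count` is hypothetical.

```python
# Config values copied from the model card.
d_model = 8192
n_heads = 64
d_head = 128
n_layers = 3
n_vocab = 50257


def approx_param_count(d_model: int, n_layers: int, n_vocab: int,
                       mlp_ratio: int = 4, tied_embeddings: bool = False) -> int:
    """Rough count; ignores biases, LayerNorm and positional params (assumption)."""
    attn = 4 * d_model * d_model              # Q, K, V and output projections
    mlp = 2 * mlp_ratio * d_model * d_model   # up- and down-projections (assumed 4x)
    embed = n_vocab * d_model                 # token embedding table
    unembed = 0 if tied_embeddings else n_vocab * d_model  # separate LM head (assumed)
    return n_layers * (attn + mlp) + embed + unembed


assert n_heads * d_head == d_model  # the heads exactly tile the residual width
print(f"{approx_param_count(d_model, n_layers, n_vocab):,}")
```

Under these assumptions the total is about 3.24B, of which roughly 0.82B sits in the two embedding matrices because the model is so wide and shallow. With tied embeddings the same arithmetic gives about 2.83B, so the 3.2B figure is consistent with (but does not confirm) an untied LM head.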

Usage:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("michaelbzhu/test-3.2B-base", trust_remote_code=True)
model = model.cuda().eval()
tokenizer = AutoTokenizer.from_pretrained("michaelbzhu/test-3.2B-base", trust_remote_code=True)

prompt = "The future of AI is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
# Sample 20 tokens one at a time; no_grad avoids building an autograd graph.
with torch.no_grad():
    for _ in range(20):
        logits = model(input_ids).logits[0, -1, :]  # logits for the last position
        next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1).unsqueeze(0)
        input_ids = torch.cat([input_ids, next_token], dim=1)
print(tokenizer.decode(input_ids[0]))
```

Eval:
```
$ lm_eval --model hf \
    --model_args pretrained=michaelbzhu/test-3.2B-base,trust_remote_code=True \
    --tasks mmlu_college_medicine,hellaswag,lambada_openai,arc_easy,winogrande,arc_challenge,openbookqa \
    --device cuda:0 \
    --batch_size 16

|     Tasks      |Version|Filter|n-shot|  Metric  |   | Value |   |Stderr|
|----------------|------:|------|-----:|----------|---|------:|---|-----:|
|arc_challenge   |      1|none  |     0|acc       |↑  | 0.2363|±  |0.0124|
|                |       |none  |     0|acc_norm  |↑  | 0.2637|±  |0.0129|
|arc_easy        |      1|none  |     0|acc       |↑  | 0.5758|±  |0.0101|
|                |       |none  |     0|acc_norm  |↑  | 0.4996|±  |0.0103|
|hellaswag       |      1|none  |     0|acc       |↑  | 0.3827|±  |0.0049|
|                |       |none  |     0|acc_norm  |↑  | 0.4846|±  |0.0050|
|lambada_openai  |      1|none  |     0|acc       |↑  | 0.4238|±  |0.0069|
|                |       |none  |     0|perplexity|↓  |14.7850|±  |0.4335|
|college_medicine|      1|none  |     0|acc       |↑  | 0.2370|±  |0.0324|
|openbookqa      |      1|none  |     0|acc       |↑  | 0.2180|±  |0.0185|
|                |       |none  |     0|acc_norm  |↑  | 0.3180|±  |0.0208|
|winogrande      |      1|none  |     0|acc       |↑  | 0.5367|±  |0.0140|
```
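The Stderr column is the standard error reported by lm_eval. One rough way to read it, assuming a normal approximation (our assumption, not something the harness output states), is `value ± 1.96 × stderr` for an approximate 95% interval. The sketch below applies that to one metric per task, with values copied from the table above; `normal_ci` is a hypothetical helper.

```python
# (task, metric, value, stderr) copied from the lm_eval table above.
RESULTS = [
    ("arc_challenge", "acc_norm", 0.2637, 0.0129),
    ("arc_easy", "acc", 0.5758, 0.0101),
    ("hellaswag", "acc_norm", 0.4846, 0.0050),
    ("lambada_openai", "acc", 0.4238, 0.0069),
    ("college_medicine", "acc", 0.2370, 0.0324),
    ("openbookqa", "acc_norm", 0.3180, 0.0208),
    ("winogrande", "acc", 0.5367, 0.0140),
]


def normal_ci(value: float, stderr: float, z: float = 1.96) -> tuple[float, float]:
    """Approximate two-sided ~95% interval, assuming a normal sampling distribution."""
    return value - z * stderr, value + z * stderr


for task, metric, value, stderr in RESULTS:
    lo, hi = normal_ci(value, stderr)
    print(f"{task:<16} {metric:<8} {value:.4f}  95% CI [{lo:.4f}, {hi:.4f}]")
```

For example, winogrande's interval comes out at roughly [0.509, 0.564], just above the 0.5 chance level of its two-option format.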