3.2B-parameter base model trained on ~64B tokens from the FineWeb dataset.
Uses the gpt2 tokenizer from tiktoken.
- Note: batch size was increased from 8 to 512 at step 2,160,000
- Final checkpoint: step 2,187,000, val_loss: 2.7489
- Trained on an 8xH100 80GB node using data parallelism
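As a rough sanity check, the step counts and batch sizes above can be turned into a token estimate. The sketch below assumes that batch size counts sequences per optimizer step and that each sequence is 2048 tokens long; neither is stated in this card, so `SEQ_LEN` is a hypothetical value chosen for illustration.

```python
# Rough token-budget estimate for the training schedule listed above.
# ASSUMPTIONS (not stated in the card): batch size is measured in sequences
# per optimizer step, and every sequence is SEQ_LEN tokens long.
SEQ_LEN = 2048            # hypothetical context length
SWITCH_STEP = 2_160_000   # step at which the batch size changed (from the card)
FINAL_STEP = 2_187_000    # final checkpoint step (from the card)


def tokens_seen(seq_len: int = SEQ_LEN) -> int:
    """Total tokens processed across both batch-size phases."""
    small_phase = SWITCH_STEP * 8 * seq_len
    large_phase = (FINAL_STEP - SWITCH_STEP) * 512 * seq_len
    return small_phase + large_phase


total = tokens_seen()
print(f"{total / 1e9:.1f}B tokens")  # prints "63.7B tokens"
```

Under these assumptions the estimate lands close to the stated ~64B tokens, with roughly 44% of the tokens coming from the final 27,000 steps at batch size 512.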
Model config:
"d_head": 128,
"d_model": 8192,
"n_heads": 64,
"n_layers": 3,
"n_vocab": 50257
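The config is unusually wide and shallow, so it can help to see where the parameters sit. The sketch below is a back-of-the-envelope count that assumes a GPT-2-style block (four attention projections plus an MLP with a 4x hidden size) and untied input and output embeddings, and ignores biases, layer norms, and positional embeddings. These architectural details are assumptions for illustration; the actual layout is defined by the model's remote code.

```python
# Back-of-the-envelope parameter count from the config above.
# ASSUMPTIONS (not confirmed by the card): GPT-2-style blocks, MLP hidden
# size of 4 * d_model, untied embedding and unembedding matrices, and no
# biases, layer norms, or positional embeddings counted.
config = {"d_head": 128, "d_model": 8192, "n_heads": 64, "n_layers": 3, "n_vocab": 50257}


def estimate_params(cfg: dict, tied_embeddings: bool = False) -> int:
    d_model = cfg["d_model"]
    d_attn = cfg["n_heads"] * cfg["d_head"]
    attn = 4 * d_model * d_attn        # Q, K, V and output projections
    mlp = 2 * d_model * (4 * d_model)  # up and down projections
    blocks = cfg["n_layers"] * (attn + mlp)
    embed = cfg["n_vocab"] * d_model
    return blocks + (embed if tied_embeddings else 2 * embed)


total = estimate_params(config)
print(f"{total / 1e9:.2f}B parameters")  # prints "3.24B parameters"
```

Under these assumptions about a quarter of the parameters are in the embedding and unembedding matrices, and the total is consistent with the stated 3.2B.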
Usage:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("michaelbzhu/test-3.2B-base", trust_remote_code=True)
model = model.cuda()
tokenizer = AutoTokenizer.from_pretrained("michaelbzhu/test-3.2B-base", trust_remote_code=True)
prompt = "The future of AI is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
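# Sample 20 new tokens, one at a time, from the model's next-token distribution.
with torch.no_grad():
    for _ in range(20):
        logits = model(input_ids).logits[0, -1, :]  # logits at the last position
        next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1).unsqueeze(0)
        input_ids = torch.cat([input_ids, next_token], dim=1)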
print(tokenizer.decode(input_ids[0]))
Eval:
$ lm_eval --model hf \
--model_args pretrained=michaelbzhu/test-3.2B-base,trust_remote_code=True \
--tasks mmlu_college_medicine,hellaswag,lambada_openai,arc_easy,winogrande,arc_challenge,openbookqa \
--device cuda:0 \
--batch_size 16
| Tasks |Version|Filter|n-shot| Metric | | Value | |Stderr|
|----------------|------:|------|-----:|----------|---|------:|---|-----:|
|arc_challenge | 1|none | 0|acc |↑ | 0.2363|± |0.0124|
| | |none | 0|acc_norm |↑ | 0.2637|± |0.0129|
|arc_easy | 1|none | 0|acc |↑ | 0.5758|± |0.0101|
| | |none | 0|acc_norm |↑ | 0.4996|± |0.0103|
|hellaswag | 1|none | 0|acc |↑ | 0.3827|± |0.0049|
| | |none | 0|acc_norm |↑ | 0.4846|± |0.0050|
|lambada_openai | 1|none | 0|acc |↑ | 0.4238|± |0.0069|
| | |none | 0|perplexity|↓ |14.7850|± |0.4335|
|college_medicine| 1|none | 0|acc |↑ | 0.2370|± |0.0324|
|openbookqa | 1|none | 0|acc |↑ | 0.2180|± |0.0185|
| | |none | 0|acc_norm |↑ | 0.3180|± |0.0208|
|winogrande | 1|none | 0|acc |↑ | 0.5367|± |0.0140|
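To put the accuracies above in context, it can help to compare them with random guessing. The sketch below copies the `acc` values and standard errors from the table and uses approximate chance baselines based on the usual number of answer options per task (four for the multiple-choice tasks, two for winogrande). The baselines and the two-standard-error threshold are assumptions for illustration, not part of the lm_eval output.

```python
# Compare zero-shot accuracy (acc, not acc_norm) against approximate chance.
# Values and standard errors are copied from the results table above.
results = {
    "arc_challenge": (0.2363, 0.0124),
    "arc_easy": (0.5758, 0.0101),
    "hellaswag": (0.3827, 0.0049),
    "college_medicine": (0.2370, 0.0324),
    "openbookqa": (0.2180, 0.0185),
    "winogrande": (0.5367, 0.0140),
}

# ASSUMPTION: approximate chance levels from the typical number of options.
chance = {
    "arc_challenge": 0.25,
    "arc_easy": 0.25,
    "hellaswag": 0.25,
    "college_medicine": 0.25,
    "openbookqa": 0.25,
    "winogrande": 0.50,
}


def clearly_above_chance(task: str, z: float = 2.0) -> bool:
    """True if accuracy exceeds chance by more than z standard errors."""
    acc, stderr = results[task]
    return acc - chance[task] > z * stderr


above = [task for task in results if clearly_above_chance(task)]
print(above)  # prints ['arc_easy', 'hellaswag', 'winogrande']
```

By this rough criterion, arc_challenge, college_medicine, and openbookqa are not distinguishable from guessing on plain accuracy.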