3.2B-parameter base model trained on ~64B tokens from the FineWeb dataset.

Uses the `gpt2` tokenizer from tiktoken.

wandb training metrics

  • Note: the batch size was increased from 8 to 512 at step 2,160,000
  • Final checkpoint: step 2,187,000, val_loss: 2.7489
  • Trained on an 8×H100 80GB node using data parallelism
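
The reported `val_loss` is easier to interpret as a perplexity. The sketch below assumes the loss is the mean per-token cross-entropy in nats (natural log), which is PyTorch's default. The card does not state the base explicitly.

```python
import math

# Final validation loss reported above.
# Assumption: mean per-token cross-entropy in nats (natural log), as computed
# by torch.nn.functional.cross_entropy; the card does not say so explicitly.
val_loss = 2.7489

# Perplexity is the exponential of the mean per-token cross-entropy.
val_ppl = math.exp(val_loss)
print(f"validation perplexity ~ {val_ppl:.2f}")
```

Under that assumption, the validation perplexity is about 15.6. This figure is not directly comparable to the LAMBADA perplexity in the eval table. LAMBADA is a different dataset, and its perplexity is scored only on each passage's final word.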

Model config:

```json
{
  "d_head": 128,
  "d_model": 8192,
  "n_heads": 64,
  "n_layers": 3,
  "n_vocab": 50257
}
```
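
A rough parameter count is a useful sanity check on this very wide, shallow config (three layers, each 8192 wide). This is a sketch under assumptions the card does not confirm:

- GPT-2-style blocks, each with four d_model×d_model attention projections and a 4× MLP.
- An untied token embedding and output projection.
- No learned position embeddings.

Biases and norm weights are ignored.

```python
# Rough parameter count for the config above.
# Assumptions (not confirmed by this card): GPT-2-style blocks with four
# d_model x d_model attention projections, a 4x MLP, untied input/output
# embeddings, no learned position embeddings; biases and norms are ignored.
config = {"d_head": 128, "d_model": 8192, "n_heads": 64, "n_layers": 3, "n_vocab": 50257}


def estimate_params(cfg: dict) -> int:
    d, n_layers, vocab = cfg["d_model"], cfg["n_layers"], cfg["n_vocab"]
    assert cfg["d_head"] * cfg["n_heads"] == d  # heads tile d_model exactly
    attn = 4 * d * d            # W_q, W_k, W_v, W_o
    mlp = 2 * d * (4 * d)       # up- and down-projection with a 4x hidden size
    embeddings = 2 * vocab * d  # token embedding + separate output projection
    return n_layers * (attn + mlp) + embeddings


n = estimate_params(config)
print(f"{n:,} params ~ {n / 1e9:.2f}B")
print(f"fp32 weights ~ {n * 4 / 1e9:.1f} GB")
```

Under these assumptions, the estimate is about 3.24B, which matches the reported model size. Roughly 2.42B of that sits in the three transformer blocks and about 0.82B in the two embedding matrices. With tied embeddings, the same estimate would drop to about 2.83B. Stored as F32, the weights alone need roughly 13 GB of GPU memory.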

Usage:

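```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

repo = "michaelbzhu/test-3.2B-base"
# trust_remote_code=True allows transformers to execute modeling code hosted in the repo
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
model = model.cuda().eval()
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)

prompt = "The future of AI is"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

# Sample 20 tokens one at a time. There is no KV cache here, so every step
# re-runs the model on the whole sequence; no_grad avoids storing activations.
with torch.no_grad():
    for _ in range(20):
        logits = model(input_ids).logits[0, -1, :]  # next-token logits at the last position
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, 1).unsqueeze(0)  # shape [1, 1]
        input_ids = torch.cat([input_ids, next_token], dim=1)
print(tokenizer.decode(input_ids[0]))
```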
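
The loop above samples from the full softmax at temperature 1. A common variant scales the logits by a temperature and keeps only the top-k candidates before sampling. Below is a minimal PyTorch sketch. `sample_next_token` is a hypothetical helper, not part of this repo or of transformers.

```python
from typing import Optional

import torch


def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      top_k: Optional[int] = None) -> torch.Tensor:
    """Pick one token id from a 1-D logits vector; returns a tensor of shape [1, 1].

    Hypothetical helper for illustration; not part of this repo or of transformers.
    """
    if temperature <= 0:
        # Temperature 0 is treated as greedy decoding.
        return logits.argmax(dim=-1).view(1, 1)
    logits = logits / temperature  # <1 sharpens, >1 flattens the distribution
    if top_k is not None:
        k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, k).values[-1]
        # Drop every logit below the k-th largest (ties may keep more than k tokens).
        logits = logits.masked_fill(logits < kth_largest, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, 1).unsqueeze(0)
```

To use it in the loop, replace the line(s) that compute `next_token` with `next_token = sample_next_token(logits, temperature=0.8, top_k=50)`. Passing `temperature=0` gives greedy decoding. The values 0.8 and 50 are illustrative, not tuned for this model.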

Eval:

```shell
$ lm_eval --model hf \
    --model_args pretrained=michaelbzhu/test-3.2B-base,trust_remote_code=True \
    --tasks mmlu_college_medicine,hellaswag,lambada_openai,arc_easy,winogrande,arc_challenge,openbookqa \
    --device cuda:0 \
    --batch_size 16
```

|     Tasks      |Version|Filter|n-shot|  Metric  |   | Value |   |Stderr|
|----------------|------:|------|-----:|----------|---|------:|---|-----:|
|arc_challenge   |      1|none  |     0|acc       |↑  | 0.2363|±  |0.0124|
|                |       |none  |     0|acc_norm  |↑  | 0.2637|±  |0.0129|
|arc_easy        |      1|none  |     0|acc       |↑  | 0.5758|±  |0.0101|
|                |       |none  |     0|acc_norm  |↑  | 0.4996|±  |0.0103|
|hellaswag       |      1|none  |     0|acc       |↑  | 0.3827|±  |0.0049|
|                |       |none  |     0|acc_norm  |↑  | 0.4846|±  |0.0050|
|lambada_openai  |      1|none  |     0|acc       |↑  | 0.4238|±  |0.0069|
|                |       |none  |     0|perplexity|↓  |14.7850|±  |0.4335|
|college_medicine|      1|none  |     0|acc       |↑  | 0.2370|±  |0.0324|
|openbookqa      |      1|none  |     0|acc       |↑  | 0.2180|±  |0.0185|
|                |       |none  |     0|acc_norm  |↑  | 0.3180|±  |0.0208|
|winogrande      |      1|none  |     0|acc       |↑  | 0.5367|±  |0.0140|
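
Some of these scores sit close to random guessing, so it helps to measure each `acc` against its chance level in units of the reported standard error. This is a rough sketch, not a formal significance test. The chance levels are assumptions:

- 1/4 for the four-option tasks (a few ARC questions have 3 or 5 options).
- 1/2 for WinoGrande.

LAMBADA is left out because it is open-ended rather than multiple choice.

```python
# 0-shot "acc" values and standard errors copied from the table above
# (acc_norm and LAMBADA are left out).
results = {
    "arc_challenge": (0.2363, 0.0124),
    "arc_easy": (0.5758, 0.0101),
    "hellaswag": (0.3827, 0.0049),
    "college_medicine": (0.2370, 0.0324),
    "openbookqa": (0.2180, 0.0185),
    "winogrande": (0.5367, 0.0140),
}

# Assumed random-guess accuracy: 1/4 for four-option tasks, 1/2 for WinoGrande.
# (A small fraction of ARC questions have 3 or 5 options, so 0.25 is approximate.)
chance = {
    "arc_challenge": 0.25,
    "arc_easy": 0.25,
    "hellaswag": 0.25,
    "college_medicine": 0.25,
    "openbookqa": 0.25,
    "winogrande": 0.50,
}


def z_scores(results: dict, chance: dict) -> dict:
    """How many standard errors each accuracy lies above its chance level."""
    return {task: (acc - chance[task]) / se for task, (acc, se) in results.items()}


scores = z_scores(results, chance)
for task, z in scores.items():
    # z > 2 is a rough rule of thumb, not a formal significance test
    label = "above chance" if z > 2 else "not clearly above chance"
    print(f"{task:<16} z = {z:+6.1f}  {label}")
```

On these numbers, ARC-Easy, HellaSwag and WinoGrande are clearly above chance. ARC-Challenge, OpenBookQA and college medicine are all within about two standard errors of random guessing.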

Model size: 3.24B params (Safetensors, tensor type F32)