mtasic85 committed on
Commit
1386dd6
·
1 Parent(s): afa8f6e

grokadamw.GrokAdamW

Browse files
scripts/pretrain-core-model.yaml CHANGED
@@ -63,8 +63,8 @@ train:
63
  log_interval: 1
64
 
65
  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
66
- # global_batch_size: 512
67
- global_batch_size: 256
68
 
69
  # Number of samples per data-parallel rank (type: int, default: 4)
70
  # micro_batch_size: 4
@@ -114,12 +114,24 @@ eval:
114
 
115
  # Optimizer-related arguments
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  optimizer:
118
- # class_path: torch.optim.AdamW
119
- # class_path: torchao.prototype.low_bit_optim.AdamW8bit
120
- # class_path: torchao.prototype.low_bit_optim.AdamW4bit
121
- # class_path: bitsandbytes.optim.AdamW8bit
122
- class_path: bitsandbytes.optim.PagedAdamW8bit
123
  init_args:
124
  # (type: float, default: 0.001)
125
  lr: 1e-4
@@ -128,7 +140,7 @@ optimizer:
128
  # (type: tuple, default: (0.9,0.999))
129
  betas:
130
  - 0.9
131
- - 0.99
132
 
133
  # How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
134
  devices: auto
 
63
  log_interval: 1
64
 
65
  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 512)
66
+ global_batch_size: 512
67
+ # global_batch_size: 256
68
 
69
  # Number of samples per data-parallel rank (type: int, default: 4)
70
  # micro_batch_size: 4
 
114
 
115
  # Optimizer-related arguments
116
 
117
+ # optimizer:
118
+ # # class_path: torch.optim.AdamW
119
+ # class_path: torchao.prototype.low_bit_optim.AdamW8bit
120
+ # # class_path: torchao.prototype.low_bit_optim.AdamW4bit
121
+ # # class_path: bitsandbytes.optim.AdamW8bit
122
+ # # class_path: bitsandbytes.optim.PagedAdamW8bit
123
+ # init_args:
124
+ # # (type: float, default: 0.001)
125
+ # lr: 1e-4
126
+ # # (type: float, default: 0.01)
127
+ # weight_decay: 0.01
128
+ # # (type: tuple, default: (0.9,0.999))
129
+ # betas:
130
+ # - 0.9
131
+ # - 0.99
132
+
133
  optimizer:
134
+ class_path: grokadamw.GrokAdamW
 
 
 
 
135
  init_args:
136
  # (type: float, default: 0.001)
137
  lr: 1e-4
 
140
  # (type: tuple, default: (0.9,0.999))
141
  betas:
142
  - 0.9
143
+ - 0.999
144
 
145
  # How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
146
  devices: auto
scripts/requirements.in CHANGED
@@ -1,29 +1,22 @@
1
  # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
2
  torch>=2.5.0,<2.6.0
3
  numpy<2.0
4
- torchao
5
 
6
  tqdm
 
7
  datasets
8
  jinja2
9
  transformers
10
  wandb
11
- # litgpt[all]
12
  litgpt[all] @ git+https://github.com/Lightning-AI/litgpt.git
13
  mergekit @ git+https://github.com/arcee-ai/mergekit.git
14
- # litgpt @ git+https://github.com/Lightning-AI/litgpt.git
15
- # litdata
16
- # litdata @ git+https://github.com/Lightning-AI/litdata.git
17
- # lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git
18
  # muon @ git+https://github.com/KellerJordan/Muon
19
  # pytorch-optimizer
20
- lm_eval[ifeval,math]
21
  bitsandbytes
22
- # grokadamw
23
  # sophia-opt
24
- # bitsandbytes
25
  # pyzstd
26
  # zstd
27
- unsloth
28
-
29
- Pillow
 
1
  # pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
2
  torch>=2.5.0,<2.6.0
3
  numpy<2.0
 
4
 
5
  tqdm
6
+ Pillow
7
  datasets
8
  jinja2
9
  transformers
10
  wandb
 
11
  litgpt[all] @ git+https://github.com/Lightning-AI/litgpt.git
12
  mergekit @ git+https://github.com/arcee-ai/mergekit.git
 
 
 
 
13
  # muon @ git+https://github.com/KellerJordan/Muon
14
  # pytorch-optimizer
15
+ torchao
16
  bitsandbytes
17
+ grokadamw
18
  # sophia-opt
 
19
  # pyzstd
20
  # zstd
21
+ # unsloth
22
+ lm_eval[ifeval,math]