from transformers import PretrainedConfig
class AuriStreamConfig(PretrainedConfig):
    """Configuration container for the AuriStream model.

    Each constructor argument is stored unchanged as an instance attribute
    of the same name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        n_layer: Stored as ``self.n_layer``. Presumably the number of
            transformer blocks -- TODO confirm against the model code.
        n_head: Stored as ``self.n_head``. Presumably the number of
            attention heads -- TODO confirm.
        n_embd: Stored as ``self.n_embd``. Presumably the embedding /
            hidden width -- TODO confirm.
        vocab_size: Stored as ``self.vocab_size``.
        dropout: Stored as ``self.dropout``.
        bias: Stored as ``self.bias``. Presumably toggles bias terms in
            the model's layers -- TODO confirm.
        use_rope: Stored as ``self.use_rope``. Presumably selects rotary
            position embeddings -- TODO confirm.
        n_pred_steps: Stored as ``self.n_pred_steps``.
        seq_len: Stored as ``self.seq_len``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "AuriStream.AuriStream"

    def __init__(
        self,
        n_layer: int = 48,
        n_head: int = 16,
        n_embd: int = 1280,
        vocab_size: int = 8192,
        dropout: float = 0.0,
        bias: bool = False,
        use_rope: bool = False,
        n_pred_steps: int = 20,
        seq_len: int = 4096,
        **kwargs
    ) -> None:
        # Model-specific fields are assigned before delegating to the base
        # class, which then processes the remaining keyword arguments.
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.bias = bias
        self.use_rope = use_rope
        self.n_pred_steps = n_pred_steps
        self.seq_len = seq_len
        super().__init__(**kwargs)