import json

from transformers import PretrainedConfig


class HyenaConfig(PretrainedConfig):
    """Hugging Face configuration for HyenaDNA models."""

    model_type = "hyenadna"
    def __init__(
        self,
        vocab_size=12,
        d_model=256,
        d_inner=None,  # defaults to 4 * d_model when left as None
        use_bias=True,
        train_freq=True,
        max_seq_len=1024,  # "l_max" in the original config format
        emb_dim=3,
        n_layer=12,
        num_inner_mlps=2,
        hyena_order=2,
        short_filter_order=3,
        filter_order=64,
        activation_freq=1,  # "w" in the original config format
        embed_dropout=0.1,
        hyena_dropout=0.0,
        hyena_filter_dropout=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        pad_vocab_size_multiple=8,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        # Conventional 4x MLP expansion unless an inner width is given.
        if d_inner is None:
            self.d_inner = 4 * d_model
        else:
            self.d_inner = d_inner
        self.use_bias = use_bias
        self.train_freq = train_freq
        self.max_seq_len = max_seq_len
        self.emb_dim = emb_dim
        self.n_layer = n_layer
        self.hyena_order = hyena_order
        self.filter_order = filter_order
        self.short_filter_order = short_filter_order
        self.activation_freq = activation_freq
        self.num_inner_mlps = num_inner_mlps
        self.embed_dropout = embed_dropout
        self.hyena_dropout = hyena_dropout
        self.hyena_filter_dropout = hyena_filter_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        super().__init__(**kwargs)

    @classmethod
    def from_original_config(cls, config_path, **kwargs):
        """Build a HyenaConfig from an original (pre-Hugging Face) HyenaDNA
        JSON config file."""
        with open(config_path, "r") as f:
            config = json.load(f)

        vocab_size = config["vocab_size"]
        d_model = config["d_model"]
        d_inner = config["d_inner"]
        max_seq_len = config["layer"]["l_max"]
        emb_dim = config["layer"]["emb_dim"]
        filter_order = config["layer"]["filter_order"]
        # Different config versions name the short filter order differently;
        # check both keys and fall back to the default of 3.
        if "local_order" in config["layer"]:
            short_filter_order = config["layer"]["local_order"]
        elif "short_filter_order" in config["layer"]:
            short_filter_order = config["layer"]["short_filter_order"]
        else:
            short_filter_order = 3
        n_layer = config["n_layer"]
        activation_freq = config["layer"]["w"]
        embed_dropout = config["embed_dropout"]
        pad_vocab_size_multiple = config["pad_vocab_size_multiple"]
        return cls(
            vocab_size=vocab_size,
            d_model=d_model,
            d_inner=d_inner,
            max_seq_len=max_seq_len,
            emb_dim=emb_dim,
            filter_order=filter_order,
            short_filter_order=short_filter_order,
            n_layer=n_layer,
            activation_freq=activation_freq,
            embed_dropout=embed_dropout,
            pad_vocab_size_multiple=pad_vocab_size_multiple,
            # Input and output embeddings are kept untied.
            tie_word_embeddings=False,
            **kwargs,
        )
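

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the published module): constructs a
# config directly, then round-trips a minimal original-format JSON through
# from_original_config. The JSON values below are made-up example numbers;
# only the key layout mirrors what the parser above reads.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    # Direct construction: d_inner falls back to 4 * d_model when omitted.
    cfg = HyenaConfig(d_model=128, n_layer=4)
    assert cfg.d_inner == 512

    # Minimal original-format config covering the keys parsed above.
    original = {
        "vocab_size": 12,
        "d_model": 256,
        "d_inner": 1024,
        "n_layer": 8,
        "embed_dropout": 0.1,
        "pad_vocab_size_multiple": 8,
        "layer": {"l_max": 1026, "emb_dim": 5, "filter_order": 64, "w": 10},
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(original, f)
        path = f.name

    cfg2 = HyenaConfig.from_original_config(path)
    # Neither "local_order" nor "short_filter_order" is present, so the
    # converter falls back to short_filter_order=3.
    assert cfg2.max_seq_len == 1026 and cfg2.short_filter_order == 3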