from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset

# Local modules: the custom SASOK architecture and its tokenizer
from model import SASOKModel, SASOKConfig
from tokenizer import tokenizer

# Load WikiText-2 (raw) and tokenize. Padding every example to the model's
# max length keeps batches uniform; truncation drops anything longer.
# Note: padding assumes the tokenizer defines a pad_token (GPT-2-style
# tokenizers need e.g. tokenizer.pad_token = tokenizer.eos_token).
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
tokenized = dataset.map(
    lambda e: tokenizer(e["text"], truncation=True, padding="max_length"),
    batched=True,
)
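
# Optional cleanup (not in the original script): WikiText-2 contains many
# empty and whitespace-only rows that tokenize to padding-only examples.
# The raw "text" column survives map(), so we can filter on it here.
tokenized = tokenized.filter(lambda e: len(e["text"].strip()) > 0)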

# Build the model from scratch with the default SASOK configuration
config = SASOKConfig()
model = SASOKModel(config)
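
# Quick size check (not in the original script): count trainable parameters,
# assuming SASOKModel is a standard torch.nn.Module.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {n_params / 1e6:.1f}M")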

# Training setup: evaluate every 500 steps, checkpoint every 1000.
# ("evaluation_strategy" was renamed "eval_strategy" in transformers >= 4.41;
# use whichever name your installed version expects.)
training_args = TrainingArguments(
    output_dir="./sasok_output",
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    logging_dir="./logs",
)

# mlm=False -> causal-LM collation: labels are a copy of input_ids with
# padding positions set to -100 so they are ignored by the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
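
# Optional sanity check (not in the original script): collate two tokenized
# examples and confirm the batch exposes input_ids, attention_mask, and labels.
sample = [{"input_ids": tokenized["train"][i]["input_ids"]} for i in range(2)]
print(data_collator(sample).keys())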

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    tokenizer=tokenizer,  # newer transformers versions take processing_class= instead
    data_collator=data_collator,
)

# Train from scratch on WikiText-2
trainer.train()
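
# Follow-up (not in the original script): persist the final weights and report
# validation perplexity, assuming the model returns a standard LM loss.
import math

trainer.save_model("./sasok_output/final")
metrics = trainer.evaluate()
print(f"eval_loss = {metrics['eval_loss']:.4f}, "
      f"perplexity = {math.exp(metrics['eval_loss']):.2f}")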