# SASOK_V1 / train.py
# Train SASOKModel as a causal language model on WikiText-2.
# Uploaded by TSheylock ("Upload 5 files", commit 10f998c, verified).
import inspect

from datasets import load_dataset
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

from model import SASOKModel, SASOKConfig
from tokenizer import tokenizer
def _supported_kwarg(callable_, preferred, legacy):
    """Return *preferred* if *callable_* accepts it as a keyword, else *legacy*.

    transformers has renamed several keyword arguments across releases; picking
    the name from the live signature keeps this script working on both sides
    of a rename.
    """
    return preferred if preferred in inspect.signature(callable_).parameters else legacy


# Load WikiText-2 (raw); the "train" and "validation" splits are used below.
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Many WikiText rows are empty strings (blank lines between paragraphs).
# Padded to max length they become all-padding examples with no target tokens:
# wasted compute, and a batch made only of them can produce a NaN loss with a
# mean-reduced cross-entropy. Drop them before tokenizing.
dataset = dataset.filter(lambda example: example["text"].strip() != "")

# Tokenize in batches. The raw "text" column is removed so only token fields
# reach the data collator, instead of relying on Trainer pruning columns that
# the custom model's forward() does not accept.
# NOTE(review): padding="max_length" without an explicit max_length pads to
# tokenizer.model_max_length -- confirm the project tokenizer defines it.
tokenized = dataset.map(
    lambda batch: tokenizer(batch["text"], truncation=True, padding="max_length"),
    batched=True,
    remove_columns=["text"],
)

config = SASOKConfig()
model = SASOKModel(config)

# `evaluation_strategy` was renamed `eval_strategy` in transformers 4.41 and
# later removed, so pick whichever name this installed version accepts.
eval_strategy_key = _supported_kwarg(TrainingArguments, "eval_strategy", "evaluation_strategy")
training_args = TrainingArguments(
    output_dir="./sasok_output",
    eval_steps=500,
    per_device_train_batch_size=4,
    num_train_epochs=3,
    save_steps=1000,
    logging_dir="./logs",
    **{eval_strategy_key: "steps"},
)

# mlm=False selects causal-LM collation: labels are a copy of input_ids with
# padding positions masked out of the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Newer transformers releases take the tokenizer as `processing_class`;
# older ones only know `tokenizer`.
tokenizer_key = _supported_kwarg(Trainer, "processing_class", "tokenizer")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    data_collator=data_collator,
    **{tokenizer_key: tokenizer},
)
trainer.train()