| |
|
| | from pathlib import Path
|
| | from model import build_transformer
|
| | from util import create_resources
|
| | import torch
|
| | import sys
|
| | import yaml
|
| | import sacrebleu
|
| |
|
# Lazily populated by _load_translation_resources() so repeated translate()
# calls do not rebuild the datasets/tokenizers or re-read the checkpoint.
_TRANSLATION_RESOURCES = None


def _load_translation_resources():
    """Load the config, tokenizers and trained model once and reuse them.

    Returns:
        Tuple ``(config, device, model, tokenizer_src, tokenizer_tgt)`` where
        ``model`` has the checkpoint weights loaded, lives on ``device`` and
        is in eval mode.
    """
    global _TRANSLATION_RESOURCES
    if _TRANSLATION_RESOURCES is not None:
        return _TRANSLATION_RESOURCES

    with open("config.yaml", "r") as file:
        config = yaml.safe_load(file)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Only the tokenizers are used for inference; the dataloaders returned by
    # create_resources() are discarded.
    _, _, _, tokenizer_src, tokenizer_tgt = create_resources()

    model = build_transformer(
        tokenizer_src.get_vocab_size(),
        tokenizer_tgt.get_vocab_size(),
        config["seq_len"],
        config["seq_len"],
        config["num_enc_dec_blocks"],
        config["num_of_heads"],
        config["d_model"],
    )
    model = model.to(device)

    model_filename = "models/model_epoch_15.pth"
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    state = torch.load(model_filename, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    _TRANSLATION_RESOURCES = (config, device, model, tokenizer_src, tokenizer_tgt)
    return _TRANSLATION_RESOURCES


def translate(sentence: str):
    """Translate ``sentence`` using greedy (arg-max) decoding.

    The sentence is tokenized, wrapped in ``[SOS]``/``[EOS]``, padded with
    ``[PAD]`` up to ``config["seq_len"]`` and encoded once. Target tokens are
    then generated one at a time, always taking the highest-scoring token,
    until ``[EOS]`` is produced or ``config["seq_len"]`` tokens exist.
    Each generated token is also printed to stdout as it is produced.

    Args:
        sentence: Source-language text to translate.

    Returns:
        The result of ``tokenizer_tgt.decode`` over every id in the decoder
        sequence (starting with the ``[SOS]`` id).
    """
    config, device, model, tokenizer_src, tokenizer_tgt = _load_translation_resources()
    seq_len = config["seq_len"]

    with torch.no_grad():
        encoding = tokenizer_src.encode(sentence)
        print(encoding, encoding.ids)

        sos_id = tokenizer_src.token_to_id('[SOS]')
        eos_id = tokenizer_src.token_to_id('[EOS]')
        pad_id = tokenizer_src.token_to_id('[PAD]')

        # Two slots are reserved for [SOS] and [EOS]. Without this guard an
        # over-long sentence yields a negative pad count, which silently
        # produces an input longer than seq_len.
        source_ids = encoding.ids
        max_source_tokens = seq_len - 2
        if len(source_ids) > max_source_tokens:
            print(
                f"Input has {len(source_ids)} tokens; truncating to {max_source_tokens}.",
                file=sys.stderr,
            )
            source_ids = source_ids[:max_source_tokens]
        pad_count = seq_len - len(source_ids) - 2

        source = torch.tensor(
            [sos_id, *source_ids, eos_id, *([pad_id] * pad_count)],
            dtype=torch.int64,
            device=device,
        ).unsqueeze(0)

        print(source.shape)

        # 1 marks a real token, 0 marks padding.
        source_mask = (source != pad_id).unsqueeze(0).unsqueeze(0).int().to(device)

        encoder_output = model.encode(source, source_mask)
        decoder_input = torch.full((1, 1), tokenizer_tgt.token_to_id('[SOS]'),
                                   dtype=torch.long, device=device)

        target_eos_id = tokenizer_tgt.token_to_id('[EOS]')

        while decoder_input.size(1) < seq_len:
            size = decoder_input.size(1)
            # Causal mask: position i may attend to positions <= i. It uses
            # the same "1 = attend" convention as source_mask above; the
            # previous triu(diagonal=1) mask marked the *future* positions
            # with 1, i.e. it was inverted relative to source_mask.
            # NOTE(review): assumes model.decode applies both masks with the
            # same convention -- confirm against the model implementation.
            decoder_mask = torch.tril(
                torch.ones((1, size, size), dtype=torch.int, device=device)
            )

            out = model.decode(decoder_input, encoder_output, source_mask, decoder_mask)
            # Score only the last position and take the arg-max token.
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_id = next_word.item()

            next_token = torch.full((1, 1), next_id, dtype=torch.long, device=device)
            decoder_input = torch.cat([decoder_input, next_token], dim=1)

            print(f"{tokenizer_tgt.decode([next_id])}", end=' ')

            if next_id == target_eos_id:
                break

    return tokenizer_tgt.decode(decoder_input[0].tolist())
|
| |
|
| |
|
| |
|
| |
|
# Sample sentences translated when this file is executed.
sentences = [
    "Defending champion Kolkata Knight Riders (KKR) hosts Royal Challengers Bengaluru (RCB) at the Eden Gardens in the Indian Premier League 2025 opener on Saturday.",
    "Which South American country is home to the Amazon Rainforest and the Christ the Redeemer statue?",
]

# Keep every result: reusing a single variable for both calls discarded the
# first translation before it could be written out.
translations = [translate(sentence) for sentence in sentences]
a = translations[-1]  # `a` still names the last translation

# Append one line per translation. The encoding is explicit because the
# translated text may contain non-ASCII characters and the platform default
# encoding is not guaranteed to handle them.
with open("output.txt", "a", encoding="utf-8") as w:
    for translation in translations:
        w.write(f"\n{translation}")
|
| |
|