Does the RnaBERT model have an lm_head module?

#3
by mmjwxbc - opened

I want to feed an RNA sequence into RnaBERT and then reconstruct the original sequence from the logits produced by the model.
But it can't restore the sequence successfully!
Here is my code:

from multimolecule import RnaTokenizer, RnaBertModel, RnaBertConfig, RnaBertForMaskedLM
import torch

# Create a configuration object and set the number of hidden layers to 10
config = RnaBertConfig()
config.num_hidden_layers = 10

# Load the pretrained tokenizer
tokenizer = RnaTokenizer.from_pretrained("multimolecule/rnabert")

# Initialize the model with the configuration
model = RnaBertForMaskedLM(config)

# Define the input RNA sequence
text = "AUGC"
# Tokenize the text and convert it to PyTorch tensors
input = tokenizer(text, return_tensors="pt")

# Feed the input into the model
output = model(**input)

# Decode the predicted token IDs to get the reverted sequence
revert_seq = tokenizer.batch_decode(torch.argmax(output['logits'], dim=-1))
print(revert_seq)

Here is the output:

['<pad> <cls> R <pad> R -']

Hasn't the lm_head module been trained?
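
A quick way to check whether an LM head module is present at all (a minimal sketch using only the standard PyTorch module API; the exact submodule name may differ):

# List the top-level submodules of the model and look for an LM-head-like module
for name, module in model.named_children():
    print(name, type(module).__name__)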

MultiMolecule org

You should use RnaBertForMaskedLM.from_pretrained("multimolecule/rnabert") to load the pretrained RnaBERT.

That said, the pretraining of RnaBERT was not very effective, so the pretrained model may not be as good as you might expect.
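
For example, a minimal sketch of the full round trip with the pretrained weights (the decoding step is the same as in your snippet):

import torch
from multimolecule import RnaTokenizer, RnaBertForMaskedLM

# Load the pretrained tokenizer and model instead of a randomly initialized one
tokenizer = RnaTokenizer.from_pretrained("multimolecule/rnabert")
model = RnaBertForMaskedLM.from_pretrained("multimolecule/rnabert")

input = tokenizer("AUGC", return_tensors="pt")
output = model(**input)

# Greedy decode of the logits back to a sequence
print(tokenizer.batch_decode(torch.argmax(output["logits"], dim=-1)))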

Thanks, I had already tried that, but the result was bad. If I use RNA-FM, will it be better than RnaBERT?

MultiMolecule org

You can check out the examples at the top right of the model cards.
RNA-FM does perform better, but I believe only Ernie-RNA outperforms the random baseline on miR-21. Note that these sequences are for demo purposes only, so they are very short, which may not accurately reflect the actual performance of the models.

import torch
from multimolecule import ErnieRnaForMaskedLM, RnaTokenizer

# Load the pretrained Ernie-RNA model and its tokenizer
model = ErnieRnaForMaskedLM.from_pretrained("multimolecule/ernierna")
tokenizer = RnaTokenizer.from_pretrained("multimolecule/ernierna")

# Tokenize a short demo sequence
input = tokenizer("ACGU", return_tensors="pt")

# Passing labels also returns the masked-LM loss alongside the logits
output = model(**input, labels=input["input_ids"])

# Greedy decode of the logits back to a sequence
out_logits = output["logits"]
out_ids = torch.argmax(out_logits, dim=-1)
out_seq = tokenizer.batch_decode(out_ids)
print(out_seq)
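
The same pattern should work for RNA-FM by swapping the class and checkpoint; a minimal sketch, assuming the RnaFmForMaskedLM class and the multimolecule/rnafm checkpoint:

import torch
from multimolecule import RnaFmForMaskedLM, RnaTokenizer

# Assumed class and checkpoint names; adjust if the repo uses different identifiers
model = RnaFmForMaskedLM.from_pretrained("multimolecule/rnafm")
tokenizer = RnaTokenizer.from_pretrained("multimolecule/rnafm")

input = tokenizer("ACGU", return_tensors="pt")
output = model(**input)
print(tokenizer.batch_decode(torch.argmax(output["logits"], dim=-1)))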

Thanks, I have tried it, but I still don't think the LM performs well enough.
