---
license: apache-2.0
---
|
First, we define the `T5ClassificationModel` class:
|
```python
import torch
import torch.nn as nn
from transformers import (
    T5Config,
    T5EncoderModel,
    T5Tokenizer,
    PreTrainedModel,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)


class T5ClassificationModel(PreTrainedModel):
    config_class = T5Config

    def __init__(self, config, d_model=None, num_classes=2):
        super().__init__(config)
        self.num_classes = num_classes

        # Backbone encoder; these weights are replaced by the fine-tuned
        # checkpoint when the model is loaded with from_pretrained().
        self.encoder = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_uniref50")

        hidden_dim = d_model if d_model is not None else config.d_model
        self.classification_head = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = encoder_outputs.last_hidden_state  # [batch_size, seq_len, d_model]

        # Mean-pool the hidden states over non-padding tokens.
        mask = attention_mask.unsqueeze(-1)
        pooled_output = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
        logits = self.classification_head(pooled_output)  # [batch_size, num_classes]

        loss = None
        if labels is not None:
            # CrossEntropyLoss expects integer class indices as targets.
            loss = nn.CrossEntropyLoss()(logits, labels.long())

        return {"loss": loss, "logits": logits}
```
|
Then we load the pretrained tokenizer and model:
|
```python
tokenizer = T5Tokenizer.from_pretrained("jiaxie/DeepProtT5-Human", do_lower_case=False)
model = T5ClassificationModel.from_pretrained("jiaxie/DeepProtT5-Human", torch_dtype=torch.bfloat16).to("cuda")
```
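
For inference, a minimal sketch (not part of the original card) might look like the following. It assumes the usual ProtT5 input convention of space-separated residues with the rare amino acids U, Z, O, and B mapped to X; the sequence shown is a hypothetical example.

```python
import re

# Hypothetical example sequence; replace with your own protein.
sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

# Assumption: inputs follow the Rostlab ProtT5 preprocessing convention of
# space-separated residues with rare amino acids mapped to X.
sequence = " ".join(re.sub(r"[UZOB]", "X", sequence))

inputs = tokenizer(sequence, return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])

predicted_class = outputs["logits"].argmax(dim=-1).item()
print(predicted_class)
```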