import torch.nn as nn

from transformers import PreTrainedModel

from .configuration_exon_classifier import Evo2ExonConfig


class Evo2ExonModel(PreTrainedModel):
    """Per-token binary exon classifier over precomputed Evo2 embeddings."""

    config_class = Evo2ExonConfig
    base_model_prefix = "evo2_exon_classifier"

    def __init__(self, config: Evo2ExonConfig):
        super().__init__(config)
        # Build (Linear + ReLU) * num_hidden_layers, then a final Linear
        # projecting each position down to a single logit.
        layers = [nn.Linear(config.embedding_dim, config.hidden_dim), nn.ReLU()]
        for _ in range(config.num_hidden_layers - 1):
            layers += [nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(config.hidden_dim, 1)]
        self.fc_layers = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()  # converts logits to probabilities

    def forward(self, inputs_embeds, labels=None, **kwargs):
        """
        inputs_embeds : (batch, seq_len, embedding_dim) float tensor
        labels        : (batch, seq_len), optional, 0/1 floats or ints
        """
        bsz, seq_len, _ = inputs_embeds.shape

        # Flatten to (batch * seq_len, embedding_dim), run the FC stack,
        # then reshape back to (batch, seq_len). reshape (rather than view)
        # also handles non-contiguous inputs_embeds.
        logits = self.fc_layers(inputs_embeds.reshape(-1, inputs_embeds.size(-1)))
        logits = logits.view(bsz, seq_len)
        probs = self.sigmoid(logits)

        if labels is not None:
            # BCE on post-sigmoid probabilities; nn.BCEWithLogitsLoss on the
            # raw logits would be the numerically safer equivalent.
            loss = nn.BCELoss()(probs, labels.float())
            return {"loss": loss, "logits": probs}
        # Note: the output key is "logits" per the usual HF convention, but
        # the values are post-sigmoid probabilities.
        return {"logits": probs}
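

# ---------------------------------------------------------------------------
# Minimal smoke test (sketch, not part of the model API). It assumes
# Evo2ExonConfig accepts embedding_dim, hidden_dim, and num_hidden_layers as
# keyword arguments, and the tensor sizes below are illustrative only. Because
# of the relative import above, run it as a module from the package root,
# e.g. `python -m <package>.modeling_exon_classifier`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    # Hypothetical config values; substitute the real Evo2 embedding width.
    config = Evo2ExonConfig(embedding_dim=4096, hidden_dim=512, num_hidden_layers=2)
    model = Evo2ExonModel(config)

    # Fake per-token Evo2 embeddings: batch of 2 sequences, 128 positions each.
    inputs_embeds = torch.randn(2, 128, config.embedding_dim)
    # Binary per-position labels (exon vs. non-exon).
    labels = torch.randint(0, 2, (2, 128))

    out = model(inputs_embeds=inputs_embeds, labels=labels)
    print(out["logits"].shape)  # torch.Size([2, 128]) per-token probabilities
    print(out["loss"].item())   # scalar BCE loss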