---
library_name: transformers
license: apache-2.0
metrics:
- perplexity
base_model:
- facebook/esm1b_t33_650M_UR50S
---
# Fine-Tuning ESM-1b with Multiple Sequence Alignment (MSA) for Phosphosites
This repository provides a version of ESM-1b fine-tuned with the Masked Language Modeling (MLM) objective, incorporating evolutionary information by leveraging long phosphosite sequences from the DARKIN dataset together with Multiple Sequence Alignments (MSAs) of those phosphosites. The goal is to enhance the model's understanding of phosphorylation by integrating sequence conservation patterns.
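The fine-tuning follows the standard Hugging Face masked-language-modeling recipe. Below is a minimal sketch of such a setup, assuming sequences live in a `datasets` `Dataset` with a `sequence` column; the in-memory example data, output directory, and hyperparameters are illustrative placeholders, not the exact training configuration used for this model.

```python
from datasets import Dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
model = AutoModelForMaskedLM.from_pretrained("facebook/esm1b_t33_650M_UR50S")

# Placeholder dataset; in practice this holds the ~98,000 phosphosite/MSA sequences.
dataset = Dataset.from_dict({"sequence": ["MKTLLLTLVVVTIVCLDLGYTGV"]})

def tokenize(batch):
    # Truncate to 128 residues, matching the preprocessing described below.
    return tokenizer(batch["sequence"], truncation=True, max_length=128)

tokenized = dataset.map(tokenize, batched=True, remove_columns=["sequence"])

# The collator randomly masks tokens so the model trains on the MLM objective.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="esm1b_phosphosite_mlm",
                           per_device_train_batch_size=8),
    train_dataset=tokenized,
    data_collator=collator,
)
trainer.train()
```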
**Developed by:** Zeynep Işık (MSc, Sabanci University)
## Dataset & Preprocessing
To construct a robust dataset, we extracted 256 MSA sequences per phosphosite from publicly available sequence databases, yielding approximately 2 million sequences in total. Given this volume, the following preprocessing steps were applied:
- **Selection of MSA Sequences for Labeled Data**
  - Up to 10 MSA sequences were selected per human phosphosite.
  - This yielded a final dataset of 98,000 samples.
- **Dataset Splitting**
  - 10% of the data was reserved for validation.
  - The remaining 90% was used for fine-tuning with the Masked Language Modeling (MLM) objective.
- **Sequence Processing**
  - Special attention was given to preserving the phosphorylated residue within each sequence.
  - To optimize memory efficiency, sequences were truncated to 128 amino acids (a code sketch of these steps follows the list).
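The selection, truncation, and splitting steps can be sketched as follows. This assumes each record carries a phosphosite identifier, an aligned sequence, and the 0-based index of the phosphorylated residue; the record layout and helper names are assumptions for illustration.

```python
import random
from collections import defaultdict

MAX_MSA_PER_SITE = 10
MAX_LEN = 128

def truncate_around_site(seq: str, site_idx: int, max_len: int = MAX_LEN) -> str:
    """Crop a window of at most max_len residues that keeps the phosphosite inside it."""
    start = max(0, min(site_idx - max_len // 2, len(seq) - max_len))
    return seq[start:start + max_len]

def build_dataset(records, seed=42):
    # records: iterable of (site_id, sequence, site_idx) tuples
    by_site = defaultdict(list)
    for site_id, seq, site_idx in records:
        by_site[site_id].append((seq, site_idx))

    samples = []
    for site_id, entries in by_site.items():
        # Keep at most 10 MSA sequences per phosphosite.
        for seq, site_idx in entries[:MAX_MSA_PER_SITE]:
            samples.append(truncate_around_site(seq, site_idx))

    # Shuffle, then hold out 10% for validation.
    random.Random(seed).shuffle(samples)
    n_val = len(samples) // 10
    return samples[n_val:], samples[:n_val]  # (train, validation)
```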
## Evaluation
**Perplexity:** 2.69 on the validation set, down from 7.05 before fine-tuning.
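Perplexity here is the exponential of the mean masked-LM cross-entropy on the held-out validation split. A minimal sketch of that computation, assuming batches come from a `DataLoader` backed by `DataCollatorForLanguageModeling` (so each batch already contains `labels`); the function name is illustrative:

```python
import math
import torch

def mlm_perplexity(model, dataloader, device="cpu"):
    model.eval().to(device)
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for batch in dataloader:
            # Each batch holds input_ids, attention_mask, and labels for masked positions.
            batch = {k: v.to(device) for k, v in batch.items()}
            total_loss += model(**batch).loss.item()
            n_batches += 1
    return math.exp(total_loss / n_batches)
```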
## Usage
```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# Load the model and tokenizer
model_name = "isikz/phosphosite_msa_finetuned_esm1b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Example sequence with a masked residue (ESM tokenizers use "<mask>", not "[MASK]")
sequence = "MKTLLLTLVVV<mask>VCLDLGYTGV"

# Tokenize input
inputs = tokenizer(sequence, return_tensors="pt")

# Locate the masked position instead of hard-coding its index
mask_index = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]

# Get prediction
with torch.no_grad():
    logits = model(**inputs).logits

predicted_token_id = logits[0, mask_index].argmax(dim=-1).item()
predicted_token = tokenizer.decode([predicted_token_id])
print(f"Predicted Residue: {predicted_token}")
```