lhallee committed on
Commit b1f43c1 · verified · 1 Parent(s): d6fcec3

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +112 -120
README.md CHANGED
@@ -1,120 +1,112 @@
- ---
- library_name: transformers
- tags: []
- ---
-
- # FastESM
- FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.
-
- Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
-
- Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
- Various other optimizations also make the base implementation slightly different from the one in transformers.
-
- # FastESM2-650
-
- ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
- To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) at sequence lengths up to **2048**.
-
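- For reference, a 20% masking objective like this can be set up with the standard 🤗 masked-LM collator. The sketch below shows only that masking step; the tokenizer choice and everything else about the training run are illustrative assumptions, not the actual training code:
- ```python
- from transformers import AutoTokenizer, DataCollatorForLanguageModeling
-
- # Any ESM2 tokenizer shares the same vocabulary; the 650M checkpoint is used here only for illustration.
- tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')
- collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.2)
-
- batch = tokenizer(['MPRTEIN', 'MSEQWENCE'], padding=True, truncation=True, max_length=2048)
- features = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
- masked = collator(features)  # ~20% of positions are selected for the MLM loss; labels are -100 elsewhere
- print(masked['input_ids'].shape, masked['labels'].shape)
- ```
-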
- ## Use with 🤗 transformers
-
- ### For working with embeddings
- ```python
- import torch
- from transformers import AutoModel, AutoTokenizer
-
- model_path = 'Synthyra/FastESM2_650'
- model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
- tokenizer = model.tokenizer
-
- sequences = ['MPRTEIN', 'MSEQWENCE']
- tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
- with torch.no_grad():
-     embeddings = model(**tokenized).last_hidden_state
-
- print(embeddings.shape) # (2, 11, 1280)
- ```
-
- ### For working with sequence logits
- ```python
- import torch
- from transformers import AutoModelForMaskedLM, AutoTokenizer
-
- model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
- with torch.no_grad():
-     logits = model(**tokenized).logits
-
- print(logits.shape) # (2, 11, 33)
- ```
-
- ### For working with attention maps
- ```python
- import torch
- from transformers import AutoModel, AutoTokenizer
-
- model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
- with torch.no_grad():
-     attentions = model(**tokenized, output_attentions=True).attentions # tuple of (batch_size, num_heads, seq_len, seq_len), one per layer
-
- print(attentions[-1].shape) # (2, 20, 11, 11)
- ```
-
-
- ## Embed entire datasets with no new code
- To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress-bar estimate is usually much longer than the actual run time.
- ```python
- embeddings = model.embed_dataset(
-     sequences=sequences, # list of protein strings
-     batch_size=16, # embedding batch size
-     max_len=2048, # truncate to max_len
-     full_embeddings=True, # return residue-wise embeddings
-     full_precision=False, # if True, store at full precision (float32)
-     pooling_type='mean', # pooling used when returning protein-wise embeddings
-     num_workers=0, # data loading num workers
-     sql=False, # return dictionary of sequences and embeddings
- )
-
- _ = model.embed_dataset(
-     sequences=sequences, # list of protein strings
-     batch_size=16, # embedding batch size
-     max_len=2048, # truncate to max_len
-     full_embeddings=True, # return residue-wise embeddings
-     full_precision=False, # if True, store at full precision (float32)
-     pooling_type='mean', # pooling used when returning protein-wise embeddings
-     num_workers=0, # data loading num workers
-     sql=True, # store sequences in a local SQL database
-     sql_db_path='embeddings.db', # path to .db file of choice
- )
- ```
-
- ## Model probes
- We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
-
- The plot below shows performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel) and regression task scores are averaged between Spearman rho and R2.
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png)
-
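- As a rough sketch of what such a linear probe looks like (the dataset, labels, and probe settings here are placeholders, not the ones behind the plot; `model` and `tokenizer` are from the embeddings example above):
- ```python
- import torch
- from sklearn.linear_model import LogisticRegression
- from sklearn.metrics import matthews_corrcoef, f1_score
-
- def mean_pool(sequences, batch_size=16):
-     """Mean-pool FastESM hidden states over residues for each sequence."""
-     pooled = []
-     for i in range(0, len(sequences), batch_size):
-         toks = tokenizer(sequences[i:i + batch_size], padding=True, return_tensors='pt')
-         with torch.no_grad():
-             hidden = model(**toks).last_hidden_state  # (B, L, D)
-         mask = toks['attention_mask'].unsqueeze(-1)
-         pooled.append((hidden * mask).sum(1) / mask.sum(1))
-     return torch.cat(pooled).float().numpy()
-
- # train_seqs / train_labels / test_seqs / test_labels stand in for a standard property dataset
- probe = LogisticRegression(max_iter=1000).fit(mean_pool(train_seqs), train_labels)
- preds = probe.predict(mean_pool(test_seqs))
- print(matthews_corrcoef(test_labels, preds), f1_score(test_labels, preds, average='weighted'))
- ```
-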
- ## Comparison of half precisions
- Presumably because we trained in fp16 mixed precision, fp16 outputs track the fp32 weights more closely than bf16 outputs do. We therefore recommend loading in fp16.
-
- Comparing the outputs of 1000 sequences against the model loaded in fp32:
-
- Average MSE for FP16: 0.00000140
-
- Average MSE for BF16: 0.00004125
-
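- A minimal sketch of that comparison (only two toy sequences here; the 1000-sequence set itself is not included):
- ```python
- import torch
- from transformers import AutoModel
-
- model_path = 'Synthyra/FastESM2_650'
- ref = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval()
- half = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
- tokenizer = ref.tokenizer
-
- toks = tokenizer(['MPRTEIN', 'MSEQWENCE'], padding=True, return_tensors='pt')
- with torch.no_grad():
-     ref_out = ref(**toks).last_hidden_state
-     half_out = half(**toks).last_hidden_state.float()
- mse = torch.mean((ref_out - half_out) ** 2)  # average this over your sequence set
- print(mse.item())
- ```
-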
- ### Inference speed
- We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 on longer sequences. PyTorch 2.5+ is required for the largest savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png)
-
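- A simple way to reproduce this kind of measurement is to time batched forward passes on GPU. This is a sketch with an arbitrary batch of random sequences, not the exact benchmark behind the plot (`model` and `tokenizer` are from the embeddings example above):
- ```python
- import random
- import time
- import torch
-
- device = 'cuda'
- model_gpu = model.to(device)
- seqs = [''.join(random.choices('ACDEFGHIKLMNPQRSTVWY', k=random.randint(128, 1024))) for _ in range(64)]
- toks = tokenizer(seqs, padding=True, return_tensors='pt').to(device)
-
- with torch.no_grad():
-     for _ in range(3):  # warmup
-         model_gpu(**toks)
-     torch.cuda.synchronize()
-     start = time.time()
-     for _ in range(10):
-         model_gpu(**toks)
-     torch.cuda.synchronize()
- print(f'{64 * 10 / (time.time() - start):.1f} sequences/sec')
- ```
-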
- ### Citation
- If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
- ```
- @misc {FastESM2,
-     author = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
-     title = { FastESM2 },
-     year = 2024,
-     url = { https://huggingface.co/Synthyra/FastESM2_650 },
-     doi = { 10.57967/hf/3729 },
-     publisher = { Hugging Face }
- }
- ```
 
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # FastESM
+ FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.
+
+ Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.
+
+ Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
+ Various other optimizations also make the base implementation slightly different from the one in transformers.
+
+ ## Use with 🤗 transformers
+
+ ### Supported models
+ ```python
+ model_dict = {
+     # Synthyra/ESM2-8M
+     'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
+     # Synthyra/ESM2-35M
+     'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
+     # Synthyra/ESM2-150M
+     'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
+     # Synthyra/ESM2-650M
+     'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
+     # Synthyra/ESM2-3B
+     'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
+ }
+ ```
+
+ ### For working with embeddings
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ model_path = 'Synthyra/ESM2-8M'
+ model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
+ tokenizer = model.tokenizer
+
+ sequences = ['MPRTEIN', 'MSEQWENCE']
+ tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
+ with torch.no_grad():
+     embeddings = model(**tokenized).last_hidden_state
+
+ print(embeddings.shape) # (2, 11, 320) for ESM2-8M; hidden size varies by model
+ ```
+
+ ### For working with sequence logits
+ ```python
+ import torch
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+ model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
+ with torch.no_grad():
+     logits = model(**tokenized).logits
+
+ print(logits.shape) # (2, 11, 33)
+ ```
+
+ ### For working with attention maps
+ ```python
+ import torch
+ from transformers import AutoModel, AutoTokenizer
+
+ model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
+ with torch.no_grad():
+     attentions = model(**tokenized, output_attentions=True).attentions # tuple of (batch_size, num_heads, seq_len, seq_len), one per layer
+
+ print(attentions[-1].shape) # (2, 20, 11, 11)
+ ```
+
+ ## Embed entire datasets with no new code
+ To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress-bar estimate is usually much longer than the actual run time.
+ ```python
+ embeddings = model.embed_dataset(
+     sequences=sequences, # list of protein strings
+     batch_size=16, # embedding batch size
+     max_len=2048, # truncate to max_len
+     full_embeddings=True, # return residue-wise embeddings
+     full_precision=False, # if True, store at full precision (float32)
+     pooling_type='mean', # pooling used when returning protein-wise embeddings
+     num_workers=0, # data loading num workers
+     sql=False, # return dictionary of sequences and embeddings
+ )
+
+ _ = model.embed_dataset(
+     sequences=sequences, # list of protein strings
+     batch_size=16, # embedding batch size
+     max_len=2048, # truncate to max_len
+     full_embeddings=True, # return residue-wise embeddings
+     full_precision=False, # if True, store at full precision (float32)
+     pooling_type='mean', # pooling used when returning protein-wise embeddings
+     num_workers=0, # data loading num workers
+     sql=True, # store sequences in a local SQL database
+     sql_db_path='embeddings.db', # path to .db file of choice
+ )
+ ```
+
+
+ ### Citation
+ If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
+ ```
+ @misc {FastESM2,
+     author = { Hallee, L. and Bichara, D. and Gleghorn, J. P. },
+     title = { FastESM2 },
+     year = 2024,
+     url = { https://huggingface.co/Synthyra/FastESM2_650 },
+     doi = { 10.57967/hf/3729 },
+     publisher = { Hugging Face }
+ }
+ ```