tlhuang committed Commit 3346881 · verified · 1 Parent(s): 0d8d98b

Upload folder using huggingface_hub

Files changed (3)
  1. .gitattributes +1 -0
  2. README.md +133 -0
  3. stfm.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ stfm.pth filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,133 @@
# STPath: A Generative Foundation Model for Integrating Spatial Transcriptomics and Whole Slide Images

This is the Hugging Face repository for the paper:

> Tinglin Huang, Tianyu Liu, Mehrtash Babadi, Rex Ying, and Wengong Jin (2025). STPath: A Generative Foundation Model for Integrating Spatial Transcriptomics and Whole Slide Images. Paper on [bioRxiv](https://www.biorxiv.org/content/10.1101/2025.04.19.649665v2.abstract). Code on [GitHub](https://github.com/Graph-and-Geometric-Learning/STPath).

## Usage

We provide an easy-to-use interface for running inference with the pre-trained model, implemented in `app/pipeline/inference.py`. The following snippet shows how to use it:

```python
from stpath.app.pipeline.inference import STPathInference

agent = STPathInference(
    gene_voc_path='STPath_dir/utils_data/symbol2ensembl.json',  # gene symbol -> Ensembl ID vocabulary
    model_weight_path='your_dir/stpath.pkl',  # path to the pre-trained weights (shipped as stfm.pth in this repo)
    device=0
)

pred_adata = agent.inference(
    coords=coords,            # [number_of_spots, 2]
    img_features=embeddings,  # [number_of_spots, 1536], image features extracted using GigaPath
    organ_type="Kidney",      # default is None
    tech_type="Visium",       # default is None
    save_gene_names=hvg_list  # gene names to keep in the returned adata, e.g., ['GATA3', 'UBE2C', ...]; None keeps all genes in the model
)

# save the predicted AnnData
pred_adata.write_h5ad(f"your_dir/pred_{sample_id}.h5ad")
```
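
The `img_features` above are 1536-dimensional GigaPath tile embeddings, one per spot. Below is a minimal sketch of extracting them with `timm`, assuming access to the gated `prov-gigapath/prov-gigapath` checkpoint on the Hugging Face Hub; `spot_patch.png` is a hypothetical image patch cropped around one spot:

```python
import timm
import torch
from PIL import Image
from torchvision import transforms

# GigaPath tile encoder (gated checkpoint; requires accepting its license on the Hub)
encoder = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True).eval()

preprocess = transforms.Compose([
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

with torch.inference_mode():
    tile = preprocess(Image.open("spot_patch.png").convert("RGB")).unsqueeze(0)
    embedding = encoder(tile)  # shape [1, 1536]; stack one row per spot to build img_features
```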

The vocabularies for organs and technologies can be found in the following locations:
* [organ vocabulary](https://github.com/Graph-and-Geometric-Learning/STPath/blob/main/stpath/utils/constants.py#L98)
* [tech vocabulary](https://github.com/Graph-and-Geometric-Learning/STPath/blob/main/stpath/utils/constants.py#L20)

If the organ type or the tech type is unknown, set it to `None` in the inference call. Note that the predicted gene expression values are log1p-transformed (`log(1 + x)`), consistent with the transformation applied during the training of STPath.
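
Because of this, raw-scale expression values can be recovered from the predictions with the inverse transform:

```python
import numpy as np

raw_expression = np.expm1(pred_adata.X)  # inverse of log1p: exp(x) - 1
```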

### Example of Inference

Here, we provide an example of how to perform inference on a [sample](https://github.com/Graph-and-Geometric-Learning/STPath/tree/main/example_data) from the HEST dataset:

```python
import os
import json

import numpy as np
import scanpy as sc
from scipy.stats import pearsonr

from stpath.hest_utils.st_dataset import load_adata
from stpath.hest_utils.file_utils import read_assets_from_h5

sample_id = "INT2"
source_dataroot = "STPath_dir"  # the root directory of the STPath repository
with open(os.path.join(source_dataroot, "example_data/var_50genes.json")) as f:
    hvg_list = json.load(f)['genes']  # list of highly variable genes

# load coordinates, image features, and barcodes from the h5 file
data_dict, _ = read_assets_from_h5(os.path.join(source_dataroot, f"{sample_id}.h5"))
coords = data_dict["coords"]
embeddings = data_dict["embeddings"]
barcodes = data_dict["barcodes"].flatten().astype(str).tolist()
adata = sc.read_h5ad(os.path.join(source_dataroot, f"{sample_id}.h5ad"))[barcodes, :]

# the returned pred_adata contains the expression of the genes in hvg_list
pred_adata = agent.inference(
    coords=coords,
    img_features=embeddings,
    organ_type="Kidney",
    tech_type="Visium",
    save_gene_names=hvg_list  # we only need the highly variable genes for evaluation
)

# Pearson correlation between the predicted and ground-truth expression, per gene
all_pearson_list = []
gt = np.log1p(adata[:, hvg_list].X.toarray())  # sparse -> dense, log1p to match the model output
for i in range(len(hvg_list)):
    pearson_corr, _ = pearsonr(gt[:, i], pred_adata.X[:, i])
    all_pearson_list.append(pearson_corr.item())
print(f"Pearson correlation for {sample_id}: {np.mean(all_pearson_list)}")  # 0.1562
```
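
One caveat about the evaluation loop: `scipy.stats.pearsonr` returns NaN when either input is constant (e.g., a gene that is all-zero in this sample). A NaN-safe aggregate guards against that:

```python
# mean over genes, ignoring any NaN correlations from constant genes
print(f"Pearson correlation for {sample_id}: {np.nanmean(all_pearson_list)}")
```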

### In-context Learning

STPath also supports in-context learning, which allows users to provide the expression of a few spots as context to guide the model's predictions for the remaining spots:

```python
from stpath.data.sampling_utils import PatchSampler

# anchor at the rightmost spot and mask the 95% of spots nearest to it for prediction
rightest_coord = np.where(coords[:, 0] == coords[:, 0].max())[0][0]
masked_ids = PatchSampler.sample_nearest_patch(coords, int(len(coords) * 0.95), rightest_coord)
context_ids = np.setdiff1d(np.arange(len(coords)), masked_ids)  # spots not in masked_ids serve as context
context_gene_exps = adata.X.toarray()[context_ids]
context_gene_names = adata.var_names.tolist()

pred_adata = agent.inference(
    coords=coords,
    img_features=embeddings,
    context_ids=context_ids,                # indices of the context spots
    context_gene_exps=context_gene_exps,    # expression of the context spots
    context_gene_names=context_gene_names,  # gene names of the context spots
    organ_type="Kidney",
    tech_type="Visium",
    save_gene_names=hvg_list,
)

all_pearson_list = []
gt = np.log1p(adata[:, hvg_list].X.toarray())[masked_ids, :]  # ground-truth expression of the masked spots
pred = pred_adata.X[masked_ids, :]  # predicted expression of the masked spots
for i in range(len(hvg_list)):
    pearson_corr, _ = pearsonr(gt[:, i], pred[:, i])
    all_pearson_list.append(pearson_corr.item())
print(f"Pearson correlation for {sample_id}: {np.mean(all_pearson_list)}")  # 0.2449
```
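
The context does not have to be a contiguous patch; any subset of spot indices works with the same API. For example, a randomly scattered 5% context (an illustrative variant, not from the paper):

```python
# sample 5% of the spots uniformly at random as context
rng = np.random.default_rng(0)
context_ids = rng.choice(len(coords), size=int(len(coords) * 0.05), replace=False)
masked_ids = np.setdiff1d(np.arange(len(coords)), context_ids)  # evaluate on the remaining 95%
```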
111
+
112
+
113
+ ## Reference
114
+
115
+ If you find our work useful in your research, please consider citing our paper:
116
+
117
+ ```
118
+ @inproceedings{huang2025stflow,
119
+ title={Scalable Generation of Spatial Transcriptomics from Histology Images via Whole-Slide Flow Matching},
120
+ author={Huang, Tinglin and Liu, Tianyu and Babadi, Mehrtash and Jin, Wengong and Ying, Rex},
121
+ booktitle={International Conference on Machine Learning},
122
+ year={2025}
123
+ }
124
+
125
+ @article{huang2025stpath,
126
+ title={STPath: A Generative Foundation Model for Integrating Spatial Transcriptomics and Whole Slide Images},
127
+ author={Huang, Tinglin and Liu, Tianyu and Babadi, Mehrtash and Ying, Rex and Jin, Wengong},
128
+ journal={bioRxiv},
129
+ pages={2025--04},
130
+ year={2025},
131
+ publisher={Cold Spring Harbor Laboratory}
132
+ }
133
+ ```
stfm.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:03d49af98103c22eaee064632a366ad6ba2c1e627adcf47e00b13746a3b348fe
+ size 196728540
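
`stfm.pth` is stored with Git LFS; after downloading, one can verify that the pointer resolved to the actual weights (rather than the pointer text itself) by checking the size and SHA-256 against the values above:

```python
import hashlib
import os

# expected values from the LFS pointer above
EXPECTED_SHA256 = "03d49af98103c22eaee064632a366ad6ba2c1e627adcf47e00b13746a3b348fe"
EXPECTED_SIZE = 196728540

assert os.path.getsize("stfm.pth") == EXPECTED_SIZE
with open("stfm.pth", "rb") as f:
    assert hashlib.sha256(f.read()).hexdigest() == EXPECTED_SHA256
```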