STPath: A Generative Foundation Model for Integrating Spatial Transcriptomics and Whole Slide Images
This is a Hugging Face repository for the paper:
Tinglin Huang, Tianyu Liu, Mehrtash Babadi, Rex Ying, and Wengong Jin (2025). STPath: A Generative Foundation Model for Integrating Spatial Transcriptomics and Whole Slide Images. Paper on bioRxiv. Code on GitHub.
Usage
We provide an easy-to-use interface for running inference with the pre-trained model, located in app/pipeline/inference.py. The following snippet shows how to use it:
from stpath.app.pipeline.inference import STPathInference
agent = STPathInference(
    gene_voc_path='STPath_dir/utils_data/symbol2ensembl.json',
    model_weight_path='your_dir/stpath.pkl',
    device=0
)
pred_adata = agent.inference(
    coords=coords,  # [number_of_spots, 2]
    img_features=embeddings,  # [number_of_spots, 1536], image features extracted with Gigapath
    organ_type="Kidney",  # default: None
    tech_type="Visium",  # default: None
    save_gene_names=hvg_list  # gene names to keep in the AnnData, e.g. ['GATA3', 'UBE2C', ...]; None keeps all genes in the model
)
# save the predicted AnnData object
pred_adata.write_h5ad(f"your_dir/pred_{sample_id}.h5ad")
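Before calling agent.inference, it can help to confirm that the inputs have the shapes stated in the comments above: coords as [number_of_spots, 2] and img_features as [number_of_spots, 1536]. The helper below is a minimal sketch; check_inference_inputs is a hypothetical name that is not part of STPath, and its 1536 default simply mirrors the Gigapath feature size given in the snippet.

```python
import numpy as np

# Assumption: 1536 mirrors the Gigapath feature size stated in the usage snippet.
EXPECTED_FEATURE_DIM = 1536


def check_inference_inputs(coords, img_features, feature_dim=EXPECTED_FEATURE_DIM):
    """Hypothetical helper (not part of STPath): validate shapes before inference.

    Returns the inputs as NumPy arrays, raising ValueError on any mismatch.
    """
    coords = np.asarray(coords)
    img_features = np.asarray(img_features)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"coords must have shape [n_spots, 2], got {coords.shape}")
    if img_features.ndim != 2 or img_features.shape[1] != feature_dim:
        raise ValueError(
            f"img_features must have shape [n_spots, {feature_dim}], got {img_features.shape}"
        )
    # Each spot needs exactly one coordinate pair and one feature row.
    if coords.shape[0] != img_features.shape[0]:
        raise ValueError(
            f"spot count mismatch: {coords.shape[0]} coords vs {img_features.shape[0]} feature rows"
        )
    return coords, img_features
```

For example, coords, embeddings = check_inference_inputs(coords, embeddings) fails early with a readable message instead of inside the model.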
The vocabularies for organs and technologies can be found in the following locations:
If the organ type or technology type is unknown, set organ_type or tech_type to None in the inference call. Note that the predicted gene expression values are log1p-transformed, i.e. log(1 + x), consistent with the transformation applied when training STPath.
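Because predictions are on the log1p scale, ground-truth expression should be transformed the same way before comparison (the evaluation example applies np.log1p for this reason), and predictions can be mapped back with np.expm1, the exact inverse of log1p. A minimal NumPy illustration with toy values:

```python
import numpy as np

# Toy expression matrix: 2 spots x 3 genes (illustrative values only).
raw_counts = np.array([[0.0, 3.0, 10.0],
                       [1.0, 0.0, 5.0]])

# Put ground truth on the same scale as STPath's predictions: log(1 + x).
log_counts = np.log1p(raw_counts)

# Undo the transform: exp(y) - 1, the inverse of log1p.
recovered = np.expm1(log_counts)
```

Assuming pred_adata.X is a dense array (the evaluation code indexes it directly), np.expm1(pred_adata.X) undoes the log1p transform on the predictions.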
Example of Inference
Here, we provide an example of how to perform inference on a sample from the HEST dataset:
import json
import os
import numpy as np
import scanpy as sc
from scipy.stats import pearsonr
from stpath.hest_utils.st_dataset import load_adata
from stpath.hest_utils.file_utils import read_assets_from_h5
sample_id = "INT2"
source_dataroot = "STPath_dir"  # the root directory of the STPath repository
with open(os.path.join(source_dataroot, "example_data/var_50genes.json")) as f:
    hvg_list = json.load(f)['genes']
data_dict, _ = read_assets_from_h5(os.path.join(source_dataroot, f"{sample_id}.h5"))  # load the data from the h5 file
coords = data_dict["coords"]
embeddings = data_dict["embeddings"]
barcodes = data_dict["barcodes"].flatten().astype(str).tolist()
adata = sc.read_h5ad(os.path.join(source_dataroot, f"{sample_id}.h5ad"))[barcodes, :]
# agent is the STPathInference instance created in the Usage section.
# The returned pred_adata contains the predicted expression of the genes in hvg_list (a list of highly variable genes).
pred_adata = agent.inference(
    coords=coords,
    img_features=embeddings,
    organ_type="Kidney",
    tech_type="Visium",
    save_gene_names=hvg_list  # only the highly variable genes are needed for evaluation
)
# compute the Pearson correlation between the predicted and ground-truth gene expression
all_pearson_list = []
gt = np.log1p(adata[:, hvg_list].X.toarray())  # sparse -> dense, then log1p to match the prediction scale
# iterate over each gene in the highly variable gene list
for i in range(len(hvg_list)):
    pearson_corr, _ = pearsonr(gt[:, i], pred_adata.X[:, i])
    all_pearson_list.append(pearson_corr.item())
print(f"Pearson correlation for {sample_id}: {np.mean(all_pearson_list)}")  # 0.1562
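The per-gene loop above can also be written as a single vectorized computation over all genes. The sketch below centers each column and computes Pearson's r directly; per_gene_pearson is a hypothetical helper (not part of STPath or SciPy), and it should agree with scipy.stats.pearsonr column by column up to floating-point error.

```python
import numpy as np


def per_gene_pearson(gt, pred):
    """Hypothetical helper: Pearson r for each column (gene) of two [n_spots, n_genes] arrays.

    Columns with zero variance yield NaN, as scipy.stats.pearsonr does.
    """
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    # Center each gene (column) across spots.
    gt_c = gt - gt.mean(axis=0)
    pred_c = pred - pred.mean(axis=0)
    # r = sum(x * y) / sqrt(sum(x^2) * sum(y^2)) on centered data.
    num = (gt_c * pred_c).sum(axis=0)
    den = np.sqrt((gt_c ** 2).sum(axis=0) * (pred_c ** 2).sum(axis=0))
    # 0 / 0 for constant columns becomes NaN without emitting a warning.
    with np.errstate(invalid="ignore", divide="ignore"):
        return num / den
```

Assuming pred_adata.X is dense, np.mean(per_gene_pearson(gt, pred_adata.X)) corresponds to the averaged value printed above. If a gene is constant across spots its r is NaN, so np.mean returns NaN; np.nanmean skips such genes instead.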
In-context Learning
STPath also supports in-context learning: users can provide the measured expression of a few spots (the context) to guide the model's predictions for the remaining spots:
# continues from the previous example (coords, embeddings, adata, hvg_list, sample_id and agent are already defined)
from stpath.data.sampling_utils import PatchSampler
# index of the rightmost spot (largest x-coordinate)
rightmost_idx = np.argmax(coords[:, 0])
# predict the expression of the 95% of spots nearest to the rightmost spot
masked_ids = PatchSampler.sample_nearest_patch(coords, int(len(coords) * 0.95), rightmost_idx)
# spots not in masked_ids are used as context
context_ids = np.setdiff1d(np.arange(len(coords)), masked_ids)
context_gene_exps = adata.X.toarray()[context_ids]  # measured expression of the context spots
context_gene_names = adata.var_names.tolist()
pred_adata = agent.inference(
    coords=coords,
    img_features=embeddings,
    context_ids=context_ids,  # indices of the context spots
    context_gene_exps=context_gene_exps,  # expression of the context spots
    context_gene_names=context_gene_names,  # gene names (columns) of context_gene_exps
    organ_type="Kidney",
    tech_type="Visium",
    save_gene_names=hvg_list,
)
# evaluate only on the masked spots, since the context spots were given to the model
all_pearson_list = []
gt = np.log1p(adata[:, hvg_list].X.toarray())[masked_ids, :]  # ground-truth expression of the masked spots
pred = pred_adata.X[masked_ids, :]  # predicted expression of the masked spots
for i in range(len(hvg_list)):
    pearson_corr, _ = pearsonr(gt[:, i], pred[:, i])
    all_pearson_list.append(pearson_corr.item())
print(f"Pearson correlation for {sample_id}: {np.mean(all_pearson_list)}")  # 0.2449
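The in-context example masks a contiguous region: the spots nearest to the rightmost spot are predicted, and the remaining spots serve as context. To experiment with other context fractions or anchor spots, a nearest-k split can be sketched in NumPy as below. This assumes PatchSampler.sample_nearest_patch selects the k spots closest to the anchor in Euclidean distance; it is not its actual implementation, and nearest_k_split is a hypothetical name.

```python
import numpy as np


def nearest_k_split(coords, k, anchor_idx):
    """Hypothetical stand-in for a nearest-patch split (not STPath's PatchSampler).

    Returns (masked_ids, context_ids): the k spots closest to coords[anchor_idx]
    in Euclidean distance, and all remaining spots, both as sorted index arrays.
    """
    coords = np.asarray(coords, dtype=np.float64)
    # Distance of every spot to the anchor spot.
    dists = np.linalg.norm(coords - coords[anchor_idx], axis=1)
    # Nearest first; a stable sort keeps index order among ties.
    order = np.argsort(dists, kind="stable")
    masked_ids = np.sort(order[:k])
    context_ids = np.setdiff1d(np.arange(len(coords)), masked_ids)
    return masked_ids, context_ids
```

For instance, masked_ids, context_ids = nearest_k_split(coords, int(len(coords) * 0.9), np.argmax(coords[:, 0])) would keep roughly 10% of the spots as context.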
Reference
If you find our work useful in your research, please consider citing our paper:
@inproceedings{huang2025stflow,
title={Scalable Generation of Spatial Transcriptomics from Histology Images via Whole-Slide Flow Matching},
author={Huang, Tinglin and Liu, Tianyu and Babadi, Mehrtash and Jin, Wengong and Ying, Rex},
booktitle={International Conference on Machine Learning},
year={2025}
}
@article{huang2025stpath,
title={STPath: A Generative Foundation Model for Integrating Spatial Transcriptomics and Whole Slide Images},
author={Huang, Tinglin and Liu, Tianyu and Babadi, Mehrtash and Ying, Rex and Jin, Wengong},
journal={bioRxiv},
pages={2025--04},
year={2025},
publisher={Cold Spring Harbor Laboratory}
}