Getting Started with SILPAG
SILPAG (Spatial Inter-slice Linking via Patch-level Anchor Gene-guided Optimal Transport) is a framework for cross-slice spatial transcriptomics alignment and gene expression analysis. It learns shared representations across tissue slices using vector-quantized embeddings and anchor-gene-guided optimal transport.
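The optimal-transport component can be pictured with a generic entropic (Sinkhorn) transport plan between two toy embedding sets. The sketch below is only an illustration of the plain OT idea; it is not SILPAG's AGOT, which additionally uses anchor-gene guidance.

```python
import numpy as np

def sinkhorn(cost, reg=0.1, n_iter=200):
    """Entropic OT plan between uniform marginals (generic sketch, not AGOT)."""
    n, m = cost.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-cost / reg)
    u = np.ones(n)
    for _ in range(n_iter):
        v = b / (K.T @ u)   # scale columns toward marginal b
        u = a / (K @ v)     # scale rows toward marginal a
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))  # toy embeddings for slice 0
y = rng.normal(size=(8, 2))  # toy embeddings for slice 1
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean cost
plan = sinkhorn(cost)
print(plan.shape)  # a 6 x 8 soft matching between the two sets
```

Each entry of `plan` is the mass transported between a pair of points, so rows and columns sum to the chosen marginals.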
[18]:
import warnings
warnings.filterwarnings('ignore')
import scanpy as sc
import anndata as ad
import gseapy as gp
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.decomposition import PCA
from scipy import sparse
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
matplotlib.rcParams['font.family'] = 'Arial'
import SILPAG as sp
Download Demo Data
This tutorial uses two 10x Visium spatial transcriptomics samples:
| Sample | Description | File |
|---|---|---|
| hNB-V02 | Human normal breast tissue | V02_nb.h5ad |
| hBC-H1 | Human breast cancer tissue | H1_nb.h5ad |
Download link: OneDrive – Demo Data
After downloading, place the files in the data/ directory:
SILPAG/
├── __init__.py
├── main.py
├── model.py
├── agot.py # Anchor-Gene-guided Optimal Transport
├── util.py
└── data/
├── V02_nb.h5ad # normal breast (reference)
└── H1_nb.h5ad # breast cancer (query)
Quick Start
1. Load & preprocess data
[14]:
# Load reference (normal breast) and query (breast cancer) slices
ref = sc.read_h5ad('SILPAG/data/V02_nb.h5ad')
adata = sc.read_h5ad('SILPAG/data/H1_nb.h5ad')
ref.var_names_make_unique()
adata.var_names_make_unique()

# Drop special genes and keep the gene set shared by both slices
sp.prefilter_specialgenes(ref)
sp.prefilter_specialgenes(adata)
gene = list(set(ref.var_names) & set(adata.var_names))
ref = ref[:, gene]
adata = adata[:, gene]

# Keep genes expressed in at least 2% of spots in either slice
g1 = sc.pp.filter_genes(ref, min_cells=int(ref.shape[0] * 0.02), inplace=False)
g2 = sc.pp.filter_genes(adata, min_cells=int(adata.shape[0] * 0.02), inplace=False)
ref = ref[:, g1[0] | g2[0]]
adata = adata[:, g1[0] | g2[0]]

# Re-intersect so both slices share an identical gene set
key = list(set(ref.var_names) & set(adata.var_names))
ref = ref[:, key]
adata = adata[:, key]

# Binary labels marking housekeeping (anchor-candidate) genes
ref.var['hkgene'] = adata.var['hkgene']
# ref.var['hkgene'] = 1
labels = np.zeros(ref.shape[1], dtype=int)
labels[np.where(ref.var['hkgene'] == 1)[0]] = 1
print(ref.shape, adata.shape)
del ref.raw
(1896, 7874) (613, 7874)
2. Configure & train
[3]:
args = sp.Config()
args = args.parse()
args.device = 'cuda:7'
args.distr = ['nb', 'nb']  # per-slice likelihood; alternatively ['gaussian', 'gaussian']
args.K = [1024] * args.num_slice  # initial codebook size per slice
args.patch_size = [(6, 6), (6, 6)]
args.resize_factor = [1, 1]
args.hist = [False, False]  # no histology input for either slice
args.beta = 1e-4  # cross-view loss weight
args.learning_rate = 1.5e-3
args.weight_decay = 5e-2  # L2 regularization
args.batch_size = 64
args.anchor_size = 512
args.pre_epochs = 30
args.warmup_epochs = 20
args.epochs = 200
args.anchor_epochs = 10
args.lr_decay_epochs = [200]  # e.g. [120, 160] for stepped decay
args.trace = True
args.save = True
args.model_path = '/data/luy/SILPAG/saved_model/hBC_demo'
args.train = True
gene_images = [ref.varm['gene_img'], adata.varm['gene_img']]
train_dataset = sp.GeneDataset(gene_images, labels, args)
[4]:
# train model
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
model, dataset, PI, Q = sp.train_marker(train_loader, hist_data=None, args=args)
print('Number of Anchor gene:', dataset.labels.sum())
code = sp.get_code(model, dataset, hist_data=None, pi_init=PI, source_idx=0, target_idx=1, args=args)
adata.varm['code_ref'] = code[0]
adata.varm['code_tgt'] = code[1]
Shape of each View: [(66, 96), (30, 30)]
Pre-training: 20%|██ | 6/30 [00:24<01:35, 3.97s/epochs]
Ranking loss: 1.0484376436022294
Huber loss: 1.1827829629144966
Cosine loss: 0.07646908347062238
VQ loss: 0.0
Pre-training: 40%|████ | 12/30 [00:47<01:10, 3.94s/epochs]
Ranking loss: 0.9576785071931092
Huber loss: 0.38763301365563857
Cosine loss: 0.06005056376138677
VQ loss: 0.0
Pre-training: 60%|██████ | 18/30 [01:10<00:46, 3.84s/epochs]
Ranking loss: 0.8954382856878679
Huber loss: 0.25106055748994077
Cosine loss: 0.0517275154151558
VQ loss: 0.0
Pre-training: 80%|████████ | 24/30 [01:35<00:24, 4.04s/epochs]
Ranking loss: 1.0781532580821624
Huber loss: 0.3098702563688723
Cosine loss: 3.209030464628429
VQ loss: 0.010244798983035881
Pre-training: 100%|██████████| 30/30 [01:59<00:00, 3.99s/epochs]
Ranking loss: 1.0641289414686088
Huber loss: 0.32430742846426697
Cosine loss: 3.110177480863823
VQ loss: 0.007852736139489759
==================== Slice 0 ====================
Original codebook size: 1024
New codebook size: 51
==================== Slice 1 ====================
Original codebook size: 1024
New codebook size: 203
Training: 20%|█▉ | 39/200 [11:46<49:16, 18.36s/epochs]
Ranking loss: 0.9412368250613834
Huber loss: 0.25285821554319854
AGOT loss: 2.7221589030303837
VQ loss: 0.029331755227082146
Anchor Pool size: 2009
P0: tensor(0.1887, device='cuda:7')
Mean Q: tensor(0.1878, device='cuda:7')
Training: 40%|███▉ | 79/200 [24:02<37:15, 18.48s/epochs]
Ranking loss: 0.850529023042814
Huber loss: 0.1461822370231409
AGOT loss: 0.8229230126058023
VQ loss: 0.026565113252254496
Anchor Pool size: 2338
P0: tensor(0.0626, device='cuda:7')
Mean Q: tensor(0.0473, device='cuda:7')
Training: 60%|█████▉ | 119/200 [36:19<25:02, 18.54s/epochs]
Ranking loss: 0.7562366306918026
Huber loss: 0.10002416282975495
AGOT loss: 0.425887694766384
VQ loss: 0.027144083628326404
Anchor Pool size: 2383
P0: tensor(0.0425, device='cuda:7')
Mean Q: tensor(0.0237, device='cuda:7')
Training: 80%|███████▉ | 159/200 [48:34<12:31, 18.32s/epochs]
Ranking loss: 0.7011057725708467
Huber loss: 0.07753480484878045
AGOT loss: 0.4764239428641077
VQ loss: 0.026703196980583356
Anchor Pool size: 2368
P0: tensor(0.0451, device='cuda:7')
Mean Q: tensor(0.0267, device='cuda:7')
Training: 100%|█████████▉| 199/200 [1:00:54<00:18, 18.54s/epochs]
Ranking loss: 0.6846841919201111
Huber loss: 0.07288047036968393
AGOT loss: 0.4087873572530497
VQ loss: 0.02679214139954145
Anchor Pool size: 2382
P0: tensor(0.0418, device='cuda:7')
Mean Q: tensor(0.0227, device='cuda:7')
Training: 100%|██████████| 200/200 [1:01:12<00:00, 18.36s/epochs]
Number of Anchor gene: 2382
3. Generate view-transferred gene expression
[6]:
result1 = model.generate(args, adata, tgt_index=1, code_id='code_ref', gene='all')
result2 = model.generate(args, adata, tgt_index=1, code_id='code_tgt', gene='all')
Generating...: 100%|██████████| 7874/7874 [00:08<00:00, 951.35gene/s]
Generating...: 100%|██████████| 7874/7874 [00:08<00:00, 966.43gene/s]
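A common way to check a transferred expression matrix (such as `result1` or `result2`) against the observed counts is a per-gene Pearson correlation. The sketch below is a generic evaluation recipe on synthetic matrices, not a SILPAG API:

```python
import numpy as np

rng = np.random.default_rng(0)
observed = rng.poisson(3.0, size=(100, 5)).astype(float)          # spots x genes
generated = observed + rng.normal(0.0, 0.5, size=observed.shape)  # mock reconstruction

def per_gene_corr(a, b):
    """Pearson correlation per gene (column-wise)."""
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    denom = np.sqrt((a ** 2).sum(axis=0) * (b ** 2).sum(axis=0))
    return (a * b).sum(axis=0) / denom

corr = per_gene_corr(observed, generated)
print(corr.shape)  # one correlation value per gene: (5,)
```

The same function applies directly to the real matrices once both are dense arrays with matching spot and gene order.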