Getting Started with SILPAG

SILPAG (Spatial Inter-slice Linking via Patch-level Anchor Gene-guided Optimal Transport) is a framework for cross-slice spatial transcriptomics alignment and gene expression analysis. It learns shared representations across tissue slices using vector-quantized embeddings and anchor-gene-guided optimal transport.
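The exact anchor-gene-guided transport lives in SILPAG's agot.py, but the core idea can be illustrated with a plain entropic (Sinkhorn) optimal-transport sketch in which known anchor-gene pairs get their matching cost shrunk, biasing transport mass toward those correspondences. Everything here (the toy embeddings, the 0.1 anchor scaling) is illustrative, not SILPAG's actual implementation.

```python
import numpy as np

def sinkhorn(C, a, b, eps=0.5, n_iter=500):
    """Entropic-regularized OT: transport plan P with marginals a and b."""
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 8))   # toy gene embeddings, slice 0
Y = rng.normal(size=(6, 8))   # toy gene embeddings, slice 1
C = np.linalg.norm(X[:, None] - Y[None, :], axis=-1)  # pairwise cost

# anchor guidance: known matched gene pairs get a reduced cost,
# so the plan concentrates mass on anchor correspondences
anchors = [(0, 0), (3, 2)]
for i, j in anchors:
    C[i, j] *= 0.1

a = np.full(5, 1 / 5)   # uniform mass over slice-0 genes
b = np.full(6, 1 / 6)   # uniform mass over slice-1 genes
P = sinkhorn(C, a, b)
```

The resulting plan `P` is nonnegative and (up to Sinkhorn convergence) has row sums `a` and column sums `b`.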

[18]:
import warnings
warnings.filterwarnings('ignore')
import scanpy as sc
import anndata as ad
import gseapy as gp
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from sklearn.decomposition import PCA
from scipy import sparse
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
matplotlib.rcParams['font.family'] = 'Arial'

import SILPAG as sp

Download Demo Data

This tutorial uses two 10x Visium spatial transcriptomics samples:

Sample     Description                    File
-------    ---------------------------    -----------
hNB-V02    Human normal breast tissue     V02_nb.h5ad
hBC-H1     Human breast cancer tissue     H1_nb.h5ad

Download link: OneDrive – Demo Data

After downloading, place the files in the data/ directory:

SILPAG/
├── __init__.py
├── main.py
├── model.py
├── agot.py          # Anchor-Gene-guided Optimal Transport
├── util.py
└── data/
    ├── V02_nb.h5ad  # normal breast (reference)
    └── H1_nb.h5ad   # breast cancer (query)
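Before running the tutorial, it can help to verify the download landed in the right place. The helper below is not part of SILPAG; it is a small stdlib-only check you can paste into a cell.

```python
from pathlib import Path

def missing_demo_files(data_dir, names=('V02_nb.h5ad', 'H1_nb.h5ad')):
    """Return the demo files that are not yet present in data_dir."""
    root = Path(data_dir)
    return [n for n in names if not (root / n).exists()]

missing = missing_demo_files('SILPAG/data')
if missing:
    print('Download these files first:', missing)
```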

Quick Start

1. Load & preprocess data

[14]:
ref = sc.read_h5ad('SILPAG/data/V02_nb.h5ad')
adata = sc.read_h5ad('SILPAG/data/H1_nb.h5ad')
ref.var_names_make_unique()
adata.var_names_make_unique()
sp.prefilter_specialgenes(ref)
sp.prefilter_specialgenes(adata)

# intersect the gene sets of the two slices
gene = list(set(ref.var_names) & set(adata.var_names))
ref = ref[:, gene]
adata = adata[:, gene]

# keep genes expressed in at least 2% of spots in either slice
g1 = sc.pp.filter_genes(ref, min_cells=int(ref.shape[0] * 0.02), inplace=False)
g2 = sc.pp.filter_genes(adata, min_cells=int(adata.shape[0] * 0.02), inplace=False)
ref = ref[:, g1[0] | g2[0]]
adata = adata[:, g1[0] | g2[0]]

# re-intersect so both slices end up on an identical gene set
key = list(set(ref.var_names) & set(adata.var_names))
ref = ref[:, key]
adata = adata[:, key]

# housekeeping-gene flags mark candidate anchor genes (1 = anchor)
ref.var['hkgene'] = adata.var['hkgene']
# ref.var['hkgene'] = 1
labels = np.zeros(ref.shape[1], dtype=int)
labels[np.where(ref.var['hkgene'] == 1)[0]] = 1
print(ref.shape, adata.shape)
del ref.raw
(1896, 7874) (613, 7874)
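The training cell below reads per-gene spatial images from `ref.varm['gene_img']`. Their construction is handled inside SILPAG, but the underlying idea, rasterizing each gene's per-spot expression onto the slice's spatial grid (the `Shape of each View: [(66, 96), (30, 30)]` log reports exactly these grid shapes), can be sketched as follows. The function name and coordinate layout here are illustrative, not SILPAG's API.

```python
import numpy as np

def gene_to_image(expr, rows, cols, shape):
    """Rasterize one gene's per-spot expression onto a 2-D grid.

    expr  : (n_spots,) expression values for a single gene
    rows, cols : integer grid coordinates of each spot
    shape : (H, W) grid size of the slice
    """
    img = np.zeros(shape, dtype=float)
    img[rows, cols] = expr
    return img

# toy example: 4 spots placed on a 3x3 grid
expr = np.array([1.0, 2.0, 3.0, 4.0])
rows = np.array([0, 0, 1, 2])
cols = np.array([0, 2, 1, 2])
img = gene_to_image(expr, rows, cols, (3, 3))
```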

2. Configure & train

[3]:
args = sp.Config()
args = args.parse()
args.device = 'cuda:7'
args.distr = ['nb', 'nb']  # per-slice likelihood ('nb' = negative binomial); alternative: ['gaussian', 'gaussian']
args.K = [1024] * args.num_slice
args.patch_size = [(6, 6), (6, 6)]
args.resize_factor = [1, 1]
args.hist = [False, False]
args.beta = 1e-4 # cross-view loss weight
args.learning_rate = 1.5e-3
args.weight_decay = 5e-2 # L2 regularization
args.batch_size = 64
args.anchor_size = 512
args.pre_epochs = 30
args.warmup_epochs = 20
args.epochs = 200
args.anchor_epochs = 10
args.lr_decay_epochs = [200]  # alternative: [120, 160]
args.trace = True
args.save = True
args.model_path = '/data/luy/SILPAG/saved_model/hBC_demo'
args.train = True

gene_images = [ref.varm['gene_img'], adata.varm['gene_img']]
train_dataset = sp.GeneDataset(gene_images, labels, args)
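`args.K = [1024] * args.num_slice` sets the per-slice codebook size for the vector-quantized embeddings (the training log later shows these codebooks pruned to 51 and 203 active codes). The quantization step, mapping each patch embedding to its nearest codebook vector, can be sketched in plain NumPy; this is an illustration of the general VQ mechanism, not SILPAG's model code.

```python
import numpy as np

def vq_assign(z, codebook):
    """Assign each embedding in z (N, D) to its nearest codebook row (K, D)."""
    d = ((z[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)  # (N, K) squared distances
    idx = d.argmin(axis=1)
    return idx, codebook[idx]

rng = np.random.default_rng(0)
codebook = rng.normal(size=(8, 4))  # K=8 codes of dim 4 (SILPAG uses K=1024)
# two embeddings perturbed slightly away from codes 2 and 5
z = codebook[[2, 5]] + 0.01 * rng.normal(size=(2, 4))
idx, zq = vq_assign(z, codebook)
```

Each embedding snaps back to the code it was perturbed from, which is exactly what lets distinct slices share a discrete vocabulary of patch patterns.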
[4]:
# train model
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)
model, dataset, PI, Q = sp.train_marker(train_loader, hist_data=None, args=args)
print('Number of Anchor gene:', dataset.labels.sum())
code = sp.get_code(model, dataset, hist_data=None, pi_init=PI, source_idx=0, target_idx=1, args=args)
adata.varm['code_ref'] = code[0]
adata.varm['code_tgt'] = code[1]
Shape of each View:  [(66, 96), (30, 30)]
Pre-training:  20%|██        | 6/30 [00:24<01:35,  3.97s/epochs]
Ranking loss:  1.0484376436022294
Huber loss:  1.1827829629144966
Cosine loss:  0.07646908347062238
VQ loss:  0.0
Pre-training:  40%|████      | 12/30 [00:47<01:10,  3.94s/epochs]
Ranking loss:  0.9576785071931092
Huber loss:  0.38763301365563857
Cosine loss:  0.06005056376138677
VQ loss:  0.0
Pre-training:  60%|██████    | 18/30 [01:10<00:46,  3.84s/epochs]
Ranking loss:  0.8954382856878679
Huber loss:  0.25106055748994077
Cosine loss:  0.0517275154151558
VQ loss:  0.0
Pre-training:  80%|████████  | 24/30 [01:35<00:24,  4.04s/epochs]
Ranking loss:  1.0781532580821624
Huber loss:  0.3098702563688723
Cosine loss:  3.209030464628429
VQ loss:  0.010244798983035881
Pre-training: 100%|██████████| 30/30 [01:59<00:00,  3.99s/epochs]
Ranking loss:  1.0641289414686088
Huber loss:  0.32430742846426697
Cosine loss:  3.110177480863823
VQ loss:  0.007852736139489759

==================== Slice  0 ====================
    Original codebook size:  1024
    New codebook size:  51
==================== Slice  1 ====================
    Original codebook size:  1024
    New codebook size:  203
Training:  20%|█▉        | 39/200 [11:46<49:16, 18.36s/epochs]
Ranking loss:  0.9412368250613834
Huber loss:  0.25285821554319854
AGOT loss:  2.7221589030303837
VQ loss:  0.029331755227082146
Anchor Pool size:  2009
P0:  tensor(0.1887, device='cuda:7')
Mean Q:  tensor(0.1878, device='cuda:7')
Training:  40%|███▉      | 79/200 [24:02<37:15, 18.48s/epochs]
Ranking loss:  0.850529023042814
Huber loss:  0.1461822370231409
AGOT loss:  0.8229230126058023
VQ loss:  0.026565113252254496
Anchor Pool size:  2338
P0:  tensor(0.0626, device='cuda:7')
Mean Q:  tensor(0.0473, device='cuda:7')
Training:  60%|█████▉    | 119/200 [36:19<25:02, 18.54s/epochs]
Ranking loss:  0.7562366306918026
Huber loss:  0.10002416282975495
AGOT loss:  0.425887694766384
VQ loss:  0.027144083628326404
Anchor Pool size:  2383
P0:  tensor(0.0425, device='cuda:7')
Mean Q:  tensor(0.0237, device='cuda:7')
Training:  80%|███████▉  | 159/200 [48:34<12:31, 18.32s/epochs]
Ranking loss:  0.7011057725708467
Huber loss:  0.07753480484878045
AGOT loss:  0.4764239428641077
VQ loss:  0.026703196980583356
Anchor Pool size:  2368
P0:  tensor(0.0451, device='cuda:7')
Mean Q:  tensor(0.0267, device='cuda:7')
Training: 100%|█████████▉| 199/200 [1:00:54<00:18, 18.54s/epochs]
Ranking loss:  0.6846841919201111
Huber loss:  0.07288047036968393
AGOT loss:  0.4087873572530497
VQ loss:  0.02679214139954145
Anchor Pool size:  2382
P0:  tensor(0.0418, device='cuda:7')
Mean Q:  tensor(0.0227, device='cuda:7')
Training: 100%|██████████| 200/200 [1:01:12<00:00, 18.36s/epochs]
Number of Anchor gene: 2382

3. Generate view-transferred gene expression

[6]:
result1 = model.generate(args, adata, tgt_index=1, code_id='code_ref', gene='all')
result2 = model.generate(args, adata, tgt_index=1, code_id='code_tgt', gene='all')
Generating...: 100%|██████████| 7874/7874 [00:08<00:00, 951.35gene/s]
Generating...: 100%|██████████| 7874/7874 [00:08<00:00, 966.43gene/s]
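A common next step is to check how well the transferred expression matches the observed data, for example with a per-gene Pearson correlation. The sketch below assumes the generated and observed matrices are dense (spots x genes) arrays; adapt the reshaping to the actual layout of `result1`/`result2` in your run.

```python
import numpy as np

def per_gene_pearson(a, b):
    """Column-wise Pearson correlation between two (spots, genes) matrices."""
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    num = (a * b).sum(axis=0)
    den = np.sqrt((a ** 2).sum(axis=0) * (b ** 2).sum(axis=0))
    return num / np.maximum(den, 1e-12)

# toy stand-ins for observed vs. generated expression
rng = np.random.default_rng(0)
obs = rng.poisson(3.0, size=(50, 10)).astype(float)
gen = obs + rng.normal(scale=0.1, size=obs.shape)  # near-perfect reconstruction
r = per_gene_pearson(obs, gen)
```

High per-gene correlations on held-out genes would indicate the view transfer preserves expression structure; systematically low values for specific genes flag where the alignment struggles.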