{ "cells": [ { "cell_type": "markdown", "id": "bb8ebdf2", "metadata": {}, "source": [ "# Getting Started with SILPAG\n", "\n", "SILPAG (**S**patial **I**nter-slice **L**inking via **P**atch-level **A**nchor **G**ene-guided Optimal Transport) is a framework for cross-slice spatial transcriptomics alignment and gene expression analysis. It learns shared representations across tissue slices using vector-quantized embeddings and anchor-gene-guided optimal transport." ] }, { "cell_type": "code", "execution_count": 18, "id": "6d719e9f", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "warnings.filterwarnings('ignore') \n", "import scanpy as sc\n", "import anndata as ad\n", "import gseapy as gp\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import torch.nn.functional as F\n", "from torch.utils.data import TensorDataset, DataLoader\n", "from sklearn.decomposition import PCA\n", "from scipy import sparse\n", "from matplotlib.lines import Line2D\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "import matplotlib\n", "matplotlib.rcParams['font.family'] = 'Arial'\n", "\n", "import SILPAG as sp\n" ] }, { "cell_type": "markdown", "id": "7ebf0aef", "metadata": {}, "source": [ "## Download Demo Data\n", "\n", "This tutorial uses two 10x Visium spatial transcriptomics samples:\n", "\n", "| Sample | Description | File |\n", "|--------|-------------|------|\n", "| **hNB-V02** | Human normal breast tissue | `V02_nb.h5ad` |\n", "| **hBC-H1** | Human breast cancer tissue | `H1_nb.h5ad` |\n", "\n", "**Download link:** [OneDrive – Demo Data](https://cuhko365-my.sharepoint.com/:f:/g/personal/225040459_link_cuhk_edu_cn/IgB8oVwMOxjRTZndM5LLiexwASgi3ts3kse0_dRJZ9SLnzc?e=pwTcxQ)\n", "\n", "After downloading, place the files in the `data/` directory:\n", "\n", "\n", "```\n", "SILPAG/\n", "├── __init__.py\n", "├── main.py\n", "├── model.py\n", "├── agot.py # Anchor-Gene-guided Optimal Transport\n", "├── util.py\n", "└── data/\n", " ├── 
V02_nb.h5ad # normal breast (reference)\n", " └── H1_nb.h5ad # breast cancer (query)\n", "```\n", "\n" ] }, { "cell_type": "markdown", "id": "7dd91abf", "metadata": {}, "source": [ "## Quick Start" ] }, { "cell_type": "markdown", "id": "bf297743", "metadata": {}, "source": [ "### 1. Load & preprocess data" ] }, { "cell_type": "code", "execution_count": 14, "id": "0df3ec68", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1896, 7874) (613, 7874)\n" ] } ], "source": [ "# Load both slices and restrict them to a common, filtered gene set\n", "ref = sc.read_h5ad('SILPAG/data/V02_nb.h5ad')\n", "adata = sc.read_h5ad('SILPAG/data/H1_nb.h5ad')\n", "ref.var_names_make_unique()\n", "adata.var_names_make_unique()\n", "sp.prefilter_specialgenes(ref)\n", "sp.prefilter_specialgenes(adata)\n", "\n", "# Shared genes, kept in ref's order. A set intersection would reorder the genes\n", "# on every kernel restart (string hashing is randomised per process), so saved\n", "# models and codes could not be matched back to genes in a later session.\n", "shared = set(adata.var_names)\n", "gene = [g for g in ref.var_names if g in shared]\n", "ref = ref[:, gene]\n", "adata = adata[:, gene]\n", "\n", "# Keep genes that pass the 2% min_cells threshold in at least one slice\n", "g1 = sc.pp.filter_genes(ref, min_cells=int(ref.shape[0]*0.02), inplace=False)\n", "g2 = sc.pp.filter_genes(adata, min_cells=int(adata.shape[0]*0.02), inplace=False)\n", "keep = g1[0] | g2[0]\n", "ref = ref[:, keep].copy()  # explicit copies: .var / .varm are written to later\n", "adata = adata[:, keep].copy()\n", "\n", "# Both slices were subset with the same gene list and mask, so they must agree\n", "assert list(ref.var_names) == list(adata.var_names)\n", "key = list(ref.var_names)\n", "ref.var['hkgene'] = adata.var['hkgene']\n", "# labels[i] is 1 where hkgene == 1 for gene i, else 0\n", "labels = (ref.var['hkgene'].to_numpy() == 1).astype(int)\n", "print(ref.shape, adata.shape)\n", "del ref.raw" ] }, { "cell_type": "markdown", "id": "b9847795", "metadata": {}, "source": [ "### 2. Configure & train" ] }, { "cell_type": "code", "execution_count": 3, "id": "fca1577e", "metadata": {}, "outputs": [], "source": [ "# Build the run configuration from sp.Config() and override selected fields\n", "args = sp.Config()\n", "args = args.parse()\n", "args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'  # first GPU if present, otherwise CPU\n", "args.distr = ['nb', 'nb']  # alternative: ['gaussian', 'gaussian']\n", "args.K = [1024] * args.num_slice\n", "args.patch_size = [(6, 6), (6, 6)]\n", "args.resize_factor = [1, 1]\n", "args.hist = [False, False]\n", 
"args.beta = 1e-4 # cross-view loss weight\n", "args.learning_rate = 1.5e-3\n", "args.weight_decay = 5e-2 # L2 regularization\n", "args.batch_size = 64\n", "args.anchor_size = 512\n", "args.pre_epochs = 30\n", "args.warmup_epochs = 20\n", "args.epochs = 200\n", "args.anchor_epochs = 10\n", "args.lr_decay_epochs = [200]  # alternative: [120, 160]\n", "args.trace = True\n", "args.save = True\n", "# Relative to the working directory, like the 'SILPAG/data/...' inputs above\n", "args.model_path = 'SILPAG/saved_model/hBC_demo'\n", "args.train = True\n", "\n", "gene_images = [ref.varm['gene_img'], adata.varm['gene_img']]\n", "train_dataset = sp.GeneDataset(gene_images, labels, args)" ] }, { "cell_type": "code", "execution_count": 4, "id": "658dd5e2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Shape of each View: [(66, 96), (30, 30)]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Pre-training: 20%|██ | 6/30 [00:24<01:35, 3.97s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 1.0484376436022294\n", "Huber loss: 1.1827829629144966\n", "Cosine loss: 0.07646908347062238\n", "VQ loss: 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Pre-training: 40%|████ | 12/30 [00:47<01:10, 3.94s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.9576785071931092\n", "Huber loss: 0.38763301365563857\n", "Cosine loss: 0.06005056376138677\n", "VQ loss: 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Pre-training: 60%|██████ | 18/30 [01:10<00:46, 3.84s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.8954382856878679\n", "Huber loss: 0.25106055748994077\n", "Cosine loss: 0.0517275154151558\n", "VQ loss: 0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Pre-training: 80%|████████ | 24/30 [01:35<00:24, 4.04s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 1.0781532580821624\n", "Huber loss: 0.3098702563688723\n", "Cosine loss: 
3.209030464628429\n", "VQ loss: 0.010244798983035881\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Pre-training: 100%|██████████| 30/30 [01:59<00:00, 3.99s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 1.0641289414686088\n", "Huber loss: 0.32430742846426697\n", "Cosine loss: 3.110177480863823\n", "VQ loss: 0.007852736139489759\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "==================== Slice 0 ====================\n", " Original codebook size: 1024\n", " New codebook size: 51\n", "==================== Slice 1 ====================\n", " Original codebook size: 1024\n", " New codebook size: 203\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 20%|█▉ | 39/200 [11:46<49:16, 18.36s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.9412368250613834\n", "Huber loss: 0.25285821554319854\n", "AGOT loss: 2.7221589030303837\n", "VQ loss: 0.029331755227082146\n", "Anchor Pool size: 2009\n", "P0: tensor(0.1887, device='cuda:7')\n", "Mean Q: tensor(0.1878, device='cuda:7')\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 40%|███▉ | 79/200 [24:02<37:15, 18.48s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.850529023042814\n", "Huber loss: 0.1461822370231409\n", "AGOT loss: 0.8229230126058023\n", "VQ loss: 0.026565113252254496\n", "Anchor Pool size: 2338\n", "P0: tensor(0.0626, device='cuda:7')\n", "Mean Q: tensor(0.0473, device='cuda:7')\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 60%|█████▉ | 119/200 [36:19<25:02, 18.54s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.7562366306918026\n", "Huber loss: 0.10002416282975495\n", "AGOT loss: 0.425887694766384\n", "VQ loss: 0.027144083628326404\n", "Anchor Pool size: 2383\n", "P0: tensor(0.0425, 
device='cuda:7')\n", "Mean Q: tensor(0.0237, device='cuda:7')\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 80%|███████▉ | 159/200 [48:34<12:31, 18.32s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.7011057725708467\n", "Huber loss: 0.07753480484878045\n", "AGOT loss: 0.4764239428641077\n", "VQ loss: 0.026703196980583356\n", "Anchor Pool size: 2368\n", "P0: tensor(0.0451, device='cuda:7')\n", "Mean Q: tensor(0.0267, device='cuda:7')\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|█████████▉| 199/200 [1:00:54<00:18, 18.54s/epochs]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Ranking loss: 0.6846841919201111\n", "Huber loss: 0.07288047036968393\n", "AGOT loss: 0.4087873572530497\n", "VQ loss: 0.02679214139954145\n", "Anchor Pool size: 2382\n", "P0: tensor(0.0418, device='cuda:7')\n", "Mean Q: tensor(0.0227, device='cuda:7')\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Training: 100%|██████████| 200/200 [1:01:12<00:00, 18.36s/epochs]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Number of Anchor gene: 2382\n" ] } ], "source": [ "# train model\n", "train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=False)\n", "model, dataset, PI, Q = sp.train_marker(train_loader, hist_data=None, args=args)\n", "print('Number of Anchor gene:', dataset.labels.sum())\n", "code = sp.get_code(model, dataset, hist_data=None, pi_init=PI, source_idx=0, target_idx=1, args=args)\n", "adata.varm['code_ref'] = code[0]\n", "adata.varm['code_tgt'] = code[1]" ] }, { "cell_type": "markdown", "id": "24224ed8", "metadata": {}, "source": [ "### 3. 
Generate view-transferred gene expression" ] }, { "cell_type": "code", "execution_count": 6, "id": "bcbd1357", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Generating...: 100%|██████████| 7874/7874 [00:08<00:00, 951.35gene/s]\n", "Generating...: 100%|██████████| 7874/7874 [00:08<00:00, 966.43gene/s]\n" ] } ], "source": [ "result1 = model.generate(args, adata, tgt_index=1, code_id='code_ref', gene='all')\n", "result2 = model.generate(args, adata, tgt_index=1, code_id='code_tgt', gene='all')" ] }, { "cell_type": "markdown", "id": "00b93530", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }