{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4760ef5ca3ad7b8a",
   "metadata": {},
   "source": [
    "# Few-shot Tutorial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5591e5ea0b8550e6",
   "metadata": {},
   "source": [
    "This tutorial demonstrates how to fine-tune the scLinguist model on a few-shot dataset and then use it to predict protein expression from RNA data."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c20285aad68c6452",
   "metadata": {},
   "source": [
    "Import the required packages and define the checkpoint paths and the output directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T04:26:08.061560Z",
     "start_time": "2025-08-31T04:26:04.537193Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import importlib\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Make the repository root importable so the local `scLinguist` package resolves.\n",
    "sys.path.append('../../')\n",
    "from scLinguist.data_loaders.data_loader import scMultiDataset\n",
    "from scLinguist.model.configuration_hyena import HyenaConfig\n",
    "from scLinguist.model.model import scTrans\n",
    "\n",
    "# Expose `scLinguist.model` under the top-level module name `model`.\n",
    "# NOTE(review): presumably the checkpoints were pickled while this package was\n",
    "# importable as `model`; confirm the alias is still required to load them.\n",
    "sys.modules['model'] = importlib.import_module('scLinguist.model')\n",
    "\n",
    "ENCODER_CKPT = Path(\"../../pretrained_model/encoder.ckpt\")\n",
    "DECODER_CKPT = Path(\"../../pretrained_model/decoder.ckpt\")\n",
    "FINETUNE_CKPT = Path(\"../../pretrained_model/finetune.ckpt\")\n",
    "SAVE_DIR = Path(\"../../docs/tutorials/fewshot_output\")\n",
    "# parents=True: also create missing intermediate directories on a fresh checkout.\n",
    "SAVE_DIR.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "162900a151b5b992",
   "metadata": {},
   "source": [
    "Configure dataloaders for the few-shot and test datasets.\n",
    "\n",
    "First, inspect the data structure of the datasets to ensure they are compatible with the model."
] },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "bff7d74c1ed6cbae",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:10.791866Z",
     "start_time": "2025-08-31T03:17:10.734023Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnnData object with n_obs × n_vars = 20 × 19202\n",
       " obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import scanpy as sc\n",
    "rna_train = sc.read_h5ad('../../data/fewshot_sample_rna.h5ad')\n",
    "rna_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "406faeec6e69a11f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:11.957905Z",
     "start_time": "2025-08-31T03:17:11.947240Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[ 0., 0., 0., ..., 15., 1., 27.],\n",
       " [ 0., 0., 0., ..., 0., 0., 2.],\n",
       " [ 0., 0., 0., ..., 2., 0., 7.],\n",
       " ...,\n",
       " [ 0., 0., 0., ..., 2., 0., 1.],\n",
       " [ 0., 0., 0., ..., 31., 0., 69.],\n",
       " [ 0., 0., 0., ..., 2., 0., 8.]], dtype=float32)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rna_train.X.todense()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "4e26d6ed540d442d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:15.306723Z",
     "start_time": "2025-08-31T03:17:15.297132Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['ENSG00000186092', 'ENSG00000284733', 'ENSG00000284662',\n",
       " 'ENSG00000187634', 'ENSG00000188976', 'ENSG00000187961',\n",
       " 'ENSG00000187583', 'ENSG00000187642', 'ENSG00000188290',\n",
       " 'ENSG00000187608',\n",
       " ...\n",
       " 'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899',\n",
       " 'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907',\n",
       " 'ENSG00000198886', 'ENSG00000198786', 'ENSG00000198695',\n",
       " 'ENSG00000198727'],\n",
       " dtype='object', length=19202)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rna_train.var_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1c06bdf4dddb39e3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:17.433460Z",
     "start_time": "2025-08-31T03:17:17.312128Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnnData object with n_obs × n_vars = 10546 × 19202\n",
       " obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rna_test = sc.read_h5ad('../../data/test_sample_rna.h5ad')\n",
    "rna_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "4c53fdea8be7bd78",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:18.880122Z",
     "start_time": "2025-08-31T03:17:18.566271Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "matrix([[ 0., 0., 0., ..., 5., 0., 3.],\n",
       " [ 0., 0., 0., ..., 0., 0., 7.],\n",
       " [ 0., 0., 0., ..., 1., 0., 6.],\n",
       " ...,\n",
       " [ 0., 0., 0., ..., 1., 0., 2.],\n",
       " [ 0., 0., 0., ..., 2., 0., 2.],\n",
       " [ 0., 0., 0., ..., 4., 0., 12.]], dtype=float32)"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rna_test.X.todense()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2c6747360b131ca2",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:20.270653Z",
     "start_time": "2025-08-31T03:17:20.261668Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['ENSG00000186092', 'ENSG00000284733', 'ENSG00000284662',\n",
       " 'ENSG00000187634', 'ENSG00000188976', 'ENSG00000187961',\n",
       " 'ENSG00000187583', 'ENSG00000187642', 'ENSG00000188290',\n",
       " 'ENSG00000187608',\n",
       " ...\n",
       " 'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899',\n",
       " 'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907',\n",
       " 'ENSG00000198886', 'ENSG00000198786', 'ENSG00000198695',\n",
       " 'ENSG00000198727'],\n",
       " dtype='object', length=19202)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rna_test.var_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b84f4dda663c60c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the ADT (protein) file that the few-shot dataloader pairs with\n",
    "# `fewshot_sample_rna.h5ad`; its n_obs is expected to match `rna_train`.\n",
    "adt_train = sc.read_h5ad('../../data/fewshot_sample_adt.h5ad')\n",
    "adt_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "497bd4321c403102",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Keep only the protein columns that are not NaN in the first cell (presumably the\n",
    "# proteins measured in this panel). NOTE(review): assumes all cells share that NaN pattern.\n",
    "mask = ~np.isnan(adt_train.X[0].toarray()).ravel()\n",
    "adt_train[:, mask].X.todense()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "23db37c02a2cc70e",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:20:13.452987Z",
     "start_time": "2025-08-31T03:20:13.444177Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['SP110', 'GTPBA', 'SNX2', 'FRG1', 'TT21A', 'RHG18', 'AR', 'DOCK1',\n",
       " 'RAB1A', 'MUC1.HMFG2',\n",
       " ...\n",
       " 'CYTSA', 'LFNG', 'PFKFB4', 'LIPB1', 'ZN225', 'TRI69', 'CCL14', 'ZN541',\n",
       " 'TAP1', 'SCG3'],\n",
       " dtype='object', length=6427)"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adt_train.var_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "19c888ff6fe5f34a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:17:37.257945Z",
     "start_time": "2025-08-31T03:17:36.797743Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnnData object with n_obs × n_vars = 10546 × 6427\n",
       " obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adt_test = sc.read_h5ad('../../data/test_sample_adt.h5ad')\n",
    "adt_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3dff4dd5528c93a8",
   "metadata": {},
   "source": [
    "Then, create the dataloaders for the few-shot and test datasets. The `scMultiDataset` class loads the RNA and protein data from the specified paths."
] },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "80ddcdbf5b659c88",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:18:07.226968Z",
     "start_time": "2025-08-31T03:18:04.587677Z"
    }
   },
   "outputs": [],
   "source": [
    "BATCH_SIZE = 4\n",
    "\n",
    "# Each scMultiDataset reads an RNA file (data_dir_1) and its matching protein file (data_dir_2).\n",
    "fewshot_data = scMultiDataset(\n",
    "    data_dir_1=\"../../data/fewshot_sample_rna.h5ad\",\n",
    "    data_dir_2=\"../../data/fewshot_sample_adt.h5ad\",\n",
    ")\n",
    "test_data = scMultiDataset(\n",
    "    data_dir_1=\"../../data/test_sample_rna.h5ad\",\n",
    "    data_dir_2=\"../../data/test_sample_adt.h5ad\",\n",
    ")\n",
    "\n",
    "# Options shared by both loaders.\n",
    "common_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True)\n",
    "\n",
    "# Training loader: reshuffled every epoch, fed by 8 worker processes.\n",
    "fewshot_dataloader = DataLoader(fewshot_data, shuffle=True, num_workers=8, **common_loader_kwargs)\n",
    "# Test loader: fixed order, final partial batch kept, loaded in the main process.\n",
    "test_dataloader = DataLoader(test_data, shuffle=False, drop_last=False, num_workers=0, **common_loader_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37a4e2fd1165741d",
   "metadata": {},
   "source": [
    "Finally, load the fine-tuned model from its checkpoint, point it at the encoder and decoder checkpoints, and set the mode to \"RNA-protein\". The `HyenaConfig` objects describe the encoder and decoder architecture parameters (`d_model`, `emb_dim`, `max_seq_len`, `vocab_len`, `n_layer`); the model itself is restored from `FINETUNE_CKPT`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "97bf208af1652d7d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:18:11.653358Z",
     "start_time": "2025-08-31T03:18:07.266110Z"
    }
   },
   "outputs": [],
   "source": [
    "N_GENES = 19202    # RNA vocabulary size (n_vars of the RNA AnnData objects above)\n",
    "N_PROTEINS = 6427  # protein vocabulary size (n_vars of the ADT AnnData objects above)\n",
    "\n",
    "# Hyper-parameters shared by the encoder and decoder configs.\n",
    "shared_cfg = dict(d_model=128, emb_dim=5, n_layer=1, output_hidden_states=False)\n",
    "enc_cfg = HyenaConfig(max_seq_len=N_GENES, vocab_len=N_GENES, **shared_cfg)\n",
    "dec_cfg = HyenaConfig(max_seq_len=N_PROTEINS, vocab_len=N_PROTEINS, **shared_cfg)\n",
    "\n",
    "# NOTE(review): enc_cfg / dec_cfg are not passed to the model below; the model is\n",
    "# restored from FINETUNE_CKPT. Confirm the two stay consistent.\n",
    "model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)\n",
    "model.encoder_ckpt_path = ENCODER_CKPT\n",
    "model.decoder_ckpt_path = DECODER_CKPT\n",
    "model.mode = \"RNA-protein\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15572e727148bd7e",
   "metadata": {},
   "source": [
    "Fine-tune the model with PyTorch Lightning. The `ModelCheckpoint` callback keeps the checkpoint with the lowest `valid_loss`, and the `EarlyStopping` callback stops training once `valid_loss` has not improved for 3 consecutive validation checks. Because the few-shot dataloader is also passed as the validation loader, `valid_loss` is computed on the training cells."
] },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "d21252083680bc85",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T03:20:13.397572Z",
     "start_time": "2025-08-31T03:19:15.028049Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True, used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      " | Name | Type | Params\n",
      "------------------------------------------------\n",
      "0 | encoder | scHeyna_enc | 313 K \n",
      "1 | decoder | scHeyna_dec | 249 K \n",
      "2 | translator | MLPTranslator | 284 M \n",
      "3 | cos_gene | CosineSimilarity | 0 \n",
      "4 | cos_cell | CosineSimilarity | 0 \n",
      "------------------------------------------------\n",
      "285 M Trainable params\n",
      "0 Non-trainable params\n",
      "285 M Total params\n",
      "1,141.275 Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc14adc5d8ba4306b1d39b88c45ece7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation sanity check: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8311eb35fab24d01b8299e214daba15d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c6dcb6ff6357417da461fd36fca2af04",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da76a1e996b64ba6aad6e9fcac2ce634",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a565eabac1a5413f8a5f5a2cc0ab32b0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e3aa5ec7fff4066ab255ff26ab7f589",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f70afd3c542d4d27afb1a233c32cefba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73b2bf0fb4cf46d38f8eca5399a3a8e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
    "\n",
    "# Seed Python, NumPy and torch RNGs so the shuffled few-shot batches are reproducible.\n",
    "pl.seed_everything(42)\n",
    "\n",
    "# Keep only the checkpoint with the lowest `valid_loss`.\n",
    "ckpt_cb = ModelCheckpoint(\n",
    "    dirpath=SAVE_DIR / \"ckpt\",\n",
    "    monitor=\"valid_loss\",\n",
    "    mode=\"min\",\n",
    "    save_top_k=1,\n",
    "    filename=\"best-{epoch}-{valid_loss:.4f}\",\n",
    ")\n",
    "# Stop once `valid_loss` has not improved for 3 consecutive validation checks.\n",
    "early_cb = EarlyStopping(monitor=\"valid_loss\", mode=\"min\", patience=3)\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    accelerator=\"gpu\",\n",
    "    devices=1,  # one visible GPU; a hard-coded index such as [1] fails on single-GPU machines\n",
    "    max_epochs=6,\n",
    "    log_every_n_steps=50,\n",
    "    callbacks=[ckpt_cb, early_cb],\n",
    ")\n",
    "\n",
    "# NOTE: the few-shot loader doubles as the validation loader, so `valid_loss` (and hence\n",
    "# checkpoint selection and early stopping) is measured on the training cells.\n",
    "trainer.fit(model, fewshot_dataloader, fewshot_dataloader)\n",
    "best_ckpt = ckpt_cb.best_model_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "874f62a4f446a010",
   "metadata": {},
   "source": [
    "Run inference with the fine-tuned model on the test dataset: use the RNA profiles to predict the proteins listed in `../../docs/tutorials/protein_names.txt`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b597c4c88b552386",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import scanpy as sc\n",
    "import torch\n",
    "\n",
    "N_CELLS = 10  # only a handful of test cells, to keep the example fast\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# `.copy()` turns the sliced view into a real AnnData so the in-place scanpy\n",
    "# preprocessing below does not operate on a view.\n",
    "test_adata = sc.read_h5ad(\"../../data/test_sample_rna.h5ad\")[:N_CELLS].copy()\n",
    "sc.pp.normalize_total(test_adata, target_sum=10000)\n",
    "sc.pp.log1p(test_adata)\n",
    "# Accept both sparse and dense expression matrices.\n",
    "rna_matrix = test_adata.X.toarray() if hasattr(test_adata.X, \"toarray\") else test_adata.X\n",
    "test_rna_tensor = torch.tensor(rna_matrix, dtype=torch.float32, device=device)\n",
    "\n",
    "# NOTE: `model` holds the weights at the end of `trainer.fit`; load `best_ckpt`\n",
    "# instead if you want the checkpoint with the lowest `valid_loss`.\n",
    "model.eval().to(device)\n",
    "with torch.no_grad():\n",
    "    _, _, protein_pred = model(test_rna_tensor)\n",
    "\n",
    "# Proteins to report, one name per line (blank lines are ignored).\n",
    "with open(\"../../docs/tutorials/protein_names.txt\") as fh:\n",
    "    target_proteins = [line.strip() for line in fh if line.strip()]\n",
    "\n",
    "# Map each protein name to its column in the model's protein output.\n",
    "prot_map = pd.read_csv(\"../../docs/tutorials/protein_index_map.csv\")\n",
    "name_to_idx = dict(zip(prot_map[\"name\"], prot_map[\"index\"]))\n",
    "\n",
    "# Keep column names and indices aligned: proteins absent from the map are dropped from both.\n",
    "found_proteins = [p for p in target_proteins if p in name_to_idx]\n",
    "missing_proteins = [p for p in target_proteins if p not in name_to_idx]\n",
    "if missing_proteins:\n",
    "    print(f\"Skipping {len(missing_proteins)} protein(s) not found in protein_index_map.csv: {missing_proteins}\")\n",
    "idx = [name_to_idx[p] for p in found_proteins]\n",
    "\n",
    "pred_df = pd.DataFrame(\n",
    "    protein_pred[:, idx].cpu().numpy(),\n",
    "    columns=found_proteins,\n",
    "    index=test_adata.obs_names,\n",
    ")\n",
    "pred_df.to_csv(SAVE_DIR / \"predicted_protein_expression.csv\")\n",
    "pred_df.head()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.8",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}