{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": "# Zero-shot Tutorial", "id": "913371d9d9205a10" }, { "metadata": {}, "cell_type": "markdown", "source": "This tutorial demonstrates how to use the scLinguist model for zero-shot prediction of protein expression from RNA data, without fine-tuning. The model is pre-trained and can be applied directly to new datasets.", "id": "109043e1b5a396" }, { "metadata": {}, "cell_type": "markdown", "source": "Import the required packages and define the checkpoint paths and the output directory.", "id": "a1c9680c71be0485" }, { "cell_type": "code", "id": "initial_id", "metadata": { "collapsed": true, "ExecuteTime": { "end_time": "2025-08-01T18:24:47.200045Z", "start_time": "2025-08-01T18:24:44.312415Z" } }, "source": [ "import importlib\n", "import sys\n", "from pathlib import Path\n", "\n", "# Make the repository root importable so that `scLinguist` can be found.\n", "sys.path.append('../../')\n", "from scLinguist.model.configuration_hyena import HyenaConfig\n", "from scLinguist.model.model import scTrans\n", "\n", "# Register `scLinguist.model` under the top-level name `model`, presumably so that\n", "# module paths stored in the checkpoint resolve when it is unpickled.\n", "sys.modules['model'] = importlib.import_module('scLinguist.model')\n", "\n", "ENCODER_CKPT = Path(\"../../pretrained_model/encoder.ckpt\")\n", "DECODER_CKPT = Path(\"../../pretrained_model/decoder.ckpt\")\n", "FINETUNE_CKPT = Path(\"../../pretrained_model/finetune.ckpt\")\n", "SAVE_DIR = Path(\"../../docs/tutorials/zeroshot_output\")\n", "SAVE_DIR.mkdir(parents=True, exist_ok=True)" ], "outputs": [], "execution_count": 1 }, { "metadata": {}, "cell_type": "markdown", "source": "**Since this is a zero-shot task, there is no need to prepare dataloaders.** We load the model from its checkpoint (`FINETUNE_CKPT`), attach the pre-trained encoder and decoder checkpoint paths, and set the mode to \"RNA-protein\".
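Because no `DataLoader` is used, the RNA matrix is sent to the model as one tensor. That works for the 10-cell example, but a full dataset may not fit in GPU memory at once. A minimal sketch of chunked inference without a `DataLoader`, assuming NumPy inputs: `to_dense` and `predict_in_batches` are hypothetical helpers, not part of scLinguist, and `predict_fn` is any callable mapping a dense (cells x genes) chunk to a (cells x proteins) array; in this notebook it would wrap the tensor conversion and the `model(...)` call.

```python
import numpy as np


def to_dense(x):
    # Sparse matrices (scipy.sparse, AnnData sparse views) expose .toarray().
    if hasattr(x, 'toarray'):
        x = x.toarray()
    return np.asarray(x, dtype=np.float32)


def predict_in_batches(predict_fn, x, batch_size=256):
    # Run predict_fn on consecutive row chunks of x and stack the results,
    # so only batch_size cells are densified and processed at a time.
    outputs = []
    for start in range(0, x.shape[0], batch_size):
        chunk = to_dense(x[start:start + batch_size])
        outputs.append(np.asarray(predict_fn(chunk)))
    return np.concatenate(outputs, axis=0)


# Usage with a stand-in predictor that doubles each value (not the real model).
rna = np.arange(12, dtype=np.float32).reshape(6, 2)
pred = predict_in_batches(lambda chunk: chunk * 2, rna, batch_size=4)
print(pred.shape)  # (6, 2)
```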
The `HyenaConfig` class holds the model configuration parameters, such as `d_model`, `emb_dim`, `max_seq_len`, `vocab_len`, and `n_layer`. The encoder and decoder configurations below differ only in sequence length and vocabulary size (19,202 and 6,427, respectively).", "id": "c7641486791dff9e" }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-01T18:24:53.411689Z", "start_time": "2025-08-01T18:24:47.630459Z" } }, "cell_type": "code", "source": [ "# Encoder (RNA side) and decoder (protein side) configurations.\n", "# Note: these objects are not passed to `load_from_checkpoint` below.\n", "enc_cfg = HyenaConfig(\n", "    d_model=128,\n", "    emb_dim=5,\n", "    max_seq_len=19202,\n", "    vocab_len=19202,\n", "    n_layer=1,\n", "    output_hidden_states=False,\n", ")\n", "dec_cfg = HyenaConfig(\n", "    d_model=128,\n", "    emb_dim=5,\n", "    max_seq_len=6427,\n", "    vocab_len=6427,\n", "    n_layer=1,\n", "    output_hidden_states=False,\n", ")\n", "\n", "# Load the model, attach the pre-trained encoder/decoder checkpoint paths,\n", "# and switch it to RNA-to-protein prediction.\n", "model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)\n", "model.encoder_ckpt_path = ENCODER_CKPT\n", "model.decoder_ckpt_path = DECODER_CKPT\n", "model.mode = \"RNA-protein\"" ], "id": "e05dad9a57fed188", "outputs": [], "execution_count": 2 }, { "metadata": {}, "cell_type": "markdown", "source": "Run zero-shot inference with the pre-trained model on the test dataset.
The RNA profiles are used to predict the proteins listed in `../../docs/tutorials/protein_names.txt`; protein names are mapped to model output columns through `../../docs/tutorials/protein_index_map.csv`.", "id": "c7928d6998c7261d" }, { "metadata": { "ExecuteTime": { "end_time": "2025-08-01T18:24:54.551270Z", "start_time": "2025-08-01T18:24:53.496931Z" } }, "cell_type": "code", "source": [ "import pandas as pd\n", "import scanpy as sc\n", "import scipy.sparse as sp\n", "import torch\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Only use the first 10 cells for this example.\n", "zeroshot_adata = sc.read_h5ad(\"../../data/test_sample_rna.h5ad\")[:10]\n", "X = zeroshot_adata.X\n", "X = X.toarray() if sp.issparse(X) else X\n", "zeroshot_rna_tensor = torch.tensor(X, dtype=torch.float32, device=device)\n", "\n", "model.eval().to(device)\n", "\n", "with torch.no_grad():\n", "    # The third output holds the predicted protein expression.\n", "    _, _, protein_pred = model(zeroshot_rna_tensor)\n", "\n", "# Proteins to report, one name per line.\n", "with open(\"../../docs/tutorials/protein_names.txt\") as f:\n", "    target_proteins = [line.strip() for line in f if line.strip()]\n", "\n", "# Map protein names to column indices of the model output.\n", "prot_map = pd.read_csv(\"../../docs/tutorials/protein_index_map.csv\")\n", "name_to_idx = dict(zip(prot_map[\"name\"], prot_map[\"index\"]))\n", "\n", "# Keep only proteins present in the map so that columns and indices stay aligned.\n", "found = [p for p in target_proteins if p in name_to_idx]\n", "missing = [p for p in target_proteins if p not in name_to_idx]\n", "if missing:\n", "    print(f\"Skipping {len(missing)} protein(s) not in the index map: {missing}\")\n", "idx = [int(name_to_idx[p]) for p in found]\n", "\n", "pred_df = pd.DataFrame(\n", "    protein_pred[:, idx].cpu().numpy(),\n", "    columns=found,\n", "    index=zeroshot_adata.obs_names,\n", ")\n", "pred_df.to_csv(SAVE_DIR / \"predicted_protein_expression.csv\")" ], "id": "c1717a304b3ca0fd", "outputs": [], "execution_count": 3 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }