Few-shot Tutorial
This tutorial demonstrates how to fine-tune the scLinguist model on a few-shot dataset and then use it to predict protein expression from RNA data.
Import necessary packages and define paths for checkpoints and save directory.
[1]:
import importlib
import sys
from pathlib import Path

import torch
from torch.utils.data import DataLoader

# Make the repository root importable when running from docs/tutorials.
sys.path.append('../../')
from scLinguist.data_loaders.data_loader import scMultiDataset
from scLinguist.model.configuration_hyena import HyenaConfig
from scLinguist.model.model import scTrans

# Expose scLinguist.model under the top-level name "model"
# (presumably so that checkpoints saved with that module path can be loaded).
sys.modules['model'] = importlib.import_module('scLinguist.model')

ENCODER_CKPT = Path("../../pretrained_model/encoder.ckpt")
DECODER_CKPT = Path("../../pretrained_model/decoder.ckpt")
FINETUNE_CKPT = Path("../../pretrained_model/finetune.ckpt")
SAVE_DIR = Path("../../docs/tutorials/fewshot_output")
# parents=True also creates any missing intermediate directories.
SAVE_DIR.mkdir(parents=True, exist_ok=True)
Configure dataloaders for the few-shot and test datasets.
First, inspect the structure of the datasets to ensure they are compatible with the model.
[25]:
import scanpy as sc
rna_train = sc.read_h5ad('../../data/fewshot_sample_rna.h5ad')
rna_train
[25]:
AnnData object with n_obs × n_vars = 20 × 19202
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'
[26]:
rna_train.X.todense()
[26]:
matrix([[ 0., 0., 0., ..., 15., 1., 27.],
[ 0., 0., 0., ..., 0., 0., 2.],
[ 0., 0., 0., ..., 2., 0., 7.],
...,
[ 0., 0., 0., ..., 2., 0., 1.],
[ 0., 0., 0., ..., 31., 0., 69.],
[ 0., 0., 0., ..., 2., 0., 8.]], dtype=float32)
[27]:
rna_train.var_names
[27]:
Index(['ENSG00000186092', 'ENSG00000284733', 'ENSG00000284662',
'ENSG00000187634', 'ENSG00000188976', 'ENSG00000187961',
'ENSG00000187583', 'ENSG00000187642', 'ENSG00000188290',
'ENSG00000187608',
...
'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899',
'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907',
'ENSG00000198886', 'ENSG00000198786', 'ENSG00000198695',
'ENSG00000198727'],
dtype='object', length=19202)
[28]:
import scanpy as sc
rna_test = sc.read_h5ad('../../data/test_sample_rna.h5ad')
rna_test
[28]:
AnnData object with n_obs × n_vars = 10546 × 19202
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'
[29]:
rna_test.X.todense()
[29]:
matrix([[ 0., 0., 0., ..., 5., 0., 3.],
[ 0., 0., 0., ..., 0., 0., 7.],
[ 0., 0., 0., ..., 1., 0., 6.],
...,
[ 0., 0., 0., ..., 1., 0., 2.],
[ 0., 0., 0., ..., 2., 0., 2.],
[ 0., 0., 0., ..., 4., 0., 12.]], dtype=float32)
[30]:
rna_test.var_names
[30]:
Index(['ENSG00000186092', 'ENSG00000284733', 'ENSG00000284662',
'ENSG00000187634', 'ENSG00000188976', 'ENSG00000187961',
'ENSG00000187583', 'ENSG00000187642', 'ENSG00000188290',
'ENSG00000187608',
...
'ENSG00000198712', 'ENSG00000228253', 'ENSG00000198899',
'ENSG00000198938', 'ENSG00000198840', 'ENSG00000212907',
'ENSG00000198886', 'ENSG00000198786', 'ENSG00000198695',
'ENSG00000198727'],
dtype='object', length=19202)
[31]:
import scanpy as sc
adt_train = sc.read_h5ad('../../data/train_sample_adt.h5ad')
adt_train
[31]:
AnnData object with n_obs × n_vars = 16994 × 6427
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'
[32]:
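import numpy as np
# Keep only the protein columns that are not NaN in the first cell.
# ravel() turns the (1, n_vars) row into a 1-D boolean mask.
mask = ~np.isnan(adt_train.X[0].toarray()).ravel()
adt_train[:, mask].X.todense()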
[32]:
matrix([[1.100e+02, 1.400e+01, 4.900e+01, ..., 1.100e+01, 2.120e+02,
2.800e+01],
[1.200e+02, 1.000e+00, 5.890e+02, ..., 2.000e+00, 3.600e+01,
2.400e+01],
[6.450e+02, 5.000e+00, 1.256e+03, ..., 1.000e+00, 7.200e+01,
1.320e+02],
...,
[2.330e+02, 2.700e+01, 8.420e+02, ..., 3.000e+00, 7.700e+01,
4.600e+01],
[3.120e+02, 1.500e+01, 1.079e+03, ..., 4.000e+00, 4.800e+01,
8.000e+01],
[1.960e+02, 2.000e+00, 1.910e+02, ..., 0.000e+00, 3.400e+01,
1.900e+01]])
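The masking step above can be illustrated on a small dense array. This is a minimal sketch on toy values, and it assumes (as the cell above does by looking only at the first cell) that unmeasured protein slots are NaN in every cell.

```python
import numpy as np

# Toy ADT matrix: 3 cells x 5 protein slots; slots 1 and 3 were not measured.
adt = np.array([
    [110.0, np.nan,   49.0, np.nan,  28.0],
    [120.0, np.nan,  589.0, np.nan,  24.0],
    [645.0, np.nan, 1256.0, np.nan, 132.0],
])

# 1-D boolean mask built from the first cell: True where a value is present.
measured = ~np.isnan(adt[0])

# Column selection keeps only the measured protein slots.
measured_only = adt[:, measured]
print(measured_only.shape)
```

If the NaN pattern could differ between cells, a mask such as `~np.isnan(adt).any(axis=0)` would keep only the columns measured in all cells.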
[38]:
adt_train.var_names
[38]:
Index(['SP110', 'GTPBA', 'SNX2', 'FRG1', 'TT21A', 'RHG18', 'AR', 'DOCK1',
'RAB1A', 'MUC1.HMFG2',
...
'CYTSA', 'LFNG', 'PFKFB4', 'LIPB1', 'ZN225', 'TRI69', 'CCL14', 'ZN541',
'TAP1', 'SCG3'],
dtype='object', length=6427)
[33]:
import scanpy as sc
adt_test = sc.read_h5ad('../../data/test_sample_adt.h5ad')
adt_test
[33]:
AnnData object with n_obs × n_vars = 10546 × 6427
obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'lane', 'donor', 'celltype.l1', 'celltype.l2', 'RNA.weight', 'cell_type'
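Both RNA files report 19202 variables and both ADT files report 6427, but matching counts alone do not guarantee that the features are the same or in the same order. The following is a minimal sketch of such a check; `check_feature_alignment` is a hypothetical helper written for this tutorial, not part of scLinguist, and the toy lists reuse three gene IDs from the output above.

```python
def check_feature_alignment(reference_names, other_names):
    """Report whether two feature lists match in content and in order."""
    reference_names = list(reference_names)
    other_names = list(other_names)
    return {
        "same_order": reference_names == other_names,
        "missing_in_other": sorted(set(reference_names) - set(other_names)),
        "extra_in_other": sorted(set(other_names) - set(reference_names)),
    }

# Toy example: same genes, but two of them are swapped in the second list.
train_genes = ["ENSG00000186092", "ENSG00000284733", "ENSG00000284662"]
test_genes = ["ENSG00000186092", "ENSG00000284662", "ENSG00000284733"]

report = check_feature_alignment(train_genes, test_genes)
print(report)
```

In the notebook, the same helper could be called as `check_feature_alignment(rna_train.var_names, rna_test.var_names)`.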
Then, create the dataloaders for both the few-shot and test datasets. The scMultiDataset class loads the paired RNA and protein data from the specified paths.
[34]:
BATCH_SIZE = 4
fewshot_data = scMultiDataset(
    data_dir_1="../../data/fewshot_sample_rna.h5ad",
    data_dir_2="../../data/fewshot_sample_adt.h5ad",
)
test_data = scMultiDataset(
    data_dir_1="../../data/test_sample_rna.h5ad",
    data_dir_2="../../data/test_sample_adt.h5ad",
)
fewshot_dataloader = DataLoader(
    fewshot_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)
test_dataloader = DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    pin_memory=True,
)
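To get a feel for what these settings imply, the number of batches per epoch can be worked out from the dataset sizes shown earlier. This is a small arithmetic sketch; it assumes that scMultiDataset yields one item per cell, and `n_batches` is a hypothetical helper.

```python
import math

def n_batches(n_cells, batch_size, drop_last=False):
    """Number of batches a DataLoader yields per pass over n_cells items."""
    if drop_last:
        return n_cells // batch_size
    return math.ceil(n_cells / batch_size)

print(n_batches(20, 4))       # few-shot set, 20 cells    -> 5
print(n_batches(10546, 4))    # test set, 10546 cells     -> 2637
```

With only 5 batches per epoch on the few-shot set, each training epoch is very short.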
Finally, configure the model: load it from the fine-tuning checkpoint, point it to the encoder and decoder checkpoints, and set the mode to "RNA-protein". The HyenaConfig class defines the model configuration parameters such as d_model, emb_dim, max_seq_len, vocab_len, and n_layer.
[35]:
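# Encoder configuration: sequence length and vocabulary match the 19202 RNA features.
enc_cfg = HyenaConfig(
    d_model = 128,
    emb_dim = 5,
    max_seq_len = 19202,
    vocab_len = 19202,
    n_layer = 1,
    output_hidden_states=False,
)
# Decoder configuration: sequence length and vocabulary match the 6427 protein features.
dec_cfg = HyenaConfig(
    d_model = 128,
    emb_dim = 5,
    max_seq_len = 6427,
    vocab_len = 6427,
    n_layer = 1,
    output_hidden_states=False,
)
# Note: this cell restores the model from FINETUNE_CKPT and does not pass
# enc_cfg or dec_cfg to it explicitly.
model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)
model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"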
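The max_seq_len and vocab_len values in both configurations equal the feature counts shown in the data inspection (19202 genes, 6427 protein slots). A minimal sketch of a sanity check follows; it uses plain dictionaries as stand-ins for the HyenaConfig objects, and `config_mismatches` is a hypothetical helper, not part of scLinguist.

```python
def config_mismatches(cfg, n_features, keys=("max_seq_len", "vocab_len")):
    """Return the config keys whose value differs from the number of features."""
    return [k for k in keys if cfg[k] != n_features]

# Stand-ins for enc_cfg and dec_cfg, with the values used in this tutorial.
enc = {"max_seq_len": 19202, "vocab_len": 19202}
dec = {"max_seq_len": 6427, "vocab_len": 6427}

# Feature counts reported by the AnnData summaries.
n_rna_features = 19202
n_adt_features = 6427

print(config_mismatches(enc, n_rna_features))
print(config_mismatches(dec, n_adt_features))
```

An empty list means the configuration agrees with the data; for a different dataset, any key listed would need to be adjusted.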
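Train the model with PyTorch Lightning. The ModelCheckpoint callback saves the best model according to valid_loss, and the EarlyStopping callback stops training if valid_loss does not improve for 3 consecutive epochs (patience=3). Note that trainer.fit receives the few-shot dataloader as both the training and the validation loader, so valid_loss is computed on the same few-shot cells used for training.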
[37]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

# Keep only the checkpoint with the lowest validation loss.
ckpt_cb = ModelCheckpoint(
    dirpath = SAVE_DIR/"ckpt",
    monitor = "valid_loss",
    mode = "min",
    save_top_k = 1,
    filename = "best-{epoch}-{valid_loss:.4f}",
)
early_cb = EarlyStopping(monitor="valid_loss", mode="min", patience=3)
trainer = pl.Trainer(
    accelerator = "gpu",
    devices = [1],  # GPU with index 1; adjust to the devices available on your machine
    max_epochs = 6,
    log_every_n_steps = 50,
    callbacks = [ckpt_cb, early_cb],
)
# The few-shot dataloader serves as both the training and the validation loader.
trainer.fit(model, fewshot_dataloader, fewshot_dataloader)
best_ckpt = ckpt_cb.best_model_path
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
  | Name       | Type             | Params
------------------------------------------------
0 | encoder    | scHeyna_enc      | 313 K
1 | decoder    | scHeyna_dec      | 249 K
2 | translator | MLPTranslator    | 284 M
3 | cos_gene   | CosineSimilarity | 0
4 | cos_cell   | CosineSimilarity | 0
------------------------------------------------
285 M Trainable params
0 Non-trainable params
285 M Total params
1,141.275 Total estimated model params size (MB)
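The interaction between patience=3 and max_epochs=6 can be sketched in plain Python. This is a simplified model of early stopping in "min" mode, not Lightning's implementation (for instance, it ignores min_delta); the function name and the loss values are made up for illustration.

```python
def epochs_until_stop(valid_losses, patience=3, max_epochs=6):
    """Return (epochs actually run, best loss seen) under simple early stopping."""
    best = float("inf")
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(valid_losses[:max_epochs], start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best
    return min(len(valid_losses), max_epochs), best

# Loss improves for two epochs, then stalls for three: training stops after epoch 5.
stalled = epochs_until_stop([1.0, 0.9, 0.95, 0.96, 0.97, 0.5])
# Loss improves every epoch: training runs for all 6 epochs.
improving = epochs_until_stop([6.0, 5.0, 4.0, 3.0, 2.0, 1.0])
print(stalled, improving)
```

In the stalled case the checkpoint kept by a "save best" policy would correspond to the loss of 0.9, not to the last epoch that ran.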
Run inference with the trained model on the test dataset, using RNA data to predict the proteins listed in ../../docs/tutorials/protein_names.txt.
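The inference cell below applies sc.pp.normalize_total with target_sum=10000 followed by sc.pp.log1p before passing the RNA counts to the model. As a rough NumPy sketch of what these two steps compute on a dense matrix (an illustration on toy counts, not scanpy's implementation; `normalize_log1p` is a hypothetical helper):

```python
import numpy as np

def normalize_log1p(counts, target_sum=10_000):
    """Scale each cell (row) to target_sum total counts, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=np.float64)
    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1.0  # leave all-zero cells unchanged instead of dividing by zero
    return np.log1p(counts / totals * target_sum)

# Two toy cells with four genes each.
counts = np.array([[0, 15, 1, 27],
                   [0,  0, 0,  2]])
normalized = normalize_log1p(counts)

# Undoing the log shows that every cell now sums to target_sum.
row_sums = np.expm1(normalized).sum(axis=1)
print(row_sums)
```

Zero counts stay zero after both steps, so the sparsity pattern of the input is preserved.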
[ ]:
import pandas as pd
import scanpy as sc
import torch

# Only use the first 10 test cells for this example.
# copy() avoids running the preprocessing on a view of the full dataset.
fewshot_adata = sc.read_h5ad("../../data/test_sample_rna.h5ad")[:10].copy()
sc.pp.normalize_total(fewshot_adata, target_sum=10000)
sc.pp.log1p(fewshot_adata)
fewshot_rna_tensor = torch.tensor(fewshot_adata.X.todense(), dtype=torch.float32).cuda()

model.eval().cuda()
with torch.no_grad():
    _, _, protein_pred = model(fewshot_rna_tensor)

# Read the names of the proteins to predict, skipping empty lines.
with open("../../docs/tutorials/protein_names.txt") as f:
    target_proteins = [line.strip() for line in f if line.strip()]

prot_map = pd.read_csv("../../docs/tutorials/protein_index_map.csv")
name_to_idx = dict(zip(prot_map["name"], prot_map["index"]))

# Keep only proteins present in the index map so that column names and indices stay aligned.
found_proteins = [p for p in target_proteins if p in name_to_idx]
idx = [name_to_idx[p] for p in found_proteins]

pred_df = pd.DataFrame(
    protein_pred[:, idx].cpu().numpy(),
    columns = found_proteins,
    index = fewshot_adata.obs_names,
)
pred_df.to_csv(SAVE_DIR/"predicted_protein_expression.csv")
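The name-to-index lookup in the cell above can be tried on toy data. This is a minimal sketch: the protein names are taken from the adt_train.var_names output, but the indices and prediction values are made up, and "NOT_IN_MAP" is a deliberately absent name showing why the column labels must be filtered together with the indices.

```python
import numpy as np
import pandas as pd

# Hypothetical name-to-column map (in the tutorial this comes from protein_index_map.csv).
name_to_idx = {"SP110": 0, "SNX2": 2, "TAP1": 3}
target_proteins = ["SNX2", "NOT_IN_MAP", "SP110"]

# Toy predictions: 2 cells x 4 protein slots.
pred = np.arange(8, dtype=float).reshape(2, 4)

# Filter names and indices together so they always have the same length and order.
found = [p for p in target_proteins if p in name_to_idx]
missing = [p for p in target_proteins if p not in name_to_idx]
idx = [name_to_idx[p] for p in found]

pred_df = pd.DataFrame(pred[:, idx], columns=found, index=["cell_0", "cell_1"])
print(pred_df)
print("not in map:", missing)
```

Passing the unfiltered name list as `columns` would raise a shape error here, because the selected array has two columns while the list has three names.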