Imputation Tutorial

This tutorial demonstrates how to use a fine-tuned scLinguist model to impute protein expression from RNA data.

Import the necessary packages and define the paths for the checkpoints and the save directory.

[1]:
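import importlib
import sys
from pathlib import Path

from torch.utils.data import DataLoader

# Make the repository root importable when running from docs/tutorials.
sys.path.append('../../')
from scLinguist.data_loaders.data_loader import scMultiDataset
from scLinguist.model.configuration_hyena import HyenaConfig
from scLinguist.model.model import scTrans

# Register scLinguist.model under the top-level name `model`
# (presumably so that checkpoints saved under that module path can be resolved).
sys.modules['model'] = importlib.import_module('scLinguist.model')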

ENCODER_CKPT = Path("../../pretrained_model/encoder.ckpt")
DECODER_CKPT = Path("../../pretrained_model/decoder.ckpt")
FINETUNE_CKPT = Path("../../pretrained_model/finetune.ckpt")
SAVE_DIR = Path("../../docs/tutorials/finetune_output")
# Create the output directory, including any missing parent directories.
SAVE_DIR.mkdir(parents=True, exist_ok=True)
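
Before loading anything, it can help to confirm that the checkpoint files are where the paths above expect them. The following is a minimal sketch, not part of scLinguist: the helper name `missing_checkpoints` is hypothetical, and it only assumes that each checkpoint is a regular file on disk.

```python
from pathlib import Path


def missing_checkpoints(paths):
    """Return the paths (as Path objects) that do not point to an existing file."""
    return [Path(p) for p in paths if not Path(p).is_file()]


# Hypothetical usage with the paths defined in the cell above:
# missing = missing_checkpoints([ENCODER_CKPT, DECODER_CKPT, FINETUNE_CKPT])
# if missing:
#     raise FileNotFoundError(f"Missing checkpoint files: {missing}")
```

Failing early with a list of missing files tends to be easier to debug than an error raised partway through checkpoint loading.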

Load the fine-tuned model from its checkpoint, point it to the encoder and decoder checkpoints, and set the mode to “RNA-protein”. The HyenaConfig class defines model configuration parameters such as d_model, emb_dim, max_seq_len, vocab_len, and n_layer.

[3]:
enc_cfg = HyenaConfig(
    d_model        = 128,
    emb_dim        = 5,
    max_seq_len    = 19202,
    vocab_len      = 19202,
    n_layer        = 1,
    output_hidden_states=False,
)
dec_cfg = HyenaConfig(
    d_model        = 128,
    emb_dim        = 5,
    max_seq_len    = 6427,
    vocab_len      = 6427,
    n_layer        = 1,
    output_hidden_states=False,
)
model = scTrans.load_from_checkpoint(checkpoint_path=FINETUNE_CKPT)
model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"

Put your proteins of interest in ../../docs/tutorials/protein_names.txt, one name per line.

Important: each protein name must appear in ../../docs/tutorials/protein_index_map.csv (6427 proteins in total).
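
To see which requested proteins are actually covered by the index map, the names can be split into found and missing before any prediction columns are selected. The sketch below is an illustration only: `split_known_proteins` is a hypothetical helper, and the protein names and indices in `toy_map` are made up for the example and are not taken from protein_index_map.csv.

```python
def split_known_proteins(requested, name_to_idx):
    """Split requested names into those present in the map and those absent.

    Order is preserved so that the found names stay aligned with the
    column indices looked up from the same map.
    """
    found, missing = [], []
    for name in requested:
        if name in name_to_idx:
            found.append(name)
        else:
            missing.append(name)
    return found, missing


# Toy mapping with invented entries (stand-in for the real name -> index map).
toy_map = {"CD3": 0, "CD4": 1, "CD8": 2}
found, missing = split_known_proteins(["CD4", "CD19", "CD3"], toy_map)
print("found:", found)
print("missing:", missing)
```

Reporting the missing names explicitly makes it obvious when a typo or naming mismatch silently drops a protein from the output.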

[5]:
import scanpy as sc
import torch

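# Use only the first 10 cells for this example.
test_adata = sc.read_h5ad("../../data/test_sample_rna.h5ad")[:10]
X = test_adata.X
# Densify if X is a SciPy sparse matrix; dense arrays pass through unchanged.
X = X.toarray() if hasattr(X, "toarray") else X
rna_tensor = torch.tensor(X, dtype=torch.float32).cuda()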

model.encoder_ckpt_path = ENCODER_CKPT
model.decoder_ckpt_path = DECODER_CKPT
model.mode = "RNA-protein"
model.eval().cuda()

with torch.no_grad():
    _, _, protein_pred = model(rna_tensor)

# Read the proteins of interest, skipping blank lines.
with open("../../docs/tutorials/protein_names.txt") as f:
    target_proteins = [line.strip() for line in f if line.strip()]

import pandas as pd
prot_map = pd.read_csv("../../docs/tutorials/protein_index_map.csv")
name_to_idx = dict(zip(prot_map["name"], prot_map["index"]))

# Keep only proteins present in the index map so the column names
# stay aligned with the selected prediction columns.
found_proteins = [p for p in target_proteins if p in name_to_idx]
idx = [name_to_idx[p] for p in found_proteins]

pred_df = pd.DataFrame(
    protein_pred[:, idx].cpu().numpy(),
    columns = found_proteins,
    index   = test_adata.obs_names,
)
pred_df.to_csv(SAVE_DIR/"predicted_protein_expression.csv")
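
The example above runs on only 10 cells. For a larger dataset, passing every cell through the model at once may not fit in GPU memory, so one option is to process the cells in fixed-size batches. The following is a minimal sketch of the batching logic only; `batch_slices` is a hypothetical helper, and the batch size is an arbitrary choice for illustration.

```python
def batch_slices(n_items, batch_size):
    """Yield consecutive slices that together cover range(n_items)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, n_items, batch_size):
        yield slice(start, min(start + batch_size, n_items))


# Example: 10 cells in batches of 4 gives slices of 4, 4, and 2 cells.
slices = list(batch_slices(10, 4))
print(slices)
```

Each slice could then index the RNA tensor (for example `rna_tensor[s]`), with the per-batch predictions moved to the CPU and concatenated with `torch.cat` before building the DataFrame.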