from tqdm import tqdm
from anndata import AnnData
import numpy as np
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence as kl
from torch.utils.data import DataLoader
from scproca import settings
from scproca.module.scproca_module import scProca_VAE
from scproca.utils.data import data2input, to_one_hot, DS, split_dataset_into_train_valid
from scproca.utils.estimation import get_priors_size, get_init_priors_adt
from scproca.utils.optimization import compute_kl_weight, EarlyStopping
from sklearn.preprocessing import LabelEncoder
from torch.optim.lr_scheduler import ReduceLROnPlateau
from typing import Literal, List
import logging
# Named logger for this package; this module configures no handlers or levels for it.
logger: logging.Logger = logging.getLogger(name="scProca")
[docs]
class scProca:
"""Single-cell model of proteomics with/from transcriptomics using cross-attention.
Parameters
----------
adata : AnnData
AnnData object. `adata.X` contains the RNA measurements.
key_adt : str
The key used to access ADT measurements stored in `adata.obsm`.
key_batch : str
The key used to access batch annotations stored in `adata.obs`.
key_valid_adt : str
The key used to access whether the ADT measurements are valid or just placeholders in `adata.obs`.
d_latent : int, optional (default=20)
The dimensionality of the latent space.
distribution_rna : {"ZINB", "NB"}, optional (default="NB")
The distribution to model RNA data. One of:
* ``'NB'`` - Negative Binomial distribution
* ``'ZINB'`` - Zero-Inflated Negative Binomial distribution
distribution_adt : {"MixtureNB", "NB"}, optional (default="MixtureNB")
The distribution to model ADT data. One of:
* ``'NB'`` - Negative Binomial distribution
* ``'MixtureNB'`` - Mixture of two Negative Binomial distributions
activation : {"relu", "mish"}, optional (default="mish")
The activation function used in the neural networks. One of:
* ``'relu'`` - Rectified Linear Unit
* ``'mish'`` - Mish activation function
norm : {"BatchNorm", "LayerNorm"}, optional (default="LayerNorm")
The type of normalization used in the networks. One of:
* ``'BatchNorm'`` - Batch normalization
* ``'LayerNorm'`` - Layer normalization
mode : {"none", "cross_attention", "NN"}, optional (default="cross_attention")
Defines the mode of interaction between RNA and ADT data. One of:
* ``'none'`` - No interaction between RNA and ADT data (independent processing).
* ``'cross_attention'`` - Cross-attention mechanism between RNA and ADT data.
* ``'NN'`` - Nearest Neighbors averaging approach.
dropout : float, optional (default=0.2)
The dropout rate applied during training.
d_hidden : tuple of int, optional (default=(256, 256))
A tuple indicating the number of neurons in each hidden layer.
pre_to_device : bool, optional (default=True)
Whether to move the data to the device (e.g., GPU) beforehand to reduce data transfer overhead.
For large datasets, this should be set to False.
Examples
--------
>>> scproca.settings.seed = seed
>>> scproca.settings.batch_size = batch_size (default=512)
>>> scproca.settings.device = index_cuda (None if using 'cpu')
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> batch = adata.obs[key_batch].values.ravel()
>>> valid_adt = np.array([True] * len(adata))
>>> valid_adt[adt_not_valid] = False
>>> adata.obs["valid_adt"] = valid_adt
>>> scproca = scProca(adata=adata, key_adt=key_adt, key_batch=key_batch, key_valid_adt="valid_adt")
>>> scproca.train()
>>> adata.obsm["latent"], adata.obsm["embedding_rna"], adata.obsm["embedding_adt"] = scproca.get_latent_representation()
>>> adata.obsm["protein_generation"] = scproca.generation(anchor_batch=list_str_anchor_batch)
"""
def __init__(
self,
adata: AnnData,
key_adt: str,
key_batch: str,
key_valid_adt: str,
d_latent: int = 20,
distribution_rna: Literal["ZINB", "NB"] = "NB",
distribution_adt: Literal["MixtureNB", "NB"] = "MixtureNB",
activation: Literal["relu", "mish"] = "mish",
norm: Literal["BatchNorm", "LayerNorm"] = "LayerNorm",
mode: Literal["none", "cross_attention", "NN"] = "cross_attention",
dropout: float = 0.2,
d_hidden: tuple = (256, 256),
pre_to_device: bool = True,
):
self.device = settings.device
self.rna = data2input(adata.X, pre_to_device)
self.adt = data2input(adata.obsm[key_adt].to_numpy().astype(np.float32), pre_to_device)
self.valid_adt = data2input(adata.obs[key_valid_adt].to_numpy(), pre_to_device)
self.batch = adata.obs[key_batch].to_numpy()
self.n_batch = len(np.unique(self.batch))
label_encoder = LabelEncoder()
batch_code = np.array(label_encoder.fit_transform(self.batch))
self.batch_mapping = dict(zip(np.unique(self.batch), label_encoder.transform(np.unique(self.batch))))
self.batch_one_hot = data2input(to_one_hot(batch_code, self.n_batch), pre_to_device)
self.pre_to_device = pre_to_device
self.distribution_rna = distribution_rna
self.distribution_adt = distribution_adt
if self.distribution_adt == "MixtureNB":
init_background_mean_adt, init_background_std_adt = get_init_priors_adt(
adata.obsm[key_adt], batch_code, adata.obs[key_valid_adt].to_numpy())
else:
init_background_mean_adt, init_background_std_adt = None, None
if self.distribution_adt == "NB":
library_log_means, library_log_vars = get_priors_size(
adata.obsm[key_adt], batch_code)
else:
library_log_means, library_log_vars = None, None
self.module = scProca_VAE(
d_rna=self.rna.shape[-1],
d_adt=self.adt.shape[-1],
n_batch=self.n_batch,
d_latent=d_latent,
d_hidden=d_hidden,
activation=activation,
norm=norm,
dropout=dropout,
distribution_rna=self.distribution_rna,
distribution_adt=self.distribution_adt,
mode=mode,
prior_parameters={
"init_background_mean_adt": init_background_mean_adt,
"init_background_std_adt": init_background_std_adt,
"library_log_means": library_log_means,
"library_log_vars": library_log_vars,
}
).to(self.device)
self.log = {
"loss_elbo": [],
"loss_discriminator": [],
}
[docs]
def train(
self,
batch_size: int | None = None, lambda_a: float = 30.0,
adversarial_step=1, epochs=400, lr=4e-3,
ratio_val: float = 0.1, epochs_warmup: int | None = None,
steps_warmup: int | None = None,
bool_also_reconstructed_from_embedding: bool = True,
):
"""Trains the model using variational inference.
Parameters
----------
batch_size : int, optional
The minibatch size used during training. Can also be specified via `scproca.settings.batch_size`.
lambda_a : float (default=30.0)
The coefficient for the adversarial loss.
adversarial_step : int, optional (default=1)
The number of steps for adversarial network optimization in each training epoch.
epochs : int, optional (default=400)
The maximum number of training epochs.
lr : float, optional (default=4e-3)
The learning rate.
ratio_val : float, optional (default=0.1)
The proportion of the dataset used as the validation set.
epochs_warmup : int, optional (default=None)
The number of epochs to use for warmup.
steps_warmup : int, optional (default=None)
The number of steps to use for warmup.
bool_also_reconstructed_from_embedding : bool, optional (default=True)
Whether to additionally train the reconstruction loss from the embeddings,
apart from the latent space reconstruction loss.
"""
if batch_size is None:
batch_size = settings.batch_size
else:
settings.batch_size = batch_size
parameters_classifiers = (
list(self.module.classifier_embedding.parameters())
+ list(self.module.classifier_embedding_rna.parameters())
+ list(self.module.classifier_embedding_adt.parameters())
)
optimizer = optim.Adam(
filter(lambda p: id(p) not in map(id, parameters_classifiers),
self.module.parameters()), lr=lr, weight_decay=1e-6, eps=0.01
)
optimizer_discriminator = optim.Adam(
parameters_classifiers, lr=1e-3, weight_decay=1e-6, eps=0.01)
scheduler = ReduceLROnPlateau(
optimizer,
patience=30,
factor=0.6,
threshold=0.0,
min_lr=0,
threshold_mode="abs",
verbose=True
)
epoch_progress = tqdm(range(epochs), desc='Training Progress', ncols=120)
dataset = DS(self.rna, self.adt, self.batch_one_hot, self.valid_adt)
dataset_train, dataset_val = split_dataset_into_train_valid(dataset, batch_size, ratio_val)
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_valid = DataLoader(dataset_val, batch_size=batch_size, shuffle=True)
early_stopping = EarlyStopping(patience=45, delta=0.0)
step = 0
steps_warmup = (
steps_warmup if steps_warmup is not None else int(0.75 * len(dataset))
)
for epoch in epoch_progress:
self.module.train()
for (_, rna, adt, batch_one_hot, valid_adt) in dataloader_train:
lambda_kl_epoch = compute_kl_weight(epoch, step, epochs_warmup, steps_warmup)
lambda_a_epoch = lambda_a + 1.0 - lambda_kl_epoch
if not self.pre_to_device:
rna = rna.to(self.device)
adt = adt.to(self.device)
batch_one_hot = batch_one_hot.to(self.device)
valid_adt = valid_adt.to(self.device)
(latent_mean, latent_std), embedding_rna, embedding_adt, size_rna, size_adt = self.module.encode(
rna, adt, batch_one_hot, valid_adt)
latent_distribution = Normal(latent_mean, latent_std)
loss_kl = kl(latent_distribution, Normal(0, 1)).sum(dim=-1).mean() * lambda_kl_epoch
latent = latent_distribution.rsample()
for _ in range(adversarial_step):
loss_discriminator = self.module.classifier_loss(
latent.detach(), embedding_rna.detach(), embedding_adt.detach(), batch_one_hot
) * lambda_a_epoch
optimizer_discriminator.zero_grad()
loss_discriminator.backward()
optimizer_discriminator.step()
(parameters_unshared_rna_from_latent,
parameters_unshared_rna_from_embedding,
parameters_shared_rna,
parameters_unshared_adt_from_latent,
parameters_unshared_adt_from_embedding,
parameters_shared_adt) = self.module.decode(
latent, embedding_rna, embedding_adt, size_rna, size_adt, batch_one_hot, valid_adt,
also_from_embedding=bool_also_reconstructed_from_embedding, mode="train")
loss_r = self.module.reconstruction_loss(
rna, parameters_unshared_rna_from_latent, parameters_unshared_rna_from_embedding,
parameters_shared_rna,
adt, parameters_unshared_adt_from_latent, parameters_unshared_adt_from_embedding,
parameters_shared_adt,
valid_adt,
also_from_embedding=bool_also_reconstructed_from_embedding
)
if self.distribution_adt == "MixtureNB":
(mean_log_beta, std_log_beta) = parameters_unshared_adt_from_latent[-2:]
(prior_mean_log_beta, prior_std_log_beta) = parameters_shared_adt[-2:]
prior_mean_log_beta = F.linear(
batch_one_hot, prior_mean_log_beta
)
prior_std_log_beta = F.linear(
batch_one_hot, prior_std_log_beta,
)
loss_kl_log_beta = (kl(
Normal(mean_log_beta, std_log_beta),
Normal(prior_mean_log_beta, prior_std_log_beta)
).sum(-1) * valid_adt.float()).mean() * lambda_kl_epoch
else:
loss_kl_log_beta = 0.0
loss_adversarial = - self.module.classifier_loss(
latent, embedding_rna, embedding_adt, batch_one_hot) * lambda_a_epoch
loss = loss_kl + loss_r + loss_kl_log_beta + loss_adversarial
optimizer.zero_grad()
loss.backward()
optimizer.step()
step += 2
self.module.eval()
with torch.no_grad():
loss_elbo = 0.0
loss_discriminator = 0.0
n_val = 0
for (_, rna, adt, batch_one_hot, valid_adt) in dataloader_valid:
if not self.pre_to_device:
rna = rna.to(self.device)
adt = adt.to(self.device)
batch_one_hot = batch_one_hot.to(self.device)
valid_adt = valid_adt.to(self.device)
(latent_mean, latent_std), embedding_rna, embedding_adt, size_rna, size_adt = self.module.encode(
rna, adt, batch_one_hot, valid_adt)
latent_distribution = Normal(latent_mean, latent_std)
loss_kl = kl(latent_distribution, Normal(0, 1)).sum(dim=-1).mean()
latent = latent_distribution.rsample()
(parameters_unshared_rna_from_latent,
parameters_unshared_rna_from_embedding,
parameters_shared_rna,
parameters_unshared_adt_from_latent,
parameters_unshared_adt_from_embedding,
parameters_shared_adt) = self.module.decode(
latent, embedding_rna, embedding_adt, size_rna, size_adt, batch_one_hot, valid_adt,
also_from_embedding=bool_also_reconstructed_from_embedding, mode="train"
)
loss_r = self.module.reconstruction_loss(
rna, parameters_unshared_rna_from_latent, parameters_unshared_rna_from_embedding,
parameters_shared_rna,
adt, parameters_unshared_adt_from_latent, parameters_unshared_adt_from_embedding,
parameters_shared_adt,
valid_adt,
also_from_embedding=bool_also_reconstructed_from_embedding
)
if self.distribution_adt == "MixtureNB":
(mean_log_beta, std_log_beta) = parameters_unshared_adt_from_latent[-2:]
(prior_mean_log_beta, prior_std_log_beta) = parameters_shared_adt[-2:]
prior_mean_log_beta = F.linear(
batch_one_hot, prior_mean_log_beta
)
prior_std_log_beta = F.linear(
batch_one_hot, prior_std_log_beta,
)
loss_kl_log_beta = (kl(
Normal(mean_log_beta, std_log_beta),
Normal(prior_mean_log_beta, prior_std_log_beta)
).sum(-1) * valid_adt.float()).mean()
else:
loss_kl_log_beta = 0.0
loss_elbo += loss_kl + loss_r + loss_kl_log_beta
loss_discriminator += self.module.classifier_loss(
latent, embedding_rna, embedding_adt, batch_one_hot)
n_val += 1
if n_val > 0:
loss_elbo /= n_val
loss_discriminator /= n_val
self.log["loss_elbo"].append((float(epoch), float(loss_elbo.item())))
self.log["loss_discriminator"].append((float(epoch), float(loss_discriminator.item())))
epoch_progress.set_postfix(
loss_elbo=loss_elbo.item(),
loss_discriminator=loss_discriminator.item(),
)
early_stop = early_stopping(loss_elbo.item())
scheduler.step(loss_elbo.item())
if n_val > 0 and early_stop:
print("Early stopping triggered.")
break
[docs]
@torch.inference_mode()
def get_latent_representation(self, n_shuffle: int | None = 100):
"""Infers the integrated latent representation, RNA-specific embedding, and ADT-specific embedding for each cell.
Parameters
----------
n_shuffle : int, optional (default=100)
The number of repetitions used to estimate the mean representation.
Returns
-------
- **latent** - integrated latent representation
- **embedding_rna** - RNA-specific embedding representation
- **embedding_adt** - ADT-specific embedding representation
"""
batch_size = settings.batch_size
dataset = DS(self.rna, self.adt, self.batch_one_hot, self.valid_adt)
self.module.eval()
sum_latent = np.zeros((len(dataset), self.module.d_latent), dtype=np.float32)
sum_embedding_rna = np.zeros((len(dataset), self.module.d_latent), dtype=np.float32)
sum_embedding_adt = np.zeros((len(dataset), self.module.d_latent), dtype=np.float32)
counts = np.zeros((len(dataset), 1), dtype=np.int32)
epoch_progress = tqdm(range(n_shuffle), desc='Inference Progress', ncols=120)
for _ in epoch_progress:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
latent = np.zeros((len(dataset), self.module.d_latent), dtype=np.float32)
embedding_rna = np.zeros((len(dataset), self.module.d_latent), dtype=np.float32)
embedding_adt = np.zeros((len(dataset), self.module.d_latent), dtype=np.float32)
for (idx, rna, adt, batch_one_hot, valid_adt) in dataloader:
if not self.pre_to_device:
rna = rna.to(self.device)
adt = adt.to(self.device)
batch_one_hot = batch_one_hot.to(self.device)
valid_adt = valid_adt.to(self.device)
(latent_mean, _), embedding_from_rna, embedding_from_adt, _, _ = self.module.encode(
rna, adt, batch_one_hot, valid_adt)
latent[idx.numpy()] = latent_mean.cpu().numpy()
embedding_rna[idx.numpy()] = embedding_from_rna.cpu().numpy()
embedding_adt[idx.numpy()] = embedding_from_adt.cpu().numpy()
counts[idx.numpy(), 0] += 1
sum_latent = sum_latent + latent
sum_embedding_rna = sum_embedding_rna + embedding_rna
sum_embedding_adt = sum_embedding_adt + embedding_adt
return sum_latent / counts, sum_embedding_rna / counts, sum_embedding_adt / counts
[docs]
@torch.inference_mode()
def generation(self, anchor_batch: str | List[str] | None, n_shuffle: int | None = 100):
"""Generates ADT measurements for each cell.
Parameters
----------
anchor_batch : str or List[str], optional (default=None)
The batch or list of batches used to the generated measurements.
If None, it refers to the original batch of the cells.
n_shuffle : int, optional (default=100)
The number of repetitions used to estimate the mean generated measurements.
Returns
-------
- **protein_generation** - generated ADT measurements.
"""
anchor_batch_list = anchor_batch if isinstance(anchor_batch, list) else [anchor_batch]
anchor_batch_list = [
None if anchor_batch is None else self.batch_mapping[anchor_batch] for anchor_batch in anchor_batch_list
]
self.module.eval()
batch_size = settings.batch_size
dataset = DS(self.rna, self.adt, self.batch_one_hot, self.valid_adt)
self.module.eval()
sum_expression = np.zeros((len(dataset), self.module.d_adt), dtype=np.float32)
counts_expression = np.zeros((len(dataset), 1), dtype=np.int32)
epoch_progress = tqdm(range(n_shuffle), desc='Inference Progress', ncols=120)
for _ in epoch_progress:
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
expression = np.zeros((len(dataset), self.module.d_adt), dtype=np.float32)
for (idx, rna, adt, batch_one_hot, valid_adt) in dataloader:
for anchor_batch in anchor_batch_list:
if not self.pre_to_device:
rna = rna.to(self.device)
adt = adt.to(self.device)
batch_one_hot = batch_one_hot.to(self.device)
valid_adt = valid_adt.to(self.device)
(latent_mean, _), embedding_from_rna, embedding_from_adt, size_rna, size_adt = self.module.encode(
rna, adt, batch_one_hot, valid_adt)
if anchor_batch is not None:
batch_one_hot = torch.zeros_like(batch_one_hot)
batch_one_hot[:, anchor_batch] = torch.ones_like(batch_one_hot[:, anchor_batch])
(_, _, _, parameters_unshared_adt, _, _) = self.module.decode(
latent_mean, embedding_from_rna, embedding_from_adt, size_rna, size_adt, batch_one_hot,
valid_adt, mode="imputation"
)
if self.distribution_adt == "NB":
expression[idx.numpy()] = parameters_unshared_adt[0].cpu().numpy()
if self.distribution_adt == "MixtureNB":
expression[idx.numpy()] = (
torch.sigmoid(parameters_unshared_adt[2]) * parameters_unshared_adt[0]
+ (
1 - torch.sigmoid(parameters_unshared_adt[2])
) * parameters_unshared_adt[0] * parameters_unshared_adt[1]
).cpu().numpy()
counts_expression[idx.numpy(), 0] += 1
sum_expression = sum_expression + expression
return sum_expression / counts_expression
[docs]
def curve_loss(self, key_loss):
"""Plots the loss curve for the validation dataset during the training process.
Parameters
----------
key_loss : str, optional (default="loss_elbo")
The key used to specify which loss to plot. Choices are:
* ``'loss_elbo'`` - ELBO (Evidence Lower Bound) loss
* ``'loss_discriminator'`` - Loss for the discriminators
"""
import matplotlib.pyplot as plt
plt.figure(figsize=(4, 2))
plt.plot(np.array(self.log[key_loss])[:, 0] + 1, np.array(self.log[key_loss])[:, 1])
plt.xlabel('training epoch')
plt.ylabel(key_loss)
plt.title('Loss curve')
plt.show()