import math
import random
import numpy as np
import pandas as pd
import torch
import tqdm
from peptdeep.model.building_block import (
Hidden_HFace_Transformer,
PositionalEncoding,
SeqAttentionSum,
ascii_embedding,
)
from peptdeep.utils import get_available_device, logging
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from .constants._const import D_MODEL
from .mhc_utils import NonSpecificDigest
# Seed the global Python, NumPy and PyTorch random generators with a fixed
# value (1337). This runs once, as a side effect of importing this module.
random.seed(1337)
np.random.seed(1337)
torch.random.manual_seed(1337)
# peptdeep has removed this function,
# so a copy is kept here as a local function.
[docs]
def get_cosine_schedule_with_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
    last_epoch: int = -1,
) -> LambdaLR:
    """Build a ``LambdaLR`` scheduler with linear warmup and cosine decay.

    The learning-rate multiplier grows linearly from 0 while the step counter
    is below ``num_warmup_steps``; afterwards it follows
    ``0.5 * (1 + cos(pi * 2 * num_cycles * progress))`` clipped at 0, where
    ``progress`` is the fraction of post-warmup steps already done.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Number of scheduler steps used for the linear warmup.
        num_training_steps: Total number of scheduler steps.
        num_cycles: Number of cosine cycles (0.5 means a single half-cycle).
        last_epoch: Index of the last epoch (-1 when starting from scratch).

    Returns:
        LambdaLR: The scheduler wrapping ``optimizer``.
    """
    # Both denominators are floored at 1 to avoid division by zero.
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def lr_lambda(current_step: int) -> float:
        if current_step >= num_warmup_steps:
            # Cosine decay phase.
            progress = (current_step - num_warmup_steps) / decay_span
            cosine_factor = 0.5 * (
                1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)
            )
            return max(0.0, cosine_factor)
        # Linear warmup phase.
        return current_step / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch)
[docs]
def get_ascii_indices(seq_array: list[str]) -> torch.LongTensor:
    """Convert peptide sequences into a 2D tensor of character codes.

    The sequences are stored as a fixed-width NumPy unicode array, whose raw
    buffer is reinterpreted as 32-bit integers (one code point per character).
    Sequences shorter than the longest one are right-padded with 0.

    Args:
        seq_array: Peptide sequence strings (e.g., ['GLCTLVAML', ...]). Besides
            plain lists, object-dtype arrays or pandas Series of strings are
            accepted as well.

    Returns:
        A tensor of shape (batch_size, max_sequence_length), dtype=torch.long.
        An empty input yields a tensor of shape (0, 0).
    """
    # Forcing a unicode dtype makes object-dtype inputs (e.g. ``Series.values``)
    # work; for a plain list of str it is what NumPy infers anyway.
    chars = np.array(seq_array, dtype=np.str_)
    if chars.size == 0:
        # ``reshape(0, -1)`` is ambiguous for empty arrays and would raise.
        return torch.zeros((0, 0), dtype=torch.long)
    return torch.tensor(
        chars.view(np.int32).reshape(len(chars), -1),
        dtype=torch.long,
    )
[docs]
class ModelSeqEncoder(torch.nn.Module):
    """Peptide sequence encoder: embedding, positions, Transformer, pooling."""

    def __init__(
        self, d_model: int = D_MODEL, layer_num: int = 4, dropout: float = 0.2
    ) -> None:
        """Build the sub-modules of the peptide encoder.

        Args:
            d_model: Embedding / hidden dimension.
            layer_num: Number of Transformer layers.
            dropout: Dropout rate passed to the Transformer.
        """
        super().__init__()
        # Sub-modules are created in a fixed order and under fixed attribute
        # names; both matter for seeded initialisation and for checkpoints.
        self.embedding = ascii_embedding(d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=100)
        self.bert = Hidden_HFace_Transformer(
            hidden_dim=d_model,
            nlayers=layer_num,
            dropout=dropout,
        )
        self.out_nn = SeqAttentionSum(d_model)

    def forward(self, aa_idxes: torch.Tensor) -> torch.Tensor:
        """Embed a batch of index-encoded peptides.

        Args:
            aa_idxes: Tensor of shape (batch_size, seq_len) holding character
                indices; positions with index 0 are treated as padding.

        Returns:
            Normalized embedding tensor of shape (batch_size, d_model).
        """
        not_padding = aa_idxes > 0
        hidden = self.pos_encoder(self.embedding(aa_idxes))
        hidden = self.bert(hidden, not_padding)[0]
        # Zero the hidden states of padded positions before pooling.
        hidden = hidden * not_padding.unsqueeze(-1)
        pooled = self.out_nn(hidden)
        return torch.nn.functional.normalize(pooled)
[docs]
class ModelHlaEncoder(torch.nn.Module):
    """HLA encoder: Transformer over per-position embeddings plus pooling."""

    def __init__(
        self, d_model: int = D_MODEL, layer_num: int = 1, dropout: float = 0.2
    ) -> None:
        """Build the sub-modules of the HLA encoder.

        Args:
            d_model: Embedding / hidden dimension.
            layer_num: Number of Transformer layers.
            dropout: Dropout rate passed to the Transformer.
        """
        super().__init__()
        self.nn = Hidden_HFace_Transformer(
            d_model,
            nlayers=layer_num,
            dropout=dropout,
        )
        self.out_nn = SeqAttentionSum(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool padded HLA embeddings into one vector per allele.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model). Positions
                whose feature vector is entirely zero are treated as padding.

        Returns:
            Normalized embedding tensor of shape (batch_size, d_model).
        """
        # A position is real if at least one of its features is non-zero.
        is_real = x.ne(0).any(dim=2)
        hidden = self.nn(x, is_real)[0]
        hidden = hidden * is_real.unsqueeze(-1)
        pooled = self.out_nn(hidden)
        return torch.nn.functional.normalize(pooled)
[docs]
class HlaDataSet(Dataset):
"""Dataset providing paired HLA embeddings and peptides for training."""
[docs]
def __init__(
self,
hla_df: pd.DataFrame,
hla_esm_list: list[np.ndarray],
pept_df: pd.DataFrame | None,
protein_data: pd.DataFrame | list | str,
min_peptide_len: int = 8,
max_peptide_len: int = 14,
) -> None:
"""Initialize the dataset.
Args:
hla_df: DataFrame with HLA information; must have 'allele' column.
hla_esm_list: List of HLA ESM embeddings corresponding to hla_df rows.
pept_df: Peptide DataFrame with columns 'sequence' and 'allele'.
protein_data: Protein FASTA path(s) or DataFrame to generate negatives.
min_peptide_len: Minimum length for random digestion.
max_peptide_len: Maximum length for random digestion.
"""
self.hla_esm_list = hla_esm_list
hla_df["hla_id"] = range(len(hla_df))
self.allele_idxes_dict: dict = (
hla_df.groupby("allele")["hla_id"].apply(list).to_dict()
)
self._expand_allele_names()
self.hla_df = hla_df
if pept_df is not None:
self.pept_df = (
pept_df.groupby("sequence")[["allele"]]
.agg(list)
.reset_index(drop=False)
)
self.pept_seq_list = self.pept_df.sequence
self.pept_allele_list = self.pept_df.allele
self.digest = NonSpecificDigest(protein_data, min_peptide_len, max_peptide_len)
self.prob_pept_from_hla_df = 0.8
def _expand_allele_names(self) -> None:
"""Add underscore-free allele names to ``allele_idxes_dict``."""
self.allele_idxes_dict.update(
[
(allele.replace("_", ""), val)
for allele, val in self.allele_idxes_dict.items()
]
)
[docs]
def get_neg_pept(self) -> str:
"""Sample a negative peptide sequence.
Returns:
Random peptide string from the dataset or digested proteins.
"""
if random.random() > self.prob_pept_from_hla_df:
return self.pept_seq_list[random.randint(0, len(self.pept_seq_list) - 1)]
idx = random.randint(0, len(self.digest.digest_starts) - 1)
return self.digest.cat_protein_sequence[
self.digest.digest_starts[idx] : self.digest.digest_stops[idx]
]
[docs]
def get_allele_embed(self, index: int) -> np.ndarray:
"""Get HLA embedding for a specific peptide.
Args:
index: Index of the peptide.
Returns:
Corresponding HLA embedding.
"""
alleles = self.pept_allele_list[index]
allele = alleles[random.randint(0, len(alleles) - 1)]
hla_ids = self.allele_idxes_dict[allele]
return self.hla_esm_list[hla_ids[random.randint(0, len(hla_ids) - 1)]]
def __getitem__(self, index: int) -> tuple[np.ndarray, str, str]:
"""Returns a training triplet: (HLA embed, positive peptide, negative peptide).
Args:
index: Index of the sample.
Returns:
A tuple containing:
- hla_embedding: HLA ESM embedding.
- pos_peptide: Known binding peptide.
- neg_peptide: Non-binding (negative) peptide.
"""
return (
self.get_allele_embed(index),
self.pept_seq_list[index],
self.get_neg_pept(),
)
def __len__(self) -> int:
"""Return number of peptide samples."""
return len(self.pept_df)
[docs]
def batchify_hla_esm_list(batch_esm_list: list[np.ndarray]) -> torch.Tensor:
    """Converts a list of variable-length HLA ESM embeddings into a padded tensor.

    Args:
        batch_esm_list: List of arrays, each of shape (1, seq_len, d_model).

    Returns:
        Zero-padded tensor of shape (batch_size, max_seq_len, d_model).

    Raises:
        ValueError: If ``batch_esm_list`` is empty.
    """
    # The padded length must come from the sequence axis (second to last).
    # ``len(x)`` is the leading axis, i.e. 1 for (1, seq_len, d_model) inputs,
    # which made the fill below fail for any seq_len > 1.
    max_hla_len = max(x.shape[-2] for x in batch_esm_list)
    hla_x = np.zeros(
        (len(batch_esm_list), max_hla_len, batch_esm_list[0].shape[-1]),
        dtype=np.float32,
    )
    for i, x in enumerate(batch_esm_list):
        # x[0] drops the leading singleton axis -> (seq_len, d_model).
        seq_embed = x[0]
        hla_x[i, : len(seq_embed), :] = seq_embed
    return torch.tensor(hla_x, dtype=torch.float32)
[docs]
def pept_hla_collate(
    batch: list[tuple[np.ndarray, str, str]],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function for creating batches from HlaDataSet.

    Zero-pads the variable-length HLA embeddings to a common length and
    index-encodes the positive and negative peptides (each peptide group is
    padded independently to its own longest sequence).

    Args:
        batch: List of tuples (hla_embed, pos_peptide, neg_peptide), where
            every ``hla_embed`` has shape (1, seq_len, d_model).

    Returns:
        A tuple of:
            - hla_tensor: Padded HLA embeddings, (batch, max_seq_len, d_model).
            - pos_pept_tensor: Index-encoded positive peptides.
            - neg_pept_tensor: Index-encoded negative peptides.
    """
    hla_embeds = [sample[0] for sample in batch]
    pos_pept_array = [sample[1] for sample in batch]
    neg_pept_array = [sample[2] for sample in batch]

    # Pad along the sequence axis (second to last). ``len(x)`` would be the
    # leading singleton axis and made the fill fail for any seq_len > 1.
    max_hla_len = max(x.shape[-2] for x in hla_embeds)
    hla_x = np.zeros(
        (len(batch), max_hla_len, hla_embeds[0].shape[-1]), dtype=np.float32
    )
    for i, x in enumerate(hla_embeds):
        seq_embed = x[0]  # (seq_len, d_model)
        hla_x[i, : len(seq_embed), :] = seq_embed

    return (
        torch.tensor(hla_x, dtype=torch.float32),
        get_ascii_indices(pos_pept_array),
        get_ascii_indices(neg_pept_array),
    )
[docs]
def get_hla_dataloader(
    dataset: HlaDataSet, batch_size: int, shuffle: bool
) -> DataLoader:
    """Wrap an ``HlaDataSet`` in a ``DataLoader`` using ``pept_hla_collate``.

    Args:
        dataset: The dataset to load.
        batch_size: Number of samples per batch.
        shuffle: Whether to shuffle data each epoch.

    Returns:
        A DataLoader with pept_hla_collate as collate_fn.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": pept_hla_collate,
    }
    return DataLoader(dataset, **loader_options)
[docs]
class SiameseCELoss:
    """Contrastive Siamese loss for HLA-peptide similarity learning.

    Positive pairs are penalised by their squared Euclidean distance, negative
    pairs by the squared hinge ``max(0, margin - distance)``.
    """

    margin: float = 1
    # Floor applied to the squared distance before taking the square root.
    # sqrt has an infinite derivative at 0, which turns into NaN gradients as
    # soon as one pair has identical embeddings; values above the floor are
    # left untouched.
    _min_dist_sq: float = 1e-12

    def get_loss(
        self, hla_x: torch.Tensor, x: torch.Tensor, y: float = 1.0
    ) -> torch.Tensor:
        """Computes the contrastive loss of a batch of pairs sharing one label.

        Args:
            hla_x: HLA embedding tensor, shape (batch_size, dim).
            x: Peptide embedding tensor, shape (batch_size, dim).
            y: Label (1.0 for positive pair, 0.0 for negative).

        Returns:
            Scalar loss tensor (batch mean divided by 2).
        """
        dist_sq = torch.sum(torch.pow(hla_x - x, 2), 1)
        dist = torch.sqrt(torch.clamp(dist_sq, min=self._min_dist_sq))
        hinge = torch.clamp(self.margin - dist, min=0.0)
        per_pair = y * dist_sq + (1 - y) * torch.pow(hinge, 2)
        return torch.mean(per_pair) / 2.0

    def __call__(
        self, hla_x: torch.Tensor, pos_x: torch.Tensor, neg_x: torch.Tensor
    ) -> torch.Tensor:
        """Computes total Siamese loss from positive and negative pairs.

        Args:
            hla_x: HLA embedding.
            pos_x: Positive (binding) peptide embedding.
            neg_x: Negative (non-binding) peptide embedding.

        Returns:
            Average of the positive-pair and negative-pair losses.
        """
        pos_loss = self.get_loss(hla_x, pos_x, 1)
        neg_loss = self.get_loss(hla_x, neg_x, 0)
        return (pos_loss + neg_loss) / 2
[docs]
def train(
    hla_encoder: ModelHlaEncoder,
    pept_encoder: ModelSeqEncoder,
    dataset: HlaDataSet,
    batch_size: int = 256,
    lr: float = 1e-4,
    epoch: int = 100,
    warmup_epoch: int = 20,
    verbose: bool = True,
    device: str = "cuda",
    test_bundle: tuple | None = None,
    neptune_run=None,
) -> None:
    """Train the peptide/HLA encoders.

    Both encoders are optimised jointly with Adam on the Siamese contrastive
    loss. The learning-rate scheduler, if any, is stepped once per epoch.

    Args:
        hla_encoder: Encoder for HLA embeddings.
        pept_encoder: Encoder for peptide sequences.
        dataset: Training dataset.
        batch_size: Number of samples per batch.
        lr: Learning rate for the optimizer.
        epoch: Total number of epochs.
        warmup_epoch: Number of warmup epochs for the scheduler; no scheduler
            is used when this is not positive.
        verbose: Whether to log training progress.
        device: Device identifier for ``torch.device``.
        test_bundle: Optional tuple ``(test_df, test_allele_list, hla_df,
            hla_esm_list, fasta_list)`` passed to :func:`test` after every epoch.
        neptune_run: Optional Neptune experiment for logging.
    """
    loss_func = SiameseCELoss()
    dataloader = get_hla_dataloader(dataset, batch_size, True)
    device = torch.device(device)
    hla_encoder.to(device)
    pept_encoder.to(device)
    optimizer = torch.optim.Adam(
        [
            {"params": pept_encoder.parameters()},
            {"params": hla_encoder.parameters()},
        ],
        lr=lr,
    )
    if warmup_epoch > 0:
        # NOTE(review): the warmup multiplier is 0 at step 0 and the scheduler
        # is stepped per epoch, so the first epoch appears to run with lr=0 --
        # confirm this is intended.
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_epoch,
            num_training_steps=epoch,
        )
    else:
        lr_scheduler = None
    if verbose:
        logging.info(f"{len(dataset)} training samples")

    for i_epoch in range(epoch):
        # The per-epoch evaluation embeds with an auto-detected device and may
        # leave the encoders somewhere else than ``device``; move them back so
        # the batches and the models always live on the same device.
        hla_encoder.to(device)
        pept_encoder.to(device)
        hla_encoder.train()
        pept_encoder.train()

        loss_list = []
        for hla_x, pos_x, neg_x in dataloader:
            hla_x = hla_encoder(hla_x.to(device))
            pos_x = pept_encoder(pos_x.to(device))
            neg_x = pept_encoder(neg_x.to(device))
            loss = loss_func(hla_x, pos_x, neg_x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())

        if lr_scheduler is not None:
            lr_scheduler.step()
            # Learning rate after stepping, i.e. the one used by the next epoch.
            _lr = lr_scheduler.get_last_lr()[0]
        else:
            _lr = lr
        mean_loss = np.mean(loss_list)
        if verbose:
            logging.info(f"[Epoch={i_epoch}] loss={mean_loss:.5f}, lr={_lr:.3e}")

        if test_bundle:
            test_df, test_allele_list, hla_df, hla_esm_list, fasta_list = test_bundle
            (
                mean_rank01_recall_rate,
                mean_rank05_recall_rate,
                mean_rank20_recall_rate,
            ) = test(
                test_df,
                test_allele_list,
                hla_encoder,
                pept_encoder,
                hla_df,
                hla_esm_list,
                fasta_list,
            )
            print(
                f"test alleles rank%<0.1 average recall rate: {mean_rank01_recall_rate:.2f}"
            )
            print(
                f"test alleles rank%<0.5 average recall rate: {mean_rank05_recall_rate:.2f}"
            )
            print(
                f"test alleles rank%<2 average recall rate: {mean_rank20_recall_rate:.2f}"
            )

        if neptune_run:
            neptune_run["train/loss"].log(mean_loss)
            neptune_run["train/lr"].log(_lr)
            if test_bundle:
                # The "test/loss*" series carry the three recall rates.
                neptune_run["test/loss1"].log(mean_rank01_recall_rate)
                neptune_run["test/loss2"].log(mean_rank05_recall_rate)
                neptune_run["test/loss3"].log(mean_rank20_recall_rate)
[docs]
def embed_hla_esm_list(
    hla_encoder: ModelHlaEncoder,
    hla_esm_list: list[np.ndarray],
    batch_size: int = 200,
    device: str | torch.device | None = None,
    verbose: bool = False,
) -> np.ndarray:
    """Encode raw HLA ESM features into one fixed-size vector per allele.

    The encoder is moved to ``device`` and switched to eval mode; inference
    runs batch by batch without gradient tracking.

    Args:
        hla_encoder: Trained HLA encoder model.
        hla_esm_list: List of raw ESM embeddings for HLA alleles.
        batch_size: Inference batch size.
        device: Device to use. Auto-detected if None.
        verbose: Show progress bar.

    Returns:
        Array of shape (num_hla, d_model) containing encoded HLA embeddings,
        where d_model is the last dimension of the first input array.
    """
    device = device or get_available_device()[0]
    hla_encoder.to(device)
    hla_encoder.eval()

    n_hla = len(hla_esm_list)
    embeds = np.zeros((n_hla, hla_esm_list[0].shape[-1]), dtype=np.float32)
    batch_starts = range(0, n_hla, batch_size)
    if verbose:
        batch_starts = tqdm.tqdm(batch_starts)

    with torch.no_grad():
        for start in batch_starts:
            stop = start + batch_size
            batch = batchify_hla_esm_list(hla_esm_list[start:stop]).to(device)
            embeds[start:stop] = hla_encoder(batch).detach().cpu().numpy()
    torch.cuda.empty_cache()
    return embeds
[docs]
def embed_peptides(
    pept_encoder: ModelSeqEncoder,
    seqs: list[str],
    d_model: int = D_MODEL,
    batch_size: int = 512,
    device: str | torch.device | None = None,
    verbose: bool = False,
) -> np.ndarray:
    """Encode peptide sequences into fixed-size embedding vectors.

    The encoder is moved to ``device`` and switched to eval mode; inference
    runs batch by batch without gradient tracking.

    Args:
        pept_encoder: Trained peptide encoder model.
        seqs: List of peptide strings.
        d_model: Expected embedding dimension (width of the output array).
        batch_size: Inference batch size.
        device: Device to use (auto-detected if None).
        verbose: Show progress bar.

    Returns:
        Array of shape (num_peptides, d_model) with peptide embeddings.
    """
    device = device or get_available_device()[0]
    pept_encoder.to(device)
    pept_encoder.eval()

    n_seqs = len(seqs)
    embeds = np.zeros((n_seqs, d_model), dtype=np.float32)
    batch_starts = range(0, n_seqs, batch_size)
    if verbose:
        batch_starts = tqdm.tqdm(batch_starts)

    with torch.no_grad():
        for start in batch_starts:
            stop = start + batch_size
            aa_idxes = get_ascii_indices(seqs[start:stop]).to(device)
            embeds[start:stop, :] = pept_encoder(aa_idxes).detach().cpu().numpy()
    torch.cuda.empty_cache()
    return embeds
[docs]
def test(
    test_df: pd.DataFrame,
    test_allele_list,
    hla_encoder: ModelHlaEncoder,
    pept_encoder: ModelSeqEncoder,
    hla_df: pd.DataFrame,
    hla_esm_list: list[np.ndarray],
    fasta_list: list[str],
) -> tuple[float, float, float]:
    """Compute rank-based recall of the encoders on a set of test alleles.

    For every allele, its test peptides are scored against the allele's
    embedding and the fraction of peptides whose ``best_allele_rank`` is at
    most 0.1, 0.5 and 2 is recorded; the fractions are averaged over alleles.

    Args:
        test_df: DataFrame with test peptide-allele pairs ('allele' and
            'sequence' columns are used).
        test_allele_list: List of HLA alleles to evaluate.
        hla_encoder: Trained HLA encoder.
        pept_encoder: Trained peptide encoder.
        hla_df: HLA metadata DataFrame.
        hla_esm_list: List of raw HLA ESM embeddings.
        fasta_list: List of protein FASTA file paths.

    Returns:
        Tuple of mean recall rates at rank <= 0.1, <= 0.5, and <= 2.
    """
    from .mhc_binding_retriever import MHCBindingRetriever

    retriever = MHCBindingRetriever(
        hla_encoder,
        pept_encoder,
        hla_df,
        embed_hla_esm_list(hla_encoder, hla_esm_list),
        fasta_list,
    )
    retriever.n_decoy_samples = 1000000

    peptides_by_allele = test_df.groupby("allele")
    rank_queries = (
        "best_allele_rank<=0.1",
        "best_allele_rank<=0.5",
        "best_allele_rank<=2",
    )
    # One list of per-allele recall values for each rank threshold.
    recalls = [[] for _ in rank_queries]
    for allele in test_allele_list:
        allele_peptides = peptides_by_allele.get_group(allele)
        # Only the first hla_df row registered for the allele is evaluated.
        hla_id = retriever.dataset.allele_idxes_dict[allele][0]
        metrics = retriever.get_binding_metrics_for_embeds(
            retriever.hla_embeds[hla_id], allele_peptides.sequence.values
        )
        for per_allele, query in zip(recalls, rank_queries):
            per_allele.append(len(metrics.query(query)) / len(metrics))
    return tuple(np.mean(per_allele) for per_allele in recalls)