Analysis of Cole et al. (2026), GenBio AI — bioRxiv preprint. Generated on March 20, 2026.
This paper tackles a central question in computational biology: can foundation models (FMs) — large pretrained neural networks — actually help predict how cells respond to genetic or chemical perturbations? Recent literature has given contradictory answers, with some groups claiming FMs are no better than simple baselines like PCA. The authors resolve this by running an exhaustive benchmark of over 600 models.
The key finding is nuanced: some FMs do fail to beat baselines, but others — particularly those trained on protein-protein interaction networks (interactome data) — significantly outperform simple methods. Furthermore, combining embeddings from multiple FMs through an attention-based fusion model pushes performance even further, in some cases reaching the theoretical limit set by experimental noise.
This matters because accurate perturbation prediction is foundational for drug discovery, understanding disease mechanisms, and building "virtual cell" models that simulate biology in silico.
Predicting how cells respond to perturbations — whether you knock out a gene, overexpress it, or treat the cell with a drug — is one of the holy grails of molecular biology. If we could accurately simulate these responses computationally, it would transform drug discovery by letting researchers screen interventions in silico before running expensive wet-lab experiments.
The field has evolved from differential equation-based models of gene regulatory networks, through classical ML approaches (ElasticNet, matrix factorization), to deep learning methods (Dr.VAE, scGen, CPA, GEARS). Most recently, transformer-based foundation models pretrained on massive single-cell atlases — Geneformer, scGPT, scFoundation — have entered the scene, claiming strong results. But a counter-narrative has emerged: papers like PerturbBench and Ahlmann-Eltze et al. argue that FMs don't outperform simple linear baselines.
This paper resolves the contradiction by showing that the answer depends on which FM you use. The authors take an "embedding-centric" view: rather than comparing complex end-to-end architectures, they extract embeddings from various FMs and feed them into simple predictors (kNN). This isolates the quality of the embedding itself — the core biological knowledge captured by the FM.
The most striking finding is that embedding quality varies enormously across FM types, and this variation is primarily explained by the modality of the training data. The authors benchmarked embeddings on the Essential dataset (4 cell lines, ~2000 perturbations each), which is much larger than the commonly used Norman dataset, whose small size had masked real performance differences in earlier comparisons.
Interactome-based embeddings (WaveGC, STRING GNN, GenotypeVAE, GenePT) consistently rank at the top. These capture how genes interact with each other in cellular networks. Expression-based FMs (scGPT, AIDO.Cell) are middling, while protein sequence and DNA sequence FMs generally perform worst. The intuition is that knowing what a gene's protein looks like matters less than knowing what it does and who it talks to in the cell.
Remarkably, on the K562 cell line, the best single embedding closes 77% of the gap between a naive baseline and the estimated experimental error limit — using nothing more than kNN regression. This tells us the embedding is doing the heavy lifting, not the prediction algorithm.
import numpy as np


def knn_perturbation_prediction(
    train_embeddings: np.ndarray,  # (N_train, d) embeddings of training perturbations
    train_lfc: np.ndarray,         # (N_train, G) observed LFC for training perturbations
    test_embeddings: np.ndarray,   # (N_test, d) embeddings of test perturbations
    k: int = 20,
) -> np.ndarray:
    """
    Predict perturbation response using kNN in embedding space.

    The core idea: perturbations with similar embeddings should
    produce similar cellular responses. The FM embedding encodes
    biological similarity — kNN leverages that to predict LFC.

    Args:
        train_embeddings: FM embeddings for seen perturbations
        train_lfc: Log fold-change vectors for seen perturbations
        test_embeddings: FM embeddings for unseen perturbations
        k: Number of neighbors

    Returns:
        Predicted LFC vectors for test perturbations (N_test, G)
    """
    predictions = []
    for test_emb in test_embeddings:
        # Step 1: Compute distances in embedding space
        distances = np.linalg.norm(train_embeddings - test_emb, axis=1)
        # Step 2: Find the k nearest neighbors
        nn_indices = np.argsort(distances)[:k]
        # Step 3: Average their observed responses
        predicted_lfc = train_lfc[nn_indices].mean(axis=0)
        predictions.append(predicted_lfc)
    return np.array(predictions)


# Example: evaluate with L2 error
def evaluate_l2(true_lfc: np.ndarray, pred_lfc: np.ndarray) -> float:
    """Average L2 error across perturbations."""
    return float(np.mean(np.linalg.norm(true_lfc - pred_lfc, axis=1)))
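Figures like "closes 77% of the gap" are, concretely, a ratio of error reductions between a naive baseline and the experimental noise floor. A minimal sketch of that reading (the helper name and the example numbers are mine, not from the paper):

```python
def fraction_of_gap_closed(model_error: float,
                           baseline_error: float,
                           noise_floor: float) -> float:
    """Fraction of the baseline-to-noise-floor gap a model closes.

    1.0 means the model reaches the experimental error limit;
    0.0 means it is no better than the naive baseline.
    """
    gap = baseline_error - noise_floor
    if gap <= 0:
        raise ValueError("baseline error must exceed the noise floor")
    return (baseline_error - model_error) / gap


# Illustrative numbers only: a baseline L2 of 10.0, a noise floor of
# 6.0, and a model L2 of 6.92 close (10.0 - 6.92) / (10.0 - 6.0) = 0.77
# of the gap.
```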
Can you improve an FM's perturbation predictions by fine-tuning it on the actual perturbation data? The answer is: it depends. The authors tested two approaches for AIDO.Cell (3M) and one for STRING GNN.
For AIDO.Cell, the "In-Silico KO" method — where the target gene's expression is masked and the model predicts the downstream effect — provided a significant boost, outperforming both the frozen kNN baseline and an MLP ablation. However, a simpler "Indexing" approach (extracting the gene's embedding and training a head on top) actually hurt performance. For STRING GNN, fine-tuning degraded results relative to using frozen embeddings.
The takeaway is sobering: current perturbation datasets may be too small to reliably fine-tune large models. Overfitting is a real risk. Using frozen FM embeddings as features for simple predictors is often the safer bet.
import numpy as np


def in_silico_ko_prediction(
    control_expression: np.ndarray,  # (G,) mean expression of control cells
    target_gene_idx: int,            # index of the knocked-out gene
    encoder_fn,                      # FM encoder: (G,) -> (G, d)
    prediction_head_fn,              # head: (d,) -> (G,)
) -> np.ndarray:
    """
    In-Silico Knockout: mask the target gene, encode with the FM,
    then predict LFC from the contextualized embedding.

    This approach works because the FM learns to propagate
    information through gene-gene relationships. When a gene
    is masked, the FM's output captures what the network
    'expects' should change.

    Args:
        control_expression: Average control cell profile
        target_gene_idx: Which gene to knock out
        encoder_fn: Foundation model encoder
        prediction_head_fn: Learned prediction head

    Returns:
        Predicted log fold-change vector (G,)
    """
    # Step 1: Mask the target gene (simulate knockout)
    masked_expression = control_expression.copy()
    masked_expression[target_gene_idx] = 0.0
    # Step 2: Run through the FM encoder to get gene embeddings
    gene_embeddings = encoder_fn(masked_expression)  # (G, d)
    # Step 3: Extract the embedding at the target position;
    # the FM contextualizes it based on the other genes
    target_embedding = gene_embeddings[target_gene_idx]  # (d,)
    # Step 4: Predict LFC from the contextualized embedding
    predicted_lfc = prediction_head_fn(target_embedding)  # (G,)
    return predicted_lfc
Given a perturbation embedding, how should you translate it into a predicted expression change? The literature proposes sophisticated generative approaches — Latent Diffusion, Flow Matching, Schrödinger Bridge — that model the full distribution of perturbed single cells. The authors benchmarked simple implementations of each against kNN.
The result: none of these advanced methods outperform kNN paired with the best embedding. This is remarkable because the generative methods use the same embedding as input but are far more computationally expensive (~1000 GPU-hours each to train). Similarly, GEARS, a published GNN-based perturbation model, didn't beat the embedding+kNN approach.
The implication is clear: for predicting average perturbation effects, the bottleneck is the embedding quality, not the prediction model complexity.
Chemical perturbations present a harder prediction problem than genetic ones. A small molecule may hit multiple targets, the chemical space is enormous (~10^60 possible molecules), and we have less network knowledge for drugs.
The authors tested molecular fingerprints, SMILES-based FMs (ChemBERTa, Uni-Mol, MiniMol), target-based embeddings (embedding the drug's predicted protein target with a gene FM), and LLM-based text embeddings of drug descriptions.
In the LFC regression formulation, no embedding clearly outperformed baselines — the signal was weak. But in the DEG (differentially expressed gene) classification formulation, target-based embeddings worked best: if you know (or can predict) what protein a drug binds, embedding that protein with an scRNA-seq FM gives you useful information. Traditional molecular structure fingerprints (ECFP, radius 2) were also competitive. Interestingly, SMILES-based FMs generally underperformed, likely because they were trained for chemical — not biological — property prediction.
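Mechanically, the target-based representation is a lookup: map the drug to its (predicted) protein target, then reuse a gene FM's embedding of that target. A sketch under that reading; the function name and all identifiers below are illustrative, not the paper's code:

```python
import numpy as np
from typing import Dict, Optional


def target_based_drug_embedding(
    drug_id: str,
    drug_to_target: Dict[str, str],          # drug -> predicted protein target
    gene_embeddings: Dict[str, np.ndarray],  # gene -> FM embedding
) -> Optional[np.ndarray]:
    """Represent a drug by the FM embedding of its (predicted) protein target.

    Returns None when no target is known or the target has no embedding,
    mirroring the missing-embedding handling used elsewhere in the paper.
    """
    target = drug_to_target.get(drug_id)
    if target is None:
        return None
    return gene_embeddings.get(target)


# Toy usage with illustrative values (the embedding vector is made up)
gene_embeddings = {"EGFR": np.array([0.2, -1.3, 0.7])}
drug_to_target = {"gefitinib": "EGFR"}
emb = target_based_drug_embedding("gefitinib", drug_to_target, gene_embeddings)
```

The resulting vector can then be fed to the same kNN predictor used for genetic perturbations, which is what makes this representation interchangeable with the others benchmarked here.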
Since different FM types capture different aspects of gene biology (sequence, structure, interactions, function), the authors hypothesized that combining them could be more powerful than any single embedding. They designed an attention-based fusion model that ingests embeddings from all sources and learns to weight them dynamically.
The results are impressive: fusion consistently beats the best unimodal embedding (WaveGC) on all four cell lines in Essential. For K562 and Jurkat, the fusion model actually matches the estimated experimental error limit — meaning the model is as accurate as the experiment itself. For the other two cell lines, it bridges 86% (Hep-G2) and 53% (hTERT-RPE1) of the gap between random performance and the experimental error bound.
However, fusion did not help for chemical perturbations, likely because individual drug embeddings were too weak to provide meaningful complementary information.
import numpy as np
from typing import Dict, List, Optional


class EmbeddingFusionModel:
    """
    Simplified attention-based fusion of multiple FM embeddings.

    Each perturbation is represented by J embeddings from different
    FMs. A transformer learns to attend across these sources and
    produce a unified prediction.
    """

    def __init__(self, embedding_dims: List[int], common_dim: int = 100,
                 n_heads: int = 5, n_genes: int = 1000):
        """
        Args:
            embedding_dims: Dimension of each source embedding
            common_dim: Shared projection dimension
            n_heads: Attention heads in the transformer
            n_genes: Number of genes to predict
        """
        self.common_dim = common_dim
        self.n_sources = len(embedding_dims)
        # Per-source projection matrices: map each source to common_dim
        self.projections = [
            np.random.randn(d, common_dim) * 0.1
            for d in embedding_dims
        ]
        # Learnable cell line embeddings, populated lazily
        self.cell_line_embeddings: Dict[str, np.ndarray] = {}
        # Prediction head weights (simplified)
        self.pred_weights = np.random.randn(common_dim, n_genes) * 0.01

    def project_embeddings(
        self,
        embeddings: List[Optional[np.ndarray]],
    ) -> np.ndarray:
        """
        Project each source embedding to the common space.
        Handles missing embeddings (not all sources cover all genes).

        Returns:
            tokens: (n_valid + 1, common_dim) including the CLS token
        """
        tokens = []
        # CLS token for aggregation
        cls_token = np.zeros(self.common_dim)
        tokens.append(cls_token)
        # Project each available embedding; None entries are skipped,
        # which is how missing sources are handled
        for i, emb in enumerate(embeddings):
            if emb is not None:
                projected = emb @ self.projections[i]
                tokens.append(projected)
        return np.array(tokens)  # (n_tokens, common_dim)

    def predict(
        self,
        embeddings: List[Optional[np.ndarray]],
        cell_line: str,
    ) -> np.ndarray:
        """
        Predict LFC by fusing all available embeddings.

        Returns:
            predicted_lfc: (n_genes,) vector
        """
        # Step 1: Project to the common space and add the cell line embedding
        tokens = self.project_embeddings(embeddings)
        if cell_line not in self.cell_line_embeddings:
            # Lazily initialize an embedding for an unseen cell line
            self.cell_line_embeddings[cell_line] = (
                np.random.randn(self.common_dim) * 0.01
            )
        tokens = tokens + self.cell_line_embeddings[cell_line]
        # Step 2: Self-attention (simplified to a mean for illustration;
        # in practice: multi-head self-attention transformer layers)
        cls_output = tokens.mean(axis=0)  # (common_dim,)
        # Step 3: Predict LFC from the CLS output
        predicted_lfc = cls_output @ self.pred_weights  # (n_genes,)
        return predicted_lfc
The paper frames perturbation prediction as a regression problem: given a perturbation (gene knockout or drug treatment), predict the vector of per-gene expression changes. Specifically, they predict the "batch-aware average treatment effect" (BA-ATE), which they call log fold-change (LFC) for simplicity.
The key idea is to compare the average expression of perturbed cells to the average expression of control cells, accounting for batch effects by weighting each batch equally. For datasets with large batches, per-batch control means are subtracted. For datasets with small batches (like Essential), a global control mean is used instead to avoid noisy per-batch estimates.
The primary evaluation metric is the L2 error between predicted and observed LFC vectors, averaged across all test perturbations.
import numpy as np


def compute_batch_aware_ate(
    expression_matrix: np.ndarray,     # (N, G) normalized expression
    cell_labels: np.ndarray,           # (N,) perturbation ID or 'ctrl'
    batch_labels: np.ndarray,          # (N,) batch assignment
    perturbation_id: str,
    use_global_control: bool = False,  # True for small-batch datasets
) -> np.ndarray:
    """
    Compute the Batch-Aware Average Treatment Effect (BA-ATE).

    Large batches: subtract a per-batch control mean.
    Small batches: subtract a global control mean (more stable).

    Args:
        expression_matrix: log1p-normalized scRNA-seq data
        cell_labels: Which perturbation each cell received
        batch_labels: Batch assignment for each cell
        perturbation_id: Target perturbation to compute the ATE for
        use_global_control: Use a global vs. per-batch control mean

    Returns:
        lfc: (G,) vector of per-gene treatment effects
    """
    ctrl_mask = cell_labels == 'ctrl'
    pert_mask = cell_labels == perturbation_id
    batches = np.unique(batch_labels)

    if use_global_control:
        # Small-batch mode: average the per-batch control means so that
        # each batch contributes equally to the global control profile
        global_ctrl = np.mean([
            expression_matrix[ctrl_mask & (batch_labels == b)].mean(axis=0)
            for b in batches
            if np.any(ctrl_mask & (batch_labels == b))
        ], axis=0)

    # A batch is usable if it contains perturbed cells and, in
    # per-batch mode, control cells as well
    batch_effects = []
    for b in batches:
        if not np.any(pert_mask & (batch_labels == b)):
            continue
        pert_mean = expression_matrix[pert_mask & (batch_labels == b)].mean(axis=0)
        if use_global_control:
            batch_effects.append(pert_mean - global_ctrl)
        elif np.any(ctrl_mask & (batch_labels == b)):
            ctrl_mean = expression_matrix[ctrl_mask & (batch_labels == b)].mean(axis=0)
            batch_effects.append(pert_mean - ctrl_mean)

    if not batch_effects:
        raise ValueError(f"No usable batches for perturbation {perturbation_id!r}")
    # Weight each batch equally in the average treatment effect
    return np.mean(batch_effects, axis=0)  # (G,)
An alternative to predicting exact expression changes is to classify each gene as upregulated (+1), downregulated (−1), or unchanged (0): the DEG formulation. For large-batch datasets, a Student's t-test with Benjamini-Hochberg correction is run per batch, and the per-batch calls are combined by majority vote. For small-batch datasets, all cells are pooled into a single test. The metric is the macro F1 score.
This formulation is particularly useful for chemical perturbations where the continuous LFC signal is weak but discrete changes can still be detected.
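The pooled (small-batch) variant of the DEG pipeline can be sketched as follows. This is my illustrative version, using Welch's form of the t-test and a hand-rolled Benjamini-Hochberg step, not the paper's code; the per-batch majority-vote variant is omitted:

```python
import numpy as np
from scipy import stats


def deg_labels(pert_cells: np.ndarray,   # (n_pert, G) perturbed-cell expression
               ctrl_cells: np.ndarray,   # (n_ctrl, G) control-cell expression
               alpha: float = 0.05) -> np.ndarray:
    """Label each gene +1 (up), -1 (down), or 0 (unchanged).

    Per-gene Welch t-test, Benjamini-Hochberg FDR correction, with the
    direction taken from the sign of the mean difference.
    """
    # Per-gene two-sample t-test (Welch's variant, unequal variances)
    _, pvals = stats.ttest_ind(pert_cells, ctrl_cells,
                               axis=0, equal_var=False)

    # Benjamini-Hochberg: adjusted p_(i) = p_(i) * G / i on the sorted
    # p-values, then enforce monotonicity from the largest rank down
    G = pvals.size
    order = np.argsort(pvals)
    raw_adj = pvals[order] * G / np.arange(1, G + 1)
    adjusted = np.empty(G)
    adjusted[order] = np.minimum.accumulate(raw_adj[::-1])[::-1]
    adjusted = np.clip(adjusted, 0.0, 1.0)

    # Significant genes get the sign of the effect; the rest get 0
    sign = np.sign(pert_cells.mean(axis=0) - ctrl_cells.mean(axis=0))
    return np.where(adjusted < alpha, sign, 0.0).astype(int)
```

The resulting label vectors can be scored against observed labels with a macro F1, averaging F1 over the three classes so that the rare up/down classes count as much as the dominant "unchanged" class.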
Four datasets span different perturbation types, cell lines, and scales:
Essential is the primary benchmark for genetic perturbations because it has enough perturbations (~2000 per cell line) to reliably distinguish embedding quality. Norman is included for comparison with prior work, but its small size masks performance differences. Tahoe is the largest chemical perturbation dataset, enabling more robust evaluation of drug embeddings.
The paper evaluates embeddings from four modalities for genetic perturbation prediction: expression (scRNA-seq FMs), DNA sequence, protein sequence/structure, and prior knowledge (interaction networks, annotations, text). For chemical perturbations, additional sources include molecular fingerprints, SMILES-based FMs, target-based embeddings, and text embeddings from LLMs.
A key methodological contribution is the systematic evaluation across all these sources using a unified framework (same datasets, metrics, cross-validation splits).
The fusion model uses a transformer encoder to integrate variable numbers of embeddings per perturbation. Each embedding is projected to a common 100-dimensional space, a learnable cell line token is added, and a CLS token aggregates information through self-attention. The CLS output feeds into a prediction head that outputs the LFC vector.
A key design choice is handling missing embeddings — not every FM produces an embedding for every gene. The attention mechanism naturally handles this by operating on variable-length input sets. Training uses L2 loss with Optuna hyperparameter tuning (100 trials), and the model is trained jointly across cell lines.
The authors conclude that foundation models do improve perturbation prediction — but only when you choose the right ones. The key findings paint a nuanced picture:
For genetic perturbations, interactome-based embeddings are the single most valuable information source. This has practical implications: organizations building "virtual cell" models should invest in interaction data across contexts (cell types, diseases, developmental stages), potentially even more than in additional single-cell expression data. Fusion of multiple FM types pushes performance to the experimental noise floor.
For chemical perturbations, the picture is harder. Genetic perturbations are specific (one gene → one knockout), while drugs can hit multiple targets through complex pharmacology. The search space is vastly larger (~10^60 molecules vs ~20K genes), and we have much less interaction network data for small molecules. Off-the-shelf SMILES-based FMs perform poorly because they were trained to predict chemical, not biological, properties. The field needs a biological function-aware molecular FM.
Fine-tuning remains a challenge due to limited perturbation data. The authors suggest that jointly fine-tuning multiple FMs within a fusion framework could help, but overfitting risk is high without more training data.
Foundation models DO improve perturbation prediction — but only certain types. Interactome-based FMs (WaveGC, GenePT, GenotypeVAE) significantly outperform baselines, while DNA and protein sequence FMs add little value for this task.
Embedding quality > model complexity: kNN with the right embedding beats Latent Diffusion, Flow Matching, Schrödinger Bridge, and GEARS — all at a fraction of the compute cost.
Multi-modal fusion reaches experimental limits: Combining embeddings from diverse FMs via attention-based fusion matches the noise floor in 2 of 4 cell lines tested.
Chemical perturbations remain hard: Drug response prediction is fundamentally more difficult due to multi-target effects and the lack of biology-aware molecular FMs.
Data and benchmarks need to scale up: Small datasets mask real differences between methods; larger benchmarks like Essential reveal robust trends. The field needs both more perturbation data and harder evaluation splits.
Original paper: Cole, E. et al. (2026). "Foundation Models Improve Perturbation Response Prediction." bioRxiv. DOI: 10.64898/2026.02.18.706454
Code and Data: https://github.com/genbio-ai/foundation-models-perturbation
This analysis was generated to make complex academic concepts more accessible. For complete technical details, mathematical formulations, supplementary figures, and all 600+ model results, please refer to the original paper.