Lab Notes: Building a GenePT Cell Typer - Part 3


Introduction
We recently built a cell-typer model based on a subset of the CellXGene V2 dataset. This model is a stepping stone towards a “universal” cell-typer model that will allow us to annotate cell types (and other properties) from single-cell RNA transcription data.
I’ve noticed that people prefer shorter, multi-part articles over longer ones, so I’ve broken these Lab Notes up into three parts. This is the third part:
Part 3: Model design and training ← you are here
If you’re interested in more detail, you can find the notebook I used to build the training and test sets here:
https://github.com/honicky/GenePT-tools/blob/main/notebooks/cellxgene_v2_mlp.ipynb
Here is a summary:
We chose this design to simplify initial model training, but we will relax these constraints as we develop an increasingly general cell-typer.
Scaling Gradient Boosted Trees vs Multi-layer Perceptrons
The previous models we’ve built use LightGBM and XGBoost as classifiers on top of our embeddings. These work great, they are easy and fast to train (for small datasets), and on tabular data they often outperform neural networks. However, gradient boosted trees (such as LightGBM and XGBoost) require the entire dataset to be in memory during training (with some important caveats). There have been some successful attempts to relax this constraint, but they introduce complexity that pushed us towards MLPs.
XGBoost supports a few features that allow for training on very large datasets:
External memory - this lets you offload data from memory to disk, as well as from GPU memory to CPU memory. In theory this is what we need, but based on various blog posts and forums, the results in practice seem mixed. It also has only experimental support for distributed GPU training.
Gradient-based sampling - this lets us sample a subset of the data for building each tree, weighted by the gradients. This is much more sample-efficient, and allows sampling ratios as low as 10% without a significant loss of accuracy (according to this video). It seems mostly intended for offloading GPU memory to CPU, which is useful, but doesn’t get us to TB scale very cost-effectively. We might also be able to use disk offload, but that still leaves a floor of around 10% (or whatever the ratio turns out to be), which will be prohibitive with large datasets.
Spark - XGBoost supports distributed GPU training using the XGBoost4J-Spark library. If you already have a GPU-enabled Spark cluster (via Databricks or whatever), this might be a good option to get started with. On the other hand, Spark is heavyweight if you don’t have a managed cluster, and it still suffers from the same fundamental scaling limitations.
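For concreteness, gradient-based sampling is switched on through ordinary XGBoost training parameters. This is a sketch assuming XGBoost 2.x (which uses the `device` parameter) and a multi-class objective; `N_CELL_TYPES` is a hypothetical placeholder, and this is not necessarily a configuration we ran.

```python
# Sketch of XGBoost (>= 2.0) parameters for GPU training with gradient-based sampling.
# N_CELL_TYPES is a hypothetical placeholder for the number of cell-type labels.
N_CELL_TYPES = 100

params = {
    "objective": "multi:softprob",        # output a probability for every class
    "num_class": N_CELL_TYPES,
    "tree_method": "hist",
    "device": "cuda",                     # gradient_based sampling is only supported on CUDA
    "sampling_method": "gradient_based",  # selection probability tracks the (regularized) gradient
    "subsample": 0.1,                     # keep roughly 10% of rows per tree
}

# Usage, assuming xgboost is installed, a GPU is available, and dtrain is an xgboost.DMatrix:
#   booster = xgboost.train(params, dtrain, num_boost_round=200)
```

Even at a 10% ratio, the full training matrix still has to live somewhere reachable (host memory or external memory), which is the limitation described above.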
Multi-layer perceptrons (MLPs), on the other hand, easily handle batched training (via stochastic gradient descent): we can just keep feeding in data one batch at a time. The size of the model, and the memory needed to train it, is independent of the size of the data. There is plenty of infrastructure at different levels of complexity, so we can start simple and scale to massive clusters as needed. Overall, this seems like a better approach than investing in learning and building infrastructure that will be limited in its ability to scale to very large datasets (like the ones we hope to use going forward).
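A minimal NumPy sketch of this idea: the training loop only ever holds one minibatch, pulled from a generator, so memory depends on the model and batch size rather than on how many cells exist. `TinyMLP` and `stream_batches` are hypothetical names, the clustered synthetic data stands in for embedded cells, and a real pipeline would use a framework such as PyTorch rather than hand-written backprop.

```python
import numpy as np


class TinyMLP:
    """One-hidden-layer MLP trained with plain minibatch SGD (illustrative only)."""

    def __init__(self, d_in, d_hidden, n_classes, lr=0.1, seed=0):
        rng = np.random.default_rng(seed)
        self.W1 = rng.normal(0.0, np.sqrt(2.0 / d_in), (d_in, d_hidden))
        self.b1 = np.zeros(d_hidden)
        self.W2 = rng.normal(0.0, np.sqrt(1.0 / d_hidden), (d_hidden, n_classes))
        self.b2 = np.zeros(n_classes)
        self.lr = lr

    def forward(self, X):
        h = np.maximum(X @ self.W1 + self.b1, 0.0)           # ReLU hidden layer
        logits = h @ self.W2 + self.b2
        logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        p = np.exp(logits)
        return h, p / p.sum(axis=1, keepdims=True)           # softmax probabilities

    def step(self, X, y):
        """One SGD update on a single minibatch; returns its mean cross-entropy."""
        n = len(y)
        h, p = self.forward(X)
        loss = -np.log(p[np.arange(n), y] + 1e-12).mean()
        g = p.copy()
        g[np.arange(n), y] -= 1.0
        g /= n                                                # dLoss/dlogits
        gh = (g @ self.W2.T) * (h > 0)                        # backprop through ReLU (old W2)
        self.W2 -= self.lr * (h.T @ g)
        self.b2 -= self.lr * g.sum(axis=0)
        self.W1 -= self.lr * (X.T @ gh)
        self.b1 -= self.lr * gh.sum(axis=0)
        return loss


def stream_batches(centers, n_batches, batch_size, rng):
    """Hypothetical stand-in for reading embedded cells from disk, one batch at a time."""
    n_classes, dim = centers.shape
    for _ in range(n_batches):
        y = rng.integers(0, n_classes, size=batch_size)
        X = centers[y] + rng.normal(scale=0.5, size=(batch_size, dim))
        yield X, y


rng = np.random.default_rng(0)
centers = rng.normal(size=(5, 16))  # 5 synthetic "cell types" in a 16-d embedding space
model = TinyMLP(d_in=16, d_hidden=32, n_classes=5)
losses = [model.step(X, y) for X, y in stream_batches(centers, 300, 64, rng)]
print(f"mean loss, first 10 batches: {np.mean(losses[:10]):.3f}; last 10: {np.mean(losses[-10:]):.3f}")
```

Nothing in the loop grows with the number of batches: you could stream 300 batches or 300 million, and only the wall-clock time changes.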
So MLPs it is!
Can we “iron out” distribution differences?
The CellXGene V2 data we are using has 961 files, drawn from different publicly available datasets of various sizes and collected under various conditions. Each file will differ at least slightly from the others (batch effects). On the other hand, randomly shuffling all the cells in our training set requires a lot of work and resources. Our hypothesis was that if we run a few epochs, the later epochs will have learned the different distributions and reached a minimum in loss space that accommodates all of them, in the same way that random shuffling would.
Here is the first epoch of a training run after some hyperparameter tuning. Each spike corresponds to a new training file, in which all of the cells come from the same source dataset. The solid line is the training loss, the dashed line is the validation loss, and the dotted line is the loss on a 5k-cell subset of the validation set.
There is a slight downward trend in all of the losses over the epoch. Based on the hypothesis above, we would expect the spikes to get smaller in the following epochs as the model learns how to integrate all of the datasets.
Hmm… the spikes do shrink a bit, and the model learns a bit over time, but not nearly as much as I had hoped. Zooming in on a couple of spikes, we can see that the 5k validation loss goes down for a few steps after a new file starts, but quickly starts going up again. This means that the model is overfitting to the specific file.
We could try lowering the learning rate to see whether averaging over more steps reduces overfitting, but maybe just randomizing the dataset will work better. As an experiment, let’s shrink the embeddings so that the entire training set fits in memory and we can easily shuffle it.
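To make the two batch orders concrete, here is a sketch contrasting them: iterating file by file (every batch drawn from a single source, as in the runs above) versus loading everything into memory and applying one global permutation. The function names and toy data are hypothetical.

```python
import numpy as np


def file_ordered_batches(files, batch_size):
    """Yield batches file by file, so every batch comes from a single source dataset."""
    for X, y in files:
        for start in range(0, len(y), batch_size):
            yield X[start:start + batch_size], y[start:start + batch_size]


def shuffled_batches(files, batch_size, rng):
    """Load every file into memory, apply one global permutation, then batch."""
    X = np.concatenate([X for X, _ in files])
    y = np.concatenate([y for _, y in files])
    perm = rng.permutation(len(y))
    X, y = X[perm], y[perm]
    for start in range(0, len(y), batch_size):
        yield X[start:start + batch_size], y[start:start + batch_size]


# Toy data: 3 "files" of 100 cells each. The label is just the source-file index,
# so we can count how many sources end up mixed into each batch.
files = [(np.full((100, 4), i, dtype=float), np.full(100, i)) for i in range(3)]
rng = np.random.default_rng(0)
per_file = [len(np.unique(y)) for _, y in file_ordered_batches(files, 32)]
shuffled = [len(np.unique(y)) for _, y in shuffled_batches(files, 32, rng)]
print("sources per batch, per-file:", per_file)
print("sources per batch, shuffled:", shuffled)
```

In the per-file order every gradient step sees exactly one source, which is what lets the model drift toward whichever file it is currently reading.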
In-memory shuffling
An interesting thing about OpenAI’s embeddings is that they use Matryoshka Representation Learning to concentrate the most information in the first (lowest-index) dimensions. This means that you can take just the first k dimensions and still have reasonably informative features. I picked k=500 because that lets the entire dataset fit into memory, so we can easily shuffle the data without gymnastics like 2-pass shuffling. If it works well, we can run ablations to find a good tradeoff between embedding size and performance, and use 2-pass shuffling on our larger dataset.
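Truncating Matryoshka-style embeddings is just a slice, plus (optionally) re-normalizing each row, which OpenAI's embedding guide recommends after shortening vectors. Whether we re-normalized is not stated above, so treat that step as an assumption. The memory helper is a rough back-of-the-envelope estimate for a float32 matrix; both function names are hypothetical.

```python
import numpy as np


def truncate_embeddings(emb, k):
    """Keep the first k dimensions of Matryoshka-style embeddings, then re-L2-normalize rows."""
    head = np.asarray(emb[:, :k], dtype=np.float32)
    norms = np.linalg.norm(head, axis=1, keepdims=True)
    return head / np.maximum(norms, 1e-12)  # guard against all-zero rows


def in_memory_gib(n_cells, k, bytes_per_value=4):
    """Rough size of an n_cells x k float32 matrix, ignoring labels and overhead."""
    return n_cells * k * bytes_per_value / 2**30


rng = np.random.default_rng(0)
full = rng.normal(size=(8, 1536))  # arbitrary stand-in: 8 cells x 1536 embedding dimensions
small = truncate_embeddings(full, 500)
print(small.shape)  # (8, 500)
print(f"{in_memory_gib(1_000_000, 500):.2f} GiB per million cells at k=500")
```

Because memory scales linearly with k, the same helper gives a quick way to check how far k can grow before a given machine stops fitting the whole training set.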
With a 64GB machine, an A10 GPU, and the first 500 dimensions of our embeddings, we get this:
Ok, that’s dramatically better! Zooming in, we see more clearly that the new model has a much lower validation loss, learns much faster, and is much smoother:
We should obviously also try the full 3072-dimensional embeddings, but have not done that yet.
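Scaling up to the full-size embeddings would require an out-of-core shuffle such as the 2-pass shuffling mentioned earlier. Here is a minimal sketch of that idea, with in-memory lists standing in for temporary files on disk (the function name and pile count are hypothetical): the first pass scatters records into random piles, and the second pass shuffles each pile, which must fit in memory, and concatenates them. Because every record lands in a uniformly random pile and each pile is shuffled uniformly, the output is a uniform random permutation.

```python
import random


def two_pass_shuffle(records, n_piles, seed=0):
    """Two-pass shuffle sketch; each in-memory list stands in for a temporary file on disk."""
    rng = random.Random(seed)
    piles = [[] for _ in range(n_piles)]
    # Pass 1: stream through the data once, appending each record to a uniformly random pile.
    for rec in records:
        piles[rng.randrange(n_piles)].append(rec)
    # Pass 2: load one pile at a time (each must fit in memory), shuffle it, and emit it.
    for pile in piles:
        rng.shuffle(pile)
        yield from pile


shuffled = list(two_pass_shuffle(range(1000), n_piles=8))
print(shuffled[:10])
```

The pile count is the main knob: choose it so that roughly N / n_piles records fit comfortably in memory during the second pass.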
The contrast between the per-file model, which overfits to each dataset, and the randomized version tells us a few things:
The distributions of different datasets are quite different. We already knew this, but this puts a face and a number on the impact of batch effects.
We are, to some extent, able to learn a classifier that can ignore batch effects, i.e., it can “integrate” datasets! This is an unexpected outcome: maybe we don’t need a fancy binning algorithm like the one that scGPT uses to account for batch effects after all. It will be interesting to see how much using both scGPT and GenePT embeddings helps our predictor - maybe not much?
Vibes?
We can use a kinda-sorta confusion matrix to get a sense of how well we are predicting (click on the image to see an interactive version).
I say “kinda-sorta” because we sum the predicted probability of each class instead of using the normalized count of predicted labels. This shows the distribution the model is learning more accurately. Rows are the true labels, and columns are the mean predicted probability of each label, given the true label. We also cluster the rows and columns separately, so that empirically related cell types are more likely to appear close to each other in the diagram.
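The probability-based confusion matrix can be sketched in a few lines of NumPy: group cells by true label and average their predicted probability vectors, so each non-empty row sums to 1 instead of holding argmax counts. `soft_confusion` is a hypothetical name and the toy inputs are made up.

```python
import numpy as np


def soft_confusion(probs, y_true, n_classes):
    """Row i is the mean predicted probability vector over cells whose true label is i."""
    totals = np.zeros((n_classes, n_classes))
    np.add.at(totals, y_true, probs)                # sum probability vectors per true label
    counts = np.bincount(y_true, minlength=n_classes)
    return totals / np.maximum(counts, 1)[:, None]  # rows with no cells stay all-zero


probs = np.array([[0.7, 0.2, 0.1],
                  [0.5, 0.4, 0.1],
                  [0.1, 0.1, 0.8]])
y_true = np.array([0, 0, 2])
print(soft_confusion(probs, y_true, 3))
```

To get a clustered layout similar to the figure, one option (an assumption, not necessarily what produced our plot) is to reorder rows and columns independently using `scipy.cluster.hierarchy.linkage` and `leaves_list`, once on the matrix and once on its transpose.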
If you mouse over the different clusters of cell types, you can see that we are mostly predicting cell types that are approximately correct, even when we are slightly off.
Evals
Ok, so good vibes. What about numbers?
The cell-type hierarchy complicates evaluation. If we use macro precision, recall, and F1, the typical metrics for a multi-class classifier, then we are “wrong” if we classify a cell more specifically (e.g. further down the hierarchy) than its label.
This example hierarchy, taken from the EMBL-EBI Ontology Lookup Service, illustrates the dilemma. If I classify a cell as “CD16-positive, CD56-dim natural killer cell,” but the label is “mature natural killer cell,” then my prediction is “wrong” according to the macro statistics. The reverse is also true.
We tried several approaches to deal with this. As expected, our algorithm performs very poorly when evaluated using these statistics:
Ooof! Our algorithm is struggling to get the exact labels correct even a small percentage of the time.
Since the algorithm gives us a probability distribution over all of the classes, we can also check whether the correct label is among its top 2, 5, or 10 answers (Recall@2, Recall@5, Recall@10). Mean Reciprocal Rank (MRR), a ranking metric often used in information retrieval, averages 1/rank of the correct label, giving partial credit when the correct answer is close to the top predicted class. Discounted Cumulative Gain (DCG) also gives more credit for being closer to the top of the list, and takes the probability assigned to each class into account.
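The ranking metrics all start from the rank of the true label in each cell's sorted probability vector. A minimal sketch follows; the helper names are hypothetical, and because the exact probability-weighted DCG variant used here is not spelled out, this uses the textbook binary-relevance DCG (only the true label counts as relevant).

```python
import numpy as np


def true_label_ranks(probs, y_true):
    """1-based rank of the true class per cell, sorting classes by descending probability."""
    true_p = probs[np.arange(len(y_true)), y_true]
    # Rank = 1 + number of classes scored strictly higher (ties resolved in the true label's favor).
    return 1 + (probs > true_p[:, None]).sum(axis=1)


def recall_at_k(ranks, k):
    return float(np.mean(ranks <= k))


def mean_reciprocal_rank(ranks):
    return float(np.mean(1.0 / ranks))


def mean_dcg(ranks):
    # Binary relevance: each cell contributes 1 / log2(rank + 1).
    return float(np.mean(1.0 / np.log2(ranks + 1)))


probs = np.array([[0.6, 0.3, 0.1],
                  [0.2, 0.5, 0.3],
                  [0.1, 0.2, 0.7]])
y_true = np.array([0, 2, 0])
ranks = true_label_ranks(probs, y_true)
print(ranks)  # [1 2 3]
print(recall_at_k(ranks, 2), mean_reciprocal_rank(ranks), mean_dcg(ranks))
```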
We do significantly better on these ranking metrics, suggesting that we are indeed mostly predicting a level in the hierarchy that doesn’t match the ground-truth labels, rather than actually mispredicting. The fact that we do better on DCG than MRR suggests that we are generally fairly confident about the correct prediction. On the other hand, while ranking metrics are useful as a diagnostic, a user really needs to know the correct cell type!
Hierarchical Evals
What if we instead use a hierarchical metric that measures how close our prediction is, in the hierarchy, to the correct label? The paper “Learning and Evaluation in the Presence of Class Hierarchies: Application to Text Categorization” describes a good approach for this. Consider the diagram above. In this example, if the correct classification is e and we predict d, then we got a and b correct, so we should get credit for those. Intuitively, we should score 2/3, whereas we should score only 1/3 if we predict f.
To compute this, we expand every true and predicted label to include all of its ancestors. Hierarchical precision (hP) is the number of nodes we got correct, summed over all of our predictions, divided by the total number of nodes we predicted. Hierarchical recall (hR) is the same number of correct nodes divided by the total number of nodes in the true labels. This is a natural extension of precision and recall with the following properties (from the paper):
The measure gives credit to partially correct classification
The measure punishes distant errors more heavily
The measure punishes errors at higher levels of a hierarchy more heavily
Interestingly, it also favors classifiers that are more specific: if you always predict the root node in our example, you will have high precision but low recall.
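An F-measure naturally falls out of these scores as well:
$$hF_\beta = \frac{(\beta^2 + 1) \cdot hP \cdot hR}{\beta^2 \cdot hP + hR}$$
We can set β = 1 to get hF1, the harmonic mean of hP and hR.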
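Here is a minimal sketch of hP, hR, and hF_β, summing overlaps over all predictions as described above. Since the diagram is not reproduced here, the toy hierarchy (a is the parent of b and c, b is the parent of d and e, c is the parent of f) is an assumption chosen to match the worked numbers in the text. Parents are stored as lists because the real Cell Ontology is a DAG in which a term can have several parents; the function names are hypothetical.

```python
def ancestors_inclusive(node, parents):
    """The node plus all of its ancestors; `parents` maps each node to a list of parents."""
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(parents.get(n, []))
    return seen


def hierarchical_prf(y_true, y_pred, parents, beta=1.0):
    """hP, hR and hF_beta over ancestor-expanded label sets, summed across all predictions."""
    overlap = n_pred = n_true = 0
    for t, p in zip(y_true, y_pred):
        true_set = ancestors_inclusive(t, parents)
        pred_set = ancestors_inclusive(p, parents)
        overlap += len(true_set & pred_set)
        n_pred += len(pred_set)
        n_true += len(true_set)
    hp, hr = overlap / n_pred, overlap / n_true
    b2 = beta ** 2
    hf = (b2 + 1) * hp * hr / (b2 * hp + hr) if overlap else 0.0
    return hp, hr, hf


# Toy hierarchy consistent with the worked example: a -> {b, c}, b -> {d, e}, c -> {f}.
parents = {"b": ["a"], "c": ["a"], "d": ["b"], "e": ["b"], "f": ["c"]}
print(hierarchical_prf(["e"], ["d"], parents))  # wrong leaf, right branch
print(hierarchical_prf(["e"], ["f"], parents))  # wrong branch
print(hierarchical_prf(["e"], ["a"], parents))  # always predicting the root
```

The last call illustrates the specificity property: predicting the root gives perfect precision (1.0) but a recall of only 1/3.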
So how does our algorithm do?
Fantastic! We get lots of our predictions correct. How do we compare with another commonly used cell typer, such as CellTypist? CellTypist has the disadvantage that you need to know the cell class in order to choose a model that is good at that cell class. On the other hand, it is widely used and growing in popularity.
Spot comparison with CellTypist
We benchmarked a few datasets against the best CellTypist model in their suite for each dataset. We chose these datasets at random:
Dataset | Algorithm | Hierarchical Precision | Hierarchical Recall | Hierarchical F1 |
Immune | CellTypist (Immune_All_Low) | 0.8688 | 0.4026 | 0.5502 |
Immune | Miraomics CellTyper | 0.8740 | 0.8395 | 0.8564 |
Brain | CellTypist (Adult_Human_MTG) | 0.7145 | 0.1127 | 0.1947 |
Brain | Miraomics CellTyper | 0.9967 | 0.9707 | 0.9835 |
Immune Dataset:
Fink et al. (2022). Single-cell and spatial mapping identify cell types and signaling networks in the human ureter (immune subset).
Brain Dataset:
Jorstad et al. (2023). Transcriptomic cytoarchitecture reveals principles of human neocortex organization
Obviously this is just a spot check to make sure we are improving over existing methods. These results are very promising, but to be confident that we are robust to real-world factors, we need a more comprehensive comparison across a variety of algorithms on a well-designed test set. We are looking at the Open Problems dataset and the benchmark from “Benchmarking atlas-level data integration in single-cell genomics”, but we need to make sure that the newer models haven’t been contaminated with data from these benchmarks.
Next up: comparison against SoTA models
We are in the process of doing a comprehensive comparison against the existing SoTA algorithms across a variety of datasets. This is a bit tricky because it requires ensuring that the algorithms under comparison have not been trained on the test data.
Stay tuned! As always, your questions and feedback, especially negative feedback, are greatly appreciated. Find me on:
LinkedIn: rj-honicky
BlueSky: honicky.bsky.social
X: honicky
Written by RJ Honicky
