Train, assess, and use a pySCN classifier

In this brief tutorial, we show you how to train a pySCN classifier, assess its performance, and use it to predict the cell types of cells in independent data.

Data

We use the training data that was processed in the "prepare training data" tutorial, together with a second PBMC data set as the query data.

Load packages

import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import scipy as sp
import numpy as np
import anndata
import pySingleCellNet as pySCN
import igraph as ig
from joblib import dump, load
import sys

sc.settings.verbosity = 3 
sc.logging.print_header()
scanpy==1.10.0 anndata==0.10.6 umap==0.5.5 numpy==1.26.4 scipy==1.12.0 pandas==2.2.1 scikit-learn==1.4.1.post1 statsmodels==0.14.1 igraph==0.11.4 pynndescent==0.5.12

Load data

Important

pySCN assumes that expression estimates are in their raw, untransformed state. A shifted log normalization should not hurt classifier performance, but inputting scaled data will reduce it.

adRef = sc.read_h5ad("../../data/adPBMC_ref_040623.h5ad")
adQ1 = sc.read_h5ad("../../data/adPBMC_query_1_20k_HT_040723.h5ad")
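Because pySCN expects raw values, it can be worth checking a matrix such as `adRef.X` before training. The helper below, `looks_like_raw_counts`, is a hypothetical sketch and not part of pySCN. It checks whether every stored value is a non-negative integer and accepts either a dense NumPy array or a SciPy sparse matrix (for sparse input it inspects only the explicitly stored entries).

```python
import numpy as np


def looks_like_raw_counts(X, atol=1e-8):
    """Hypothetical helper (not part of pySCN).

    Returns True if every stored value is a non-negative integer.
    For SciPy sparse matrices only the stored entries (`.data`) are checked;
    implicit zeros are valid counts anyway.
    """
    # SciPy sparse matrices expose `nnz`; NumPy arrays do not.
    values = X.data if hasattr(X, "nnz") else np.asarray(X)
    values = np.asarray(values, dtype=float)
    if values.size == 0:
        return True
    non_negative = bool(np.all(values >= 0))
    # rtol=0 so large counts are held to the same absolute tolerance.
    integer_valued = bool(np.allclose(values, np.round(values), rtol=0.0, atol=atol))
    return non_negative and integer_valued
```

For example, `looks_like_raw_counts(adRef.X)`. A shifted-log-normalized matrix also fails this check even though, per the note above, it is acceptable input. Treat a `False` result as a prompt to inspect the data rather than as a hard error.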

Limit to genes shared in both data sets

pySCN.limit_anndata_to_common_genes([adRef, adQ1])
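Conceptually, restricting both objects to shared genes is an intersection of their gene names (`var_names`). The function below is a hypothetical sketch, not pySCN's implementation, though we assume the pySCN call behaves similarly. It returns the genes present in every list, in the order of the first list. With AnnData objects, you could then subset each one with `ad[:, common].copy()`.

```python
def common_genes(gene_lists):
    """Hypothetical helper: genes present in every list, in first-list order."""
    if not gene_lists:
        return []
    shared = set(gene_lists[0])
    for genes in gene_lists[1:]:
        shared &= set(genes)
    # Keep the order of the first list and drop any duplicates.
    seen = set()
    result = []
    for gene in gene_lists[0]:
        if gene in shared and gene not in seen:
            seen.add(gene)
            result.append(gene)
    return result
```

For example: `common = common_genes([list(adRef.var_names), list(adQ1.var_names)])`.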

Train classifier

First, we split the reference data into a training set containing an equal number of cells per cell type (here, `ncells=50`) and a held-out data set.

adTrain, adHeldOut = pySCN.splitCommonAnnData(adRef, ncells=50, dLevel="cell_type")
CD4 T cell : 3554
B cell : 1450
Megakaryocyte : 59
FCGR3A monocyte : 327
CD8 T cell : 1029
NK cell : 608
CD14 monocyte : 3128
Dendritic : 154
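The idea behind this balanced split can be sketched as follows: for each cell type, sample up to `ncells` cells for training and keep the rest as held-out data. `balanced_split` is a hypothetical re-implementation for illustration. `splitCommonAnnData` may differ in details such as random seeding or its handling of cell types with fewer than `ncells` cells.

```python
import numpy as np


def balanced_split(labels, ncells, seed=0):
    """Hypothetical sketch: per-class sampling of training indices.

    Returns (train_idx, heldout_idx) as sorted integer arrays. In this sketch,
    classes with fewer than `ncells` members put all of their cells in training.
    """
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train = []
    for label in np.unique(labels):
        idx = np.flatnonzero(labels == label)
        take = min(ncells, idx.size)
        # Sample without replacement so no cell is used twice.
        train.extend(rng.choice(idx, size=take, replace=False))
    train_idx = np.sort(np.asarray(train, dtype=int))
    heldout_idx = np.setdiff1d(np.arange(labels.size), train_idx)
    return train_idx, heldout_idx
```

With AnnData, the indices could be applied as `adRef[train_idx].copy()` and `adRef[heldout_idx].copy()`, using `adRef.obs["cell_type"].to_numpy()` as the labels.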

Now, we will train the pySCN classifier.

clf = pySCN.scn_train(adTrain, dLevel='cell_type', nTopGenes=200, nTopGenePairs=200, nRand=100, nTrees=1000, stratify=False, propOther=0.4)
normalizing by total count per cell
    finished (0:00:00): normalized adata.X and added    'n_counts', counts per cell before normalization (adata.obs)
HVG
extracting highly variable genes
    finished (0:00:00)
--> added
    'highly_variable', boolean vector (adata.var)
    'means', float vector (adata.var)
    'dispersions', float vector (adata.var)
    'dispersions_norm', float vector (adata.var)
... as `zero_center=True`, sparse input is densified and may lead to large memory consumption
Matrix normalized
ranking genes
    finished: added to `.uns['rank_genes_groups']`
    'names', sorted np.recarray to be indexed by group ids
    'scores', sorted np.recarray to be indexed by group ids
    'logfoldchanges', sorted np.recarray to be indexed by group ids
    'pvals', sorted np.recarray to be indexed by group ids
    'pvals_adj', sorted np.recarray to be indexed by group ids (0:00:00)
There are  1096  classification genes

B cell
CD14 monocyte
CD4 T cell
CD8 T cell
Dendritic
FCGR3A monocyte
Megakaryocyte
NK cell
There are 6042 top gene pairs

Finished pair transforming the data
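The "pair transforming" step refers to the top-scoring pair representation used by SingleCellNet. Each selected gene pair (A, B) becomes a binary feature that is 1 when A is expressed more highly than B in a cell and 0 otherwise. Below is a minimal NumPy sketch of that transformation. It assumes a dense cells × genes matrix and pairs given as column indices, and `pair_transform` is a hypothetical name, not pySCN's API.

```python
import numpy as np


def pair_transform(X, pairs):
    """Hypothetical sketch of a top-scoring-pair transform.

    X     : (n_cells, n_genes) expression matrix
    pairs : sequence of (i, j) gene column indices
    Returns an (n_cells, n_pairs) 0/1 matrix where entry [c, p] is 1 if
    gene i is strictly higher than gene j in cell c.
    """
    X = np.asarray(X)
    pairs = np.asarray(pairs, dtype=int).reshape(-1, 2)
    first, second = pairs[:, 0], pairs[:, 1]
    return (X[:, first] > X[:, second]).astype(np.int8)
```

Because each feature depends only on the ordering of two genes within a cell, multiplying a cell's values by a positive constant (for example, total-count normalization) leaves the transformed features unchanged.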

Classify held-out data and assess

pySCN.scn_classify(adHeldOut, clf, nrand=0)
pySCN.barplot_classifier_f1(adHeldOut, ground_truth="cell_type", class_prediction="SCN_class")
0.9967753493371552
0.9833305826043902
0.9576460098082925
0.8952110004741584
0.9107142857142857
0.8722044728434505
0.4864864864864865
0.9655781112091791
[Figure: bar plot of per-class F1 scores on the held-out data]
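The eight values printed by `barplot_classifier_f1` appear to be per-class F1 scores, one per cell type. F1 is the harmonic mean of precision and recall. For reference, here is a plain-Python sketch of that computation from ground-truth and predicted labels. It is not pySCN's code, and we make no assumption about the order in which the printed values are listed.

```python
def per_class_f1(truth, predicted):
    """Return {class: F1}, computed one-vs-rest for each ground-truth class."""
    scores = {}
    for cls in sorted(set(truth)):
        pairs = list(zip(truth, predicted))
        tp = sum(t == cls and p == cls for t, p in pairs)
        fp = sum(t != cls and p == cls for t, p in pairs)
        fn = sum(t == cls and p != cls for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        scores[cls] = 2 * precision * recall / denom if denom else 0.0
    return scores
```

For example: `per_class_f1(list(adHeldOut.obs["cell_type"]), list(adHeldOut.obs["SCN_class"]))`. scikit-learn's `f1_score(y_true, y_pred, average=None)` returns the same per-class values as an array.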

Classify query data

pySCN.scn_classify(adQ1, clf, nrand=0)

# Note: the "cell_type" column in adQ1.obs comes from manual annotation based on marker gene expression.
pySCN.heatmap_scores(adQ1, groupby='cell_type')
[Figure: heatmap of SCN scores for query cells, grouped by manual annotation (cell_type)]

We can also group the cells by the cell type for which they received the highest classification score (SCN_class).

pySCN.heatmap_scores(adQ1, groupby='SCN_class')
[Figure: heatmap of SCN scores for query cells, grouped by SCN_class]
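Assigning each cell to its highest-scoring class is an argmax over a cells × classes score matrix. The sketch below assumes such a score array and a matching list of class names. Where pySCN stores its scores inside the AnnData object is not covered here, so treat both inputs as hypothetical.

```python
import numpy as np


def assign_top_class(scores, class_names):
    """Return the class with the highest score for each cell (row)."""
    scores = np.asarray(scores, dtype=float)
    class_names = np.asarray(class_names)
    # np.argmax breaks ties by taking the first maximal column.
    return class_names[np.argmax(scores, axis=1)]
```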

We can embed the query data and examine how SCN classifications and SCN scores are distributed across clusters.

adQ1_norm = pySCN.norm_hvg_scale_pca(adQ1)

npcs = 11
sc.pp.neighbors(adQ1_norm, n_neighbors=25, n_pcs=npcs)
sc.tl.leiden(adQ1_norm, resolution=0.1)

# Compute PAGA connectivities and positions, then use them to initialize UMAP
sc.tl.paga(adQ1_norm)
sc.pl.paga(adQ1_norm, plot=False)
sc.tl.umap(adQ1_norm, min_dist=0.25, init_pos='paga')

plt.rcdefaults()
sc.pl.umap(adQ1_norm, color=['SCN_class', 'cell_type'], alpha=0.75, s=10)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:02)
--> added
    'highly_variable', boolean vector (adata.var)
    'means', float vector (adata.var)
    'dispersions', float vector (adata.var)
    'dispersions_norm', float vector (adata.var)
computing PCA
    with n_comps=100
    finished (0:00:07)
computing neighbors
    using 'X_pca' with n_pcs = 11
    finished: added to `.uns['neighbors']`
    `.obsp['distances']`, distances for each pair of neighbors
    `.obsp['connectivities']`, weighted adjacency matrix (0:00:02)
running Leiden clustering
    finished: found 6 clusters and added
    'leiden', the cluster labels (adata.obs, categorical) (0:00:06)
running PAGA
    finished: added
    'paga/connectivities', connectivities adjacency (adata.uns)
    'paga/connectivities_tree', connectivities subtree (adata.uns) (0:00:00)
--> added 'pos', the PAGA positions (adata.uns['paga'])
computing UMAP
    finished: added
    'X_umap', UMAP coordinates (adata.obsm) (0:00:10)
[Figure: UMAP of query cells colored by SCN_class and by cell_type]

Examining the SCN scores can help explain why the manual annotation and the pySCN classification differ for the CD8 T cell and NK cell groups.

pySCN.umap_scores(adQ1_norm, ["CD8 T cell", "NK cell"])
[Figure: UMAP of query cells colored by CD8 T cell and NK cell SCN scores]
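A contingency table of the two label columns can quantify where the manual annotation and the pySCN classification disagree. With pandas, `pd.crosstab(adQ1.obs["cell_type"], adQ1.obs["SCN_class"])` produces such a table. The dependency-free sketch below builds the same counts and lists the off-diagonal (disagreeing) combinations. Both function names are hypothetical.

```python
from collections import Counter


def contingency(truth, predicted):
    """Count (manual label, predicted label) pairs.

    Returns a dict mapping each manual label to a Counter of predicted labels.
    """
    table = {}
    for t, p in zip(truth, predicted):
        table.setdefault(t, Counter())[p] += 1
    return table


def disagreements(table):
    """List (manual, predicted, count) for off-diagonal entries, largest first."""
    off = [
        (t, p, n)
        for t, row in table.items()
        for p, n in row.items()
        if t != p
    ]
    return sorted(off, key=lambda item: item[2], reverse=True)
```

For example: `disagreements(contingency(list(adQ1.obs["cell_type"]), list(adQ1.obs["SCN_class"])))`.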

Cell type composition

We can visualize the cell type composition of samples as follows:

pySCN.plot_cell_type_proportions([adHeldOut, adQ1_norm], obs_column="SCN_class", labels=["HeldOut", "Query"])
[Figure: cell type proportions (SCN_class) in the held-out and query data]
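The proportions shown in this kind of plot are the fraction of cells assigned to each class within each sample. Here is a hypothetical sketch of that computation from label lists. It is not the plotting function's actual code.

```python
from collections import Counter


def class_proportions(samples):
    """Map each sample name to {class: fraction of that sample's cells}."""
    result = {}
    for name, labels in samples.items():
        counts = Counter(labels)
        total = sum(counts.values())
        result[name] = {cls: n / total for cls, n in counts.items()} if total else {}
    return result
```

For example: `class_proportions({"HeldOut": list(adHeldOut.obs["SCN_class"]), "Query": list(adQ1_norm.obs["SCN_class"])})`.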

Rank-based classifier

Sometimes it is useful to skip the top-scoring pair transformation. Below, we show how to use pySCN with rank-transformed expression data after limiting the predictors to genes associated with the first few principal components (here, the first 7).

adRef = sc.read_h5ad("../../data/adPBMC_ref_040623.h5ad")
adQ1 = sc.read_h5ad("../../data/adPBMC_query_1_20k_HT_040723.h5ad")

adRef_Norm = pySCN.norm_hvg_scale_pca(adRef)
topgenes = pySCN.top_genes_pca(adRef_Norm, n_pcs=7, top_x_genes = 10)
len(topgenes)
normalizing counts per cell
    finished (0:00:00)
extracting highly variable genes
    finished (0:00:00)
--> added
    'highly_variable', boolean vector (adata.var)
    'means', float vector (adata.var)
    'dispersions', float vector (adata.var)
    'dispersions_norm', float vector (adata.var)
computing PCA
    with n_comps=100
    finished (0:00:03)
54
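The output above shows that 54 genes were selected. One plausible way to pick genes "associated with" the leading principal components works in two steps. First, for each of the first `n_pcs` components, take the `top_x_genes` genes with the largest absolute loadings. Second, take the union of those sets. Because a gene can rank highly on several components, the union can be smaller than 7 × 10 = 70. This is our assumption about what `top_genes_pca` does, not a description of its code. The sketch operates on a genes × components loading matrix such as the one scanpy stores in `adata.varm["PCs"]`.

```python
import numpy as np


def top_genes_from_loadings(loadings, gene_names, n_pcs, top_x_genes):
    """Hypothetical sketch: union of the top |loading| genes per component.

    loadings   : (n_genes, n_components) array, e.g. scanpy's adata.varm["PCs"]
    gene_names : sequence of n_genes names
    Returns gene names in order of first selection.
    """
    loadings = np.asarray(loadings, dtype=float)
    selected = []
    seen = set()
    for pc in range(n_pcs):
        # Indices of the genes with the largest absolute loading on this PC.
        order = np.argsort(-np.abs(loadings[:, pc]), kind="stable")[:top_x_genes]
        for i in order:
            name = gene_names[i]
            if name not in seen:
                seen.add(name)
                selected.append(name)
    return selected
```

For example: `top_genes_from_loadings(adRef_Norm.varm["PCs"], list(adRef_Norm.var_names), n_pcs=7, top_x_genes=10)`.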
adRef = adRef[:, topgenes].copy()
pySCN.limit_anndata_to_common_genes([adRef, adQ1])

adTrain, adHeldOut = pySCN.splitCommonAnnData(adRef, ncells=50, dLevel="cell_type")
clf_rank = pySCN.train_rank_classifier(adTrain, dLevel="cell_type")
pySCN.rank_classify(adHeldOut, clf_rank)
pySCN.heatmap_scores(adHeldOut, groupby='SCN_class')
CD4 T cell : 3554
B cell : 1450
Megakaryocyte : 59
FCGR3A monocyte : 327
CD8 T cell : 1029
NK cell : 608
CD14 monocyte : 3128
Dendritic : 154
[Figure: heatmap of rank-classifier scores for held-out cells, grouped by SCN_class]
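The rank transformation used in this section replaces each cell's expression values with their within-cell ranks. Like the pair transform, this makes the features insensitive to per-cell scaling. Below is a minimal NumPy sketch that assumes a dense cells × genes matrix. Whether `train_rank_classifier` ranks within cells, how it breaks ties, and whether it rescales ranks are all assumptions here, and `rank_within_cells` is a hypothetical name.

```python
import numpy as np


def rank_within_cells(X):
    """Hypothetical sketch: replace each row's values by their ranks (1 = lowest).

    Ties are broken by column order (stable argsort); other conventions,
    such as average ranks, are equally reasonable.
    """
    X = np.asarray(X, dtype=float)
    order = np.argsort(X, axis=1, kind="stable")
    ranks = np.empty_like(order)
    rows = np.arange(X.shape[0])[:, None]
    # Scatter 1..n_genes back to the original column positions of each row.
    ranks[rows, order] = np.arange(1, X.shape[1] + 1)
    return ranks
```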
pySCN.barplot_classifier_f1(adHeldOut, ground_truth="cell_type", class_prediction="SCN_class")
0.998211091234347
0.9823228151329919
0.9638483104618563
0.8880631676730144
0.7157894736842105
0.9551724137931035
0.5142857142857142
0.9830508474576272
[Figure: bar plot of per-class F1 scores for the rank-based classifier on the held-out data]
pySCN.rank_classify(adQ1, clf_rank)
pySCN.heatmap_scores(adQ1, groupby='SCN_class')
[Figure: heatmap of rank-classifier scores for query cells, grouped by SCN_class]
pySCN.plot_cell_type_proportions([adHeldOut, adQ1], obs_column="SCN_class", labels=["HeldOut", "Query"])
[Figure: cell type proportions (SCN_class) in the held-out and query data, rank-based classifier]