-
Notifications
You must be signed in to change notification settings - Fork 374
/
Copy path_totalvi.py
1365 lines (1244 loc) · 53.2 KB
/
_totalvi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import logging
import warnings
from collections.abc import Iterable as IterableClass
from functools import partial
from typing import Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
import pandas as pd
import torch
from anndata import AnnData
from mudata import MuData
from scvi import REGISTRY_KEYS, settings
from scvi._types import Number
from scvi.data import AnnDataManager, fields
from scvi.data._utils import _check_nonnegative_integers
from scvi.dataloaders import DataSplitter
from scvi.model._utils import (
_get_batch_code_from_category,
_get_var_names_from_manager,
_init_library_size,
cite_seq_raw_counts_properties,
get_max_epochs_heuristic,
)
from scvi.model.base._utils import _de_core
from scvi.module import TOTALVAE
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.utils._docstrings import de_dsp, devices_dsp, setup_anndata_dsp
from .base import ArchesMixin, BaseModelClass, RNASeqMixin, VAEMixin
logger = logging.getLogger(__name__)
class TOTALVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseModelClass):
"""total Variational Inference :cite:p:`GayosoSteier21`.
Parameters
----------
adata
AnnData object that has been registered via :meth:`~scvi.model.TOTALVI.setup_anndata`.
n_latent
Dimensionality of the latent space.
gene_dispersion
One of the following:
* ``'gene'`` - genes_dispersion parameter of NB is constant per gene across cells
* ``'gene-batch'`` - genes_dispersion can differ between different batches
* ``'gene-label'`` - genes_dispersion can differ between different labels
protein_dispersion
One of the following:
* ``'protein'`` - protein_dispersion parameter is constant per protein across cells
* ``'protein-batch'`` - protein_dispersion can differ between different batches NOT TESTED
* ``'protein-label'`` - protein_dispersion can differ between different labels NOT TESTED
gene_likelihood
One of:
* ``'nb'`` - Negative binomial distribution
* ``'zinb'`` - Zero-inflated negative binomial distribution
latent_distribution
One of:
* ``'normal'`` - Normal distribution
* ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
empirical_protein_background_prior
Set the initialization of protein background prior empirically. This option fits a GMM for each of
100 cells per batch and averages the distributions. Note that even with this option set to `True`,
this only initializes a parameter that is learned during inference. If `False`, randomly initializes.
The default (`None`), sets this to `True` if greater than 10 proteins are used.
override_missing_proteins
If `True`, will not treat proteins with all 0 expression in a particular batch as missing.
**model_kwargs
Keyword args for :class:`~scvi.module.TOTALVAE`
Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.model.TOTALVI.setup_anndata(adata, batch_key="batch", protein_expression_obsm_key="protein_expression")
>>> vae = scvi.model.TOTALVI(adata)
>>> vae.train()
>>> adata.obsm["X_totalVI"] = vae.get_latent_representation()
Notes
-----
See further usage examples in the following tutorials:
1. :doc:`/tutorials/notebooks/multimodal/totalVI`
2. :doc:`/tutorials/notebooks/multimodal/cite_scrna_integration_w_totalVI`
3. :doc:`/tutorials/notebooks/scrna/scarches_scvi_tools`
"""
_module_cls = TOTALVAE
_data_splitter_cls = DataSplitter
_training_plan_cls = AdversarialTrainingPlan
_train_runner_cls = TrainRunner
def __init__(
    self,
    adata: AnnData,
    n_latent: int = 20,
    gene_dispersion: Literal[
        "gene", "gene-batch", "gene-label", "gene-cell"
    ] = "gene",
    protein_dispersion: Literal[
        "protein", "protein-batch", "protein-label"
    ] = "protein",
    gene_likelihood: Literal["zinb", "nb"] = "nb",
    latent_distribution: Literal["normal", "ln"] = "normal",
    empirical_protein_background_prior: Optional[bool] = None,
    override_missing_proteins: bool = False,
    **model_kwargs,
):
    """Assemble the underlying module from the registered data and hyperparameters.

    Parameters
    ----------
    adata
        Registered AnnData object, forwarded to the base-class initializer.
    n_latent, gene_dispersion, protein_dispersion, gene_likelihood, latent_distribution
        Forwarded unchanged to the module class (``self._module_cls``).
    empirical_protein_background_prior
        If `None`, the empirical prior is used when more than 10 proteins are registered.
        When used, the prior mean/scale come from ``self._get_totalvi_protein_priors``;
        otherwise `None` is passed for both.
    override_missing_proteins
        If `False` and the protein state registry contains a batch mask, a ``UserWarning``
        is emitted, the mask is passed to the module and the adversarial classifier becomes
        the default for :meth:`train`. If `True`, no mask is passed.
    **model_kwargs
        Extra keyword arguments forwarded to the module class.
    """
    super().__init__(adata)
    self.protein_state_registry = self.adata_manager.get_state_registry(
        REGISTRY_KEYS.PROTEIN_EXP_KEY
    )

    has_batch_mask = (
        fields.ProteinObsmField.PROTEIN_BATCH_MASK in self.protein_state_registry
    )
    if has_batch_mask and not override_missing_proteins:
        batch_mask = self.protein_state_registry.protein_batch_mask
        # The original message was missing the space after the comma
        # ("...=True`,to override"), which garbled the warning shown to users.
        msg = (
            "Some proteins have all 0 counts in some batches. "
            "These proteins will be treated as missing measurements; however, "
            "this can occur due to experimental design/biology. "
            "Reinitialize the model with `override_missing_proteins=True`, "
            "to override this behavior."
        )
        warnings.warn(msg, UserWarning, stacklevel=settings.warnings_stacklevel)
        self._use_adversarial_classifier = True
    else:
        batch_mask = None
        self._use_adversarial_classifier = False

    # An explicit user choice wins; otherwise enable the empirical prior only
    # when more than 10 proteins are present.
    if empirical_protein_background_prior is not None:
        emp_prior = empirical_protein_background_prior
    else:
        emp_prior = self.summary_stats.n_proteins > 10
    if emp_prior:
        prior_mean, prior_scale = self._get_totalvi_protein_priors(adata)
    else:
        prior_mean, prior_scale = None, None

    # Number of categories per extra categorical covariate, if any were registered.
    if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry:
        n_cats_per_cov = self.adata_manager.get_state_registry(
            REGISTRY_KEYS.CAT_COVS_KEY
        )[fields.CategoricalJointObsField.N_CATS_PER_KEY]
    else:
        n_cats_per_cov = None

    n_batch = self.summary_stats.n_batch

    # Library-size priors are only initialized when no size factor was registered.
    use_size_factor_key = (
        REGISTRY_KEYS.SIZE_FACTOR_KEY in self.adata_manager.data_registry
    )
    library_log_means, library_log_vars = None, None
    if not use_size_factor_key:
        library_log_means, library_log_vars = _init_library_size(
            self.adata_manager, n_batch
        )

    self.module = self._module_cls(
        n_input_genes=self.summary_stats.n_vars,
        n_input_proteins=self.summary_stats.n_proteins,
        n_batch=n_batch,
        n_latent=n_latent,
        n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0),
        n_cats_per_cov=n_cats_per_cov,
        gene_dispersion=gene_dispersion,
        protein_dispersion=protein_dispersion,
        gene_likelihood=gene_likelihood,
        latent_distribution=latent_distribution,
        protein_batch_mask=batch_mask,
        protein_background_prior_mean=prior_mean,
        protein_background_prior_scale=prior_scale,
        use_size_factor_key=use_size_factor_key,
        library_log_means=library_log_means,
        library_log_vars=library_log_vars,
        **model_kwargs,
    )
    self._model_summary_string = (
        "TotalVI Model with the following params: \n"
        f"n_latent: {n_latent}, gene_dispersion: {gene_dispersion}, "
        f"protein_dispersion: {protein_dispersion}, "
        f"gene_likelihood: {gene_likelihood}, "
        f"latent_distribution: {latent_distribution}"
    )
    # Must stay last and must see the untouched constructor arguments:
    # the parameters above are deliberately never rebound.
    self.init_params_ = self._get_init_params(locals())
@devices_dsp.dedent
def train(
    self,
    max_epochs: Optional[int] = None,
    lr: float = 4e-3,
    use_gpu: Optional[Union[str, int, bool]] = None,
    accelerator: str = "auto",
    devices: Union[int, List[int], str] = "auto",
    train_size: float = 0.9,
    validation_size: Optional[float] = None,
    shuffle_set_split: bool = True,
    batch_size: int = 256,
    early_stopping: bool = True,
    check_val_every_n_epoch: Optional[int] = None,
    reduce_lr_on_plateau: bool = True,
    n_steps_kl_warmup: Union[int, None] = None,
    n_epochs_kl_warmup: Union[int, None] = None,
    adversarial_classifier: Optional[bool] = None,
    plan_kwargs: Optional[dict] = None,
    **kwargs,
):
    """Trains the model using amortized variational inference.

    Parameters
    ----------
    max_epochs
        Number of passes through the dataset. If `None`, the value is chosen by
        `get_max_epochs_heuristic` from the number of cells.
    lr
        Learning rate for optimization.
    %(param_use_gpu)s
    %(param_accelerator)s
    %(param_devices)s
    train_size
        Size of training set in the range [0.0, 1.0].
    validation_size
        Size of the validation set. If `None`, defaults to 1 - `train_size`. If
        `train_size + validation_size < 1`, the remaining cells belong to a test set.
    shuffle_set_split
        Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the
        sequential order of the data according to `validation_size` and `train_size` percentages.
    batch_size
        Minibatch size to use during training.
    early_stopping
        Whether to perform early stopping with respect to the validation set.
    check_val_every_n_epoch
        Check val every n train epochs. This method forces the value to 1 whenever
        `reduce_lr_on_plateau` is `True`; otherwise the given value is forwarded to the
        train runner as is.
    reduce_lr_on_plateau
        Reduce learning rate on plateau of validation metric (default is ELBO).
    n_steps_kl_warmup
        Number of training steps (minibatches) to scale weight on KL divergences from 0 to 1.
        Only activated when `n_epochs_kl_warmup` is set to None. If `None`, defaults
        to `floor(0.75 * adata.n_obs)`.
    n_epochs_kl_warmup
        Number of epochs to scale weight on KL divergences from 0 to 1.
        Overrides `n_steps_kl_warmup` when both are not `None`.
    adversarial_classifier
        Whether to use adversarial classifier in the latent space. This helps mixing when
        there are missing proteins in any of the batches. Defaults to `True` if missing proteins
        are detected.
    plan_kwargs
        Keyword args for :class:`~scvi.train.AdversarialTrainingPlan`. Keyword arguments passed to
        `train()` will overwrite values present in `plan_kwargs`, when appropriate. The dictionary
        passed in is not modified.
    **kwargs
        Other keyword args for :class:`~scvi.train.Trainer`.
    """
    if adversarial_classifier is None:
        adversarial_classifier = self._use_adversarial_classifier
    if n_steps_kl_warmup is None:
        n_steps_kl_warmup = int(0.75 * self.adata.n_obs)
    if reduce_lr_on_plateau:
        # The plateau scheduler needs a validation metric every epoch.
        check_val_every_n_epoch = 1

    update_dict = {
        "lr": lr,
        "adversarial_classifier": adversarial_classifier,
        "reduce_lr_on_plateau": reduce_lr_on_plateau,
        "n_epochs_kl_warmup": n_epochs_kl_warmup,
        "n_steps_kl_warmup": n_steps_kl_warmup,
    }
    # Work on a copy so the caller's dictionary is never mutated; explicit
    # `train()` arguments still take precedence over entries in `plan_kwargs`.
    plan_kwargs = dict(plan_kwargs) if plan_kwargs is not None else {}
    plan_kwargs.update(update_dict)

    if max_epochs is None:
        max_epochs = get_max_epochs_heuristic(self.adata.n_obs)

    data_splitter = self._data_splitter_cls(
        self.adata_manager,
        train_size=train_size,
        validation_size=validation_size,
        shuffle_set_split=shuffle_set_split,
        batch_size=batch_size,
    )
    training_plan = self._training_plan_cls(self.module, **plan_kwargs)
    runner = self._train_runner_cls(
        self,
        training_plan=training_plan,
        data_splitter=data_splitter,
        max_epochs=max_epochs,
        use_gpu=use_gpu,
        accelerator=accelerator,
        devices=devices,
        early_stopping=early_stopping,
        check_val_every_n_epoch=check_val_every_n_epoch,
        **kwargs,
    )
    return runner()
@torch.inference_mode()
def get_latent_library_size(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    give_mean: bool = True,
    batch_size: Optional[int] = None,
) -> np.ndarray:
    r"""Returns the latent library size for each cell.

    This is denoted as :math:`\ell_n` in the totalVI paper.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    give_mean
        If `True`, return ``exp(loc + 0.5 * scale**2)`` computed from the ``ql`` inference
        output; otherwise return the ``library_gene`` inference output.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

    Returns
    -------
    The per-minibatch results concatenated along the first axis, as a NumPy array.
    """
    self._check_if_trained(warn=False)
    adata = self._validate_anndata(adata)
    loader = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size
    )

    per_batch = []
    for tensors in loader:
        module_inputs = self.module._get_inference_input(tensors)
        inference_out = self.module.inference(**module_inputs)
        if not give_mean:
            lib = inference_out["library_gene"]
        else:
            ql = inference_out["ql"]
            lib = torch.exp(ql.loc + 0.5 * (ql.scale**2))
        per_batch.append(lib.cpu())

    return torch.cat(per_batch).numpy()
@torch.inference_mode()
def get_normalized_expression(
    self,
    adata=None,
    indices=None,
    n_samples_overall: Optional[int] = None,
    transform_batch: Optional[Sequence[Union[Number, str]]] = None,
    gene_list: Optional[Sequence[str]] = None,
    protein_list: Optional[Sequence[str]] = None,
    library_size: Optional[Union[float, Literal["latent"]]] = 1,
    n_samples: int = 1,
    sample_protein_mixing: bool = False,
    scale_protein: bool = False,
    include_protein_background: bool = False,
    batch_size: Optional[int] = None,
    return_mean: bool = True,
    return_numpy: Optional[bool] = None,
) -> Tuple[Union[np.ndarray, pd.DataFrame], Union[np.ndarray, pd.DataFrame]]:
    r"""Returns the normalized gene expression and protein expression.

    This is denoted as :math:`\rho_n` in the totalVI paper for genes, and TODO
    for proteins, :math:`(1-\pi_{nt})\alpha_{nt}\beta_{nt}`.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    n_samples_overall
        Number of samples to use in total. When given, that many entries are drawn from
        `indices` with ``np.random.choice``.
    transform_batch
        Batch to condition on.
        If transform_batch is:

        - None, then real observed batch is used
        - int, then batch transform_batch is used
        - List[int], then average over batches in list
    gene_list
        Return frequencies of expression for a subset of genes.
        This can save memory when working with large datasets and few genes are
        of interest.
    protein_list
        Return protein expression for a subset of proteins.
        This can save memory when working with large datasets and few proteins are
        of interest.
    library_size
        Scale the expression frequencies to a common library size.
        With ``"latent"`` the ``rate`` output of the gene decoder is returned; for any other
        value the ``scale`` output is returned.
        NOTE(review): a numeric value is not applied as a multiplier in this method - confirm
        whether that is intended.
    n_samples
        Get sample scale from multiple samples.
    sample_protein_mixing
        Sample mixing bernoulli, setting background to zero
    scale_protein
        Make protein expression sum to 1
    include_protein_background
        Include background component for protein expression
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    return_mean
        Whether to return the mean of the samples.
    return_numpy
        Return a `np.ndarray` instead of a `pd.DataFrame`. The DataFrame includes feature
        names as columns. Forced to `True` (with a warning if `False` was requested) when
        ``n_samples > 1`` and ``return_mean`` is `False`.

    Returns
    -------
    - **gene_normalized_expression** - normalized expression for RNA
    - **protein_normalized_expression** - normalized expression for proteins

    If ``n_samples`` > 1 and ``return_mean`` is False, then the shape is ``(cells, features, samples)``.
    Otherwise, shape is ``(cells, features)``. Return type is ``pd.DataFrame`` unless ``return_numpy`` is True.
    """
    adata = self._validate_anndata(adata)
    adata_manager = self.get_anndata_manager(adata)
    if indices is None:
        indices = np.arange(adata.n_obs)
    if n_samples_overall is not None:
        indices = np.random.choice(indices, n_samples_overall)
    post = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size
    )

    # Boolean feature masks; `slice(None)` keeps every feature.
    if gene_list is None:
        gene_mask = slice(None)
    else:
        all_genes = _get_var_names_from_manager(adata_manager)
        gene_mask = [bool(gene in gene_list) for gene in all_genes]
    if protein_list is None:
        protein_mask = slice(None)
    else:
        all_proteins = self.protein_state_registry.column_names
        protein_mask = [bool(p in protein_list) for p in all_proteins]

    if n_samples > 1 and return_mean is False:
        if return_numpy is False:
            warnings.warn(
                "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` "
                "is `False`, returning an `np.ndarray`.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        return_numpy = True

    if not isinstance(transform_batch, IterableClass):
        transform_batch = [transform_batch]
    transform_batch = _get_batch_code_from_category(adata_manager, transform_batch)
    n_transform_batches = len(transform_batch)

    # Which gene-decoder output to accumulate does not change across minibatches.
    px_key = "rate" if library_size == "latent" else "scale"

    scale_list_gene = []
    scale_list_pro = []
    for tensors in post:
        x = tensors[REGISTRY_KEYS.X_KEY]
        y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
        # Zero accumulators with the masked feature dimension.
        px_scale = torch.zeros_like(x)[..., gene_mask]
        py_scale = torch.zeros_like(y)[..., protein_mask]
        if n_samples > 1:
            px_scale = torch.stack(n_samples * [px_scale])
            py_scale = torch.stack(n_samples * [py_scale])

        for b in transform_batch:
            _, generative_outputs = self.module.forward(
                tensors=tensors,
                inference_kwargs={"n_samples": n_samples},
                generative_kwargs={"transform_batch": b},
                compute_loss=False,
            )
            px_scale += generative_outputs["px_"][px_key].cpu()[..., gene_mask]

            py_ = generative_outputs["py_"]
            # probability of background
            protein_mixing = torch.sigmoid(py_["mixing"].cpu())
            if sample_protein_mixing is True:
                protein_mixing = torch.distributions.Bernoulli(
                    protein_mixing
                ).sample()
            protein_val = py_["rate_fore"].cpu() * (1 - protein_mixing)
            if include_protein_background is True:
                protein_val += py_["rate_back"].cpu() * protein_mixing
            if scale_protein is True:
                # Normalize over all proteins before the mask is applied.
                protein_val = torch.nn.functional.normalize(
                    protein_val, p=1, dim=-1
                )
            py_scale += protein_val[..., protein_mask]

        # Average over the requested batches.
        px_scale /= n_transform_batches
        py_scale /= n_transform_batches
        scale_list_gene.append(px_scale)
        scale_list_pro.append(py_scale)

    if n_samples > 1:
        # concatenate along batch dimension -> result shape = (samples, cells, features)
        scale_list_gene = torch.cat(scale_list_gene, dim=1)
        scale_list_pro = torch.cat(scale_list_pro, dim=1)
        # (cells, features, samples)
        scale_list_gene = scale_list_gene.permute(1, 2, 0)
        scale_list_pro = scale_list_pro.permute(1, 2, 0)
    else:
        scale_list_gene = torch.cat(scale_list_gene, dim=0)
        scale_list_pro = torch.cat(scale_list_pro, dim=0)

    if return_mean is True and n_samples > 1:
        scale_list_gene = torch.mean(scale_list_gene, dim=-1)
        scale_list_pro = torch.mean(scale_list_pro, dim=-1)

    scale_list_gene = scale_list_gene.cpu().numpy()
    scale_list_pro = scale_list_pro.cpu().numpy()

    if return_numpy is None or return_numpy is False:
        gene_df = pd.DataFrame(
            scale_list_gene,
            columns=_get_var_names_from_manager(adata_manager)[gene_mask],
            index=adata.obs_names[indices],
        )
        protein_names = self.protein_state_registry.column_names
        pro_df = pd.DataFrame(
            scale_list_pro,
            columns=protein_names[protein_mask],
            index=adata.obs_names[indices],
        )
        return gene_df, pro_df
    return scale_list_gene, scale_list_pro
@torch.inference_mode()
def get_protein_foreground_probability(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    transform_batch: Optional[Sequence[Union[Number, str]]] = None,
    protein_list: Optional[Sequence[str]] = None,
    n_samples: int = 1,
    batch_size: Optional[int] = None,
    return_mean: bool = True,
    return_numpy: Optional[bool] = None,
):
    r"""Returns the foreground probability for proteins.

    This is denoted as :math:`(1 - \pi_{nt})` in the totalVI paper.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    transform_batch
        Batch to condition on.
        If transform_batch is:

        - None, then real observed batch is used
        - int, then batch transform_batch is used
        - List[int], then average over batches in list
    protein_list
        Restrict the output to this subset of proteins.
        This can save memory when working with large datasets and few proteins are
        of interest.
    n_samples
        Number of posterior samples to use for estimation.
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    return_mean
        Whether to return the mean of the samples.
    return_numpy
        Return a :class:`~numpy.ndarray` instead of a :class:`~pandas.DataFrame`. DataFrame includes
        protein names as columns. Forced to `True` (with a warning if `False` was requested)
        when `n_samples > 1` and `return_mean` is `False`.

    Returns
    -------
    - **foreground_probability** - probability foreground for each protein

    If `n_samples` > 1 and `return_mean` is False, then the shape is `(cells, proteins, samples)`.
    Otherwise, shape is `(cells, proteins)`. In this case, return type is :class:`~pandas.DataFrame` unless `return_numpy` is True.
    """
    adata = self._validate_anndata(adata)
    loader = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size
    )

    # `slice(None)` keeps every protein; otherwise build a boolean column mask.
    protein_mask = slice(None)
    if protein_list is not None:
        protein_mask = [
            bool(name in protein_list)
            for name in self.protein_state_registry.column_names
        ]

    keep_all_samples = n_samples > 1 and return_mean is False
    if keep_all_samples:
        if return_numpy is False:
            warnings.warn(
                "`return_numpy` must be `True` if `n_samples > 1` and `return_mean` "
                "is `False`, returning an `np.ndarray`.",
                UserWarning,
                stacklevel=settings.warnings_stacklevel,
            )
        return_numpy = True

    if indices is None:
        indices = np.arange(adata.n_obs)

    if not isinstance(transform_batch, IterableClass):
        transform_batch = [transform_batch]
    transform_batch = _get_batch_code_from_category(
        self.adata_manager, transform_batch
    )

    background_chunks = []
    for tensors in loader:
        observed_proteins = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
        # Accumulator for the background probability, averaged over batches below.
        background = torch.zeros_like(observed_proteins[..., protein_mask])
        if n_samples > 1:
            background = torch.stack(n_samples * [background])
        for batch_code in transform_batch:
            _, generative_outputs = self.module.forward(
                tensors=tensors,
                inference_kwargs={"n_samples": n_samples},
                generative_kwargs={"transform_batch": batch_code},
                compute_loss=False,
            )
            mixing_logits = generative_outputs["py_"]["mixing"]
            background += torch.sigmoid(mixing_logits)[..., protein_mask].cpu()
        background /= len(transform_batch)
        background_chunks.append(background)

    if n_samples > 1:
        # concatenate along batch dimension -> result shape = (samples, cells, features)
        stacked = torch.cat(background_chunks, dim=1)
        # (cells, features, samples)
        stacked = stacked.permute(1, 2, 0)
    else:
        stacked = torch.cat(background_chunks, dim=0)

    if return_mean is True and n_samples > 1:
        stacked = torch.mean(stacked, dim=-1)

    # Foreground probability is the complement of the background probability.
    foreground = 1 - stacked.cpu().numpy()

    if return_numpy is True:
        return foreground
    protein_names = self.protein_state_registry.column_names
    return pd.DataFrame(
        foreground,
        columns=protein_names[protein_mask],
        index=adata.obs_names[indices],
    )
def _expression_for_de(
    self,
    adata=None,
    indices=None,
    n_samples_overall=None,
    transform_batch: Optional[Sequence[Union[Number, str]]] = None,
    scale_protein=False,
    batch_size: Optional[int] = None,
    sample_protein_mixing=False,
    include_protein_background=False,
    protein_prior_count=0.5,
):
    """Build the joint RNA + protein matrix used for differential expression.

    Calls :meth:`get_normalized_expression` with ``n_samples=1`` and
    ``return_numpy=True``, adds ``protein_prior_count`` to the protein values and
    concatenates RNA and protein along axis 1 (RNA columns first).
    """
    expression_kwargs = {
        "adata": adata,
        "indices": indices,
        "n_samples_overall": n_samples_overall,
        "transform_batch": transform_batch,
        "return_numpy": True,
        "n_samples": 1,
        "batch_size": batch_size,
        "scale_protein": scale_protein,
        "sample_protein_mixing": sample_protein_mixing,
        "include_protein_background": include_protein_background,
    }
    rna_values, protein_values = self.get_normalized_expression(**expression_kwargs)
    # Offset the protein values by the prior count (in place, as before).
    protein_values += protein_prior_count
    return np.concatenate([rna_values, protein_values], axis=1)
@de_dsp.dedent
def differential_expression(
    self,
    adata: Optional[AnnData] = None,
    groupby: Optional[str] = None,
    group1: Optional[Iterable[str]] = None,
    group2: Optional[str] = None,
    idx1: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
    idx2: Optional[Union[Sequence[int], Sequence[bool], str]] = None,
    mode: Literal["vanilla", "change"] = "change",
    delta: float = 0.25,
    batch_size: Optional[int] = None,
    all_stats: bool = True,
    batch_correction: bool = False,
    batchid1: Optional[Iterable[str]] = None,
    batchid2: Optional[Iterable[str]] = None,
    fdr_target: float = 0.05,
    silent: bool = False,
    protein_prior_count: float = 0.1,
    scale_protein: bool = False,
    sample_protein_mixing: bool = False,
    include_protein_background: bool = False,
    **kwargs,
) -> pd.DataFrame:
    r"""A unified method for differential expression analysis.

    Implements `"vanilla"` DE :cite:p:`Lopez18` and `"change"` mode DE :cite:p:`Boyeau19`.
    Genes and proteins are tested jointly: the feature names are the gene names followed
    by the protein names.

    Parameters
    ----------
    %(de_adata)s
    %(de_groupby)s
    %(de_group1)s
    %(de_group2)s
    %(de_idx1)s
    %(de_idx2)s
    %(de_mode)s
    %(de_delta)s
    %(de_batch_size)s
    %(de_all_stats)s
    %(de_batch_correction)s
    %(de_batchid1)s
    %(de_batchid2)s
    %(de_fdr_target)s
    %(de_silent)s
    protein_prior_count
        Prior count added to protein expression before LFC computation
    scale_protein
        Force protein values to sum to one in every single cell (post-hoc normalization)
    sample_protein_mixing
        Sample the protein mixture component, i.e., use the parameter to sample a Bernoulli
        that determines if expression is from foreground/background.
    include_protein_background
        Include the protein background component as part of the protein expression
    **kwargs
        Keyword args for :meth:`scvi.model.base.DifferentialComputation.get_bayes_factors`

    Returns
    -------
    Differential expression DataFrame.
    """
    adata = self._validate_anndata(adata)
    adata_manager = self.get_anndata_manager(adata, required=True)

    # Column labels for the joint matrix: gene names first, then protein names.
    gene_names = np.asarray(_get_var_names_from_manager(adata_manager))
    protein_names = self.protein_state_registry.column_names
    col_names = np.concatenate([gene_names, protein_names])

    # Expression function with the protein-specific options already bound.
    expression_fn = partial(
        self._expression_for_de,
        scale_protein=scale_protein,
        sample_protein_mixing=sample_protein_mixing,
        include_protein_background=include_protein_background,
        protein_prior_count=protein_prior_count,
        batch_size=batch_size,
    )

    # Arguments are passed positionally, in the order the helper was originally called with.
    return _de_core(
        adata_manager,
        expression_fn,
        None,
        groupby,
        group1,
        group2,
        idx1,
        idx2,
        all_stats,
        cite_seq_raw_counts_properties,
        col_names,
        mode,
        batchid1,
        batchid2,
        delta,
        batch_correction,
        fdr_target,
        silent,
        **kwargs,
    )
@torch.inference_mode()
def posterior_predictive_sample(
    self,
    adata: Optional[AnnData] = None,
    indices: Optional[Sequence[int]] = None,
    n_samples: int = 1,
    batch_size: Optional[int] = None,
    gene_list: Optional[Sequence[str]] = None,
    protein_list: Optional[Sequence[str]] = None,
) -> np.ndarray:
    r"""Generate observation samples from the posterior predictive distribution.

    The posterior predictive distribution is written as :math:`p(\hat{x}, \hat{y} \mid x, y)`.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        Indices of cells in adata to use. If `None`, all cells are used.
    n_samples
        Number of required samples for each cell
    batch_size
        Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
    gene_list
        Names of genes of interest
    protein_list
        Names of proteins of interest

    Returns
    -------
    x_new : :class:`~numpy.ndarray`
        RNA and protein samples concatenated on the feature axis (genes first), with
        minibatches concatenated on axis 0. When `n_samples > 1` each minibatch is
        transposed with axes ``(1, 2, 0)`` first - presumably giving
        (n_cells, n_features, n_samples); confirm against the module's ``sample`` output.

    Raises
    ------
    ValueError
        If the module's ``gene_likelihood`` is not ``"nb"``.
    """
    if self.module.gene_likelihood not in ("nb",):
        raise ValueError("Invalid gene_likelihood")

    adata = self._validate_anndata(adata)
    adata_manager = self.get_anndata_manager(adata, required=True)

    # `slice(None)` keeps every feature; otherwise build boolean column masks.
    gene_mask = slice(None)
    if gene_list is not None:
        gene_mask = [
            bool(name in gene_list)
            for name in _get_var_names_from_manager(adata_manager)
        ]
    protein_mask = slice(None)
    if protein_list is not None:
        protein_mask = [
            bool(name in protein_list)
            for name in self.protein_state_registry.column_names
        ]

    loader = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size
    )

    sampled_batches = []
    for tensors in loader:
        rna_sample, protein_sample = self.module.sample(tensors, n_samples=n_samples)
        joint = torch.cat(
            [rna_sample[..., gene_mask], protein_sample[..., protein_mask]],
            dim=-1,
        ).numpy()
        if n_samples > 1:
            # Move the leading axis to the end before stacking minibatches.
            joint = np.transpose(joint, (1, 2, 0))
        sampled_batches.append(joint)

    return np.concatenate(sampled_batches, axis=0)
@torch.inference_mode()
def _get_denoised_samples(
    self,
    adata=None,
    indices=None,
    n_samples: int = 25,
    batch_size: int = 64,
    rna_size_factor: int = 1000,
    transform_batch: Optional[int] = None,
) -> np.ndarray:
    """Return samples from an adjusted posterior predictive.

    For every minibatch, a Gamma sample is drawn for the concatenated gene and protein
    features, and the protein part is zeroed wherever a Bernoulli draw of the mixing
    probability selects background.

    Parameters
    ----------
    adata
        AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
        AnnData object used to initialize the model.
    indices
        indices of `adata` to use
    n_samples
        How many samples per cell.
        NOTE(review): the 3-axis indexing and transpose below appear to require a leading
        sample axis, i.e. `n_samples > 1` - confirm before calling with `n_samples=1`.
    batch_size
        Minibatch size for data loading into model.
    rna_size_factor
        size factor for RNA prior to sampling gamma distribution
    transform_batch
        int of which batch to condition on for all cells

    Returns
    -------
    Array with minibatches concatenated on axis 0 after each minibatch was transposed
    with axes ``(1, 2, 0)``.
    """
    adata = self._validate_anndata(adata)
    scdl = self._make_data_loader(
        adata=adata, indices=indices, batch_size=batch_size
    )

    generative_kwargs = {"transform_batch": transform_batch}
    inference_kwargs = {"n_samples": n_samples}

    scdl_list = []
    for tensors in scdl:
        x = tensors[REGISTRY_KEYS.X_KEY]
        y = tensors[REGISTRY_KEYS.PROTEIN_EXP_KEY]
        n_genes = x.shape[1]

        # The decorator already runs this method in inference mode, so no
        # nested `torch.inference_mode()` context is needed here.
        _, generative_outputs = self.module.forward(
            tensors,
            inference_kwargs=inference_kwargs,
            generative_kwargs=generative_kwargs,
            compute_loss=False,
        )
        px_ = generative_outputs["px_"]
        py_ = generative_outputs["py_"]
        device = px_["r"].device

        # Background probability and a Bernoulli draw of it (1 = background).
        pi = torch.sigmoid(py_["mixing"])
        mixing_sample = torch.distributions.Bernoulli(pi).sample()

        protein_rate = py_["rate_fore"]
        rate = torch.cat((rna_size_factor * px_["scale"], protein_rate), dim=-1)

        # Expand dispersions that are not already 2-dimensional to the data shape.
        if px_["r"].dim() == 2:
            px_dispersion = px_["r"]
        else:
            px_dispersion = torch.ones_like(x).to(device) * px_["r"]
        if py_["r"].dim() == 2:
            py_dispersion = py_["r"]
        else:
            py_dispersion = torch.ones_like(y).to(device) * py_["r"]
        dispersion = torch.cat((px_dispersion, py_dispersion), dim=-1)

        # This gamma is really l*w using scVI manuscript notation
        p = rate / (rate + dispersion)
        r = dispersion
        l_train = torch.distributions.Gamma(r, (1 - p) / p).sample()
        data = l_train.cpu().numpy()

        # make background 0
        data[:, :, n_genes:] = (
            data[:, :, n_genes:] * (1 - mixing_sample).cpu().numpy()
        )
        scdl_list.append(np.transpose(data, (1, 2, 0)))

    return np.concatenate(scdl_list, axis=0)
@torch.inference_mode()
def get_feature_correlation_matrix(
self,
adata=None,
indices=None,
n_samples: int = 10,
batch_size: int = 64,
rna_size_factor: int = 1000,
transform_batch: Optional[Sequence[Union[Number, str]]] = None,
correlation_type: Literal["spearman", "pearson"] = "spearman",
log_transform: bool = False,
) -> pd.DataFrame:
"""Generate gene-gene correlation matrix using scvi uncertainty and expression.
Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
n_samples
Number of posterior samples to use for estimation.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
rna_size_factor
size factor for RNA prior to sampling gamma distribution
transform_batch
Batches to condition on.
If transform_batch is:
- None, then real observed batch is used
- int, then batch transform_batch is used
- list of int, then values are averaged over provided batches.
correlation_type
One of "pearson", "spearman".
log_transform
Whether to log transform denoised values prior to correlation calculation.
Returns
-------
Gene-protein-gene-protein correlation matrix