diff --git a/torchtree/cli/advi.py b/torchtree/cli/advi.py
index ccc489e9..9972e65f 100644
--- a/torchtree/cli/advi.py
+++ b/torchtree/cli/advi.py
@@ -11,6 +11,7 @@
 from torchtree.cli import PLUGIN_MANAGER
 from torchtree.cli.argparse_utils import list_or_int
 from torchtree.cli.evolution import (
+    COALESCENT_PIECEWISE,
     create_alignment,
     create_evolution_joint,
     create_evolution_parser,
@@ -833,12 +834,7 @@ def create_logger(id_, parameters, arg):
     models = ['joint', 'like', 'prior']
     if arg.coalescent:
         models.append('coalescent')
-        if arg.coalescent in (
-            "skyride",
-            "skygrid",
-            "piecewise-exponential",
-            "piecewise-linear",
-        ):
+        if arg.coalescent in COALESCENT_PIECEWISE:
             models.append('gmrf')
     models.append(
         {
@@ -866,7 +862,7 @@ def create_sampler(id_, var_id, parameters, arg):
     tree_file_name = 'samples.trees'

     parameters2 = list(filter(lambda x: 'tree.ratios' != x, parameters))
-    models = ['joint', 'like', 'prior', var_id]
+    models = ['joint.jacobian', 'joint', 'like', 'prior', var_id]

     if arg.location_regex:
         models.append('like.location')
@@ -875,7 +871,7 @@ def create_sampler(id_, var_id, parameters, arg):

     if arg.coalescent:
         models.append('coalescent')
-        if arg.coalescent in ('skyride', 'skygrid'):
+        if arg.coalescent in COALESCENT_PIECEWISE:
             models.append('gmrf')
     models.append(
         {
@@ -935,14 +931,8 @@ def build_advi(arg):
     jacobians_list = create_jacobians(json_list)
     if arg.clock is not None and arg.heights == 'ratio':
         jacobians_list.append('tree')
-    if arg.coalescent in (
-        "skyglide",
-        "skygrid",
-        "skyride",
-        "piecewise-exponential",
-        "piecewise-constant",
-        "piecewise-linear",
-    ):
+
+    if arg.coalescent in COALESCENT_PIECEWISE:
         jacobians_list.remove("coalescent.theta")

     joint_jacobian = {
@@ -1004,18 +994,7 @@ def build_advi(arg):

         if arg.coalescent_integrated is None:
             parameters.append("coalescent.theta")
-        if (
-            arg.coalescent
-            in (
-                "skyglide",
-                "skygrid",
-                "skyride",
-                "piecewise-constant",
-                "piecewise-exponential",
-                "piecewise-linear",
-            )
-            and not arg.gmrf_integrated
-        ):
+        if arg.coalescent in COALESCENT_PIECEWISE and not arg.gmrf_integrated:
             parameters.append('gmrf.precision')
     elif arg.coalescent == 'exponential':
         parameters.append('coalescent.growth')
diff --git a/torchtree/cli/cli.py b/torchtree/cli/cli.py
index 15eb2430..3e4c453b 100644
--- a/torchtree/cli/cli.py
+++ b/torchtree/cli/cli.py
@@ -1,15 +1,36 @@
 from __future__ import annotations

 import argparse
+import importlib
 import json

 from torchtree._version import __version__
-from torchtree.cli import PLUGIN_MANAGER
+from torchtree.cli import PLUGIN_MANAGER, evolution
 from torchtree.cli.advi import create_variational_parser
 from torchtree.cli.hmc import create_hmc_parser
 from torchtree.cli.map import create_map_parser
 from torchtree.cli.mcmc import create_mcmc_parser
 from torchtree.cli.utils import remove_constraints
+from torchtree.core.utils import package_contents
+
+
+def create_show_parser(subprasers):
+    parser = subprasers.add_parser("show", help="Show some information")
+    parser.add_argument('what')
+    from torchtree.core.utils import REGISTERED_CLASSES
+
+    def show(arg):
+        if arg.what == "classes":
+            for module in package_contents('torchtree'):
+                importlib.import_module(module)
+            for klass in REGISTERED_CLASSES:
+                print(klass)
+        elif arg.what == "plugins":
+            for plugin in PLUGIN_MANAGER.plugins():
+                print(plugin)
+        exit(0)
+
+    parser.set_defaults(func=show)


 def main():
@@ -33,6 +54,8 @@ def main():

     create_hmc_parser(subprasers)

+    create_show_parser(subprasers)
+
     PLUGIN_MANAGER.load_plugins()
     PLUGIN_MANAGER.load_arguments(subprasers)
@@ -41,6 +64,8 @@ def main():
         parser.print_help()
         exit(2)

+    evolution.check_arguments(arg, parser)
+
     json_dic = arg.func(arg)

     if not arg.debug:
diff --git a/torchtree/cli/evolution.py b/torchtree/cli/evolution.py
index 16b1cd5e..89d5b27d 100644
--- a/torchtree/cli/evolution.py
+++ b/torchtree/cli/evolution.py
@@ -47,6 +47,16 @@
 _engine = None


+COALESCENT_PIECEWISE = [
+    "piecewise-constant",
+    "piecewise-exponential",
+    "piecewise-linear",
+    "skyglide",
+    "skygrid",
+    "skyride",
+]
+
+
 def create_evolution_parser(parser):
     group = parser.add_mutually_exclusive_group()
     group.add_argument("-i", "--input", required=False, help="""alignment file""")
@@ -232,13 +242,8 @@ def add_coalescent(parser):
         choices=[
             "constant",
             "exponential",
-            "skyride",
-            "skygrid",
-            "skyglide",
-            "piecewise-constant",
-            "piecewise-exponential",
-            "piecewise-linear",
-        ],
+        ]
+        + COALESCENT_PIECEWISE,
         default=None,
         help="""type of coalescent""",
     )
@@ -278,6 +283,25 @@ def add_coalescent(parser):
     return parser


+def check_arguments(arg, parser):
+    if arg.coalescent in COALESCENT_PIECEWISE:
+        piecewise_grid = COALESCENT_PIECEWISE.copy()
+        piecewise_grid.remove("skyride")
+        if arg.coalescent == "skyride" and (
+            arg.cutoff is not None or arg.grid is not None
+        ):
+            parser.error(
+                "skyride coalescent model does not require cutoff or grid arguments"
+            )
+        elif arg.coalescent in piecewise_grid and (
+            arg.cutoff is None or arg.grid is None
+        ):
+            parser.error(
+                ", ".join(piecewise_grid)
+                + " coalescent models require cutoff and grid arguments"
+            )
+
+
 def distribution_type(arg, choices):
     """Used by argparse for specifying distributions with optional parameters."""

@@ -1191,14 +1215,7 @@ def create_coalesent(id_, tree_id, taxa, arg):
                 },
             )
         )
-    elif arg.coalescent in (
-        "skyride",
-        "skygrid",
-        "skyglide",
-        "piecewise-constant",
-        "piecewise-exponential",
-        "piecewise-linear",
-    ):
+    elif arg.coalescent in COALESCENT_PIECEWISE:
         if arg.coalescent == "skyride":
             theta_shape = [len(taxa["taxa"]) - 1]
         else:
diff --git a/torchtree/cli/hmc.py b/torchtree/cli/hmc.py
index a7403931..fe7cf8e0 100644
--- a/torchtree/cli/hmc.py
+++ b/torchtree/cli/hmc.py
@@ -2,6 +2,7 @@

 from torchtree.cli import PLUGIN_MANAGER
 from torchtree.cli.evolution import (
+    COALESCENT_PIECEWISE,
     create_alignment,
     create_evolution_joint,
     create_evolution_parser,
@@ -260,9 +261,7 @@ def build_hmc(arg):
     jacobians_list = create_jacobians(json_list)
     if arg.clock is not None and arg.heights == "ratio":
         jacobians_list.append("tree")
-    if arg.coalescent in ("skygrid", "skyride") or arg.coalescent.startswith(
-        "piecewise"
-    ):
+    if arg.coalescent in COALESCENT_PIECEWISE:
         jacobians_list.remove("coalescent.theta")

     joint_jacobian = {
diff --git a/torchtree/cli/loggers.py b/torchtree/cli/loggers.py
index b561048b..be478798 100644
--- a/torchtree/cli/loggers.py
+++ b/torchtree/cli/loggers.py
@@ -1,10 +1,14 @@
 from __future__ import annotations

+from torchtree.cli.evolution import COALESCENT_PIECEWISE
+

 def create_loggers(parameters: list[str], arg) -> dict:
     models = ["joint.jacobian", "joint", "like", "prior"]
     if arg.coalescent:
         models.append("coalescent")
+        if arg.coalescent in COALESCENT_PIECEWISE:
+            models.append('gmrf')
     return [
         {
             "id": "logger",
diff --git a/torchtree/cli/mcmc.py b/torchtree/cli/mcmc.py
index bc927f62..a1f4e241 100644
--- a/torchtree/cli/mcmc.py
+++ b/torchtree/cli/mcmc.py
@@ -2,6 +2,7 @@

 from torchtree.cli import PLUGIN_MANAGER
 from torchtree.cli.evolution import (
+    COALESCENT_PIECEWISE,
     create_alignment,
     create_evolution_joint,
     create_evolution_parser,
@@ -14,7 +15,7 @@
     create_block_updating_operator,
     create_sliding_window_operator,
 )
-from torchtree.cli.utils import make_unconstrained
+from torchtree.cli.utils import CONSTRAINT, make_unconstrained


 def create_mcmc_parser(subprasers):
@@ -58,18 +59,20 @@ def create_mcmc(joint, parameters, parameters_unres, arg):
         "operators": [],
     }

-
     for param in parameters_unres:
-        if param["id"].endswith("theta.log"):
+        if param["id"].endswith("theta.log") and arg.coalescent in (
+            "skygrid",
+            "piecewise-constant",
+        ):
             operator = create_block_updating_operator(
                 param["id"], "gmrf", "coalescent", arg
             )
-            mcmc_json["operators"].append(operator)
         else:
             operator = create_sliding_window_operator(param["id"], joint, param, arg)
-            mcmc_json["operators"].append(operator)
+        mcmc_json["operators"].append(operator)

     if arg.stem:
-        mcmc_json["loggers"] = create_loggers(parameters, arg)
+        parameters2 = list(filter(lambda x: 'tree.ratios' != x, parameters))
+        mcmc_json["loggers"] = create_loggers(parameters2, arg)

     return mcmc_json
@@ -94,9 +97,7 @@ def build_mcmc(arg):
     jacobians_list = create_jacobians(json_list)
     if arg.clock is not None and arg.heights == "ratio":
         jacobians_list.append("tree")
-    if arg.coalescent in ("skygrid", "skyride") or arg.coalescent.startswith(
-        "piecewise"
-    ):
+    if arg.coalescent in COALESCENT_PIECEWISE:
         jacobians_list.remove("coalescent.theta")

     joint_jacobian = {
diff --git a/torchtree/evolution/coalescent.py b/torchtree/evolution/coalescent.py
index 494c51f7..930c60c8 100644
--- a/torchtree/evolution/coalescent.py
+++ b/torchtree/evolution/coalescent.py
@@ -302,36 +302,82 @@ def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
             self.theta * self.growth
         )
         lchoose2 = lineage_count * (lineage_count - 1) / 2.0
-        log_thetas = torch.where(
-            node_mask_sorted == -1,
-            torch.log(self.theta * torch.exp(-heights_sorted * self.growth)),
-            torch.zeros(1, dtype=heights_sorted.dtype),
-        )
+        log_thetas = torch.log(
+            self.theta * torch.exp(-heights_sorted * self.growth)
+        ) * (node_mask_sorted == -1)
         return torch.sum(-lchoose2 * integral - log_thetas[..., 1:], -1, keepdim=True)


 class PiecewiseConstantCoalescent(AbstractCoalescentDistribution):
-    def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
-        taxa_shape = node_heights.shape[:-1] + (int((node_heights.shape[-1] + 1) / 2),)
+    def _sorted_terms(self, node_heights):
+        batch_shape = max(node_heights.shape, self.theta.shape, key=len)[:-1]
+        # if node_heights is fixed there is no batch dimension
+        if node_heights.dim() < self.theta.dim():
+            heights = node_heights.expand(batch_shape + torch.Size([-1]))
+        else:
+            heights = node_heights
+
+        taxa_shape = heights.shape[:-1] + (int((node_heights.shape[-1] + 1) / 2),)
         node_mask = torch.cat(
             [
-                torch.full(taxa_shape, 1, dtype=torch.int),  # sampling event
+                # sampling event
+                torch.full(taxa_shape, 1, dtype=torch.int),
+                # coalescent event
                 torch.full(
                     taxa_shape[:-1] + (taxa_shape[-1] - 1,),
                     -1,
                     dtype=torch.int,
                 ),
-            ],  # coalescent event
+            ],
             dim=-1,
         )
-
-        indices = torch.argsort(node_heights, descending=False)
-        heights_sorted = torch.gather(node_heights, -1, indices)
+        indices = torch.argsort(heights, descending=False)
+        heights_sorted = torch.gather(heights, -1, indices)
         node_mask_sorted = torch.gather(node_mask, -1, indices)
         lineage_count = node_mask_sorted.cumsum(-1)[..., :-1]
-        durations = heights_sorted[..., 1:] - heights_sorted[..., :-1]
         lchoose2 = lineage_count * (lineage_count - 1) / 2.0
+        intervals = heights_sorted[..., 1:] - heights_sorted[..., :-1]
+        return node_mask_sorted, lchoose2, intervals
+
+    def sufficient_statistics(
+        self, node_heights: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Returns sorted sufficient statistics and number of coalescent events
+        per interval.
+
+        This method is used by the block updating MCMC operator.
+
+        :param torch.Tensor node_heights: node heights.
+        :return: sufficient statistics and number of coalescent events
+            per interval.
+        :rtype tuple[torch.Tensor, torch.Tensor]
+        """
+        node_mask_sorted, lchoose2, intervals = self._sorted_terms(node_heights)
+
+        if self.theta.dim() > 1:
+            sufficient_statistics = []
+            for i in range(self.theta.shape[-2]):
+                groups = torch.tensor_split(
+                    lchoose2[i] * intervals[i],
+                    torch.where(node_mask_sorted[i] == -1)[0],
+                )
+                ss = torch.tensor(list(map(torch.sum, groups[:-1])))
+                sufficient_statistics.append(ss)
+            sufficient_statistics = torch.stack(sufficient_statistics)
+        else:
+            groups = torch.tensor_split(
+                lchoose2 * intervals, torch.where(node_mask_sorted == -1)[0]
+            )
+            sufficient_statistics = torch.tensor(list(map(torch.sum, groups[:-1])))
+
+        internal_shape = sufficient_statistics.shape[:-1] + (
+            int((node_heights.shape[-1] + 1) / 2) - 1,
+        )
+        return sufficient_statistics, torch.ones(internal_shape)
+
+    def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
+        node_mask_sorted, lchoose2, durations = self._sorted_terms(node_heights)

         thetas_indices = torch.where(
             node_mask_sorted == -1,
@@ -394,7 +440,7 @@ def maximum_likelihood(cls, node_heights: torch.Tensor) -> torch.Tensor:


 @register_class
-class PiecewiseConstantCoalescentModel(ConstantCoalescentModel):
+class PiecewiseConstantCoalescentModel(AbstractCoalescentModel):
     def distribution(self) -> AbstractCoalescentDistribution:
         return PiecewiseConstantCoalescent(self.theta.tensor)
@@ -420,12 +466,8 @@ def __init__(
         super().__init__(thetas, validate_args)
         self.grid = grid

-    def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
+    def _sorted_terms(self, node_heights: torch.Tensor):
         batch_shape = max(node_heights.shape, self.theta.shape, key=len)[:-1]
-        if node_heights.dim() > self.theta.dim():
-            thetas = self.theta.expand(batch_shape + torch.Size([-1]))
-        else:
-            thetas = self.theta

         grid = self.grid.expand(batch_shape + torch.Size([-1]))
@@ -464,6 +506,24 @@ def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
         durations = heights_sorted[..., 1:] - heights_sorted[..., :-1]
         lchoose2 = lineage_count * (lineage_count - 1) / 2.0

+        return node_mask_sorted, lchoose2, durations
+
+    def sufficient_statistics(self, node_heights: torch.Tensor):
+        node_mask_sorted, lchoose2, durations = self._sorted_terms(node_heights)
+        groups = torch.tensor_split(
+            lchoose2 * durations, torch.where(node_mask_sorted == 0)[0]
+        )
+        sufficient_statistics = torch.tensor(list(map(torch.sum, groups)))
+        groups = torch.tensor_split(
+            node_mask_sorted == -1, torch.where(node_mask_sorted == 0)[0]
+        )
+        coalescent_counts = torch.tensor(list(map(torch.sum, groups)))
+        return sufficient_statistics, coalescent_counts
+
+    def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
+        batch_shape = max(node_heights.shape, self.theta.shape, key=len)[:-1]
+        node_mask_sorted, lchoose2, durations = self._sorted_terms(node_heights)
+
         thetas_indices = torch.where(
             node_mask_sorted == 0,
             torch.tensor([1], dtype=torch.long),
@@ -480,7 +540,7 @@ def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
         log_thetas = torch.where(
             node_mask_sorted == -1,
             torch.log(thetas),
-            torch.zeros(1, dtype=heights.dtype),
+            torch.zeros(1, dtype=node_heights.dtype, device=node_heights.device),
         )
         return torch.sum(
             -lchoose2 * durations / thetas[..., :-1] - log_thetas[..., 1:],
@@ -925,7 +985,9 @@ def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
         indices = torch.argsort(grid_heights, descending=False)
         grid_heights_sorted = torch.gather(grid_heights, -1, indices)
         grid_heights_sorted[..., 0] = 0
-        grid[..., 0] = 0
+        grid = torch.cat((torch.tensor([0.0]), self.grid)).expand(
+            batch_shape + torch.Size([-1])
+        )

         event_mask_sorted = torch.gather(event_mask, -1, indices)
         lineage_count = event_mask_sorted.cumsum(-1)[..., :-1]
@@ -937,10 +999,10 @@ def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
         pop_sizes = torch.zeros_like(grid_heights)

         # set population size at every grid point
-        indices_grid = torch.zeros(grid_heights_sorted.shape, dtype=torch.long)
-        indices_grid[..., 1:] = 1
-        indices_grid = indices_grid.cumsum(-1)[event_mask_sorted == 0].reshape(
-            grid.shape
+        indices_grid = (
+            torch.arange(grid_heights_sorted.shape[-1])
+            .expand(batch_shape + (-1,))[event_mask_sorted == 0]
+            .reshape(grid.shape)
         )
         pop_sizes = pop_sizes.scatter(-1, indices_grid, thetas)
@@ -969,10 +1031,10 @@ def log_prob(self, node_heights: torch.Tensor) -> torch.Tensor:
             end_grid[idx] - start_grid[idx]
         )

-        p = torch.zeros(pop_sizes.shape, dtype=torch.long)
-        p[..., 1:] = 1
-        indices = p.cumsum(-1)[event_mask_sorted != 0].reshape(
-            pop_size_node_heights.shape
+        indices = (
+            torch.arange(pop_sizes.shape[-1])
+            .expand(batch_shape + (-1,))[event_mask_sorted != 0]
+            .reshape(pop_size_node_heights.shape)
         )
         pop_sizes = pop_sizes.scatter(-1, indices, pop_size_node_heights)
diff --git a/torchtree/inference/mcmc/gmrf_block_updating.py b/torchtree/inference/mcmc/gmrf_block_updating.py
index c51e9de4..ddffb278 100644
--- a/torchtree/inference/mcmc/gmrf_block_updating.py
+++ b/torchtree/inference/mcmc/gmrf_block_updating.py
@@ -219,6 +219,12 @@ def _step(self) -> Tensor:
         )
         return log_q_backward - log_q_forward

+    def _state_dict(self) -> dict[str, Any]:
+        return {"scaler": self._scaler}
+
+    def _load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self._scaler = state_dict["scaler"]
+
     @classmethod
     def from_json(
         cls, data: dict[str, Any], dic: dict[str, Identifiable]
diff --git a/torchtree/inference/mcmc/operator.py b/torchtree/inference/mcmc/operator.py
index 31d818e9..e3010a7a 100644
--- a/torchtree/inference/mcmc/operator.py
+++ b/torchtree/inference/mcmc/operator.py
@@ -233,10 +233,10 @@ def _step(self) -> Tensor:
         )

     def _state_dict(self) -> dict[str, Any]:
-        return {"width": self._scaler}
+        return {"width": self._width}

     def _load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        self._scaler = state_dict["width"]
+        self._width = state_dict["width"]

     @classmethod
     def from_json(cls, data, dic):
diff --git a/torchtree/optim/lr_scheduler.py b/torchtree/optim/lr_scheduler.py
index 95897194..cb620671 100644
--- a/torchtree/optim/lr_scheduler.py
+++ b/torchtree/optim/lr_scheduler.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import inspect
+from typing import Any

 from torch.optim.lr_scheduler import _LRScheduler as TorchScheduler

@@ -21,6 +22,12 @@ def __init__(self, scheduler: TorchScheduler) -> None:
     def step(self, *args) -> None:
         self.scheduler.step(*args)

+    def state_dict(self) -> dict[str, Any]:
+        return self.scheduler.state_dict()
+
+    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
+        self.scheduler.load_state_dict(state_dict)
+
     @classmethod
     def from_json(
         cls, data: dict[str, any], dic: dict[str, any], **kwargs
diff --git a/torchtree/optim/optimizer.py b/torchtree/optim/optimizer.py
index dee02a64..0e1701c1 100644
--- a/torchtree/optim/optimizer.py
+++ b/torchtree/optim/optimizer.py
@@ -193,11 +193,16 @@ def run(self) -> None:
         self._run()

     def state_dict(self) -> dict[str, Any]:
-        return {"iteration": self._epoch, "optimizer": self.optimizer.state_dict()}
+        state = {"iteration": self._epoch, "optimizer": self.optimizer.state_dict()}
+        if self.scheduler is not None:
+            state["scheduler"] = self.scheduler.state_dict()
+        return state

     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
         self._epoch = state_dict["iteration"]
         self.optimizer.load_state_dict(state_dict["optimizer"])
+        if self.scheduler is not None:
+            self.scheduler.load_state_dict(state_dict["scheduler"])

     def save_full_state(self, checkpoint, safely=True, overwrite=False) -> None:
         optimizer_state = {
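
Note on the CLI changes: check_arguments in torchtree/cli/evolution.py now centralizes the cutoff/grid validation for the models listed in COALESCENT_PIECEWISE, and cli.py calls it right after argument parsing. The snippet below is a minimal sketch of that behaviour once this patch is applied; it hand-builds an argparse.Namespace instead of going through the real torchtree parser, so the attribute names (coalescent, cutoff, grid) are taken from the patch and the rest is illustrative.

    import argparse

    from torchtree.cli.evolution import check_arguments

    parser = argparse.ArgumentParser(prog="torchtree")
    # skygrid is one of the grid-based piecewise models, so it needs both a
    # cutoff and a grid; leaving them unset makes check_arguments call
    # parser.error, which prints the message and exits with status 2.
    args = argparse.Namespace(coalescent="skygrid", cutoff=None, grid=None)
    check_arguments(args, parser)

The new show subcommand registered in cli.py lists registered classes or loaded plugins, e.g. `torchtree show classes` or `torchtree show plugins`.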
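
Note on PiecewiseConstantCoalescent.sufficient_statistics: for each inter-coalescent interval it returns the sum of the lchoose2-weighted interval lengths, together with a tensor of ones (one coalescent event per interval), which is the form consumed by the GMRF block-updating operator. A minimal sketch, assuming the usual torchtree convention that node_heights concatenates the leaf sampling times followed by the internal node heights (the numbers are made up for illustration):

    import torch

    from torchtree.evolution.coalescent import PiecewiseConstantCoalescent

    # 4 taxa sampled at time 0; coalescent events at times 1, 2 and 3
    node_heights = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0])
    thetas = torch.full((3,), 5.0)  # one population size per coalescent interval

    coalescent = PiecewiseConstantCoalescent(thetas)
    ss, counts = coalescent.sufficient_statistics(node_heights)
    print(ss)      # tensor([6., 3., 1.]): 4, 3 and 2 lineages over the three unit-length intervals
    print(counts)  # tensor([1., 1., 1.]): one coalescent event per interval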