Skip to content

Commit

Permalink
Add missing code in coalescent for block updating operator
Browse files Browse the repository at this point in the history
- Improve CLI for piecewise models
- Add missing load_state_dict/state_dict methods
  • Loading branch information
4ment committed May 19, 2024
1 parent 47a997c commit 0c81dde
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 88 deletions.
2 changes: 1 addition & 1 deletion torchtree/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "1.0.2-dev4"
__version__ = "1.0.2-dev5"
35 changes: 7 additions & 28 deletions torchtree/cli/advi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torchtree.cli import PLUGIN_MANAGER
from torchtree.cli.argparse_utils import list_or_int
from torchtree.cli.evolution import (
COALESCENT_PIECEWISE,
create_alignment,
create_evolution_joint,
create_evolution_parser,
Expand Down Expand Up @@ -833,12 +834,7 @@ def create_logger(id_, parameters, arg):
models = ['joint', 'like', 'prior']
if arg.coalescent:
models.append('coalescent')
if arg.coalescent in (
"skyride",
"skygrid",
"piecewise-exponential",
"piecewise-linear",
):
if arg.coalescent in COALESCENT_PIECEWISE:
models.append('gmrf')
models.append(
{
Expand Down Expand Up @@ -866,7 +862,7 @@ def create_sampler(id_, var_id, parameters, arg):
tree_file_name = 'samples.trees'

parameters2 = list(filter(lambda x: 'tree.ratios' != x, parameters))
models = ['joint', 'like', 'prior', var_id]
models = ['joint.jacobian', 'joint', 'like', 'prior', var_id]

if arg.location_regex:
models.append('like.location')
Expand All @@ -875,7 +871,7 @@ def create_sampler(id_, var_id, parameters, arg):

if arg.coalescent:
models.append('coalescent')
if arg.coalescent in ('skyride', 'skygrid'):
if arg.coalescent in COALESCENT_PIECEWISE:
models.append('gmrf')
models.append(
{
Expand Down Expand Up @@ -935,14 +931,8 @@ def build_advi(arg):
jacobians_list = create_jacobians(json_list)
if arg.clock is not None and arg.heights == 'ratio':
jacobians_list.append('tree')
if arg.coalescent in (
"skyglide",
"skygrid",
"skyride",
"piecewise-exponential",
"piecewise-constant",
"piecewise-linear",
):

if arg.coalescent in COALESCENT_PIECEWISE:
jacobians_list.remove("coalescent.theta")

joint_jacobian = {
Expand Down Expand Up @@ -1004,18 +994,7 @@ def build_advi(arg):
if arg.coalescent_integrated is None:
parameters.append("coalescent.theta")

if (
arg.coalescent
in (
"skyglide",
"skygrid",
"skyride",
"piecewise-constant",
"piecewise-exponential",
"piecewise-linear",
)
and not arg.gmrf_integrated
):
if arg.coalescent in COALESCENT_PIECEWISE and not arg.gmrf_integrated:
parameters.append('gmrf.precision')
elif arg.coalescent == 'exponential':
parameters.append('coalescent.growth')
Expand Down
27 changes: 26 additions & 1 deletion torchtree/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
from __future__ import annotations

import argparse
import importlib
import json

from torchtree._version import __version__
from torchtree.cli import PLUGIN_MANAGER
from torchtree.cli import PLUGIN_MANAGER, evolution
from torchtree.cli.advi import create_variational_parser
from torchtree.cli.hmc import create_hmc_parser
from torchtree.cli.map import create_map_parser
from torchtree.cli.mcmc import create_mcmc_parser
from torchtree.cli.utils import remove_constraints
from torchtree.core.utils import package_contents


def create_show_parser(subprasers):
parser = subprasers.add_parser("show", help="Show some information")
parser.add_argument('what')
from torchtree.core.utils import REGISTERED_CLASSES

def show(arg):
if arg.what == "classes":
for module in package_contents('torchtree'):
importlib.import_module(module)
for klass in REGISTERED_CLASSES:
print(klass)
elif arg.what == "plugins":
for plugin in PLUGIN_MANAGER.plugins():
print(plugin)
exit(0)

parser.set_defaults(func=show)


def main():
Expand All @@ -33,6 +54,8 @@ def main():

create_hmc_parser(subprasers)

create_show_parser(subprasers)

PLUGIN_MANAGER.load_plugins()
PLUGIN_MANAGER.load_arguments(subprasers)

Expand All @@ -41,6 +64,8 @@ def main():
parser.print_help()
exit(2)

evolution.check_arguments(arg, parser)

json_dic = arg.func(arg)

if not arg.debug:
Expand Down
47 changes: 32 additions & 15 deletions torchtree/cli/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@
_engine = None


COALESCENT_PIECEWISE = [
"piecewise-constant",
"piecewise-exponential",
"piecewise-linear",
"skyglide",
"skygrid",
"skyride",
]


def create_evolution_parser(parser):
group = parser.add_mutually_exclusive_group()
group.add_argument("-i", "--input", required=False, help="""alignment file""")
Expand Down Expand Up @@ -232,13 +242,8 @@ def add_coalescent(parser):
choices=[
"constant",
"exponential",
"skyride",
"skygrid",
"skyglide",
"piecewise-constant",
"piecewise-exponential",
"piecewise-linear",
],
]
+ COALESCENT_PIECEWISE,
default=None,
help="""type of coalescent""",
)
Expand Down Expand Up @@ -278,6 +283,25 @@ def add_coalescent(parser):
return parser


def check_arguments(arg, parser):
if arg.coalescent in COALESCENT_PIECEWISE:
piecewise_grid = COALESCENT_PIECEWISE.copy()
piecewise_grid.remove("skyride")
if arg.coalescent == "skyride" and (
arg.cutoff is not None or arg.grid is not None
):
parser.error(
"skyride coalescent model does not require cutoff or grid arguments"
)
elif arg.coalescent in piecewise_grid and (
arg.cutoff is None or arg.grid is None
):
parser.error(
", ".join(piecewise_grid)
+ " coalescent models require cutoff and grid arguments"
)


def distribution_type(arg, choices):
"""Used by argparse for specifying distributions with optional
parameters."""
Expand Down Expand Up @@ -1191,14 +1215,7 @@ def create_coalesent(id_, tree_id, taxa, arg):
},
)
)
elif arg.coalescent in (
"skyride",
"skygrid",
"skyglide",
"piecewise-constant",
"piecewise-exponential",
"piecewise-linear",
):
elif arg.coalescent in COALESCENT_PIECEWISE:
if arg.coalescent == "skyride":
theta_shape = [len(taxa["taxa"]) - 1]
else:
Expand Down
5 changes: 2 additions & 3 deletions torchtree/cli/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from torchtree.cli import PLUGIN_MANAGER
from torchtree.cli.evolution import (
COALESCENT_PIECEWISE,
create_alignment,
create_evolution_joint,
create_evolution_parser,
Expand Down Expand Up @@ -260,9 +261,7 @@ def build_hmc(arg):
jacobians_list = create_jacobians(json_list)
if arg.clock is not None and arg.heights == "ratio":
jacobians_list.append("tree")
if arg.coalescent in ("skygrid", "skyride") or arg.coalescent.startswith(
"piecewise"
):
if arg.coalescent in COALESCENT_PIECEWISE:
jacobians_list.remove("coalescent.theta")

joint_jacobian = {
Expand Down
4 changes: 4 additions & 0 deletions torchtree/cli/loggers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from __future__ import annotations

from torchtree.cli.evolution import COALESCENT_PIECEWISE


def create_loggers(parameters: list[str], arg) -> dict:
models = ["joint.jacobian", "joint", "like", "prior"]
if arg.coalescent:
models.append("coalescent")
if arg.coalescent in COALESCENT_PIECEWISE:
models.append('gmrf')
return [
{
"id": "logger",
Expand Down
18 changes: 10 additions & 8 deletions torchtree/cli/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from torchtree.cli import PLUGIN_MANAGER
from torchtree.cli.evolution import (
COALESCENT_PIECEWISE,
create_alignment,
create_evolution_joint,
create_evolution_parser,
Expand All @@ -14,7 +15,7 @@
create_block_updating_operator,
create_sliding_window_operator,
)
from torchtree.cli.utils import make_unconstrained
from torchtree.cli.utils import CONSTRAINT, make_unconstrained


def create_mcmc_parser(subprasers):
Expand Down Expand Up @@ -59,17 +60,20 @@ def create_mcmc(joint, parameters, parameters_unres, arg):
}

for param in parameters_unres:
if param["id"].endswith("theta.log"):
if param["id"].endswith("theta.log") and arg.coalescent in (
"skygrid",
"piecewise-constant",
):
operator = create_block_updating_operator(
param["id"], "gmrf", "coalescent", arg
)
mcmc_json["operators"].append(operator)
else:
operator = create_sliding_window_operator(param["id"], joint, param, arg)
mcmc_json["operators"].append(operator)
mcmc_json["operators"].append(operator)

if arg.stem:
mcmc_json["loggers"] = create_loggers(parameters, arg)
parameters2 = list(filter(lambda x: 'tree.ratios' != x, parameters))
mcmc_json["loggers"] = create_loggers(parameters2, arg)

return mcmc_json

Expand All @@ -94,9 +98,7 @@ def build_mcmc(arg):
jacobians_list = create_jacobians(json_list)
if arg.clock is not None and arg.heights == "ratio":
jacobians_list.append("tree")
if arg.coalescent in ("skygrid", "skyride") or arg.coalescent.startswith(
"piecewise"
):
if arg.coalescent in COALESCENT_PIECEWISE:
jacobians_list.remove("coalescent.theta")

joint_jacobian = {
Expand Down
Loading

0 comments on commit 0c81dde

Please sign in to comment.