Skip to content

Commit

Permalink
Minor changes and fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
4ment committed Aug 29, 2024
1 parent 8193473 commit 8071cd3
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 39 deletions.
36 changes: 26 additions & 10 deletions torchtree/cli/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,14 @@ def create_tree_likelihood_general(trait: str, data_type: dict, taxa: Taxa, arg)
"id": f"site_pattern.{trait}",
"type": "AttributePattern",
"taxa": "taxa",
"data_type": data_type,
"data_type": data_type["id"],
"attribute": trait,
}
site_model = create_site_model(f"sitemodel.{trait}", arg)
# site_model = create_site_model(f"sitemodel.{trait}", arg)
site_model = {
"id": f"sitemodel.{trait}",
"type": "ConstantSiteModel",
}

state_count = len(data_type["codes"])
mapping = torch.arange(state_count * (state_count - 1))
Expand All @@ -566,6 +570,7 @@ def create_tree_likelihood_general(trait: str, data_type: dict, taxa: Taxa, arg)
# CONSTRAINT.SIMPLEX.value: True
},
"state_count": state_count,
"data_type": data_type,
}
# substitution_model = {
# "id": f"substmodel.{trait}",
Expand All @@ -582,7 +587,17 @@ def create_tree_likelihood_general(trait: str, data_type: dict, taxa: Taxa, arg)
"site_pattern": site_pattern,
}
if arg.clock is not None:
treelikelihood_model["branch_model"] = "branchmodel"
treelikelihood_model["branch_model"] = {
"id": f"branchmodel.{trait}",
"type": "StrictClockModel",
"tree_model": "tree",
"rate": {
"id": f"rate.{trait}",
"type": "Parameter",
"tensor": [1.0],
CONSTRAINT.LOWER.value: 0.0,
},
}

if arg.use_ambiguities:
treelikelihood_model["use_ambiguities"] = True
Expand Down Expand Up @@ -954,7 +969,7 @@ def create_general_data_type(id_, trait, taxa):
data_type = {
"id": id_,
"type": "GeneralDataType",
"codes": unique_codes,
"codes": sorted(unique_codes),
}
return data_type

Expand Down Expand Up @@ -1062,10 +1077,11 @@ def create_taxa(id_, arg):
break
taxa_map = {taxon["id"]: taxon for taxon in taxa["taxa"]}
for line in reader:
if "attributes" not in taxa_map[line[1]]:
taxa_map[line[1]]["attributes"] = {}
taxon = line[0]
if "attributes" not in taxa_map[taxon]:
taxa_map[taxon]["attributes"] = {}
for trait, idx in zip(arg.trait, indices):
taxa_map[line[1]]["attributes"][trait] = line[idx]
taxa_map[taxon]["attributes"][trait] = line[idx]
return taxa


Expand Down Expand Up @@ -1692,9 +1708,6 @@ def create_evolution_joint(taxa, alignment, arg):
],
}

if len(prior_dic["distributions"]) > 0:
joint_dic["distributions"].append(prior_dic)

if arg.location_regex:
data_type_location = create_general_data_type(
"data_type.location", "location", taxa
Expand All @@ -1713,6 +1726,9 @@ def create_evolution_joint(taxa, alignment, arg):
)
joint_dic["distributions"].append(trait_dic)

if len(prior_dic["distributions"]) > 0:
joint_dic["distributions"].append(prior_dic)

return joint_dic


Expand Down
11 changes: 8 additions & 3 deletions torchtree/cli/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def create_hmc_parser(subprasers):
)
parser.add_argument(
"--stem",
required=True,
help="""stem for output file""",
default="samples",
help="""stem for output file [default: %(default)s]""",
)
parser.add_argument(
"--mass_matrix",
Expand Down Expand Up @@ -236,7 +236,12 @@ def create_hmc(joint, parameters, parameters_unres, arg):
)

if arg.stem:
hmc_json["loggers"] = create_loggers(parameters, arg)
parameters2 = list(filter(lambda x: 'tree.ratios' != x, parameters))
if "tree.root_height.unshifted" in parameters2:
idx = parameters2.index("tree.root_height.unshifted")
parameters2[idx] = "tree.root_height"

hmc_json["loggers"] = create_loggers(parameters2, arg)

return hmc_json

Expand Down
2 changes: 1 addition & 1 deletion torchtree/cli/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def create_loggers(parameters: list[str], arg) -> dict:
models = ["joint.jacobian", "joint", "like", "prior"]
if arg.coalescent:
models.append("coalescent")
if arg.coalescent in COALESCENT_PIECEWISE:
if arg.coalescent in COALESCENT_PIECEWISE and not arg.gmrf_integrated:
models.append('gmrf')
return [
{
Expand Down
2 changes: 1 addition & 1 deletion torchtree/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def read_dates_from_csv(input_file, date_format=None):
index_date = line.index('date')
break
for line in reader:
dates[line[index_name]] = line[index_date]
dates[line[index_name]] = float(line[index_date])

if date_format is not None:
res = re.split(r"[/-]", date_format)
Expand Down
19 changes: 9 additions & 10 deletions torchtree/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,20 @@ def _call(self, *args, **kwargs) -> torch.Tensor:

@property
def event_shape(self) -> torch.Size:
return self.dist(
**{name: p.tensor for name, p in self.dict_parameters.items()},
**self.kwargs,
).event_shape
return self.distribution.event_shape

@property
def batch_shape(self) -> torch.Size:
return self.dist(
**{name: p.tensor for name, p in self.dict_parameters.items()},
**self.kwargs,
).batch_shape
return self.distribution.batch_shape

def _sample_shape(self) -> torch.Size:
offset = 1 if len(self.batch_shape) == 0 else len(self.batch_shape)
return self.x.tensor.shape[:-offset]
x_shape = self.x.tensor.shape
if len(x_shape) > len(self.batch_shape):
offset = 1 if len(self.batch_shape) == 0 else len(self.batch_shape)
return x_shape[:-offset]
else:
# the distribution is a likelihood term
return self.batch_shape[: -len(x_shape)]

@property
def distribution(self) -> torch.distributions.Distribution:
Expand Down
7 changes: 7 additions & 0 deletions torchtree/distributions/joint_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,16 @@ def log_prob(self, x: Union[list[Parameter], Parameter] = None) -> torch.Tensor:
lp = distr()
sample_shape = distr.sample_shape
if lp.shape == sample_shape:
# [...] -> [...,1]
log_p.append(lp.unsqueeze(-1))
elif lp.shape == torch.Size([]):
# [] -> [1]
log_p.append(lp.unsqueeze(0))
elif len(lp.shape) - len(sample_shape) > 0:
# [..., x, y] -> [..., x*y].sum(-1)
log_p.append(
lp.view(lp.shape[: len(sample_shape)] + (-1,)).sum(-1, keepdim=True)
)
elif lp.shape[-1] != 1:
log_p.append(lp.sum(-1, keepdim=True))
elif lp.dim() == 1:
Expand Down
13 changes: 11 additions & 2 deletions torchtree/evolution/substitution_model/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ def from_json(cls, data, dic):
data_type = process_object(data['data_type'], dic)
rates = process_object(data['rates'], dic)
frequencies = process_object(data['frequencies'], dic)
mapping = process_object(data['mapping'], dic)
if 'mapping' not in data:
mapping_count = data_type.state_count * (data_type.state_count - 1) // 2
mapping = Parameter(None, torch.arange(mapping_count))
elif isinstance(data['mapping'], list):
mapping = Parameter(None, torch.tensor(data['mapping']))
else:
mapping = process_object(data['mapping'], dic)
return cls(id_, data_type, mapping, rates, frequencies)


Expand Down Expand Up @@ -187,7 +193,10 @@ def from_json(cls, data, dic):
data_type = process_object(data['data_type'], dic)
rates = process_object(data['rates'], dic)
frequencies = process_object(data['frequencies'], dic)
if isinstance(data['mapping'], list):
if 'mapping' not in data:
mapping_count = data_type.state_count * (data_type.state_count - 1)
mapping = Parameter(None, torch.arange(mapping_count))
elif isinstance(data['mapping'], list):
mapping = Parameter(None, torch.tensor(data['mapping']))
else:
mapping = process_object(data['mapping'], dic)
Expand Down
2 changes: 1 addition & 1 deletion torchtree/evolution/substitution_model/nucleotide.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def p_t_analytical(self, branch_lengths: torch.Tensor) -> torch.Tensor:

def q(self) -> torch.Tensor:
if len(self.frequencies.shape) == 1:
pi = self.frequencies.unsqueeze(0)
pi = self.frequencies.expand(self.kappa.shape[:-1] + (4,)).unsqueeze(-2)
else:
pi = self.frequencies.unsqueeze(-2)
kappa = self.kappa
Expand Down
3 changes: 3 additions & 0 deletions torchtree/evolution/tree_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ def _call(self, *args, **kwargs) -> torch.Tensor:
rates = rates.reshape(sample_shape + (1, -1))
probs = self.site_model.probabilities().unsqueeze(-1).unsqueeze(-1)
if self.clock_model is None:
if branch_lengths.dim() == 1:
branch_lengths = branch_lengths.expand(sample_shape + (-1,))

bls = torch.cat(
(
branch_lengths,
Expand Down
1 change: 0 additions & 1 deletion torchtree/inference/hmc/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def _step(self) -> Tensor:
trial = 0
while trial < max_trials:
momentum = self._hamiltonian.sample_momentum(self.mass_matrix)
ok = True
try:
kinetic_energy0 = self._hamiltonian.kinetic_energy(
momentum, self.inverse_mass_matrix
Expand Down
49 changes: 41 additions & 8 deletions torchtree/inference/mcmc/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@

@register_class
class MCMC(Identifiable, Runnable):
r"""Represents a Markov Chain Monte Carlo (MCMC) algorithm for inference.
:param id_: The ID of the MCMC instance.
:type id_: ID
:param joint: The joint model used for inference.
:type joint: CallableModel
:param operators: The list of MCMC operators.
:type operators: List[MCMCOperator]
:param int iterations: The number of iterations for the MCMC algorithm.
:param dict kwargs: Additional keyword arguments.
"""

def __init__(
self,
id_: ID,
Expand All @@ -28,6 +40,8 @@ def __init__(
iterations: int,
**kwargs,
) -> None:
"""Initialize an instance of the MCMC class."""

Identifiable.__init__(self, id_)
self._operators = operators
self.joint = joint
Expand All @@ -43,6 +57,7 @@ def __init__(
self.parameters.extend(parameter.parameters())

def run(self) -> None:
"""Run the MCMC algorithm."""
accept = 0

for logger in self.loggers:
Expand Down Expand Up @@ -129,11 +144,13 @@ def run(self) -> None:
)

def state_dict(self) -> dict[str, Any]:
"""Returns the current state of the MCMC object as a dictionary."""
states = {"iteration": self._epoch}
states["operators"] = [op.state_dict() for op in self._operators]
return states

def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""Load the state dictionary into the MCMC algorithm."""
for op in self._operators:
for op_state in state_dict["operators"]:
if op.id == op_state["id"]:
Expand All @@ -142,6 +159,7 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
self._epoch = state_dict["iteration"]

def save_full_state(self) -> None:
"""Save the full state of the MCMC algorithm."""
mcmc_state = {
"id": self.id,
"type": "MCMC",
Expand All @@ -152,13 +170,30 @@ def save_full_state(self) -> None:

@classmethod
def from_json(cls, data: dict[str, Any], dic: dict[str, Any]) -> MCMC:
r"""Creates an MCMC instance from a dictionary.
:param dict[str, Any] data: dictionary representation of a parameter object.
:param dict[str, Identifiable] dic: dictionary containing torchtree objects
keyed by their ID.
**JSON attributes**:
Mandatory:
- id (str): identifier of object.
- joint (str or dict): joint distribution of interest implementing CallableModel.
- operators (list of dict): list of operators implementing MCMCOperator.
- iterations (int): number of iterations.
Optional:
- loggers (list of dict): list of loggers implementing MCMCOperator.
- checkpoint (bool or str): checkpoint file name (Default: checkpoint.json).
No checkpointing if False is specified.
- checkpoint_frequency (int): frequency of checkpointing (Default: 1000).
- every (int): on-screen logging frequency (Default: 100).
"""
iterations = data["iterations"]

optionals = {}
# checkpointing is used by default and the default file name is checkpoint.json
# it can be disabled if "checkpoint": false is used
# the name of the checkpoint file can be modified using
# "checkpoint": "checkpointer.json"
if "checkpoint" in data:
if isinstance(data["checkpoint"], bool) and data["checkpoint"]:
optionals["checkpoint"] = "checkpoint.json"
Expand All @@ -167,8 +202,7 @@ def from_json(cls, data: dict[str, Any], dic: dict[str, Any]) -> MCMC:
else:
optionals["checkpoint"] = "checkpoint.json"

if "checkpoint_frequency" in data:
optionals["checkpoint_frequency"] = data["checkpoint_frequency"]
optionals["checkpoint_frequency"] = data.get("checkpoint_frequency", 1000)

if "loggers" in data:
loggers = process_objects(data["loggers"], dic)
Expand All @@ -180,7 +214,6 @@ def from_json(cls, data: dict[str, Any], dic: dict[str, Any]) -> MCMC:

operators = process_objects(data["operators"], dic)

if "every" in data:
optionals["every"] = data["every"]
optionals["every"] = data.get("every", 100)

return cls(data["id"], joint, operators, iterations, **optionals)
3 changes: 1 addition & 2 deletions torchtree/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,7 @@ def closure():

def _run(self) -> None:
for logger in self.loggers:
if hasattr(logger, 'init'):
logger.init()
logger.initialize()

handler = SignalHandler()
if self.convergence is not None:
Expand Down
1 change: 1 addition & 0 deletions torchtree/torchtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def main():
# now we update the state_dict of the algorithms (e.g. Optimizer, MCMC)
if (
arg.checkpoint is not None
and hasattr(obj, "id")
and obj.id in others
and hasattr(obj, "load_state_dict")
):
Expand Down

0 comments on commit 8071cd3

Please sign in to comment.