Skip to content

Commit

Permalink
Merge pull request #227 from galenseilis/patch-6
Browse files Browse the repository at this point in the history
Mixtures and __repr__ enhancements.
  • Loading branch information
geraintpalmer authored Apr 3, 2024
2 parents 782075c + 0978e28 commit 7747d8a
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 72 deletions.
121 changes: 104 additions & 17 deletions ciw/dists/distributions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ciw.auxiliary import *
from itertools import cycle
import numpy as np
'''Distributions available in Ciw.'''

import copy
from itertools import cycle
from operator import add, mul, sub, truediv
from random import (
expovariate,
Expand All @@ -11,7 +11,12 @@
lognormvariate,
weibullvariate,
)
from typing import List, NoReturn

import numpy as np

from ciw.auxiliary import *
from ciw.individual import Individual

class Distribution(object):
"""
Expand Down Expand Up @@ -99,7 +104,7 @@ def __init__(self, lower, upper):
self.upper = upper

def __repr__(self):
return "Uniform: {0}, {1}".format(self.lower, self.upper)
return f"Uniform({self.lower}, {self.upper})"

def sample(self, t=None, ind=None):
return uniform(self.lower, self.upper)
Expand All @@ -121,7 +126,7 @@ def __init__(self, value):
self.value = value

def __repr__(self):
return "Deterministic: {0}".format(self.value)
return f"Deterministic({self.value})"

def sample(self, t=None, ind=None):
return self.value
Expand Down Expand Up @@ -151,7 +156,7 @@ def __init__(self, lower, mode, upper):
self.upper = upper

def __repr__(self):
return "Triangular: {0}, {1}, {2}".format(self.lower, self.mode, self.upper)
return f"Triangular({self.lower}, {self.mode}, {self.upper})"

def sample(self, t=None, ind=None):
return triangular(self.lower, self.upper, self.mode)
Expand All @@ -173,7 +178,7 @@ def __init__(self, rate):
self.rate = rate

def __repr__(self):
return "Exponential: {0}".format(self.rate)
return f"Exponential({self.rate})"

def sample(self, t=None, ind=None):
return expovariate(self.rate)
Expand All @@ -193,7 +198,7 @@ def __init__(self, shape, scale):
self.scale = scale

def __repr__(self):
return "Gamma: {0}, {1}".format(self.shape, self.scale)
return f"Gamma({self.shape}, {self.scale})"

def sample(self, t=None, ind=None):
return gammavariate(self.shape, self.scale)
Expand All @@ -213,7 +218,7 @@ def __init__(self, mean, sd):
self.sd = sd

def __repr__(self):
return "Normal: {0}, {1}".format(self.mean, self.sd)
return f"Normal({self.mean}, {self.sd})"

def sample(self, t=None, ind=None):
return truncated_normal(self.mean, self.sd)
Expand All @@ -233,7 +238,7 @@ def __init__(self, mean, sd):
self.sd = sd

def __repr__(self):
return "Lognormal: {0}, {1}".format(self.mean, self.sd)
return f"Lognormal({self.mean}, {self.sd})"

def sample(self, t=None, ind=None):
return lognormvariate(self.mean, self.sd)
Expand All @@ -253,7 +258,7 @@ def __init__(self, scale, shape):
self.shape = shape

def __repr__(self):
return "Weibull: {0}, {1}".format(self.scale, self.shape)
return f"Weibull({self.scale}, {self.shape})"

def sample(self, t=None, ind=None):
return weibullvariate(self.scale, self.shape)
Expand Down Expand Up @@ -298,7 +303,10 @@ def __init__(self, sequence):
self.generator = cycle(self.sequence)

def __repr__(self):
return "Sequential"
if len(self.sequence) <= 3:
return f"Sequential({self.sequence})"
else:
return f"Sequential({self.sequence[0]}, ..., {self.sequence[-1]})"

def sample(self, t=None, ind=None):
return next(self.generator)
Expand All @@ -324,7 +332,7 @@ def __init__(self, values, probs):
self.probs = probs

def __repr__(self):
return "Pmf"
return f"Pmf({self.values}, {self.probs})"

def sample(self, t=None, ind=None):
return random_choice(self.values, self.probs)
Expand Down Expand Up @@ -420,7 +428,7 @@ def __init__(self, rate, num_phases):
super().__init__(initial_state, absorbing_matrix)

def __repr__(self):
return f"Erlang: {self.rate}, {self.num_phases}"
return f"Erlang({self.rate}, {self.num_phases})"


class HyperExponential(PhaseType):
Expand Down Expand Up @@ -611,7 +619,7 @@ def sample(self, t=None, ind=None):
return ciw.rng.poisson(lam=self.rate)

def __repr__(self):
return f"Poisson: {self.rate}"
return f"Poisson({self.rate})"


class Geometric(Distribution):
Expand All @@ -634,7 +642,7 @@ def sample(self, t=None, ind=None):
return ciw.rng.geometric(p=self.prob)

def __repr__(self):
return f"Geometric: {self.prob}"
return f"Geometric({self.prob})"


class Binomial(Distribution):
Expand Down Expand Up @@ -663,4 +671,83 @@ def sample(self, t=None, ind=None):
return ciw.rng.binomial(n=self.n, p=self.prob)

def __repr__(self):
return f"Binomial: {self.n}, {self.prob}"
return f"Binomial({self.n}, {self.prob})"


class MixtureDistribution(Distribution):
"""
A mixture distribution combining multiple probability distributions.
Parameters
----------
dists : List[Distribution]
A list of probability distributions to be combined in the mixture.
rhos : List[float]
A list of weights corresponding to the importance of each distribution in the mixture.
The weights must sum to 1.
Attributes
----------
rhos : List[float]
List of weights assigned to each distribution in the mixture.
dists : List[Distribution]
List of probability distributions in the mixture.
Methods
-------
sample(t: float, inds: List[Individual] = None) -> float:
Generate a random sample from the mixture distribution.
Notes
-----
The weights in `rhos` should sum to 1, indicating the relative importance of each distribution
in the mixture. The distributions in `dists` should be instances of `ciw.dists.Distribution`.
"""

def __init__(self, dists: List[Distribution], rhos: List[float]) -> NoReturn:
"""
Initialize the MixtureDistribution.
Parameters
----------
dists : List[Distribution]
A list of probability distributions to be combined in the mixture.
rhos : List[float]
A list of weights corresponding to the importance of each distribution in the mixture.
The weights must sum to 1.
"""
self.rhos = rhos
self.dists = dists

def sample(self, t: float, inds: List[Individual] = None) -> float:
"""
Generate a random sample from the mixture distribution.
Parameters
----------
t : float
The time parameter for the sample generation.
inds : List[Individual], optional
List of individuals associated with the sample, if applicable.
Returns
-------
float
A random sample from the mixture distribution.
"""
chosen_dist = random.choices(
population=self.dists,
weights=self.rhos,
k=1)[0]

return chosen_dist.sample(t, inds)

def __repr__(self):

dist_strs = [f'{rho} * {dist}' for rho,dist in zip(self.rhos, self.dists)]

if len(dist_strs) <= 3:
inside = ', '.join(dist_strs)
return f"Mixture({inside})"
else:
return f"Mixture({dist_strs[0]}, ..., {dist_strs[-1]})"
Loading

0 comments on commit 7747d8a

Please sign in to comment.