Skip to content

Commit

Permalink
MarkovChain: Add *_classes_indices which return indices
Browse files Browse the repository at this point in the history
- Change *_classes to return values
- Remove get_*_classes
  • Loading branch information
oyamad committed Apr 11, 2016
1 parent 78e7e67 commit f0c681f
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 80 deletions.
84 changes: 35 additions & 49 deletions quantecon/markov/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,41 @@ class MarkovChain(object):
num_communication_classes : int
The number of the communication classes.
communication_classes : list(ndarray(int))
List of numpy arrays containing the communication classes.
communication_classes_indices : list(ndarray(int))
List of numpy arrays containing the indices of the communication
classes.
communication_classes : list(ndarray)
List of numpy arrays containing the communication classes, where
the states are annotated with their values (if `state_values` is
not None).
num_recurrent_classes : int
The number of the recurrent classes.
recurrent_classes : list(ndarray(int))
List of numpy arrays containing the recurrent classes.
recurrent_classes_indices : list(ndarray(int))
List of numpy arrays containing the indices of the recurrent
classes.
recurrent_classes : list(ndarray)
List of numpy arrays containing the recurrent classes, where the
states are annotated with their values (if `state_values` is not
None).
is_aperiodic : bool
Indicate whether the Markov chain is aperiodic.
period : int
The period of the Markov chain.
cyclic_classes : list(ndarray(int))
List of numpy arrays containing the cyclic classes. Defined only
when the Markov chain is irreducible.
cyclic_classes_indices : list(ndarray(int))
List of numpy arrays containing the indices of the cyclic
classes. Defined only when the Markov chain is irreducible.
cyclic_classes : list(ndarray)
List of numpy arrays containing the cyclic classes, where the
states are annotated with their values (if `state_values` is not
None). Defined only when the Markov chain is irreducible.
Notes
-----
Expand Down Expand Up @@ -315,28 +332,26 @@ def is_irreducible(self):
def num_communication_classes(self):
return self.digraph.num_strongly_connected_components

@property
def communication_classes_indices(self):
return self.digraph.strongly_connected_components_indices

@property
def communication_classes(self):
return self.digraph.strongly_connected_components

def get_communication_classes(self, return_values=True):
return self.digraph.get_strongly_connected_components(
return_labels=return_values
)

@property
def num_recurrent_classes(self):
return self.digraph.num_sink_strongly_connected_components

@property
def recurrent_classes_indices(self):
return self.digraph.sink_strongly_connected_components_indices

@property
def recurrent_classes(self):
return self.digraph.sink_strongly_connected_components

def get_recurrent_classes(self, return_values=True):
return self.digraph.get_sink_strongly_connected_components(
return_labels=return_values
)

@property
def is_aperiodic(self):
if self.is_irreducible:
Expand Down Expand Up @@ -368,15 +383,14 @@ def cyclic_classes(self):
else:
return self.digraph.cyclic_components

def get_cyclic_classes(self, return_values=True):
@property
def cyclic_classes_indices(self):
if not self.is_irreducible:
raise NotImplementedError(
'Not defined for a reducible Markov chain'
)
else:
return self.digraph.get_cyclic_components(
return_labels=return_values
)
return self.digraph.cyclic_components_indices

def _compute_stationary(self):
"""
Expand Down Expand Up @@ -652,34 +666,6 @@ def _generate_sample_paths_sparse(P_cdfs1d, indices, indptr, init_states,
out[i, t+1] = indices[indptr[out[i, t]]+k]


_get_method_docstr = \
"""
Return a list of numpy arrays containing the {classes}.
Parameters
----------
return_values : bool(optional, default=True)
Whether to annotate the returned states with `state_values`.
Returns
-------
list(ndarray)
If `return_values=True`, and if `state_values` is not None,
each ndarray contains the state values, and the state indices
(integers) otherwise.
"""

MarkovChain.get_communication_classes.__doc__ = \
_get_method_docstr.format(classes='communication classes')

MarkovChain.get_recurrent_classes.__doc__ = \
_get_method_docstr.format(classes='recurrent classes')

MarkovChain.get_cyclic_classes.__doc__ = \
_get_method_docstr.format(classes='cyclic classes')


def mc_compute_stationary(P):
"""
Computes stationary distributions of P, one for each recurrent
Expand Down
54 changes: 23 additions & 31 deletions quantecon/markov/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,51 +406,43 @@ def test_com_rec_classes(self):
mc = mc_dict['mc']
coms = mc_dict['coms']
recs = mc_dict['recs']
methods = ['get_communication_classes',
'get_recurrent_classes']
for method, classes_ind in zip(methods, [coms, recs]):
for return_values in [True, False]:
if return_values:
classes = [self.state_values[i] for i in classes_ind]
key = lambda x: x[0, 0]
else:
properties = ['communication_classes',
'recurrent_classes']
suffix = '_indices'
for prop0, classes_ind in zip(properties, [coms, recs]):
for return_indices in [True, False]:
if return_indices:
classes = classes_ind
prop = prop0 + suffix
key = lambda x: x[0]
else:
classes = [self.state_values[i] for i in classes_ind]
prop = prop0
key = lambda x: x[0, 0]
list_of_array_equal(
sorted(getattr(mc, method)(return_values), key=key),
sorted(getattr(mc, prop), key=key),
sorted(classes, key=key)
)
# Check that the default of return_values is True
classes = [self.state_values[i] for i in classes_ind]
key = lambda x: x[0, 0]
list_of_array_equal(
sorted(getattr(mc, method)(), key=key),
sorted(classes, key=key)
)

def test_cyc_classes(self):
mc = self.mc_periodic_dict['mc']
cycs = self.mc_periodic_dict['cycs']
methods = ['get_cyclic_classes']
for method, classes_ind in zip(methods, [cycs]):
for return_values in [True, False]:
if return_values:
classes = [self.state_values[i] for i in classes_ind]
key = lambda x: x[0, 0]
else:
properties = ['cyclic_classes']
suffix = '_indices'
for prop0, classes_ind in zip(properties, [cycs]):
for return_indices in [True, False]:
if return_indices:
classes = classes_ind
prop = prop0 + suffix
key = lambda x: x[0]
else:
classes = [self.state_values[i] for i in classes_ind]
prop = prop0
key = lambda x: x[0, 0]
list_of_array_equal(
sorted(getattr(mc, method)(return_values), key=key),
sorted(getattr(mc, prop), key=key),
sorted(classes, key=key)
)
# Check that the default of return_values is True
classes = [self.state_values[i] for i in classes_ind]
key = lambda x: x[0, 0]
list_of_array_equal(
sorted(getattr(mc, method)(), key=key),
sorted(classes, key=key)
)

def test_simulate(self):
# Deterministic mc
Expand Down

0 comments on commit f0c681f

Please sign in to comment.