Skip to content

Commit

Permalink
Add stats features (#518)
Browse files Browse the repository at this point in the history
* add features to stats and filterby

* add documentation

* Update views.py

* add unit tests

* response to review
  • Loading branch information
nwlandry authored Mar 3, 2024
1 parent 2e3a757 commit b6201d6
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 4 deletions.
4 changes: 3 additions & 1 deletion docs/source/api/stats/xgi.stats.DiEdgeStat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@
~DiEdgeStat.min
~DiEdgeStat.std
~DiEdgeStat.var
~DiEdgeStat.moment
~DiEdgeStat.moment
~DiEdgeStat.argmin
~DiEdgeStat.argmax
2 changes: 2 additions & 0 deletions docs/source/api/stats/xgi.stats.DiNodeStat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,5 @@
~DiNodeStat.std
~DiNodeStat.var
~DiNodeStat.moment
~DiNodeStat.argmin
~DiNodeStat.argmax
4 changes: 3 additions & 1 deletion docs/source/api/stats/xgi.stats.EdgeStat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@
~EdgeStat.min
~EdgeStat.std
~EdgeStat.var
~EdgeStat.moment
~EdgeStat.moment
~EdgeStat.argmin
~EdgeStat.argmax
4 changes: 3 additions & 1 deletion docs/source/api/stats/xgi.stats.NodeStat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,6 @@
~NodeStat.min
~NodeStat.std
~NodeStat.var
~NodeStat.moment
~NodeStat.moment
~NodeStat.argmin
~NodeStat.argmax
6 changes: 6 additions & 0 deletions tests/stats/test_nodestats.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ def test_aggregates(edgelist1, edgelist2, edgelist8):
H = xgi.Hypergraph(edgelist1)
assert H.nodes.degree.max() == 2
assert H.nodes.degree.min() == 1
assert H.nodes.degree.argmax() == 6
assert H.nodes.degree.argmin() == 1
assert H.nodes.degree.sum() == 9
assert round(H.nodes.degree.mean(), 3) == 1.125
assert round(H.nodes.degree.std(), 3) == 0.331
Expand All @@ -240,6 +242,8 @@ def test_aggregates(edgelist1, edgelist2, edgelist8):
H = xgi.Hypergraph(edgelist2)
assert H.nodes.degree.max() == 2
assert H.nodes.degree.min() == 1
assert H.nodes.degree.argmax() == 4
assert H.nodes.degree.argmin() == 1
assert H.nodes.degree.sum() == 7
assert round(H.nodes.degree.mean(), 3) == 1.167
assert round(H.nodes.degree.std(), 3) == 0.373
Expand All @@ -248,6 +252,8 @@ def test_aggregates(edgelist1, edgelist2, edgelist8):
H = xgi.Hypergraph(edgelist8)
assert H.nodes.degree.max() == 6
assert H.nodes.degree.min() == 2
assert H.nodes.degree.argmax() == 0
assert H.nodes.degree.argmin() == 5
assert H.nodes.degree.sum() == 26
assert round(H.nodes.degree.mean(), 3) == 3.714
assert round(H.nodes.degree.std(), 3) == 1.385
Expand Down
2 changes: 1 addition & 1 deletion xgi/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def filterby(self, stat, val, mode="eq"):
elif mode == "geq":
bunch = [idx for idx in self if values[idx] >= val]
elif mode == "between":
bunch = [node for node in self if val[0] <= values[node] <= val[1]]
bunch = [idx for idx in self if val[0] <= values[idx] <= val[1]]
else:
raise ValueError(
f"Unrecognized mode {mode}. mode must be one of "
Expand Down
28 changes: 28 additions & 0 deletions xgi/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,34 @@ def moment(self, order=2, center=False):
arr = self.asnumpy()
return spmoment(arr, moment=order) if center else np.mean(arr**order)

def argmin(self):
"""The ID corresponding to the minimum of the stat
When the minimum value is not unique, returns first
ID corresponding to the minimum value.
Returns
-------
hashable
The ID to which the minimum value corresponds.
"""
d = self.asdict()
return min(d, key=d.get)

def argmax(self):
"""The ID corresponding to the maximum of the stat
When the maximal value is not unique, returns first
ID corresponding to the maximal value.
Returns
-------
hashable
The ID to which the maximum value corresponds.
"""
d = self.asdict()
return max(d, key=d.get)


class NodeStat(IDStat):
"""An arbitrary node-quantity mapping.
Expand Down

0 comments on commit b6201d6

Please sign in to comment.