Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

frequencies: Calculate all pivot points based on end pivot #1150

Merged
merged 3 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

* translate: Fix error handling when features cannot be read from reference sequence file. [#1168][] (@victorlin)
* translate: Remove an unnecessary check which allowed for inaccurate error messages to be shown. [#1169][] (@victorlin)
* frequencies: Previously, monthly pivot points calculated from the end of a month may have been shifted by 1-3 days. This is now fixed. [#1150][] (@victorlin)

[#1150]: https://github.com/nextstrain/augur/pull/1150
[#1168]: https://github.com/nextstrain/augur/pull/1168
[#1169]: https://github.com/nextstrain/augur/pull/1169

Expand Down
2 changes: 1 addition & 1 deletion augur/frequency_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_pivots(observations, pivot_interval, start_date=None, end_date=None, piv
pivot = end
while pivot >= start:
pivots.appendleft(pivot)
pivot = pivot - delta
pivot = end - delta * len(pivots)

pivots = np.array([numeric_date(pivot) for pivot in pivots])

Expand Down
45 changes: 45 additions & 0 deletions tests/test_frequencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,51 @@ def test_get_pivots_by_invalid_unit():
with pytest.raises(ValueError, match=r".*invalid_unit.*is not supported.*"):
pivots = get_pivots(observations=[], pivot_interval=1, start_date=2015.0, end_date=2016.0, pivot_interval_units="invalid_unit")


@pytest.mark.parametrize(
"start, end, expected_pivots",
[
(
"2022-01-01",
"2022-04-01",
("2022-01-01", "2022-02-01", "2022-03-01", "2022-04-01")
),
(
"2022-01-31",
"2022-03-31",
("2022-01-31", "2022-02-28", "2022-03-31")
),
# Note that Jan 31 to Apr 30 gives the same amount of pivot points as
# Jan 31 to Mar 31.
(
"2022-01-31",
"2022-04-30",
("2022-02-28", "2022-03-30", "2022-04-30")
),
# However, in practice, the interval is more likely to be Jan 30 to Apr
# 30 as long as the start date is calculated relative to the end date
# (i.e. start date = 3 months before Apr 30 = Jan 30).
# That interval includes an additional pivot point as expected.
(
"2022-01-30",
"2022-04-30",
("2022-01-30", "2022-02-28", "2022-03-30", "2022-04-30")
),
]
)
def test_get_pivots_on_month_boundaries(start, end, expected_pivots):
"""Get pivots where the start/end dates are on month boundaries.
"""
pivots = get_pivots(
observations=[],
pivot_interval=1,
start_date=numeric_date(start),
end_date=numeric_date(end),
pivot_interval_units="months"
)
assert len(pivots) == len(expected_pivots)
assert np.allclose(pivots, [numeric_date(date) for date in expected_pivots], rtol=0, atol=1e-4)

#
# Test KDE frequency estimation for trees
#
Expand Down