Skip to content

Commit

Permalink
Added 'where' clause and corresponding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenjp committed May 16, 2024
1 parent 6c86afc commit d2d82a6
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 10 deletions.
17 changes: 11 additions & 6 deletions q2_diversity/_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,21 @@ def filter_distance_matrix(distance_matrix: skbio.DistanceMatrix,
"All samples were filtered out of the distance matrix.")


def filter_alpha_diversity_artifact(alpha_diversity: pd.Series,
                                    metadata: qiime2.Metadata,
                                    where: str = None,
                                    exclude_ids: bool = False) -> pd.Series:
    """Filter alpha diversity values by the samples selected in metadata.

    By default only samples whose IDs are selected from `metadata` remain
    in the result.

    Parameters
    ----------
    alpha_diversity
        Alpha diversity values indexed by sample ID.
    metadata
        Sample metadata whose IDs select the samples to retain (or to drop
        when `exclude_ids` is set).
    where
        Optional SQLite WHERE clause forwarded to `metadata.get_ids` to
        narrow the selected IDs. When `None`, no clause is applied.
    exclude_ids
        If `True`, the selected samples are removed instead of retained.

    Returns
    -------
    pd.Series
        The entries of `alpha_diversity` whose index passed the filter.

    Raises
    ------
    ValueError
        If no samples remain after filtering.
    """
    ids_to_keep = metadata.get_ids(where=where)
    if exclude_ids:
        # Invert the selection: keep every sample that was NOT selected.
        ids_to_keep = set(alpha_diversity.index) - set(ids_to_keep)
    filtered_metric = alpha_diversity[alpha_diversity.index.isin(ids_to_keep)]
    if filtered_metric.empty:
        raise ValueError(
            "All samples were filtered out of the alpha diversity artifact.")
    return filtered_metric
14 changes: 10 additions & 4 deletions q2_diversity/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,8 @@
},
parameters={
'metadata': Metadata,
'exclude_ids': Bool
'exclude_ids': Bool,
'where': Str
},
outputs=[
('filtered_alpha_diversity_artifact', SampleData[AlphaDiversity])
Expand All @@ -669,9 +670,14 @@
'metadata': 'Sample metadata used to select samples to retain from '
'the sample data (default) or select samples to exclude '
'using the `exclude_ids` parameter.',
'exclude_ids': 'If `True`, the samples selected by `metadata` '
'will be excluded from the filtered '
'sample data instead of being retained.'
'where': 'SQLite WHERE clause specifying sample metadata criteria '
'that must be met to be included in the filtered alpha '
'diversity artifact. If not provided, all samples in '
'`metadata` that are also in the input alpha diversity '
'artifact will be retained.',
'exclude_ids': 'If `True`, the samples selected by `metadata` or the '
'`where` parameters will be excluded from the filtered '
'alpha diversity artifact instead of being retained.'
}
)

Expand Down
73 changes: 73 additions & 0 deletions q2_diversity/tests/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,79 @@ def test_filter_alpha_diversity_artifact_exclude_ids_some_filtered(self):
expected = pd.Series([3.0], index=['S3'])
self.assertTrue(filtered.sort_values().equals(expected.sort_values()))

def test_filter_alpha_diversity_artifact_exclude_ids_none_filtered(self):
    """Excluding metadata IDs absent from the series leaves it unchanged."""
    md_index = pd.Index(['S1', 'S2'], name='id')
    md_columns = {'Subject': ['subject-1', 'subject-1'],
                  'SampleType': ['gut', 'tongue']}
    metadata = qiime2.Metadata(pd.DataFrame(md_columns, index=md_index))

    # The metadata IDs (S1, S2) do not overlap the series IDs (S3, S4).
    alpha_diversity = pd.Series([1.0, 2.0], index=['S3', 'S4'])

    result = filter_alpha_diversity_artifact(
        alpha_diversity, metadata, exclude_ids=True)

    self.assertTrue(result.equals(alpha_diversity))

def test_filter_alpha_diversity_artifact_test_where_no_filtering(self):
    """A `where` clause matching every sample returns the full series."""
    md_index = pd.Index(['S1', 'S2', 'S3'], name='id')
    md_columns = {'Subject': ['subject-1', 'subject-1', 'subject-2'],
                  'SampleType': ['gut', 'tongue', 'gut']}
    metadata = qiime2.Metadata(pd.DataFrame(md_columns, index=md_index))

    alpha_diversity = pd.Series([1.0, 2.0, 3.0], index=['S1', 'S2', 'S3'])

    # Both sample types present in the metadata are accepted by the clause.
    clause = "SampleType='gut' OR SampleType='tongue'"
    result = filter_alpha_diversity_artifact(
        alpha_diversity, metadata, where=clause)

    self.assertTrue(result.equals(alpha_diversity))

def test_filter_alpha_diversity_artifact_test_where_some_filtered(self):
    """A `where` clause matching one sample keeps only that sample."""
    md_index = pd.Index(['S1', 'S2'], name='id')
    md_columns = {'Subject': ['subject-1', 'subject-1'],
                  'SampleType': ['gut', 'tongue']}
    metadata = qiime2.Metadata(pd.DataFrame(md_columns, index=md_index))

    alpha_diversity = pd.Series([1.0, 2.0, 3.0], index=['S1', 'S2', 'S3'])

    result = filter_alpha_diversity_artifact(
        alpha_diversity, metadata, where="SampleType='gut'")

    # Only S1 is a 'gut' sample in the metadata.
    expected = pd.Series([1.0], index=['S1'])
    result_sorted = result.sort_values()
    expected_sorted = expected.sort_values()
    self.assertTrue(result_sorted.equals(expected_sorted))

def test_filter_alpha_diversity_artifact_test_where_all_filtered(self):
    """A `where` clause matching no sample raises a ValueError."""
    md_index = pd.Index(['S1', 'S2', 'S3'], name='id')
    md_columns = {'Subject': ['subject-1', 'subject-1', 'subject-2'],
                  'SampleType': ['gut', 't', 'gut']}
    metadata = qiime2.Metadata(pd.DataFrame(md_columns, index=md_index))

    alpha_diversity = pd.Series([1.0, 2.0, 3.0], index=['S1', 'S2', 'S3'])

    # No metadata row has SampleType 'palm', so nothing survives the filter.
    with self.assertRaisesRegex(ValueError, "All samples.*filtered"):
        filter_alpha_diversity_artifact(
            alpha_diversity, metadata, where="SampleType='palm'")

def test_filter_alpha_diversity_artifact_test_where_extra_ids(self):
    """Metadata IDs missing from the series are ignored by the filter."""
    md_index = pd.Index(['S1', 'S4', 'S2', 'S5'], name='id')
    md_columns = {
        'Subject': ['subject-1', 'subject-1', 'subject-2', 'subject-2'],
        'SampleType': ['gut', 'tongue', 'gut', 'tongue'],
    }
    metadata = qiime2.Metadata(pd.DataFrame(md_columns, index=md_index))

    # S3 is not in the metadata; S4 and S5 are not in the series.
    alpha_diversity = pd.Series([1.0, 2.0, 3.0], index=['S1', 'S2', 'S3'])

    clause = "SampleType='gut' OR SampleType='tongue'"
    result = filter_alpha_diversity_artifact(
        alpha_diversity, metadata, where=clause)

    expected = pd.Series([1.0, 2.0], index=['S1', 'S2'])
    result_sorted = result.sort_values()
    expected_sorted = expected.sort_values()
    self.assertTrue(result_sorted.equals(expected_sorted))


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit d2d82a6

Please sign in to comment.