Skip to content

Commit

Permalink
use new refactored validation function
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisvittal committed Feb 7, 2025
1 parent d4fff08 commit 4999e0f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 51 deletions.
48 changes: 6 additions & 42 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import itertools
import operator
import os.path
from typing import Any, Callable, Iterable, Optional, TypeVar, Union

import numpy as np
Expand Down Expand Up @@ -7104,60 +7105,23 @@ def query_matrix_table(path, point_or_interval, entries_name='entries_array'):
matrix_table = hl.read_matrix_table(path)
if entries_name in matrix_table.row:
raise ValueError(
f'query_matrix_table: field "{entries_name}" is present in matrix table row fields, pick a different `entries_name`'
f'field "{entries_name}" is present in matrix table row fields, use a different `entries_name`'
)

entries_table = hl.read_table(path + '/entries')
entries_table = hl.read_table(os.path.join(path, '/entries'))
[entry_id] = list(entries_table.row)

full_row_type = tstruct(**matrix_table.row.dtype, **entries_table.row.dtype)
key_typ = matrix_table.row_key.dtype
key_names = list(key_typ)
len = builtins.len

def coerce_endpoint(point):
if point.dtype == key_typ[0]:
point = hl.struct(**{key_names[0]: point})
ts = point.dtype
if isinstance(ts, tstruct):
i = 0
while i < len(ts):
if i >= len(key_typ):
raise ValueError(
f"query_matrix_table: queried with {len(ts)} row key field(s), but matrix table only has {len(key_typ)} row key field(s)"
)
if key_typ[i] != ts[i]:
raise ValueError(
f"query_matrix_table: mismatch at row key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, matrix table row key type is {key_typ[i]}"
)
i += 1

if i == 0:
raise ValueError("query_matrix_table: cannot query with empty row key")

point_size = builtins.len(point.dtype)
return hl.tuple([
hl.struct(**{
key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i]))
for i in builtins.range(builtins.len(key_typ))
}),
hl.int32(point_size),
])
else:
raise ValueError(
f"query_matrix_table: row key mismatch: cannot query a matrix table with row key "
f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}"
)

if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval):
partition_interval = hl.interval(
start=coerce_endpoint(point_or_interval.start),
end=coerce_endpoint(point_or_interval.end),
start=__validate_and_coerce_endpoint(point_or_interval.start, key_typ),
end=__validate_and_coerce_endpoint(point_or_interval.end, key_typ),
includes_start=point_or_interval.includes_start,
includes_end=point_or_interval.includes_end,
)
else:
point = coerce_endpoint(point_or_interval)
point = __validate_and_coerce_endpoint(point_or_interval, key_typ)
partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True)
read_part_ir = ir.ReadPartition(
partition_interval._ir, reader=ir.PartitionZippedNativeIntervalReader(path, full_row_type)
Expand Down
18 changes: 9 additions & 9 deletions hail/python/test/hail/matrixtable/test_matrix_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2398,16 +2398,16 @@ def query_mt_mt():


def test_query_matrix_table_errors(query_mt_mt):
with pytest.raises(ValueError, match='query_matrix_table: field "s" is present'):
with pytest.raises(ValueError, match='field "s" is present'):
hl.query_matrix_table(query_mt_mt, 0, 's')
with pytest.raises(ValueError, match='query_matrix_table: mismatch at row key field'):
hl.query_matrix_table(query_mt_mt, hl.interval('1', '2'))
with pytest.raises(ValueError, match='query_matrix_table: row key mismatch: cannot query'):
hl.query_matrix_table(query_mt_mt, '1')
with pytest.raises(ValueError, match='query_matrix_table: cannot query with empty row key'):
hl.query_matrix_table(query_mt_mt, hl.struct())
with pytest.raises(ValueError, match='query_matrix_table: queried with 2 row key field'):
hl.query_matrix_table(query_mt_mt, hl.struct(row_idx=5, foo='s'))
with pytest.raises(ValueError, match='key mismatch: cannot use'):
hl.query_table(query_mt_mt, hl.interval('1', '2'))
with pytest.raises(ValueError, match='key mismatch: cannot use'):
hl.query_table(query_mt_mt, '1')
with pytest.raises(ValueError, match='query point value cannot be an empty struct'):
hl.query_table(query_mt_mt, hl.struct())
with pytest.raises(ValueError, match='query point type has 2 field'):
hl.query_table(query_mt_mt, hl.struct(idx=5, foo='s'))


def query_matrix_table_test_parameters():
Expand Down

0 comments on commit 4999e0f

Please sign in to comment.