Skip to content

Commit

Permalink
[query] Preserve field order in StructExpression.rename (#14807)
Browse files Browse the repository at this point in the history
Either set construction from an iterator or set difference is not order
preserving. If we iterate through the struct fields in order, then our
select inside of rename will preserve order for existing fields.

Brief description and justification of what this PR is doing.

## Security Assessment

Delete all except the correct answer:
- This change has no security impact

### Impact Description
Query only. Simply causes a single iteration to happen in a different
order.
  • Loading branch information
chrisvittal authored Feb 4, 2025
1 parent 61955df commit 3080087
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 3 additions & 3 deletions hail/python/hail/expr/expressions/typed_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2064,8 +2064,7 @@ def select(self, *fields, **named_exprs):

if len(named_exprs) == 0:
return selected_expr
else:
return selected_expr.annotate(**named_exprs)
return selected_expr.annotate(**named_exprs)

@typecheck_method(mapping=dictof(str, str))
def rename(self, mapping):
Expand Down Expand Up @@ -2110,7 +2109,8 @@ def rename(self, mapping):
new_to_old[new] = old

return self.select(
*list(set(self._fields) - set(mapping)), **{new: self._get_field(old) for old, new in mapping.items()}
*[field for field in self._fields if field not in mapping],
**{new: self._get_field(old) for old, new in mapping.items()},
)

@typecheck_method(fields=str)
Expand Down
13 changes: 13 additions & 0 deletions hail/python/test/hail/expr/test_expr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import math
import random
import string
import unittest

import numpy as np
Expand Down Expand Up @@ -4538,6 +4539,18 @@ def test_reservoir_sampling():
)


def test_struct_expr_rename_order():
keys = set(''.join(random.choice(string.ascii_letters) for _ in range(8)) for _ in range(10))
values = {k: random.randrange(30) for k in keys}
to_rename = set(x for i, x in enumerate(keys) if i % 2 == 1)
mapping = {old: ''.join(random.choice(string.ascii_letters) for _ in range(7)) for old in to_rename}
struct_expr = hl.struct(**values)
renamed_expr = struct_expr.rename(mapping)
common_left = struct_expr.drop(*to_rename)
common_right = renamed_expr.drop(*mapping.values())
assert common_left.dtype == common_right.dtype


def test_local_agg():
x = hl.literal([1, 2, 3, 4])
assert hl.eval(x.aggregate(lambda x: hl.agg.sum(x))) == 10
Expand Down

0 comments on commit 3080087

Please sign in to comment.