From 65c34b7be8525af11ef0a9ebc71b7f919c499afd Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Fri, 9 Feb 2024 11:50:27 +0100 Subject: [PATCH 01/26] Start working on issue, deactivate some breaking validation, add some todos --- xugrid/ugrid/partitioning.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index a14790a66..98073dd76 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -147,7 +147,7 @@ def merge_edges(grids, node_inverse): def validate_partition_topology(grouped, n_partition: int): n = n_partition - if not all(len(v) == n for v in grouped.values()): + if False: raise ValueError( f"Expected {n} UGRID topologies for {n} partitions, received: " f"{grouped}" ) @@ -182,6 +182,7 @@ def group_grids_by_name(partitions): def validate_partition_objects(data_objects): # Check presence of variables. + # TODO: Groupby gridtype, then test if variables present all grids per type. allvars = list({tuple(sorted(ds.data_vars)) for ds in data_objects}) if len(allvars) > 1: raise ValueError( @@ -200,7 +201,7 @@ def validate_partition_objects(data_objects): def separate_variables(data_objects, ugrid_dims): """Separate into UGRID variables grouped by dimension, and other variables.""" - validate_partition_objects(data_objects) + # validate_partition_objects(data_objects) def assert_single_dim(intersection): if len(intersection) > 1: @@ -216,6 +217,7 @@ def all_equal(iterator): return all(element == first for element in iterator) # Group variables by UGRID dimension. 
+ #TODO: Take first grid per mesh type first = data_objects[0] variables = first.variables vardims = {var: tuple(first[var].dims) for var in variables} @@ -280,6 +282,7 @@ def merge_partitions(partitions): grids = [grid for p in partitions for grid in p.grids] ugrid_dims = {dim for grid in grids for dim in grid.dimensions} grids_by_name = group_grids_by_name(partitions) + # TODO: make sure 1D variables also in vars_by_dim vars_by_dim, other_vars = separate_variables(data_objects, ugrid_dims) # First, take identical non-UGRID variables from the first partition: @@ -291,6 +294,7 @@ def merge_partitions(partitions): for grids in grids_by_name.values(): # First, merge the grid topology. grid = grids[0] + # TODO: shortcut for length 1 merge_partitions merged_grid, indexes = grid.merge_partitions(grids) merged_grids.append(merged_grid) From 92ab0643f23fbde0d31b0e4860c443b9f2353213 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Mon, 12 Feb 2024 15:31:38 +0100 Subject: [PATCH 02/26] * Add ``sizes`` property * Add ``max_face_node_dimension`` property for Ugrid2D * Add ``max_connectivity_dimensions`` property --- xugrid/ugrid/ugrid2d.py | 10 ++++++++++ xugrid/ugrid/ugridbase.py | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/xugrid/ugrid/ugrid2d.py b/xugrid/ugrid/ugrid2d.py index fd744ece5..9c9c9c1c2 100644 --- a/xugrid/ugrid/ugrid2d.py +++ b/xugrid/ugrid/ugrid2d.py @@ -417,6 +417,16 @@ def dimensions(self): self.face_dimension: self.n_face, } + @property + def max_face_node_dimension(self): + return self._attrs["max_face_nodes_dimension"] + + @property + def max_connectivity_dimensions(self): + return { + self.max_face_node_dimension: self.n_max_node_per_face, + } + @property def topology_dimension(self): """Highest dimensionality of the geometric elements: 2""" diff --git a/xugrid/ugrid/ugridbase.py b/xugrid/ugrid/ugridbase.py index 0fd84782b..052c40d6c 100644 --- a/xugrid/ugrid/ugridbase.py +++ b/xugrid/ugrid/ugridbase.py @@ -299,6 +299,14 @@ def 
edge_dimension(self): """Name of edge dimension""" return self._attrs["edge_dimension"] + @property + def max_connectivity_dimensions(self): + return {} + + @property + def sizes(self): + return self.dimensions + @property def node_coordinates(self) -> FloatArray: """Coordinates (x, y) of the nodes (vertices)""" From 9866cb2aa9ec6a09c82346411826a8c19461b353 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Mon, 12 Feb 2024 15:45:36 +0100 Subject: [PATCH 03/26] Update changelog --- docs/changelog.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 34e327ea0..abbf62a41 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -31,6 +31,11 @@ Added UGRID topologies from "intervals": the (M + 1, N + 1) vertex coordinates for N faces. - :meth:`xugrid.UgridDataArrayAccessor.from_structured` now takes ``x`` and ``y`` arguments to specify which coordinates to use as the UGRID x and y coordinates. +- :attr:`xugrid.UgridDataset.sizes` as an alternative to :attr:`xugrid.UgridDataset.dimensions` +- :attr:`xugrid.Ugrid2d.max_face_node_dimension` which returns the dimension + name designating nodes per face. +- :attr:`xugrid.AbstractUgrid.max_connectivity_dimensions` which returns all + maximum connectivity dimensions and their corresponding size. 
Changed ~~~~~~~ From e2bd159663752d0efb04b7ed8ab61d80b6d73bbd Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Mon, 12 Feb 2024 15:49:02 +0100 Subject: [PATCH 04/26] * Support merging partitions with different grids per partition * Add function to maybe pad connectivity max dims * Group data objects and other vars by gridname --- xugrid/ugrid/partitioning.py | 90 +++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 32 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 98073dd76..85d0ea175 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -180,6 +180,22 @@ def group_grids_by_name(partitions): return grouped +def group_data_objects_by_gridname(partitions): + # Convert to dataset for convenience + data_objects = [partition.obj for partition in partitions] + data_objects = [ + obj.to_dataset() if isinstance(obj, xr.DataArray) else obj + for obj in data_objects + ] + + grouped = defaultdict(list) + for partition, obj in zip(partitions, data_objects): + for grid in partition.grids: + grouped[grid.name].append(obj) + + return grouped + + def validate_partition_objects(data_objects): # Check presence of variables. # TODO: Groupby gridtype, then test if variables present all grids per type. @@ -199,7 +215,7 @@ def validate_partition_objects(data_objects): ) -def separate_variables(data_objects, ugrid_dims): +def separate_variables(objects_by_gridname, ugrid_dims): """Separate into UGRID variables grouped by dimension, and other variables.""" # validate_partition_objects(data_objects) @@ -217,32 +233,44 @@ def all_equal(iterator): return all(element == first for element in iterator) # Group variables by UGRID dimension. 
- #TODO: Take first grid per mesh type - first = data_objects[0] - variables = first.variables - vardims = {var: tuple(first[var].dims) for var in variables} grouped = defaultdict(list) # UGRID associated vars - other = [] # other vars - for var, da in variables.items(): - shapes = (obj[var].shape for obj in data_objects) - - # Check if variable depends on UGRID dimension. - intersection = ugrid_dims.intersection(da.dims) - if intersection: - assert_single_dim(intersection) - # Now check whether the non-UGRID dimensions match. - dim = intersection.pop() # Get the single element in the set. - axis = vardims[var].index(dim) - shapes = [remove_item(shape, axis) for shape in shapes] - if all_equal(shapes): - grouped[dim].append(var) - - elif all_equal(shapes): - other.append(var) + other = defaultdict(list) # other vars + + for gridname, data_objects in objects_by_gridname.items(): + first = data_objects[0] + variables = first.variables + vardims = {var: tuple(first[var].dims) for var in variables} + for var, da in variables.items(): + shapes = (obj[var].shape for obj in data_objects) + + # Check if variable depends on UGRID dimension. + intersection = ugrid_dims.intersection(da.dims) + if intersection: + assert_single_dim(intersection) + # Now check whether the non-UGRID dimensions match. + dim = intersection.pop() # Get the single element in the set. 
+ axis = vardims[var].index(dim) + shapes = [remove_item(shape, axis) for shape in shapes] + if all_equal(shapes): + grouped[dim].append(var) + + elif all_equal(shapes): + other[gridname].append(var) return grouped, other +def maybe_pad_connectivity_dims_to_max(selection, merged_grid): + nmax_dict = merged_grid.max_connectivity_dimensions + nmax_dict = {key: value for key, value in nmax_dict.items() if key in selection[0].dims} + if not nmax_dict: + return selection + + pad_width_ls = [{dim: (0, nmax - obj.sizes[dim]) for dim, nmax in nmax_dict.items()} for obj in selection] + + return [obj.pad(pad_width = pad_width) for obj, pad_width in zip(selection, pad_width_ls)] + + def merge_partitions(partitions): """ Merge topology and data, partitioned along UGRID dimensions, into a single @@ -272,27 +300,23 @@ def merge_partitions(partitions): if obj_type not in (UgridDataArray, UgridDataset): raise TypeError(msg.format(obj_type.__name__)) - # Convert to dataset for convenience - data_objects = [partition.obj for partition in partitions] - data_objects = [ - obj.to_dataset() if isinstance(obj, xr.DataArray) else obj - for obj in data_objects - ] # Collect grids grids = [grid for p in partitions for grid in p.grids] ugrid_dims = {dim for grid in grids for dim in grid.dimensions} grids_by_name = group_grids_by_name(partitions) # TODO: make sure 1D variables also in vars_by_dim - vars_by_dim, other_vars = separate_variables(data_objects, ugrid_dims) + data_objects_by_name = group_data_objects_by_gridname(partitions) + vars_by_dim, other_vars_by_name = separate_variables(data_objects_by_name, ugrid_dims) # First, take identical non-UGRID variables from the first partition: - merged = data_objects[0][other_vars] + merged = xr.Dataset() # data_objects[0][other_vars] # Merge the UGRID topologies into one, and find the indexes to index into # the data to avoid duplicates. 
merged_grids = [] - for grids in grids_by_name.values(): + for grids, data_objects, other_vars in zip(grids_by_name.values(), data_objects_by_name.values(), other_vars_by_name.values()): # First, merge the grid topology. + merged.update(data_objects[0][other_vars]) grid = grids[0] # TODO: shortcut for length 1 merge_partitions merged_grid, indexes = grid.merge_partitions(grids) @@ -305,7 +329,9 @@ def merge_partitions(partitions): obj[vars].isel({dim: index}, missing_dims="ignore") for obj, index in zip(data_objects, dim_indexes) ] - merged_selection = xr.concat(selection, dim=dim) + selection_padded = maybe_pad_connectivity_dims_to_max(selection, merged_grid) + + merged_selection = xr.concat(selection_padded, dim=dim) merged.update(merged_selection) return UgridDataset(merged, merged_grids) From ba6e25968123ea7edb491c37938135044451794b Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Mon, 12 Feb 2024 15:50:01 +0100 Subject: [PATCH 05/26] Format --- xugrid/ugrid/partitioning.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 85d0ea175..f2e1adcf6 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -262,13 +262,20 @@ def all_equal(iterator): def maybe_pad_connectivity_dims_to_max(selection, merged_grid): nmax_dict = merged_grid.max_connectivity_dimensions - nmax_dict = {key: value for key, value in nmax_dict.items() if key in selection[0].dims} + nmax_dict = { + key: value for key, value in nmax_dict.items() if key in selection[0].dims + } if not nmax_dict: return selection - - pad_width_ls = [{dim: (0, nmax - obj.sizes[dim]) for dim, nmax in nmax_dict.items()} for obj in selection] - return [obj.pad(pad_width = pad_width) for obj, pad_width in zip(selection, pad_width_ls)] + pad_width_ls = [ + {dim: (0, nmax - obj.sizes[dim]) for dim, nmax in nmax_dict.items()} + for obj in selection + ] + + return [ + 
obj.pad(pad_width=pad_width) for obj, pad_width in zip(selection, pad_width_ls) + ] def merge_partitions(partitions): @@ -306,15 +313,21 @@ def merge_partitions(partitions): grids_by_name = group_grids_by_name(partitions) # TODO: make sure 1D variables also in vars_by_dim data_objects_by_name = group_data_objects_by_gridname(partitions) - vars_by_dim, other_vars_by_name = separate_variables(data_objects_by_name, ugrid_dims) + vars_by_dim, other_vars_by_name = separate_variables( + data_objects_by_name, ugrid_dims + ) # First, take identical non-UGRID variables from the first partition: - merged = xr.Dataset() # data_objects[0][other_vars] + merged = xr.Dataset() # data_objects[0][other_vars] # Merge the UGRID topologies into one, and find the indexes to index into # the data to avoid duplicates. merged_grids = [] - for grids, data_objects, other_vars in zip(grids_by_name.values(), data_objects_by_name.values(), other_vars_by_name.values()): + for grids, data_objects, other_vars in zip( + grids_by_name.values(), + data_objects_by_name.values(), + other_vars_by_name.values(), + ): # First, merge the grid topology. 
merged.update(data_objects[0][other_vars]) grid = grids[0] @@ -329,7 +342,9 @@ def merge_partitions(partitions): obj[vars].isel({dim: index}, missing_dims="ignore") for obj, index in zip(data_objects, dim_indexes) ] - selection_padded = maybe_pad_connectivity_dims_to_max(selection, merged_grid) + selection_padded = maybe_pad_connectivity_dims_to_max( + selection, merged_grid + ) merged_selection = xr.concat(selection_padded, dim=dim) merged.update(merged_selection) From 151f416bce129841240ed51b9c7f12d1caaffdec Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Mon, 12 Feb 2024 15:53:34 +0100 Subject: [PATCH 06/26] Update changelog --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index abbf62a41..5441a7ac6 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,6 +19,8 @@ Fixed copies). - Fixed bug in :meth:`xugrid.Ugrid1d.merge_partitions`, which caused ``ValueError: indexes must be provided for attrs``. +- :func:`xugrid.merge_partitions` merges partitions with grids not contained in + other partitions. 
Added ~~~~~ From 68b2d94dc4d1303ccc81117b896b901e51e72499 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Mon, 12 Feb 2024 16:25:08 +0100 Subject: [PATCH 07/26] Add test --- tests/test_partitioning.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index 04596cf4a..9e0133fa6 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -313,3 +313,20 @@ def test_merge_partitions(self): assert merged["c"] == 0 assert self.dataset_expected.equals(merged) + + def test_merge_partitions_inconsistent_grid_types(self): + self.datasets_parts[0] = self.datasets_parts[0].drop_vars("b") + b = self.dataset_expected["b"].isel(mesh1d_nEdges=[0, 1, 2]) + self.dataset_expected = self.dataset_expected.drop_vars(["b", "mesh1d_nEdges"]) + self.dataset_expected["b"] = b + self.dataset_expected["c"] = 1 + + merged = pt.merge_partitions(self.datasets_parts) + assert isinstance(merged, xu.UgridDataset) + assert len(merged.ugrid.grids) == 2 + # In case of non-UGRID data, it should default to the first partition in + # the last grid: + assert merged["c"] == 1 + + assert self.dataset_expected.equals(merged) + From 8ab89ade3aef005657b153b513ea1524b7584e7d Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 11:39:32 +0100 Subject: [PATCH 08/26] Support merging partitions with inconsistent grids --- xugrid/ugrid/partitioning.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index f2e1adcf6..8839cae6e 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -233,18 +233,17 @@ def all_equal(iterator): return all(element == first for element in iterator) # Group variables by UGRID dimension. 
- grouped = defaultdict(list) # UGRID associated vars - other = defaultdict(list) # other vars + grouped = defaultdict(set) # UGRID associated vars + other = defaultdict(set) # other vars for gridname, data_objects in objects_by_gridname.items(): - first = data_objects[0] - variables = first.variables - vardims = {var: tuple(first[var].dims) for var in variables} + variables = {varname: data for obj in data_objects for varname, data in obj.variables.items()} + vardims = {varname: data.dims for varname, data in variables.items()} for var, da in variables.items(): - shapes = (obj[var].shape for obj in data_objects) + shapes = [obj[var].shape for obj in data_objects if var in obj] # Check if variable depends on UGRID dimension. - intersection = ugrid_dims.intersection(da.dims) + intersection = ugrid_dims.intersection(vardims[var]) if intersection: assert_single_dim(intersection) # Now check whether the non-UGRID dimensions match. @@ -252,10 +251,10 @@ def all_equal(iterator): axis = vardims[var].index(dim) shapes = [remove_item(shape, axis) for shape in shapes] if all_equal(shapes): - grouped[dim].append(var) + grouped[dim].add(var) elif all_equal(shapes): - other[gridname].append(var) + other[gridname].add(var) return grouped, other @@ -311,14 +310,14 @@ def merge_partitions(partitions): grids = [grid for p in partitions for grid in p.grids] ugrid_dims = {dim for grid in grids for dim in grid.dimensions} grids_by_name = group_grids_by_name(partitions) - # TODO: make sure 1D variables also in vars_by_dim + data_objects_by_name = group_data_objects_by_gridname(partitions) vars_by_dim, other_vars_by_name = separate_variables( data_objects_by_name, ugrid_dims ) # First, take identical non-UGRID variables from the first partition: - merged = xr.Dataset() # data_objects[0][other_vars] + merged = xr.Dataset() # Merge the UGRID topologies into one, and find the indexes to index into # the data to avoid duplicates. 
@@ -331,16 +330,20 @@ # First, merge the grid topology. merged.update(data_objects[0][other_vars]) grid = grids[0] - # TODO: shortcut for length 1 merge_partitions + merged_grid, indexes = grid.merge_partitions(grids) merged_grids.append(merged_grid) # Now remove duplicates, then concatenate along the UGRID dimension. for dim, dim_indexes in indexes.items(): vars = vars_by_dim[dim] + if len(vars) == 0: + continue + first_var = next(iter(vars)) + objects_indexes_to_select = [(obj[vars], index) for obj, index in zip(data_objects, dim_indexes) if first_var in obj] selection = [ obj[vars].isel({dim: index}, missing_dims="ignore") - for obj, index in zip(data_objects, dim_indexes) + for obj, index in objects_indexes_to_select ] selection_padded = maybe_pad_connectivity_dims_to_max( selection, merged_grid From 30d12b9101e36f2778da5b46d794e73051563f63 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 11:41:18 +0100 Subject: [PATCH 09/26] Fix comments and drop mesh1d_nEdges in partition as well. --- tests/test_partitioning.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index 9e0133fa6..833ca34ab 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -260,7 +260,8 @@ def test_merge_partitions(self): merged = pt.merge_partitions(self.datasets_partitioned) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 1 - # In case of non-UGRID data, it should default to the first partition: + # In case of non-UGRID data, it should default to the first partition of + # the grid that's checked last. 
assert merged["c"] == 0 assert self.dataset_expected.ugrid.grid.equals(merged.ugrid.grid) @@ -289,7 +290,13 @@ def setup(self): ds["a"] = ((part_a.face_dimension), values_a) ds["b"] = ((part_b.edge_dimension), values_b) ds["c"] = i - datasets_parts.append(ds) + + coords = { + part_a.face_dimension: values_a, + part_b.edge_dimension: values_b, + } + + datasets_parts.append(ds.assign_coords(**coords)) ds_expected = xu.UgridDataset(grids=[grid_a, grid_b]) ds_expected["a"] = ((grid_a.face_dimension), np.concatenate(values_parts_a)) @@ -297,8 +304,8 @@ def setup(self): ds_expected["c"] = 0 # Assign coordinates also added during merge_partitions coords = { - grid_a.face_dimension: np.arange(grid_a.n_face), - grid_b.edge_dimension: np.arange(grid_b.n_edge), + grid_a.face_dimension: np.concatenate(values_parts_a), + grid_b.edge_dimension: np.concatenate(values_parts_b), } ds_expected = ds_expected.assign_coords(**coords) @@ -309,13 +316,14 @@ def test_merge_partitions(self): merged = pt.merge_partitions(self.datasets_parts) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 2 - # In case of non-UGRID data, it should default to the first partition: + # In case of non-UGRID data, it should default to the first partition of + # the grid that's checked last. 
assert merged["c"] == 0 assert self.dataset_expected.equals(merged) def test_merge_partitions_inconsistent_grid_types(self): - self.datasets_parts[0] = self.datasets_parts[0].drop_vars("b") + self.datasets_parts[0] = self.datasets_parts[0].drop_vars(["b", "mesh1d_nEdges"]) b = self.dataset_expected["b"].isel(mesh1d_nEdges=[0, 1, 2]) self.dataset_expected = self.dataset_expected.drop_vars(["b", "mesh1d_nEdges"]) self.dataset_expected["b"] = b @@ -324,8 +332,8 @@ def test_merge_partitions_inconsistent_grid_types(self): merged = pt.merge_partitions(self.datasets_parts) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 2 - # In case of non-UGRID data, it should default to the first partition in - # the last grid: + # In case of non-UGRID data, it should default to the first partition of + # the grid that's checked last. assert merged["c"] == 1 assert self.dataset_expected.equals(merged) From 5815063f9bbca29eff8b8d2f5dcafaa38a659b55 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 14:36:26 +0100 Subject: [PATCH 10/26] Fix validation --- xugrid/ugrid/partitioning.py | 65 ++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 32 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 8839cae6e..de6caba10 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -1,6 +1,6 @@ """Create and merge partitioned UGRID topologies.""" from collections import defaultdict -from itertools import accumulate +from itertools import accumulate, chain from typing import List import numpy as np @@ -145,13 +145,7 @@ def merge_edges(grids, node_inverse): return _merge_connectivity(all_edges, slices) -def validate_partition_topology(grouped, n_partition: int): - n = n_partition - if False: - raise ValueError( - f"Expected {n} UGRID topologies for {n} partitions, received: " f"{grouped}" - ) - +def validate_partition_topology(grouped): for name, grids in grouped.items(): 
types = {type(grid) for grid in grids} if len(types) > 1: @@ -176,7 +170,7 @@ def group_grids_by_name(partitions): for grid in partition.grids: grouped[grid.name].append(grid) - validate_partition_topology(grouped, len(partitions)) + validate_partition_topology(grouped) return grouped @@ -196,28 +190,27 @@ def group_data_objects_by_gridname(partitions): return grouped -def validate_partition_objects(data_objects): - # Check presence of variables. - # TODO: Groupby gridtype, then test if variables present all grids per type. - allvars = list({tuple(sorted(ds.data_vars)) for ds in data_objects}) - if len(allvars) > 1: - raise ValueError( - "These variables are present in some partitions, but not in " - f"others: {set(allvars[0]).symmetric_difference(allvars[1])}" - ) - # Check dimensions - for var in allvars.pop(): - vardims = list({ds[var].dims for ds in data_objects}) - if len(vardims) > 1: - raise ValueError( - f"Dimensions for {var} do not match across partitions: " - f"{vardims[0]} versus {vardims[1]}" - ) +def validate_partition_objects(objects_by_gridname): + for data_objects in objects_by_gridname.values(): + allvars = list({tuple(sorted(ds.data_vars)) for ds in data_objects}) + unique_vars = set(chain(*allvars)) + # Check dimensions + dims_per_var = [ + {ds[var].dims for ds in data_objects if var in ds.data_vars} + for var in unique_vars + ] + for var, vardims in zip(unique_vars, dims_per_var): + if len(vardims) > 1: + vardims_ls = list(vardims) + raise ValueError( + f"Dimensions for '{var}' do not match across partitions: " + f"{vardims_ls[0]} versus {vardims_ls[1]}" + ) def separate_variables(objects_by_gridname, ugrid_dims): """Separate into UGRID variables grouped by dimension, and other variables.""" - # validate_partition_objects(data_objects) + validate_partition_objects(objects_by_gridname) def assert_single_dim(intersection): if len(intersection) > 1: @@ -237,18 +230,22 @@ def all_equal(iterator): other = defaultdict(set) # other vars for gridname, 
data_objects in objects_by_gridname.items(): - variables = {varname: data for obj in data_objects for varname, data in obj.variables.items()} + variables = { + varname: data + for obj in data_objects + for varname, data in obj.variables.items() + } vardims = {varname: data.dims for varname, data in variables.items()} - for var, da in variables.items(): + for var, dims in vardims.items(): shapes = [obj[var].shape for obj in data_objects if var in obj] # Check if variable depends on UGRID dimension. - intersection = ugrid_dims.intersection(vardims[var]) + intersection = ugrid_dims.intersection(dims) if intersection: assert_single_dim(intersection) # Now check whether the non-UGRID dimensions match. dim = intersection.pop() # Get the single element in the set. - axis = vardims[var].index(dim) + axis = dims.index(dim) shapes = [remove_item(shape, axis) for shape in shapes] if all_equal(shapes): grouped[dim].add(var) @@ -340,7 +337,11 @@ def merge_partitions(partitions): if len(vars) == 0: continue first_var = next(iter(vars)) - objects_indexes_to_select = [(obj[vars], index) for obj, index in zip(data_objects, dim_indexes) if first_var in obj] + objects_indexes_to_select = [ + (obj[vars], index) + for obj, index in zip(data_objects, dim_indexes) + if first_var in obj + ] selection = [ obj[vars].isel({dim: index}, missing_dims="ignore") for obj, index in objects_indexes_to_select From 147f3d4505c0bfc4af147236e4d70f402da62888 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 16:37:33 +0100 Subject: [PATCH 11/26] Add test --- tests/test_partitioning.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index 833ca34ab..da220ae79 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -192,11 +192,6 @@ def test_merge_partitions(self): assert merged["c"] == 0 def test_merge_partitions__errors(self): - pa = self.datasets[0][["a"]] - pb = 
self.datasets[1][["b"]] - with pytest.raises(ValueError, match="Expected 2 UGRID topologies"): - pt.merge_partitions([pa, pb]) - grid_a = self.datasets[1].ugrid.grids[0].copy() grid_c = self.datasets[1].ugrid.grids[1].copy() grid_c._attrs["face_dimension"] = "abcdef" @@ -322,8 +317,10 @@ def test_merge_partitions(self): assert self.dataset_expected.equals(merged) - def test_merge_partitions_inconsistent_grid_types(self): - self.datasets_parts[0] = self.datasets_parts[0].drop_vars(["b", "mesh1d_nEdges"]) + def test_merge_partitions__inconsistent_grid_types(self): + self.datasets_parts[0] = self.datasets_parts[0].drop_vars( + ["b", "mesh1d_nEdges"] + ) b = self.dataset_expected["b"].isel(mesh1d_nEdges=[0, 1, 2]) self.dataset_expected = self.dataset_expected.drop_vars(["b", "mesh1d_nEdges"]) self.dataset_expected["b"] = b @@ -338,3 +335,8 @@ def test_merge_partitions_inconsistent_grid_types(self): assert self.dataset_expected.equals(merged) + def test_merge_partitions__errors(self): + pa = self.datasets_parts[0][["a"]] * xr.DataArray([1.0, 1.0], dims=("error_dim",)) + pb = self.datasets_parts[1][["a"]] + with pytest.raises(ValueError, match="Dimensions for 'a' do not match across partitions: "): + pt.merge_partitions([pa, pb]) \ No newline at end of file From edfa12b24c026ff6eef967ece9c93142e5dea866 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 16:47:27 +0100 Subject: [PATCH 12/26] Add validation if vars in all data objects and ensure all other_vars are added to merged --- xugrid/ugrid/partitioning.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index de6caba10..c52d405dd 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -207,6 +207,14 @@ def validate_partition_objects(objects_by_gridname): f"{vardims_ls[0]} versus {vardims_ls[1]}" ) +def validate_vars_in_all_data_objects(vars, data_objects, gridname): + 
for var in vars: + var_in_objects = [True if var in obj.variables else False for obj in data_objects] + if not all(var_in_objects): + raise ValueError( + f"'{var}' does not occur in all partitions with '{gridname}'" + ) + def separate_variables(objects_by_gridname, ugrid_dims): """Separate into UGRID variables grouped by dimension, and other variables.""" @@ -319,23 +327,26 @@ def merge_partitions(partitions): # Merge the UGRID topologies into one, and find the indexes to index into # the data to avoid duplicates. merged_grids = [] - for grids, data_objects, other_vars in zip( - grids_by_name.values(), - data_objects_by_name.values(), - other_vars_by_name.values(), - ): + for gridname, grids in grids_by_name.items(): + data_objects = data_objects_by_name[gridname] + other_vars = other_vars_by_name[gridname] + # First, merge the grid topology. - merged.update(data_objects[0][other_vars]) grid = grids[0] - merged_grid, indexes = grid.merge_partitions(grids) + merged_grid, indexes = grid.merge_partitions(grids) merged_grids.append(merged_grid) + # Add all other vars to dataset + for obj in data_objects: + other_vars_obj = set(other_vars).intersection(set(obj.data_vars)) + merged.update(obj[other_vars_obj]) + # Now remove duplicates, then concatenate along the UGRID dimension. 
for dim, dim_indexes in indexes.items(): vars = vars_by_dim[dim] if len(vars) == 0: continue + validate_vars_in_all_data_objects(vars, data_objects, gridname) first_var = next(iter(vars)) objects_indexes_to_select = [ (obj[vars], index) From 1570561c52bc786ee307dfe081ecba20c23e Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 16:49:05 +0100 Subject: [PATCH 13/26] Add extra tests and adapt some tests --- tests/test_partitioning.py | 46 +++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index da220ae79..23d22ba73 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -139,12 +139,12 @@ def test_merge_partitions__errors(self): grid1 = partitions[1].ugrid.grid partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face)) - with pytest.raises(ValueError, match="These variables are present"): + with pytest.raises(ValueError, match="'extra' does not occur in all partitions with 'mesh2d'"): pt.merge_partitions(partitions) partitions = self.uds.ugrid.partition(n_part=2) partitions[1]["face_z"] = partitions[1]["face_z"].expand_dims("layer", axis=0) - with pytest.raises(ValueError, match="Dimensions for face_z do not match"): + with pytest.raises(ValueError, match="Dimensions for 'face_z' do not match"): pt.merge_partitions(partitions) uds = self.uds.copy() @@ -188,10 +188,29 @@ def test_merge_partitions(self): merged = pt.merge_partitions(self.datasets) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 2 - # In case of non-UGRID data, it should default to the first partition: - assert merged["c"] == 0 + # In case of non-UGRID data, it should default to the last partition: + assert merged["c"] == 1 + + assert len(merged["first_nFaces"]) == 6 + assert len(merged["second_nFaces"]) == 20 + + def test_merge_partitions__unique_grid_per_partition(self): + pa = self.datasets[0][["a"]] + pb = 
self.datasets[1][["b"]] + merged = pt.merge_partitions([pa, pb]) + + assert isinstance(merged, xu.UgridDataset) + assert len(merged.ugrid.grids) == 2 + + assert len(merged["first_nFaces"]) == 3 + assert len(merged["second_nFaces"]) == 10 def test_merge_partitions__errors(self): + pa = self.datasets[0][["a"]] * xr.DataArray([1.0, 1.0], dims=("error_dim",)) + pb = self.datasets[1][["a"]] + with pytest.raises(ValueError, match="Dimensions for 'a' do not match across partitions: "): + pt.merge_partitions([pa, pb]) + grid_a = self.datasets[1].ugrid.grids[0].copy() grid_c = self.datasets[1].ugrid.grids[1].copy() grid_c._attrs["face_dimension"] = "abcdef" @@ -243,7 +262,7 @@ def setup(self): ds_expected = xu.UgridDataset(grids=[grid]) ds_expected["a"] = ((grid.edge_dimension), np.concatenate(values_parts)) - ds_expected["c"] = 0 + ds_expected["c"] = 1 # Assign coordinates also added during merge_partitions coords = {grid.edge_dimension: np.arange(grid.n_edge)} ds_expected = ds_expected.assign_coords(**coords) @@ -255,9 +274,9 @@ def test_merge_partitions(self): merged = pt.merge_partitions(self.datasets_partitioned) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 1 - # In case of non-UGRID data, it should default to the first partition of + # In case of non-UGRID data, it should default to the last partition of # the grid that's checked last. 
- assert merged["c"] == 0 + assert merged["c"] == 1 assert self.dataset_expected.ugrid.grid.equals(merged.ugrid.grid) assert self.dataset_expected["a"].equals(merged["a"]) @@ -296,7 +315,7 @@ def setup(self): ds_expected = xu.UgridDataset(grids=[grid_a, grid_b]) ds_expected["a"] = ((grid_a.face_dimension), np.concatenate(values_parts_a)) ds_expected["b"] = ((grid_b.edge_dimension), np.concatenate(values_parts_b)) - ds_expected["c"] = 0 + ds_expected["c"] = 1 # Assign coordinates also added during merge_partitions coords = { grid_a.face_dimension: np.concatenate(values_parts_a), @@ -311,9 +330,9 @@ def test_merge_partitions(self): merged = pt.merge_partitions(self.datasets_parts) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 2 - # In case of non-UGRID data, it should default to the first partition of + # In case of non-UGRID data, it should default to the last partition of # the grid that's checked last. - assert merged["c"] == 0 + assert merged["c"] == 1 assert self.dataset_expected.equals(merged) @@ -329,14 +348,9 @@ def test_merge_partitions__inconsistent_grid_types(self): merged = pt.merge_partitions(self.datasets_parts) assert isinstance(merged, xu.UgridDataset) assert len(merged.ugrid.grids) == 2 - # In case of non-UGRID data, it should default to the first partition of + # In case of non-UGRID data, it should default to the last partition of # the grid that's checked last. 
assert merged["c"] == 1 assert self.dataset_expected.equals(merged) - def test_merge_partitions__errors(self): - pa = self.datasets_parts[0][["a"]] * xr.DataArray([1.0, 1.0], dims=("error_dim",)) - pb = self.datasets_parts[1][["a"]] - with pytest.raises(ValueError, match="Dimensions for 'a' do not match across partitions: "): - pt.merge_partitions([pa, pb]) \ No newline at end of file From 92fd23fff4405de0bb04454c9411a61a8a62fb92 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 16:49:53 +0100 Subject: [PATCH 14/26] format --- tests/test_partitioning.py | 10 +++++++--- xugrid/ugrid/partitioning.py | 5 ++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index 23d22ba73..c72b13f5b 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -139,7 +139,10 @@ def test_merge_partitions__errors(self): grid1 = partitions[1].ugrid.grid partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face)) - with pytest.raises(ValueError, match="'extra' does not occur not in all partitions with 'mesh2d'"): + with pytest.raises( + ValueError, + match="'extra' does not occur not in all partitions with 'mesh2d'", + ): pt.merge_partitions(partitions) partitions = self.uds.ugrid.partition(n_part=2) @@ -208,7 +211,9 @@ def test_merge_partitions__unique_grid_per_partition(self): def test_merge_partitions__errors(self): pa = self.datasets[0][["a"]] * xr.DataArray([1.0, 1.0], dims=("error_dim",)) pb = self.datasets[1][["a"]] - with pytest.raises(ValueError, match="Dimensions for 'a' do not match across partitions: "): + with pytest.raises( + ValueError, match="Dimensions for 'a' do not match across partitions: " + ): pt.merge_partitions([pa, pb]) grid_a = self.datasets[1].ugrid.grids[0].copy() @@ -353,4 +358,3 @@ def test_merge_partitions__inconsistent_grid_types(self): assert merged["c"] == 1 assert self.dataset_expected.equals(merged) - diff --git 
a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index c52d405dd..0efea2ac4 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -207,9 +207,12 @@ def validate_partition_objects(objects_by_gridname): f"{vardims_ls[0]} versus {vardims_ls[1]}" ) + def validate_vars_in_all_data_objects(vars, data_objects, gridname): for var in vars: - var_in_objects = [True if var in obj.variables else False for obj in data_objects] + var_in_objects = [ + True if var in obj.variables else False for obj in data_objects + ] if not all(var_in_objects): raise ValueError( f"'{var}' does not occur not in all partitions with '{gridname}'" From 0c85d0bcdb93dcebd427e44922c3e5b315396af2 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Tue, 13 Feb 2024 16:54:57 +0100 Subject: [PATCH 15/26] Update changelog --- docs/changelog.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5441a7ac6..ba399f5ba 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -44,6 +44,8 @@ Changed - :meth:`xugrid.Ugrid2d.from_structured` now takes ``x`` and ``y`` arguments instead of ``x_bounds`` and ``y_bounds`` arguments. +- :func:`xugrid.merge_partitions` allows merging partitions with different grids + per partition. [0.8.1] 2024-01-19 ------------------ From 22d93a20cca200fdef42195bbca799373dee2014 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 11:15:59 +0100 Subject: [PATCH 16/26] Rename max_connectivity_dimensions to max_connectivity_sizes and let the former method return a tuple of names instead --- docs/changelog.rst | 9 ++++++--- xugrid/ugrid/partitioning.py | 2 +- xugrid/ugrid/ugrid2d.py | 6 +++++- xugrid/ugrid/ugridbase.py | 6 +++++- 4 files changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index ba399f5ba..17555486c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,8 +19,8 @@ Fixed copies). 
- Fixed bug in :meth:`xugrid.Ugrid1d.merge_partitions`, which caused ``ValueError: indexes must be provided for attrs``. -- :func:`xugrid.merge_partitions` merges partitions with grids not contained in - other partitions. +- :func:`xugrid.merge_partitions` now also merges datasets with grids that are + only contained in some of the partition datasets. Added ~~~~~ @@ -36,8 +36,11 @@ Added - :attr:`xugrid.UgridDataset.sizes` as an alternative to :attr:`xugrid.UgridDataset.dimensions` - :attr:`xugrid.Ugrid2d.max_face_node_dimension` which returns the dimension name designating nodes per face. -- :attr:`xugrid.AbstractUgrid.max_connectivity_dimensions` which returns all +- :attr:`xugrid.AbstractUgrid.max_connectivity_sizes` which returns all maximum connectivity dimensions and their corresponding size. +- :attr:`xugrid.AbstractUgrid.max_connectivity_dimensions` which returns all + maximum connectivity dimensions. + Changed ~~~~~~~ diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 0efea2ac4..8a4ca14c9 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -268,7 +268,7 @@ def all_equal(iterator): def maybe_pad_connectivity_dims_to_max(selection, merged_grid): - nmax_dict = merged_grid.max_connectivity_dimensions + nmax_dict = merged_grid.max_connectivity_sizes nmax_dict = { key: value for key, value in nmax_dict.items() if key in selection[0].dims } diff --git a/xugrid/ugrid/ugrid2d.py b/xugrid/ugrid/ugrid2d.py index 9c9c9c1c2..3b7ba90ee 100644 --- a/xugrid/ugrid/ugrid2d.py +++ b/xugrid/ugrid/ugrid2d.py @@ -422,11 +422,15 @@ def max_face_node_dimension(self): return self._attrs["max_face_nodes_dimension"] @property - def max_connectivity_dimensions(self): + def max_connectivity_sizes(self) -> dict[str, int]: return { self.max_face_node_dimension: self.n_max_node_per_face, } + @property + def max_connectivity_dimensions(self) -> tuple[str]: + return (self.max_face_node_dimension,) + @property def 
topology_dimension(self): """Highest dimensionality of the geometric elements: 2""" diff --git a/xugrid/ugrid/ugridbase.py b/xugrid/ugrid/ugridbase.py index 052c40d6c..4d935191c 100644 --- a/xugrid/ugrid/ugridbase.py +++ b/xugrid/ugrid/ugridbase.py @@ -300,7 +300,11 @@ def edge_dimension(self): return self._attrs["edge_dimension"] @property - def max_connectivity_dimensions(self): + def max_connectivity_dimensions(self) -> tuple: + return () + + @property + def max_connectivity_sizes(self) -> dict[str, int]: return {} @property From 8c465bcd27d84417debc3a10b56dcd5778429275 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 11:26:28 +0100 Subject: [PATCH 17/26] Remove duplicate message --- docs/changelog.rst | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 17555486c..8885abfcc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -19,8 +19,6 @@ Fixed copies). - Fixed bug in :meth:`xugrid.Ugrid1d.merge_partitions`, which caused ``ValueError: indexes must be provided for attrs``. -- :func:`xugrid.merge_partitions` now also merges datasets with grids that are - only contained in some of the partition datasets. Added ~~~~~ @@ -47,8 +45,8 @@ Changed - :meth:`xugrid.Ugrid2d.from_structured` now takes ``x`` and ``y`` arguments instead of ``x_bounds`` and ``y_bounds`` arguments. -- :func:`xugrid.merge_partitions` allows merging partitions with different grids - per partition. +- :func:`xugrid.merge_partitions` now also merges datasets with grids that are + only contained in some of the partition datasets. 
[0.8.1] 2024-01-19 ------------------ From 694c31cb56c0ac72b8a2e90461630fb45ae960dc Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 11:37:15 +0100 Subject: [PATCH 18/26] Add type annotations and add return None --- xugrid/ugrid/partitioning.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 8a4ca14c9..069c4d9f0 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -9,6 +9,7 @@ from xugrid.constants import IntArray, IntDType from xugrid.core.wrap import UgridDataArray, UgridDataset from xugrid.ugrid.connectivity import renumber +from xugrid.ugrid.ugridbase import UgridType def labels_to_indices(labels: IntArray) -> List[IntArray]: @@ -145,7 +146,7 @@ def merge_edges(grids, node_inverse): return _merge_connectivity(all_edges, slices) -def validate_partition_topology(grouped): +def validate_partition_topology(grouped: defaultdict[str, UgridType]) -> None: for name, grids in grouped.items(): types = {type(grid) for grid in grids} if len(types) > 1: @@ -164,7 +165,7 @@ def validate_partition_topology(grouped): return None -def group_grids_by_name(partitions): +def group_grids_by_name(partitions: list[UgridDataset]) -> defaultdict[str, UgridType]: grouped = defaultdict(list) for partition in partitions: for grid in partition.grids: @@ -174,7 +175,7 @@ def group_grids_by_name(partitions): return grouped -def group_data_objects_by_gridname(partitions): +def group_data_objects_by_gridname(partitions: list[UgridDataset]) -> defaultdict[str, xr.Dataset]: # Convert to dataset for convenience data_objects = [partition.obj for partition in partitions] data_objects = [ @@ -190,7 +191,7 @@ def group_data_objects_by_gridname(partitions): return grouped -def validate_partition_objects(objects_by_gridname): +def validate_partition_objects(objects_by_gridname: defaultdict[str, xr.Dataset]) -> None: for data_objects in 
objects_by_gridname.values(): allvars = list({tuple(sorted(ds.data_vars)) for ds in data_objects}) unique_vars = set(chain(*allvars)) @@ -206,9 +207,10 @@ def validate_partition_objects(objects_by_gridname): f"Dimensions for '{var}' do not match across partitions: " f"{vardims_ls[0]} versus {vardims_ls[1]}" ) + return None -def validate_vars_in_all_data_objects(vars, data_objects, gridname): +def validate_vars_in_all_data_objects(vars: list[str], data_objects: list[xr.Dataset], gridname: str): for var in vars: var_in_objects = [ True if var in obj.variables else False for obj in data_objects @@ -217,6 +219,7 @@ def validate_vars_in_all_data_objects(vars, data_objects, gridname): raise ValueError( f"'{var}' does not occur not in all partitions with '{gridname}'" ) + return None def separate_variables(objects_by_gridname, ugrid_dims): From c92566706fafc02f89f7b8e320155c19dcb4ce8c Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 11:37:52 +0100 Subject: [PATCH 19/26] format --- xugrid/ugrid/partitioning.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 069c4d9f0..349cd82ed 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -175,7 +175,9 @@ def group_grids_by_name(partitions: list[UgridDataset]) -> defaultdict[str, Ugri return grouped -def group_data_objects_by_gridname(partitions: list[UgridDataset]) -> defaultdict[str, xr.Dataset]: +def group_data_objects_by_gridname( + partitions: list[UgridDataset] +) -> defaultdict[str, xr.Dataset]: # Convert to dataset for convenience data_objects = [partition.obj for partition in partitions] data_objects = [ @@ -191,7 +193,9 @@ def group_data_objects_by_gridname(partitions: list[UgridDataset]) -> defaultdic return grouped -def validate_partition_objects(objects_by_gridname: defaultdict[str, xr.Dataset]) -> None: +def validate_partition_objects( + objects_by_gridname: 
defaultdict[str, xr.Dataset] +) -> None: for data_objects in objects_by_gridname.values(): allvars = list({tuple(sorted(ds.data_vars)) for ds in data_objects}) unique_vars = set(chain(*allvars)) @@ -210,7 +214,9 @@ def validate_partition_objects(objects_by_gridname: defaultdict[str, xr.Dataset] return None -def validate_vars_in_all_data_objects(vars: list[str], data_objects: list[xr.Dataset], gridname: str): +def validate_vars_in_all_data_objects( + vars: list[str], data_objects: list[xr.Dataset], gridname: str +): for var in vars: var_in_objects = [ True if var in obj.variables else False for obj in data_objects From 0eb60f4bd26d786f03aafda9951c523ad2048a2f Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 11:54:04 +0100 Subject: [PATCH 20/26] type annotate separate_variables and fix comment --- xugrid/ugrid/partitioning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 349cd82ed..e2a7cd3ab 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -228,7 +228,7 @@ def validate_vars_in_all_data_objects( return None -def separate_variables(objects_by_gridname, ugrid_dims): +def separate_variables(objects_by_gridname: defaultdict[str, xr.Dataset], ugrid_dims: set[str]): """Separate into UGRID variables grouped by dimension, and other variables.""" validate_partition_objects(objects_by_gridname) @@ -348,7 +348,7 @@ def merge_partitions(partitions): merged_grid, indexes = grid.merge_partitions(grids) merged_grids.append(merged_grid) - # Add all other vars to dataset + # Add other vars, unassociated with UGRID dimensions, to dataset. 
for obj in data_objects: other_vars_obj = set(other_vars).intersection(set(obj.data_vars)) merged.update(obj[other_vars_obj]) From 96f814b12c70e1c240cb60ac7c421194ad0bdf85 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 13:41:52 +0100 Subject: [PATCH 21/26] Add type annotation --- xugrid/ugrid/ugrid2d.py | 2 +- xugrid/ugrid/ugridbase.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xugrid/ugrid/ugrid2d.py b/xugrid/ugrid/ugrid2d.py index 3b7ba90ee..1055b72dc 100644 --- a/xugrid/ugrid/ugrid2d.py +++ b/xugrid/ugrid/ugrid2d.py @@ -418,7 +418,7 @@ def dimensions(self): } @property - def max_face_node_dimension(self): + def max_face_node_dimension(self) -> str: return self._attrs["max_face_nodes_dimension"] @property diff --git a/xugrid/ugrid/ugridbase.py b/xugrid/ugrid/ugridbase.py index 4d935191c..423051c19 100644 --- a/xugrid/ugrid/ugridbase.py +++ b/xugrid/ugrid/ugridbase.py @@ -300,7 +300,7 @@ def edge_dimension(self): return self._attrs["edge_dimension"] @property - def max_connectivity_dimensions(self) -> tuple: + def max_connectivity_dimensions(self) -> tuple[str]: return () @property @@ -308,7 +308,7 @@ def max_connectivity_sizes(self) -> dict[str, int]: return {} @property - def sizes(self): + def sizes(self) -> dict[str, int]: return self.dimensions @property From d9b3494f448c9d1cacc5d2f6d0cc06bd27366de1 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 13:42:49 +0100 Subject: [PATCH 22/26] Remove useless filter loop --- xugrid/ugrid/partitioning.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index e2a7cd3ab..839fbed5e 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -359,15 +359,9 @@ def merge_partitions(partitions): if len(vars) == 0: continue validate_vars_in_all_data_objects(vars, data_objects, gridname) - first_var = next(iter(vars)) - objects_indexes_to_select = [ - 
(obj[vars], index) - for obj, index in zip(data_objects, dim_indexes) - if first_var in obj - ] selection = [ obj[vars].isel({dim: index}, missing_dims="ignore") - for obj, index in objects_indexes_to_select + for obj, index in zip(data_objects, dim_indexes) ] selection_padded = maybe_pad_connectivity_dims_to_max( selection, merged_grid From 4b262e441bacc2ba61e7b73623a7ce121b642f37 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 13:53:25 +0100 Subject: [PATCH 23/26] Type annotate merge_partitions --- xugrid/ugrid/ugrid1d.py | 2 +- xugrid/ugrid/ugrid2d.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xugrid/ugrid/ugrid1d.py b/xugrid/ugrid/ugrid1d.py index a5f705d16..748450f7a 100644 --- a/xugrid/ugrid/ugrid1d.py +++ b/xugrid/ugrid/ugrid1d.py @@ -630,7 +630,7 @@ def contract_vertices(self, indices: IntArray) -> "Ugrid1d": ) @staticmethod - def merge_partitions(grids: Sequence["Ugrid1d"]) -> "Ugrid1d": + def merge_partitions(grids: Sequence["Ugrid1d"]) -> tuple["Ugrid1d", dict[str, np.array]]: """ Merge grid partitions into a single whole. diff --git a/xugrid/ugrid/ugrid2d.py b/xugrid/ugrid/ugrid2d.py index 1055b72dc..4f8470b56 100644 --- a/xugrid/ugrid/ugrid2d.py +++ b/xugrid/ugrid/ugrid2d.py @@ -1505,7 +1505,7 @@ def partition(self, n_part: int): return [self.topology_subset(index) for index in indices] @staticmethod - def merge_partitions(grids: Sequence["Ugrid2d"]) -> "Ugrid2d": + def merge_partitions(grids: Sequence["Ugrid2d"]) -> tuple["Ugrid2d", dict[str, np.array]]: """ Merge grid partitions into a single whole. 
From 365ee195a07351b70559293fad2d2873f335febb Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 14:57:39 +0100 Subject: [PATCH 24/26] Simplify logic with reviewer's suggestions --- tests/test_partitioning.py | 2 +- xugrid/ugrid/partitioning.py | 80 ++++++++++++++++++------------------ 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/tests/test_partitioning.py b/tests/test_partitioning.py index c72b13f5b..19b3a3b4c 100644 --- a/tests/test_partitioning.py +++ b/tests/test_partitioning.py @@ -141,7 +141,7 @@ def test_merge_partitions__errors(self): partitions[1]["extra"] = (grid1.face_dimension, np.ones(grid1.n_face)) with pytest.raises( ValueError, - match="'extra' does not occur not in all partitions with 'mesh2d'", + match="Missing variables: {'extra'} in partition", ): pt.merge_partitions(partitions) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 839fbed5e..2edc350f3 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -214,19 +214,6 @@ def validate_partition_objects( return None -def validate_vars_in_all_data_objects( - vars: list[str], data_objects: list[xr.Dataset], gridname: str -): - for var in vars: - var_in_objects = [ - True if var in obj.variables else False for obj in data_objects - ] - if not all(var_in_objects): - raise ValueError( - f"'{var}' does not occur not in all partitions with '{gridname}'" - ) - return None - def separate_variables(objects_by_gridname: defaultdict[str, xr.Dataset], ugrid_dims: set[str]): """Separate into UGRID variables grouped by dimension, and other variables.""" @@ -276,22 +263,45 @@ def all_equal(iterator): return grouped, other -def maybe_pad_connectivity_dims_to_max(selection, merged_grid): - nmax_dict = merged_grid.max_connectivity_sizes - nmax_dict = { - key: value for key, value in nmax_dict.items() if key in selection[0].dims - } - if not nmax_dict: - return selection - - pad_width_ls = [ - {dim: (0, nmax - obj.sizes[dim]) 
for dim, nmax in nmax_dict.items()} - for obj in selection - ] - - return [ - obj.pad(pad_width=pad_width) for obj, pad_width in zip(selection, pad_width_ls) - ] +def merge_data_along_dim( + data_objects: list[xr.Dataset], + vars: list[str], + merge_dim: str, + indexes: list[np.array], + merged_grid: UgridType, +) -> xr.Dataset: + """" + Select variables from the data objects. + Pad connectivity dims if needed. + Concatenate along dim. + """ + max_sizes = merged_grid.max_connectivity_sizes + ugrid_connectivity_dims = set(max_sizes) + + to_merge = [] + for obj, index in zip(data_objects, indexes): + # Check for presence of vars + missing_vars = set(vars).difference(set(obj.variables.keys())) + if missing_vars: + raise ValueError(f"Missing variables: {missing_vars} in partition {obj}") + + selection = obj[vars].isel({merge_dim: index}, missing_dims="ignore") + + # Pad the ugrid connectivity dims (e.g. n_max_face_node_connectivity) if + # needed. + present_dims = ugrid_connectivity_dims.intersection(selection.dims) + pad_width = {} + for dim in present_dims: + nmax = max_sizes[dim] + size = selection.sizes[dim] + if size != nmax: + pad_width[dim] = (0, nmax - size) + if pad_width: + selection = selection.pad(pad_width=pad_width) + + to_merge.append(selection) + + return xr.concat(to_merge, dim=merge_dim) def merge_partitions(partitions): @@ -353,21 +363,11 @@ def merge_partitions(partitions): other_vars_obj = set(other_vars).intersection(set(obj.data_vars)) merged.update(obj[other_vars_obj]) - # Now remove duplicates, then concatenate along the UGRID dimension. 
for dim, dim_indexes in indexes.items(): vars = vars_by_dim[dim] if len(vars) == 0: continue - validate_vars_in_all_data_objects(vars, data_objects, gridname) - selection = [ - obj[vars].isel({dim: index}, missing_dims="ignore") - for obj, index in zip(data_objects, dim_indexes) - ] - selection_padded = maybe_pad_connectivity_dims_to_max( - selection, merged_grid - ) - - merged_selection = xr.concat(selection_padded, dim=dim) + merged_selection = merge_data_along_dim(data_objects, vars, dim, dim_indexes, merged_grid) merged.update(merged_selection) return UgridDataset(merged, merged_grids) From 20181c380ee07220ce6551cccfbaedddc934ac05 Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 14:59:07 +0100 Subject: [PATCH 25/26] format --- xugrid/ugrid/partitioning.py | 13 ++++++++----- xugrid/ugrid/ugrid1d.py | 4 +++- xugrid/ugrid/ugrid2d.py | 4 +++- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 2edc350f3..0cf22a71d 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -214,8 +214,9 @@ def validate_partition_objects( return None - -def separate_variables(objects_by_gridname: defaultdict[str, xr.Dataset], ugrid_dims: set[str]): +def separate_variables( + objects_by_gridname: defaultdict[str, xr.Dataset], ugrid_dims: set[str] +): """Separate into UGRID variables grouped by dimension, and other variables.""" validate_partition_objects(objects_by_gridname) @@ -270,7 +271,7 @@ def merge_data_along_dim( indexes: list[np.array], merged_grid: UgridType, ) -> xr.Dataset: - """" + """ " Select variables from the data objects. Pad connectivity dims if needed. Concatenate along dim. @@ -286,7 +287,7 @@ def merge_data_along_dim( raise ValueError(f"Missing variables: {missing_vars} in partition {obj}") selection = obj[vars].isel({merge_dim: index}, missing_dims="ignore") - + # Pad the ugrid connectivity dims (e.g. n_max_face_node_connectivity) if # needed. 
present_dims = ugrid_connectivity_dims.intersection(selection.dims) @@ -367,7 +368,9 @@ def merge_partitions(partitions): vars = vars_by_dim[dim] if len(vars) == 0: continue - merged_selection = merge_data_along_dim(data_objects, vars, dim, dim_indexes, merged_grid) + merged_selection = merge_data_along_dim( + data_objects, vars, dim, dim_indexes, merged_grid + ) merged.update(merged_selection) return UgridDataset(merged, merged_grids) diff --git a/xugrid/ugrid/ugrid1d.py b/xugrid/ugrid/ugrid1d.py index 748450f7a..93cb82d6b 100644 --- a/xugrid/ugrid/ugrid1d.py +++ b/xugrid/ugrid/ugrid1d.py @@ -630,7 +630,9 @@ def contract_vertices(self, indices: IntArray) -> "Ugrid1d": ) @staticmethod - def merge_partitions(grids: Sequence["Ugrid1d"]) -> tuple["Ugrid1d", dict[str, np.array]]: + def merge_partitions( + grids: Sequence["Ugrid1d"] + ) -> tuple["Ugrid1d", dict[str, np.array]]: """ Merge grid partitions into a single whole. diff --git a/xugrid/ugrid/ugrid2d.py b/xugrid/ugrid/ugrid2d.py index 4f8470b56..261cb3ba5 100644 --- a/xugrid/ugrid/ugrid2d.py +++ b/xugrid/ugrid/ugrid2d.py @@ -1505,7 +1505,9 @@ def partition(self, n_part: int): return [self.topology_subset(index) for index in indices] @staticmethod - def merge_partitions(grids: Sequence["Ugrid2d"]) -> tuple["Ugrid2d", dict[str, np.array]]: + def merge_partitions( + grids: Sequence["Ugrid2d"] + ) -> tuple["Ugrid2d", dict[str, np.array]]: """ Merge grid partitions into a single whole. 
From a96d298c948078401ff725b72b8d2462515fc5ec Mon Sep 17 00:00:00 2001 From: Joeri van Engelen Date: Wed, 14 Feb 2024 15:01:30 +0100 Subject: [PATCH 26/26] Fix typo in docstring --- xugrid/ugrid/partitioning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xugrid/ugrid/partitioning.py b/xugrid/ugrid/partitioning.py index 0cf22a71d..5d5a6e2fb 100644 --- a/xugrid/ugrid/partitioning.py +++ b/xugrid/ugrid/partitioning.py @@ -271,7 +271,7 @@ def merge_data_along_dim( indexes: list[np.array], merged_grid: UgridType, ) -> xr.Dataset: - """ " + """ Select variables from the data objects. Pad connectivity dims if needed. Concatenate along dim.