From 857dd136208249d32dddafc1b6b25ba772145381 Mon Sep 17 00:00:00 2001 From: Liam Bindle Date: Fri, 22 Jul 2022 15:20:08 -0700 Subject: [PATCH] Fixes #2. Added dimension name mangling. Also fixed crashing when input dataset has variables with dimensions that don't include the input core dims. --- sparselt/esmf.py | 5 ++++- sparselt/linear_transform.py | 42 ++++++++++++++++++++++++++---------- sparselt/xr.py | 23 +++++++++++++++++--- 3 files changed, 55 insertions(+), 15 deletions(-) diff --git a/sparselt/esmf.py b/sparselt/esmf.py index 14b8600..ffad7db 100644 --- a/sparselt/esmf.py +++ b/sparselt/esmf.py @@ -3,7 +3,10 @@ def load_weights(filename, input_dims, output_dims): - ds_weights = xr.open_dataset(filename) + if isinstance(filename, xr.Dataset): + ds_weights = filename + else: + ds_weights = xr.open_dataset(filename) # Get sparse matrix elements weights = ds_weights.S diff --git a/sparselt/linear_transform.py b/sparselt/linear_transform.py index e9f229c..5608d7b 100644 --- a/sparselt/linear_transform.py +++ b/sparselt/linear_transform.py @@ -1,12 +1,13 @@ import numpy as np import scipy.sparse - +import random +import string class SparseLinearTransform: _input_core_shape = None - _input_core_dims = None + _mangled_input_core_dims = None _output_core_shape = None - _output_core_dims = None + _mangled_output_core_dims = None _matrix = None _order = None _vfunc = None @@ -21,11 +22,15 @@ def __init__(self, weights, row_ind, col_ind, row_ind -= 1 col_ind -= 1 - self._input_core_dims = tuple(input_transform_dims[0]) + self._input_mangle_suffix = ''.join(random.choice(string.ascii_letters) for _ in range(10)) + self._demangled_input_core_dims = tuple(input_transform_dims[0]) + self._mangled_input_core_dims = tuple(f'{name}_{self._input_mangle_suffix}' for name in self.demangled_input_core_dims) self._input_core_shape = tuple(input_transform_dims[1]) input_size = np.product(self._input_core_shape) - self._output_core_dims = tuple(output_transform_dims[0]) + self._output_mangle_suffix = ''.join(random.choice(string.ascii_letters) for _ in range(10)) + self._demangled_output_core_dims = tuple(output_transform_dims[0]) + self._mangled_output_core_dims = tuple(f'{name}_{self._output_mangle_suffix}' for name in self.demangled_output_core_dims) self._output_core_shape = tuple(output_transform_dims[1]) output_size = np.product(self._output_core_shape) @@ -34,13 +39,20 @@ def __init__(self, weights, row_ind, col_ind, self._vfunc = self._create_vfunc() + def mangle_dim_names(self, dim_names, are_input_dims): + return [f'{name}_{are_input_dims if self._input_mangle_suffix else self._output_mangle_suffix}' for name in dim_names] + + def demangle_dim_names(self, dim_names): + mangle_suffix_len = 11 + return [name[:-mangle_suffix_len] for name in dim_names] + def _func(self, a: np.ndarray): a = a.flatten(order=self._order) return self._matrix.dot(a).reshape(self._output_core_shape, order=self._order) def _create_vfunc(self) -> callable: - input_signature = ','.join(self.input_core_dims) - output_signature = ','.join(self.output_core_dims) + input_signature = ','.join(self.mangled_input_core_dims) + output_signature = ','.join(self.mangled_output_core_dims) return np.vectorize(self._func, signature='({})->({})'.format(input_signature, output_signature)) @property @@ -48,9 +60,17 @@ def vfunc(self) -> callable: return self._vfunc @property - def input_core_dims(self) -> tuple: - return self._input_core_dims + def mangled_input_core_dims(self) -> tuple: + return self._mangled_input_core_dims + + @property + def demangled_input_core_dims(self) -> tuple: + return self._demangled_input_core_dims @property - def output_core_dims(self) -> tuple: - return self._output_core_dims + def mangled_output_core_dims(self) -> tuple: + return self._mangled_output_core_dims + + @property + def demangled_output_core_dims(self) -> tuple: + return self._demangled_output_core_dims diff --git a/sparselt/xr.py b/sparselt/xr.py index 2b824fa..59fe013 100644 --- a/sparselt/xr.py +++ b/sparselt/xr.py @@ -1,20 +1,37 @@ +import numpy as np import xarray as xr def apply(transform, data_in, output_dataset_template=None): + drop_list = [name for name in data_in.data_vars if any([dim not in data_in[name].dims for dim in transform.demangled_input_core_dims])] + data_in = data_in.drop_vars(drop_list) + data_in = data_in.rename({k: v for k, v in zip(transform.demangled_input_core_dims, transform.mangled_input_core_dims)}) + + keep_attrs = output_dataset_template is None data_out = xr.apply_ufunc( transform.vfunc, data_in, - input_core_dims=[transform.input_core_dims], - output_core_dims=[transform.output_core_dims], + input_core_dims=[transform.mangled_input_core_dims], + output_core_dims=[transform.mangled_output_core_dims], + keep_attrs=keep_attrs ) + data_out = data_out.rename({k: v for k, v in zip(transform.mangled_output_core_dims, transform.demangled_output_core_dims)}) if output_dataset_template is not None: output_dataset = output_dataset_template.copy() + + for name in output_dataset.data_vars: + output_dataset[name][...] = np.nan if isinstance(data_out, xr.DataArray): output_dataset[data_out.name] = data_out else: - output_dataset.update(data_out) + data_vars_intersection = set(output_dataset_template.data_vars).intersection(set(data_out.data_vars)) + + for name in data_vars_intersection: + output_dataset[name].values = data_out[name].values + + data_out = data_out.drop(data_vars_intersection) + output_dataset = xr.merge([output_dataset, data_out], compat='override', join='override', combine_attrs='override') return output_dataset else: return data_out