Skip to content

Commit

Permalink
Fixes #2. Added dimension name mangling. Also fixed crashing when inp…
Browse files Browse the repository at this point in the history
…ut dataset has variables with dimensions that don't include the input core dims.
  • Loading branch information
LiamBindle committed Jul 22, 2022
1 parent a615dca commit 857dd13
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 15 deletions.
5 changes: 4 additions & 1 deletion sparselt/esmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@


def load_weights(filename, input_dims, output_dims):
ds_weights = xr.open_dataset(filename)
if isinstance(filename, xr.Dataset):
ds_weights = filename
else:
ds_weights = xr.open_dataset(filename)

# Get sparse matrix elements
weights = ds_weights.S
Expand Down
42 changes: 31 additions & 11 deletions sparselt/linear_transform.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np
import scipy.sparse

import random
import string

class SparseLinearTransform:
_input_core_shape = None
_input_core_dims = None
_mangled_input_core_dims = None
_output_core_shape = None
_output_core_dims = None
_mangled_output_core_dims = None
_matrix = None
_order = None
_vfunc = None
Expand All @@ -21,11 +22,15 @@ def __init__(self, weights, row_ind, col_ind,
row_ind -= 1
col_ind -= 1

self._input_core_dims = tuple(input_transform_dims[0])
self._input_mangle_suffix = ''.join(random.choice(string.ascii_letters) for _ in range(10))
self._demangled_input_core_dims = tuple(input_transform_dims[0])
self._mangled_input_core_dims = tuple(f'{name}_{self._input_mangle_suffix}' for name in self.demangled_input_core_dims)
self._input_core_shape = tuple(input_transform_dims[1])
input_size = np.product(self._input_core_shape)

self._output_core_dims = tuple(output_transform_dims[0])
self._output_mangle_suffix = ''.join(random.choice(string.ascii_letters) for _ in range(10))
self._demangled_output_core_dims = tuple(output_transform_dims[0])
self._mangled_output_core_dims = tuple(f'{name}_{self._output_mangle_suffix}' for name in self.demangled_output_core_dims)
self._output_core_shape = tuple(output_transform_dims[1])
output_size = np.product(self._output_core_shape)

Expand All @@ -34,23 +39,38 @@ def __init__(self, weights, row_ind, col_ind,

self._vfunc = self._create_vfunc()

def mangle_dim_names(self, dim_names, are_input_dims):
    """Append this transform's mangle suffix to each dimension name.

    Parameters
    ----------
    dim_names : iterable of str
        Dimension names to mangle.
    are_input_dims : bool
        When true the input-side suffix is used, otherwise the
        output-side suffix.

    Returns
    -------
    list of str
        ``'<name>_<suffix>'`` for every entry of ``dim_names``.
    """
    # The flag selects WHICH suffix to use; the original expression had the
    # operands swapped and interpolated the flag itself into the name.
    suffix = self._input_mangle_suffix if are_input_dims else self._output_mangle_suffix
    return [f'{name}_{suffix}' for name in dim_names]

def demangle_dim_names(self, dim_names):
    """Strip the trailing mangle suffix from every name in ``dim_names``.

    A fixed number of trailing characters is removed from each name, so
    the caller must only pass names that actually carry a suffix.

    Returns
    -------
    list of str
        The names with their last 11 characters removed.
    """
    # Width of the part to cut off: the '_' separator plus the suffix itself.
    suffix_width = 11
    demangled = []
    for mangled in dim_names:
        demangled.append(mangled[:-suffix_width])
    return demangled

def _func(self, a: np.ndarray):
a = a.flatten(order=self._order)
return self._matrix.dot(a).reshape(self._output_core_shape, order=self._order)

def _create_vfunc(self) -> callable:
input_signature = ','.join(self.input_core_dims)
output_signature = ','.join(self.output_core_dims)
input_signature = ','.join(self.mangled_input_core_dims)
output_signature = ','.join(self.mangled_output_core_dims)
return np.vectorize(self._func, signature='({})->({})'.format(input_signature, output_signature))

@property
def vfunc(self) -> callable:
    """Return the vectorized callable stored in ``self._vfunc``."""
    return self._vfunc

@property
def input_core_dims(self) -> tuple:
    """Return the value stored in ``self._input_core_dims``."""
    return self._input_core_dims

# Each accessor needs its own decorator: callers read
# ``mangled_input_core_dims`` as an attribute, not as a method call.
@property
def mangled_input_core_dims(self) -> tuple:
    """Return the input core dimension names with the mangle suffix applied."""
    return self._mangled_input_core_dims

@property
def demangled_input_core_dims(self) -> tuple:
    """Return the input core dimension names stored in ``self._demangled_input_core_dims``."""
    return self._demangled_input_core_dims

@property
def output_core_dims(self) -> tuple:
    """Return the value stored in ``self._output_core_dims``."""
    return self._output_core_dims

# Each accessor needs its own decorator: callers read
# ``mangled_output_core_dims`` as an attribute, not as a method call.
@property
def mangled_output_core_dims(self) -> tuple:
    """Return the output core dimension names with the mangle suffix applied."""
    return self._mangled_output_core_dims

@property
def demangled_output_core_dims(self) -> tuple:
    """Return the output core dimension names stored in ``self._demangled_output_core_dims``."""
    return self._demangled_output_core_dims
23 changes: 20 additions & 3 deletions sparselt/xr.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,37 @@
import numpy as np
import xarray as xr


def apply(transform, data_in, output_dataset_template=None):
    """Apply ``transform`` to ``data_in`` along the transform's core dimensions.

    Parameters
    ----------
    transform
        Object exposing ``vfunc`` plus the ``mangled_*`` / ``demangled_*``
        input and output core dimension name tuples.
    data_in : xr.Dataset or xr.DataArray
        Data to transform. For a Dataset, variables that lack any of the
        input core dimensions are dropped before the transform is applied.
    output_dataset_template : xr.Dataset, optional
        When given, a copy of it is filled with the transformed data and
        returned; its variables are first reset to NaN.

    Returns
    -------
    xr.Dataset or xr.DataArray
        The filled template copy when a template is supplied, otherwise
        the raw result of ``xr.apply_ufunc``.
    """
    input_dims = transform.demangled_input_core_dims

    # Only a Dataset has ``data_vars``; a DataArray goes straight to the rename.
    if isinstance(data_in, xr.Dataset):
        drop_list = [
            name for name in data_in.data_vars
            if any(dim not in data_in[name].dims for dim in input_dims)
        ]
        data_in = data_in.drop_vars(drop_list)

    # NOTE(review): mangling presumably keeps input and output core dims that
    # share a name from clashing inside apply_ufunc -- confirm with the author.
    data_in = data_in.rename(dict(zip(input_dims, transform.mangled_input_core_dims)))

    # Attributes are only carried over when there is no template to supply them.
    keep_attrs = output_dataset_template is None
    data_out = xr.apply_ufunc(
        transform.vfunc,
        data_in,
        input_core_dims=[transform.mangled_input_core_dims],
        output_core_dims=[transform.mangled_output_core_dims],
        keep_attrs=keep_attrs
    )
    data_out = data_out.rename(
        dict(zip(transform.mangled_output_core_dims, transform.demangled_output_core_dims))
    )

    if output_dataset_template is None:
        return data_out

    output_dataset = output_dataset_template.copy()

    # Blank out template values so anything not overwritten below reads as missing.
    for name in output_dataset.data_vars:
        output_dataset[name][...] = np.nan

    if isinstance(data_out, xr.DataArray):
        output_dataset[data_out.name] = data_out
    else:
        data_vars_intersection = set(output_dataset_template.data_vars).intersection(data_out.data_vars)

        # Variables present in both: copy raw values into the template's variables.
        for name in data_vars_intersection:
            output_dataset[name].values = data_out[name].values

        # Remaining variables exist only in the result; merge them in as-is.
        # ``drop_vars`` replaces the deprecated ``Dataset.drop``.
        data_out = data_out.drop_vars(list(data_vars_intersection))
        output_dataset = xr.merge([output_dataset, data_out], compat='override', join='override', combine_attrs='override')
    return output_dataset

0 comments on commit 857dd13

Please sign in to comment.