Skip to content

Commit

Permalink
Align autogenerated dimension names when dims and default_dims are pr…
Browse files Browse the repository at this point in the history
…ovided
  • Loading branch information
lucianopaz committed Nov 13, 2024
1 parent 0868c9e commit ea4f2b3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
4 changes: 2 additions & 2 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ def generate_dims_coords(
for i, dim_len in enumerate(shape):
idx = i + len([dim for dim in default_dims if dim in dims])
if len(dims) < idx + 1:
dim_name = f"{var_name}_dim_{idx}"
dim_name = f"{var_name}_dim_{i}"
dims.append(dim_name)
elif dims[idx] is None:
dim_name = f"{var_name}_dim_{idx}"
dim_name = f"{var_name}_dim_{i}"
dims[idx] = dim_name
dim_name = dims[idx]
if dim_name not in coords:
Expand Down
19 changes: 18 additions & 1 deletion arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
extract,
)

from ...data.base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs
from ...data.base import (
dict_to_dataset,
generate_dims_coords,
infer_stan_dtypes,
make_attrs,
numpy_to_data_array,
)
from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
from ..helpers import ( # pylint: disable=unused-import
chains,
Expand Down Expand Up @@ -231,6 +237,17 @@ def test_dims_coords_skip_event_dims(shape):
assert "z" not in coords


@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
def test_numpy_to_data_array_with_dims(dims):
da = numpy_to_data_array(
np.empty((4, 500, 7)),
var_name="a",
dims=dims,
default_dims=["chain", "draw"],
)
assert list(da.dims) == ["chain", "draw", "a_dim_0"]


def test_make_attrs():
extra_attrs = {"key": "Value"}
attrs = make_attrs(attrs=extra_attrs)
Expand Down

0 comments on commit ea4f2b3

Please sign in to comment.