Skip to content

Commit

Permalink
Align autogenerated dimension names when dims and default_dims ar…
Browse files Browse the repository at this point in the history
…e provided (#2395)

* Align autogenerated dimension names when dims and default_dims are provided

* Add to CHANGELOG.md
  • Loading branch information
lucianopaz authored Nov 19, 2024
1 parent 0868c9e commit 529d795
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Change Log

## Unreleased

### Maintenance and fixes
- Make `arviz.data.generate_dims_coords` handle `dims` and `default_dims` consistently ([2395](https://github.com/arviz-devs/arviz/pull/2395))

## v0.20.0 (2024 Sep 28)

### New features
Expand Down
4 changes: 2 additions & 2 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ def generate_dims_coords(
for i, dim_len in enumerate(shape):
idx = i + len([dim for dim in default_dims if dim in dims])
if len(dims) < idx + 1:
dim_name = f"{var_name}_dim_{idx}"
dim_name = f"{var_name}_dim_{i}"
dims.append(dim_name)
elif dims[idx] is None:
dim_name = f"{var_name}_dim_{idx}"
dim_name = f"{var_name}_dim_{i}"
dims[idx] = dim_name
dim_name = dims[idx]
if dim_name not in coords:
Expand Down
19 changes: 18 additions & 1 deletion arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
extract,
)

from ...data.base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs
from ...data.base import (
dict_to_dataset,
generate_dims_coords,
infer_stan_dtypes,
make_attrs,
numpy_to_data_array,
)
from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata
from ..helpers import ( # pylint: disable=unused-import
chains,
Expand Down Expand Up @@ -231,6 +237,17 @@ def test_dims_coords_skip_event_dims(shape):
assert "z" not in coords


@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]])
def test_numpy_to_data_array_with_dims(dims):
da = numpy_to_data_array(
np.empty((4, 500, 7)),
var_name="a",
dims=dims,
default_dims=["chain", "draw"],
)
assert list(da.dims) == ["chain", "draw", "a_dim_0"]


def test_make_attrs():
extra_attrs = {"key": "Value"}
attrs = make_attrs(attrs=extra_attrs)
Expand Down

0 comments on commit 529d795

Please sign in to comment.