diff --git a/arviz/data/base.py b/arviz/data/base.py index 520077bce1..ef5850c062 100644 --- a/arviz/data/base.py +++ b/arviz/data/base.py @@ -201,10 +201,10 @@ def generate_dims_coords( for i, dim_len in enumerate(shape): idx = i + len([dim for dim in default_dims if dim in dims]) if len(dims) < idx + 1: - dim_name = f"{var_name}_dim_{idx}" + dim_name = f"{var_name}_dim_{i}" dims.append(dim_name) elif dims[idx] is None: - dim_name = f"{var_name}_dim_{idx}" + dim_name = f"{var_name}_dim_{i}" dims[idx] = dim_name dim_name = dims[idx] if dim_name not in coords: diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index b751fd2931..0210cc5f34 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -32,7 +32,13 @@ extract, ) -from ...data.base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs +from ...data.base import ( + dict_to_dataset, + generate_dims_coords, + infer_stan_dtypes, + make_attrs, + numpy_to_data_array, +) from ...data.datasets import LOCAL_DATASETS, REMOTE_DATASETS, RemoteFileMetadata from ..helpers import ( # pylint: disable=unused-import chains, @@ -231,6 +237,17 @@ def test_dims_coords_skip_event_dims(shape): assert "z" not in coords +@pytest.mark.parametrize("dims", [None, ["chain", "draw"], ["chain", "draw", None]]) +def test_numpy_to_data_array_with_dims(dims): + da = numpy_to_data_array( + np.empty((4, 500, 7)), + var_name="a", + dims=dims, + default_dims=["chain", "draw"], + ) + assert list(da.dims) == ["chain", "draw", "a_dim_0"] + + def test_make_attrs(): extra_attrs = {"key": "Value"} attrs = make_attrs(attrs=extra_attrs)