Skip to content

Commit

Permalink
bug_2931: resolving numpy integer check and adding test (#2451)
Browse files Browse the repository at this point in the history
  • Loading branch information
MCBoarder289 authored Jun 3, 2020
1 parent 03d161c commit 267c187
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
9 changes: 8 additions & 1 deletion packages/python/plotly/plotly/basedatatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,7 +1560,14 @@ def _validate_rows_cols(name, n, vals):
if len(vals) != n:
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)

if [r for r in vals if not isinstance(r, int)]:
try:
import numpy as np

int_type = (int, np.integer)
except ImportError:
int_type = (int,)

if [r for r in vals if not isinstance(r, int_type)]:
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
else:
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,21 @@ def test_node_generator(self):
]
for i, item in enumerate(node_generator(node0)):
self.assertEqual(item, expected_node_path_tuples[i])


class TestNumpyIntegerBaseType(TestCase):
def test_numpy_integer_import(self):
# should generate a figure with subplots of array and not throw a ValueError
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

indices_rows = np.array([1], dtype=np.int)
indices_cols = np.array([1], dtype=np.int)
fig = make_subplots(rows=1, cols=1)
fig.add_trace(go.Scatter(y=[1]), row=indices_rows[0], col=indices_cols[0])

data_path = ("data", 0, "y")
value = get_by_path(fig, data_path)
expected_value = (1,)
self.assertEqual(value, expected_value)

0 comments on commit 267c187

Please sign in to comment.