Bug: treescope fails to render typed PRNG key arrays #52

Closed
amifalk opened this issue Jan 3, 2025 · 1 comment · Fixed by #54
Labels
bug Something isn't working

Comments

amifalk commented Jan 3, 2025

With jax 0.4.37 and treescope 0.1.7:

import jax
import treescope as ts

ts.basic_interactive_setup()

jax.random.key(0)
<TypeError during deferred rendering
Traceback (most recent call last):
  File ".../.venv/lib/python3.10/site-packages/treescope/lowering.py", line 358, in _render_to_html_as_root_streaming
    replacement_part = deferred.thunk(layout_decision)
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 606, in _thunk
    summarized = adapter.get_array_summary(node, fast=False)
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 525, in get_array_summary
    output_parts.append(summarize_array_data(array))
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 479, in summarize_array_data
    output_parts.extend(_summarize_array_data_unconditionally(array))
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 433, in _summarize_array_data_unconditionally
    stat = compute_summary(array, is_floating, is_integer, is_bool)
  File ".../.venv/lib/python3.10/site-packages/treescope/external/jax_support.py", line 391, in _compute_summary
    x = xnp.array(x)
  File ".../.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 283, in __array__
    raise TypeError("JAX array with PRNGKey dtype cannot be converted to a NumPy array."
TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.
>
danieldjohnson added the bug label Jan 21, 2025
danieldjohnson added a commit that referenced this issue Jan 21, 2025
Numeric dtypes can be summarized by outputting summary statistics
(e.g. min, max, avg, sparsity), but we do not produce any summary
for extended dtypes like JAX PRNG key dtypes. However, the existing
implementation still tries to convert the arrays to numpy for faster
summarization even if the dtype cannot be summarized, which produced
an error. This commit disables summarization fully for these dtypes
so that they can be rendered as normal.

Fixes #52.
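
The idea in the commit message (skip summarization entirely when the dtype cannot be summarized, so no NumPy conversion is attempted) can be sketched roughly as follows. This is not the code from the referenced commit: `can_summarize` and `summarize_or_skip` are hypothetical helpers, and the choice of statistics is an assumption made for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def can_summarize(array) -> bool:
    """Hypothetical helper: True only for dtypes with meaningful statistics."""
    dtype = array.dtype
    # Extended dtypes (e.g. typed PRNG keys) cannot be converted to NumPy.
    if jax.dtypes.issubdtype(dtype, jax.dtypes.extended):
        return False
    return bool(
        jnp.issubdtype(dtype, jnp.floating)
        or jnp.issubdtype(dtype, jnp.integer)
        or jnp.issubdtype(dtype, jnp.bool_)
    )


def summarize_or_skip(array):
    """Returns (min, max, mean), or None when the array should not be summarized."""
    if not can_summarize(array):
        return None  # caller renders the array normally, without statistics
    x = np.asarray(array)  # only reached for standard numeric/bool dtypes
    return float(x.min()), float(x.max()), float(x.mean())


print(summarize_or_skip(jax.random.key(0)))
print(summarize_or_skip(jnp.arange(4)))
```

The important design point is ordering: the dtype check runs before any conversion, so a PRNG key array never reaches `np.asarray`.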
danieldjohnson (Collaborator) commented

Thanks for flagging this!
