Support for conversion to torch.Tensor #196

Closed
samvanstroud opened this issue Nov 21, 2022 · 3 comments · Fixed by #197
Labels
bug Something isn't working

Comments

@samvanstroud

🐛 Bug report

I am encountering a problem using the package with Lightning.

When an init_args entry expects a torch.Tensor (for example, the class weights of a loss function), an error occurs because the torch module expects a tensor, not a list.

It should be possible to register PyTorch tensors using register_type:

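import torch
from jsonargparse.typing import register_type

def serializer(x):
    # Tensor -> plain list, so the value can be dumped to YAML/JSON
    return x.tolist()

def deserializer(x):
    # List from the config -> tensor
    return torch.tensor(x)

register_type(torch.Tensor, serializer, deserializer)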

However, doing so results in an error:

main: error: Parser key "model": Problem with given class_path "MyModel":
  - Configuration check failed :: Parser key "my_model": Value "Namespace(class_path='MySubModel', init_args=Namespace(..., loss=Namespace(class_path='torch.nn.CrossEntropyLoss', init_args=Namespace(weight=tensor([  7.6200, 104.9100,   1.5600,  20.6500,  14.0500,  14.0600,   1.0000,  36.0700]), size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0))))" does not validate against any of the types in typing.Optional[MySubModel]:
    - Boolean value of Tensor with more than one value is ambiguous
...

Since a list is specified in the YAML and the namespace shows a tensor([...]), the deserializer appears to work correctly, but the subsequent validation of the value fails.
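
The last line of the error is the message PyTorch raises when a tensor with more than one element is used in a boolean context. The sketch below reproduces that failure mode without PyTorch. FakeTensor is a hypothetical stand-in, and fragile_is_set and robust_is_set are illustrative names; none of this is taken from jsonargparse's code, so where exactly such a check happens inside the library is an assumption.

```python
class FakeTensor:
    """Hypothetical stand-in that mimics how a multi-element tensor treats bool()."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        # Like a multi-element tensor: truthiness is only defined for one value.
        if len(self.values) != 1:
            raise RuntimeError(
                "Boolean value of Tensor with more than one value is ambiguous"
            )
        return bool(self.values[0])


def fragile_is_set(value):
    # Implicit truthiness calls value.__bool__(), which may raise.
    return True if value else False


def robust_is_set(value):
    # An identity check never calls __bool__, so it is safe for any type.
    return value is not None


weight = FakeTensor([7.62, 104.91, 1.56])

try:
    fragile_is_set(weight)
    fragile_error = None
except RuntimeError as err:
    fragile_error = str(err)

print(fragile_error)
print(robust_is_set(weight))
```

Under these assumptions, any code path that tests a deserialized value with a plain truthiness check would fail for multi-element tensors, while an explicit comparison against None would not.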

Expected behavior

The list should be converted to a tensor without error.
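
One way to phrase this expectation is as a round-trip property: serializing a value, writing it out as text, reading it back, and deserializing it should give an equal value. The sketch below checks that property with the standard library only. It assumes array.array as a stand-in for torch.Tensor (it also has a tolist() method), and the deserializer here is a placeholder for torch.tensor(x), not the real call.

```python
import json
from array import array


def serializer(x):
    # Same shape as the serializer in the report: object -> plain list.
    return x.tolist()


def deserializer(x):
    # Stand-in for torch.tensor(x): list -> array of doubles.
    return array("d", x)


weights = array("d", [7.62, 104.91, 1.56, 20.65])

serialized = serializer(weights)
# The serialized form must survive a trip through a text format.
reloaded = json.loads(json.dumps(serialized))
restored = deserializer(reloaded)

print(restored == weights)
```

With real tensors the final comparison would need something like torch.equal, since == on tensors is elementwise.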

Environment

  • jsonargparse 4.15.2
  • Python 3.10.6

@samvanstroud added the bug (Something isn't working) label Nov 21, 2022

@mauvilsa
Member

Thank you for reporting! Will fix it as soon as possible.

@mauvilsa
Member

I created pull request #197 with the fix. It was a small bug, but I also added a unit test to avoid regressions. @samvanstroud, can you please test it with Lightning and maybe review the pull request? You can install directly from the branch with:

pip install "jsonargparse @ https://github.com/omni-us/jsonargparse/zipball/issue-196-non-bool-cast-type"

@samvanstroud
Author

Thanks a lot @mauvilsa! That is working perfectly for me. I don't have any comments on the implementation; it looks clean.
