Make to_tensor work correctly with PackedSequence #335
Conversation
Also, add a bunch of tests for to_tensor. Partly addresses #325.
LGTM otherwise
skorch/tests/test_utils.py
        (PackedSequence(x, y), PackedSequence(x, y)),
    ])
    def test_tensor_conversion(self, to_tensor, X, expected):
        result = to_tensor(X, 'cpu')
It might be worthwhile to make this an optional GPU test as well (if a GPU is available), e.g.:
result = to_tensor(X, 'cpu')
# ... asserts ...
if not torch.cuda.is_available():
return
result = to_tensor(X, 'cuda')
# ... asserts ...
* Use correct test for PackedSequence.
* Add tests for device='cuda'.
@githubnemo Could you please review again?
        pack_padded_sequence(x, y).to('cuda')
    ),
]:
    yield X, expected, device
Can't you avoid code duplication by re-using the parameters? Something like:
devices = ['cpu']
if torch.cuda.is_available():
devices.append('cuda')
for dev in devices:
for p in params:
yield (p[0], p[1].to(dev), dev)
This unfortunately doesn't work because some of the parameters don't support .to(dev) (list, tuple, dict).
I pondered this parameter setup quite a bit but finally settled on something that is a little repetitive but, I hope, easy to understand.
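For what it's worth, one way to reuse a single parameter list across devices despite containers lacking `.to()` would be a small helper that only moves objects that actually implement it. This is a hypothetical sketch (the `move` helper and the sample `params` are not from the PR), and it would need recursion to handle tensors nested inside containers:

```python
import torch

def move(obj, dev):
    # Move only objects that implement .to(); pass plain Python
    # containers (list, tuple, dict) through unchanged.
    return obj.to(dev) if hasattr(obj, 'to') else obj

x = torch.zeros(2, 3)
params = [
    (x.numpy(), x),    # ndarray input, tensor expected
    ([1, 2], [1, 2]),  # list input, list expected (no .to())
]

devices = ['cpu']
if torch.cuda.is_available():
    devices.append('cuda')

# One flat list of (X, expected, device) cases, no duplication.
cases = [(X, move(expected, dev), dev)
         for dev in devices
         for X, expected in params]
```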
* Fix a bug that prevented to_tensor from working correctly with PackedSequence.
* Add a bunch of unit tests for to_tensor.