-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathargmax_layer.py
34 lines (27 loc) · 1.07 KB
/
argmax_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
"""Tests the methods in argmax_layer.py
"""
import numpy as np
import torch
from external.bazel_python.pytest_helper import main
from pysyrenn.frontend.argmax_layer import ArgMaxLayer
def test_compute():
    """Tests that the ArgMax layer correctly computes an ArgMax.

    Exercises both the NumPy and the PyTorch inputs to ``compute`` and checks
    each result against ``np.argmax`` taken along axis 1.
    """
    # Each row is a random permutation of 0..1024 stored as floats. This gives
    # every row a unique maximum, and every value is exactly representable in
    # float32. With unseeded float64 uniform draws, two near-maximal values
    # could collapse to the same float32 in torch.FloatTensor(...). The
    # expected index would then depend on tie-breaking, making the test
    # flaky. A seeded local generator keeps runs reproducible without
    # touching NumPy's global RNG state.
    rng = np.random.RandomState(0)
    inputs = np.stack([rng.permutation(1025) for _ in range(101)])
    inputs = inputs.astype(np.float64)
    true_argmax = np.argmax(inputs, axis=1)

    argmax_layer = ArgMaxLayer()
    # Argmax results are integer indices, so compare exactly, not with a
    # floating-point tolerance.
    assert np.array_equal(argmax_layer.compute(inputs), true_argmax)

    torch_inputs = torch.FloatTensor(inputs)
    torch_outputs = argmax_layer.compute(torch_inputs).numpy()
    assert np.array_equal(torch_outputs, true_argmax)
def test_serialize():
    """Round-trips an ArgMaxLayer through its serialized message form.

    Also verifies that ``ArgMaxLayer.deserialize`` returns None once the
    message's ``layer_data`` is switched away from ``argmax_data``.
    """
    layer = ArgMaxLayer()
    message = layer.serialize()
    assert message.WhichOneof("layer_data") == "argmax_data"

    # Serializing the deserialized layer should reproduce the message exactly.
    round_tripped = ArgMaxLayer.deserialize(message)
    assert round_tripped.serialize() == message

    # Mark relu_data as set. This presumably displaces argmax_data within the
    # layer_data oneof, so the message no longer describes an ArgMax layer.
    message.relu_data.SetInParent()
    assert ArgMaxLayer.deserialize(message) is None
# Entry hook from external.bazel_python.pytest_helper. Presumably it runs
# pytest on this file when executed directly (e.g. as a Bazel test target);
# NOTE(review): confirm against pytest_helper.main.
main(__name__, __file__)