making a quick variance fix (aws#99)
lakshya97 authored and rahul003 committed Aug 16, 2019
1 parent 3e28383 commit 045d89e
Showing 2 changed files with 14 additions and 11 deletions.
tests/pytorch/test_reduce_config.py (6 changes: 4 additions & 2 deletions)
@@ -51,7 +51,7 @@ def train(model, device, optimizer, num_steps=500, save_steps=[]):
 
 def test_reduce_config():
     reset_collections()
-    global_reduce_config = ReductionConfig(reductions=["max", "mean"])
+    global_reduce_config = ReductionConfig(reductions=["max", "mean", "variance"])
     global_save_config = SaveConfig(save_steps=[0,1,2,3])
 
     ts.get_collection("ReluActivation").include(["relu*"])
@@ -76,12 +76,14 @@ def test_reduce_config():
     tname = tr.tensors_matching_regex('Net_conv[0-9]+.weight')[0]
     print(tr.tensors())
 
-    # Global reduction with max and mean
+    # Global reduction with max and mean and variance
     weight_tensor = tr.tensor(tname)
     max_val = weight_tensor.reduction_value(step_num=1, abs=False, reduction_name='max')
     assert max_val != None
     mean_val = weight_tensor.reduction_value(step_num=1, abs=False, reduction_name='mean')
     assert mean_val != None
+    variance_val = weight_tensor.reduction_value(step_num=1, abs=False, reduction_name='variance')
+    assert variance_val != None
 
     # custom reduction at step 4 with reduction = 'min and abs reduction = 'max'
     tname = tr.tensors_matching_regex('relu0_input_0')[0]
tornasole/pytorch/util.py (19 changes: 10 additions & 9 deletions)
@@ -4,30 +4,31 @@
 from tornasole.core.reductions import get_numpy_reduction
 
 
-def get_aggregated_data(aggregation_name, tensor_data, tensor_name, abs=False):
-    reduction_name = aggregation_name
+def get_aggregated_data(reduction_name, tensor_data, tensor_name, abs=False):
     if isinstance(tensor_data, np.ndarray):
         return get_numpy_reduction(reduction_name, tensor_data, abs)
     if abs:
         tensor_data = torch.abs(tensor_data)
 
     if reduction_name in ALLOWED_REDUCTIONS:
-        assert hasattr(torch.Tensor, aggregation_name)
-        f = getattr(torch.Tensor, aggregation_name)
+        if reduction_name == "variance":
+            reduction_name = "var"
+        assert hasattr(torch.Tensor, reduction_name)
+        f = getattr(torch.Tensor, reduction_name)
         op = f(tensor_data)
         return op
     elif reduction_name in ALLOWED_NORMS:
-        if aggregation_name in ['l1', 'l2']:
-            ord = int(aggregation_name[1])
+        if reduction_name in ['l1', 'l2']:
+            ord = int(reduction_name[1])
         else:
             raise RuntimeError("Invalid normalization operation {0} for torch.Tensor".format(reduction_name))
         op = torch.norm(tensor_data, p=ord)
         return op
-    elif hasattr(torch, aggregation_name):
-        f = getattr(torch, aggregation_name)
+    elif hasattr(torch, reduction_name):
+        f = getattr(torch, reduction_name)
         op = f(tensor_data)
         return op
-    raise RuntimeError("Invalid aggregation_name {0}".format(aggregation_name))
+    raise RuntimeError("Invalid reduction_name {0}".format(reduction_name))
 
 
 def make_numpy_array(x):
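The key piece of this fix is the remap of "variance" to "var" in get_aggregated_data: torch.Tensor exposes the variance reduction as var, not variance, so resolving the user-facing reduction name directly with getattr would trip the hasattr assertion. A minimal standalone sketch (not part of the commit) illustrating that behavior:

    import torch

    t = torch.tensor([1.0, 2.0, 3.0, 4.0])

    # torch.Tensor has a var method but no variance method,
    # which is why the user-facing name must be remapped.
    assert hasattr(torch.Tensor, "var")
    assert not hasattr(torch.Tensor, "variance")

    reduction_name = "variance"
    if reduction_name == "variance":  # same remap as in get_aggregated_data
        reduction_name = "var"
    f = getattr(torch.Tensor, reduction_name)
    print(f(t))  # tensor(1.6667), PyTorch's default unbiased sample variance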
