Skip to content

Commit

Permalink
Relative steps (#21)
Browse files Browse the repository at this point in the history
  • Loading branch information
dilpath authored May 2, 2023
1 parent 8a1766e commit f7fd881
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 9 deletions.
3 changes: 3 additions & 0 deletions fiddy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,6 @@ class MethodId(str, Enum):
# FORWARD = MethodId.FORWARD
# HYBRID = MethodId.HYBRID
##>>>>>>> origin/main


EPSILON = 1e-5
24 changes: 18 additions & 6 deletions fiddy/derivative.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
from typing import Any, Callable, Dict, List, Union
import warnings

from dataclasses import dataclass

Expand All @@ -9,6 +10,7 @@
from .constants import (
MethodId,
Type,
EPSILON,
)

from .analysis import Analysis
Expand Down Expand Up @@ -133,6 +135,7 @@ def get_derivative(
success_checker: Success,
*args,
analysis_classes: List[Analysis] = None,
relative_sizes: bool = False,
directions: Union[List[Type.DIRECTION], Dict[str, Type.DIRECTION]] = None,
direction_ids: List[str] = None,
direction_indices: List[int] = None,
Expand All @@ -149,14 +152,22 @@ def get_derivative(
The IDs of the directions.
directions:
List: The directions to step along. Dictionary: keys are direction IDs, values are directions.
relative_sizes:
If `True`, sizes are scaled by the `point`, otherwise not.
"""
# TODO docs
direction_ids, directions = get_directions(
point=point,
directions=directions,
ids=direction_ids,
indices=direction_indices,
)
if directions is not None:
direction_ids, directions = get_directions(
directions=directions,
ids=direction_ids,
indices=direction_indices,
)
else:
direction_ids, directions = get_directions(
point=point,
ids=direction_ids,
indices=direction_indices,
)
if custom_methods is None:
custom_methods = {}
if analysis_classes is None:
Expand Down Expand Up @@ -184,6 +195,7 @@ def get_derivative(
size=size,
method=method,
autorun=False,
relative_size=relative_sizes,
)
computers.append(computer)
directional_derivative = DirectionalDerivative(
Expand Down
24 changes: 22 additions & 2 deletions fiddy/directional_derivative.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import abc
from typing import Any, Callable, Dict, List, Union, Tuple
import warnings

import numpy as np
import pandas as pd

from .constants import (
MethodId,
Type,
EPSILON,
)

from .step import step
Expand Down Expand Up @@ -36,6 +38,7 @@ class Computer:
completed: bool = False
results: List[ComputerResult] = field(default_factory=list)
#value: Type.DIRECTIONAL_DERIVATIVE = None
relative_size: bool = False

def __post_init__(self):
if isinstance(self.method, MethodId):
Expand All @@ -48,13 +51,30 @@ def __post_init__(self):
if self.autorun:
self()

def get_size(self):
if not self.relative_size:
return self.size

# If relative, project point onto direction as scaling factor for size
unit_direction = self.direction / np.linalg.norm(self.direction)
# TODO add some epsilon to size?
size = np.dot(self.point, unit_direction) * self.size
if size == 0:
warnings.warn(
"Point has no component in this direction. "
"Set `Computer.relative_size=False` to avoid this. "
f"Using default small step size `fiddy.EPSILON`: {EPSILON}"
)
size = EPSILON
return size

def __call__(self):
value = self.method(
point=self.point,
direction=self.direction,
size=self.size,
size=self.get_size(),
)
result = ComputerResult(method_id=self.method.id, value=value, metadata={'size': self.size})
result = ComputerResult(method_id=self.method.id, value=value, metadata={'size': self.get_size(), 'size_absolute': self.size})
self.results.append(result)
self.completed = True

Expand Down
2 changes: 1 addition & 1 deletion fiddy/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ def step(
Returns:
The step.
"""
return direction * size / np.linalg.norm(direction)
return direction / np.linalg.norm(direction) * size
61 changes: 61 additions & 0 deletions tests/test_derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,64 @@ def test_get_derivative(point, sizes, output_shape):
)
result = check(rtol=1e-2)
assert result.success


def test_get_derivative_relative():
point = np.array((3,4,0))
size = 1e-1
output_shape = (1,)

direction = np.array([1,0,0])

directions = [direction]
success_checker = Consistency(atol=1e-2)

function = partial(rosenbrock, output_shape=output_shape)

# Expected finite difference derivatives
f_0 = function(point)
f_a = function(point + direction*size)
f_r = function(point + point*direction*size) # cardinal direction, simplifies to this, but usually need dot product

g_a = (f_a-f_0)/size
g_r = (f_r-f_0)/(point*direction*size).sum() # cardinal direction, simplifies to this, but usually need dot product

# Fiddy finite difference derivatives
kwargs = {
'function': function,
'point': point,
'sizes': [size],
'method_ids': [MethodId.FORWARD],
'directions': [direction],
'success_checker': success_checker,
}
fiddy_r = float(np.squeeze(get_derivative(**kwargs, relative_sizes=True).value))
fiddy_a = float(np.squeeze(get_derivative(**kwargs, relative_sizes=False).value))

# Relative step sizes work
assert np.isclose(fiddy_r, g_r)
assert np.isclose(fiddy_a, g_a)

# Same thing, now with non-cardinal direction
function = lambda x: (x[0]-2)**2 + (x[1]+3)**2
point = np.array([3,4])
direction = np.array([1,1])
unit_direction = direction / np.linalg.norm(direction)
kwargs['function'] = function
kwargs['directions'] = [direction]
kwargs['point'] = point

size_r = size * np.dot(point, unit_direction)

f_0 = function(point)
f_a = function(point + unit_direction * size)
f_r = function(point + unit_direction * size_r)

g_a = (f_a-f_0)/size
g_r = (f_r-f_0)/size_r

fiddy_r = float(np.squeeze(get_derivative(**kwargs, relative_sizes=True).value))
fiddy_a = float(np.squeeze(get_derivative(**kwargs, relative_sizes=False).value))

assert np.isclose(fiddy_r, g_r)
assert np.isclose(fiddy_a, g_a)

0 comments on commit f7fd881

Please sign in to comment.