
Commit

Stricter typing.
Signed-off-by: Emanuele Ballarin <[email protected]>
emaballarin committed Dec 6, 2024
1 parent 96613da commit 16799c9
Showing 1 changed file with 49 additions and 47 deletions.
ebtorch/optim/adopt.py (96 changes: 49 additions & 47 deletions)
@@ -4,6 +4,7 @@
# Original implementation: https://github.com/iShohei220/adopt.git
# Derivative implementation: https://github.com/huggingface/pytorch-image-models/blob/main/timm/optim/adopt.py
# ~~ Imports ~~ ────────────────────────────────────────────────────────────────
from collections.abc import Callable
from typing import cast
from typing import List
from typing import Optional
@@ -25,12 +26,11 @@

# ~~ Error Messages ~~ ─────────────────────────────────────────────────────────
_nocapture_err: str = (
"lr as a Tensor is not supported for capturable=False and foreach=True"
"`lr` as a `Tensor` is not supported for `capturable=False` and `foreach=True`"
)

# ~~ ADOPT Optimizer ~~ ────────────────────────────────────────────────────────


# ~~ ADOPT Optimizer ~~ ────────────────────────────────────────────────────────
class ADOPT(Optimizer):
"""
ADOPT: Modified Adam Can Converge with Any β2 with the Optimal Rate: https://arxiv.org/abs/2411.02853
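For orientation (not part of the diff): a minimal usage sketch of the optimizer defined in this file, assuming ADOPT is re-exported from ebtorch.optim (otherwise import it from ebtorch.optim.adopt). Hyperparameter values below are illustrative, not defaults.

import torch
from ebtorch.optim import ADOPT  # assumption: re-exported; else use ebtorch.optim.adopt

model = torch.nn.Linear(10, 1)
optimizer = ADOPT(model.parameters(), lr=1e-3, betas=(0.9, 0.9999))

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()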
@@ -51,24 +51,24 @@ def __init__(
maximize: bool = False,
capturable: bool = False,
differentiable: bool = False,
):
) -> None:
if isinstance(lr, Tensor):
if foreach and not capturable:
raise ValueError(_nocapture_err)
if lr.numel() != 1:
raise ValueError("Tensor lr must be 1-element")
raise ValueError("`lr` as `Tensor` must be 1-element")
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
raise ValueError(f"Invalid `lr`: {lr}")
if eps < 0.0:
raise ValueError(f"Invalid epsilon value: {eps}")
raise ValueError(f"Invalid `eps`: {eps}")
if not 0.0 <= betas[0] < 1.0:
raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
raise ValueError(f"Invalid `betas[0]`: {betas[0]}")
if not 0.0 <= betas[1] < 1.0:
raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
raise ValueError(f"Invalid `betas[1]`: {betas[1]}")
if weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
raise ValueError(f"Invalid `weight_decay`: {weight_decay}")

defaults = dict(
defaults: dict = dict(
lr=lr,
betas=betas,
eps=eps,
@@ -83,7 +83,7 @@ def __init__(
)
super().__init__(params, defaults)

def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
super().__setstate__(state)
for group in self.param_groups:
group.setdefault("maximize", False)
@@ -93,9 +93,9 @@ def __setstate__(self, state):
group.setdefault("clip_exp", None)
group.setdefault("caution", False)
for p in group["params"]:
p_state = self.state.get(p, [])
p_state: list = self.state.get(p, [])
if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): # type: ignore
step_val = float(p_state["step"]) # type: ignore
step_val: float = float(p_state["step"]) # type: ignore
p_state["step"] = ( # type: ignore
torch.tensor(
step_val,
@@ -111,11 +111,11 @@ def _init_group( # NOSONAR
group,
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
state_steps,
):
has_complex = False
exp_avgs: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[Tensor],
) -> bool:
has_complex: bool = False
for p in group["params"]:
if p.grad is None:
continue
@@ -150,7 +150,7 @@ def _init_group( # NOSONAR

if group["differentiable"] and state["step"].requires_grad:
raise RuntimeError(
"`requires_grad` is not supported for `step` in differentiable mode"
"`requires_grad` is not supported for `step` in `differentiable=True` mode"
)

# Foreach without capturable does not support a tensor lr
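Not part of the diff: the constraint named by the comment above is also enforced in __init__ (see the _nocapture_err check there). A hedged sketch with hypothetical values:

import torch
from ebtorch.optim import ADOPT

params = [torch.nn.Parameter(torch.zeros(3))]
ADOPT(params, lr=torch.tensor(1e-3), foreach=True, capturable=True)   # accepted: 1-element Tensor lr
ADOPT(params, lr=torch.tensor(1e-3), foreach=True, capturable=False)  # raises ValueError (_nocapture_err)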
@@ -165,11 +165,11 @@
return has_complex

@_use_grad_for_differentiable
def step(self, closure=None):
def step(self, closure: Optional[Callable] = None):
"""Perform a single optimization step.
Args:
closure (Callable, optional): A closure that reevaluates the model
and returns the loss.
and returns the loss.
"""
self._cuda_graph_capture_health_check()
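Not part of the diff: the closure contract described in the docstring above, in a hedged minimal form (model, criterion, inputs, and targets are placeholders):

def closure():
    optimizer.zero_grad()
    output = model(inputs)
    loss = criterion(output, targets)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # step() re-evaluates the model via the closure and returns its loss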

@@ -186,7 +186,7 @@
state_steps: List[Tensor] = []
beta1, beta2 = group["betas"]

has_complex = self._init_group(
has_complex: bool = self._init_group(
group,
params_with_grad,
grads,
@@ -255,16 +255,18 @@ def _single_tensor_adopt( # NOSONAR
assert isinstance(lr, float)

for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step_t = state_steps[i]
grad: Tensor = grads[i] if not maximize else -grads[i]
exp_avg: Tensor = exp_avgs[i]
exp_avg_sq: Tensor = exp_avg_sqs[i]
step_t: Tensor = state_steps[i]

# If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
if capturable and not torch._utils.is_compiling():
from torch.optim.optimizer import _get_capturable_supported_devices

capturable_supported_devices = _get_capturable_supported_devices()
capturable_supported_devices: List[str] = (
_get_capturable_supported_devices()
)
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
Expand All @@ -275,15 +277,15 @@ def _single_tensor_adopt( # NOSONAR
step_t += 1

if torch.is_complex(param):
grad = torch.view_as_real(grad)
grad: Tensor = torch.view_as_real(grad)
if exp_avg is not None:
exp_avg = torch.view_as_real(exp_avg)
exp_avg: Tensor = torch.view_as_real(exp_avg)
if exp_avg_sq is not None:
exp_avg_sq = torch.view_as_real(exp_avg_sq)
param = torch.view_as_real(param)
exp_avg_sq: Tensor = torch.view_as_real(exp_avg_sq)
param: Tensor = torch.view_as_real(param)

if weight_decay != 0 and not decouple:
grad = grad.add(param, alpha=weight_decay)
grad: Tensor = grad.add(param, alpha=weight_decay)

step = step_t if capturable or differentiable else _get_value(step_t)
if step == 1:
@@ -293,8 +295,8 @@
if weight_decay != 0 and decouple:
param.add_(param, alpha=-lr * weight_decay)

denom = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad = grad.div(denom)
denom: Tensor = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad: Tensor = grad.div(denom)

if clip_exp is not None:
clip_val = (step - 1) ** clip_exp
@@ -304,9 +306,9 @@

if caution:
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
mask = (exp_avg * grad > 0).to(grad.dtype)
mask: Tensor = (exp_avg * grad > 0).to(grad.dtype)
mask.div_(mask.mean().clamp_(min=1e-3))
exp_avg = exp_avg * mask
exp_avg: Tensor = exp_avg * mask

param.add_(exp_avg, alpha=-lr)
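Not part of the diff: the per-parameter update carried out above, restated as a plain-tensor sketch (weight decay omitted). Steps visible in these hunks are transcribed as-is; lines folded away by the diff view are marked as assumptions and follow the referenced ADOPT paper.

# Visible above: normalize the gradient by the clamped square root of the previous second moment.
denom = torch.clamp(exp_avg_sq.sqrt(), eps)
normed_grad = grad.div(denom)

# Visible above: optional step-dependent clipping threshold; the clamp itself is an assumption.
if clip_exp is not None:
    clip_val = (step - 1) ** clip_exp
    normed_grad = normed_grad.clamp(-clip_val, clip_val)

# Assumption (folded away here): first-moment EMA of the normalized gradient.
exp_avg = exp_avg.lerp(normed_grad, 1 - beta1)

# Visible above: optional 'Cautious Optimizers' mask (https://arxiv.org/abs/2411.16085).
if caution:
    mask = (exp_avg * grad > 0).to(grad.dtype)
    mask = mask / mask.mean().clamp(min=1e-3)
    exp_avg = exp_avg * mask

# Visible above: parameter update.
param = param - lr * exp_avg

# Assumption (folded away here): second-moment EMA of the raw gradient.
exp_avg_sq = exp_avg_sq.mul(beta2).addcmul(grad, grad, value=1 - beta2)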

@@ -335,7 +337,7 @@ def _multi_tensor_adopt( # NOSONAR
maximize: bool,
capturable: bool,
differentiable: bool,
):
) -> None:
if len(params) == 0:
return

@@ -346,7 +348,7 @@ def _multi_tensor_adopt( # NOSONAR
if capturable and not torch._utils.is_compiling():
from torch.optim.optimizer import _get_capturable_supported_devices

capturable_supported_devices = _get_capturable_supported_devices(
capturable_supported_devices: List[str] = _get_capturable_supported_devices(
supports_xla=False
)
assert all(
Expand Down Expand Up @@ -382,7 +384,7 @@ def _multi_tensor_adopt( # NOSONAR
)

if maximize:
device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
device_grads: List[Tensor] = torch._foreach_neg(device_grads) # type: ignore[assignment]

# Update steps
# If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
@@ -400,7 +402,7 @@ def _multi_tensor_adopt( # NOSONAR
if maximize:
torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
else:
device_grads = torch._foreach_add(
device_grads: List[Tensor] = torch._foreach_add(
device_grads, device_params, alpha=weight_decay
)
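Not part of the diff: the two weight-decay modes this file supports, restated per-tensor (decouple selects between them; both forms also appear in the single-tensor path above):

# Coupled (decouple=False): decay is folded into the gradient before normalization.
grad = grad + weight_decay * param

# Decoupled (decouple=True): decay is applied directly to the parameter, scaled by lr.
param = param - lr * weight_decay * param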

@@ -427,10 +429,10 @@ def _multi_tensor_adopt( # NOSONAR
# Apply caution as per 'Cautious Optimizers' - https://arxiv.org/abs/2411.16085
masks = torch._foreach_mul(device_exp_avgs, device_grads)
masks = [(m > 0).to(g.dtype) for m, g in zip(masks, device_grads)]
mask_scale = [m.mean() for m in masks]
mask_scale: list = [m.mean() for m in masks]
torch._foreach_maximum_(mask_scale, 1e-3)
torch._foreach_div_(masks, mask_scale)
device_exp_avgs = torch._foreach_mul(device_exp_avgs, masks)
device_exp_avgs: List[Tensor] = torch._foreach_mul(device_exp_avgs, masks)

torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
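Not part of the diff: a small worked example of the caution mask computed above, in per-tensor form with illustrative values. Coordinates where momentum and the current gradient disagree in sign are zeroed, and the mask is rescaled by its mean, clamped at 1e-3 to avoid dividing by (near) zero when almost everything is masked.

import torch

exp_avg = torch.tensor([0.5, -0.2, 0.1, 0.3])
grad = torch.tensor([0.4, 0.1, 0.2, -0.6])

mask = (exp_avg * grad > 0).to(grad.dtype)  # -> [1., 0., 1., 0.]: signs agree at indices 0 and 2
mask = mask / mask.mean().clamp(min=1e-3)   # mean is 0.5, so surviving entries scale to 2.0
masked_exp_avg = exp_avg * mask             # -> [1.0, 0.0, 0.2, 0.0]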

@@ -465,10 +467,10 @@ def adopt( # NOSONAR
eps: float,
caution: bool,
maximize: bool,
):
) -> None:
r"""Functional API that performs ADOPT algorithm computation."""
if foreach is None:
foreach = False
foreach: bool = False

# this check is slow during compilation, so we skip it
# if it's strictly needed we can add this check back in dynamo
@@ -484,9 +486,9 @@ def adopt( # NOSONAR
raise RuntimeError("torch.jit.script not supported with foreach optimizers")

if foreach and not torch.jit.is_scripting():
func = _multi_tensor_adopt
func: Callable = _multi_tensor_adopt
else:
func = _single_tensor_adopt
func: Callable = _single_tensor_adopt

func(
params,
