Add Hoyer neuron and related layers for Hoyer training #292

Open · wants to merge 4 commits into base: main
3 changes: 2 additions & 1 deletion src/lava/lib/dl/slayer/block/__init__.py
@@ -8,5 +8,6 @@
'rf', 'rf_iz',
'alif',
'adrf', 'adrf_iz',
'sigma_delta'
'sigma_delta',
'cuba_hoyer',
]
116 changes: 116 additions & 0 deletions src/lava/lib/dl/slayer/block/cuba_hoyer.py
@@ -0,0 +1,116 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

"""CUBA-Hoyer-LIF layer blocks"""

import torch

from . import base, cuba
from ..neuron import cuba_hoyer
from ..synapse import layer as synapse
from ..axon import Delay, delay


class AbstractCubaHoyer(torch.nn.Module):
"""Abstract block class for Current Based Leaky Integrator neuron. This
should never be instantiated on it's own.
"""
def __init__(self, *args, **kwargs):
super(AbstractCubaHoyer, self).__init__(*args, **kwargs)
if self.neuron_params is not None:
self.neuron = cuba_hoyer.HoyerNeuron(**self.neuron_params)
delay = kwargs['delay'] if 'delay' in kwargs.keys() else False
self.delay = Delay(max_delay=62) if delay is True else None
del self.neuron_params


def _doc_from_base(base_doc):
"""Adapts the docstring of a base block class for the CUBA Hoyer LIF blocks."""
return base_doc.__doc__.replace(
'Abstract', 'CUBA Hoyer LIF'
).replace(
'neuron parameter', 'CUBA Hoyer LIF neuron parameter'
).replace(
'This should never be instantiated on its own.',
'The block is 8 bit quantization ready.'
)


class Dense(AbstractCubaHoyer, base.AbstractDense):
def __init__(self, *args, **kwargs):
super(Dense, self).__init__(*args, **kwargs)
self.synapse = synapse.Dense(**self.synapse_params)
if 'pre_hook_fx' not in kwargs.keys():
self.synapse.pre_hook_fx = self.neuron.quantize_8bit
del self.synapse_params

Dense.__doc__ = _doc_from_base(base.AbstractDense)

class Conv(AbstractCubaHoyer, base.AbstractConv):
def __init__(self, *args, **kwargs):
super(Conv, self).__init__(*args, **kwargs)
self.synapse = synapse.Conv(**self.synapse_params)
if 'pre_hook_fx' not in kwargs.keys():
self.synapse.pre_hook_fx = self.neuron.quantize_8bit
del self.synapse_params


Conv.__doc__ = _doc_from_base(base.AbstractConv)


def step_delay(module, x):
"""Step delay computation. This simulates the 1 timestep delay needed
for communication between layers.

Parameters
----------
module: module
python module instance
x : torch.tensor
Tensor data to be delayed.
"""
if hasattr(module, 'delay_buffer') is False:
module.delay_buffer = None
persistent_state = hasattr(module, 'neuron') \
and module.neuron.persistent_state is True
if module.delay_buffer is not None:
if module.delay_buffer.shape[0] != x.shape[0]: # batch mismatch
module.delay_buffer = None
if persistent_state:
delay_buffer = 0 if module.delay_buffer is None else module.delay_buffer
module.delay_buffer = x[..., -1]
x = delay(x, 1)
if persistent_state:
x[..., 0] = delay_buffer
return x
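The one-step shift that ``step_delay`` relies on is easiest to see on a toy tensor. A minimal sketch, assuming ``slayer.axon.delay`` zero-pads the first time step (the value that the ``persistent_state`` branch above then overwrites with the buffered last step of the previous iteration):

```python
import torch
from lava.lib.dl.slayer.axon import delay

x = torch.arange(1., 5.).reshape(1, 1, 4)  # one neuron, four time steps
y = delay(x, 1)                            # shift the time axis by one step
print(y)
# expected, under the zero-padding assumption: tensor([[[0., 1., 2., 3.]]])
```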

class Pool(cuba.Pool):
def __init__(self, *args, **kwargs):
super(Pool, self).__init__(*args, **kwargs)
self.hoyer_loss = 0.0

def forward(self, x):
"""Forward computation method. The input must be in ``NCHWT`` format.
"""
self.neuron.shape = x[0].shape
z = self.synapse(x)
# skip the neuron computation in the pooling layer
# x = self.neuron(z)
x = z
if self.delay_shift is True:
x = step_delay(self, x)
if self.delay is not None:
x = self.delay(x)

if self.count_log is True:
return x, torch.mean((x > 0).to(x.dtype))
else:
return x

Pool.__doc__ = _doc_from_base(base.AbstractPool)

class Affine(cuba.Affine):
def __init__(self, *args, **kwargs):
super(Affine, self).__init__(*args, **kwargs)

Affine.__doc__ = _doc_from_base(base.AbstractAffine)
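A minimal usage sketch of the new blocks, assuming the usual lava-dl block signature ``(neuron_params, in_neurons, out_neurons)``; the parameter values and shapes are illustrative only, and the Hoyer-specific keys are simply forwarded to ``HoyerNeuron`` through ``neuron_params``:

```python
import torch
from lava.lib.dl.slayer.block import cuba_hoyer

# Illustrative neuron parameters; T and hoyer_type are passed through to
# HoyerNeuron by AbstractCubaHoyer (values chosen for the example only).
neuron_params = {
    'threshold': 1.25,
    'current_decay': 0.25,
    'voltage_decay': 0.03,
    'T': 16,              # number of time steps in the input below
    'hoyer_type': 'sum',  # average the Hoyer ext across channels
}

# Dense Hoyer block: synapse + HoyerNeuron; the 8-bit quantization hook is
# attached automatically because no pre_hook_fx is supplied.
fc = cuba_hoyer.Dense(neuron_params, 512, 256)

x = torch.rand(8, 512, 16)  # batch, input neurons, time
y = fc(x)                   # output spike train over the same 16 time steps
```

The Hoyer loss accumulated by each neuron (``fc.neuron.hoyer_loss``) can then be collected across blocks as a regularization term for Hoyer training.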
1 change: 1 addition & 0 deletions src/lava/lib/dl/slayer/neuron/__init__.py
@@ -12,4 +12,5 @@
'adrf', 'adrf_iz',
'sigma_delta',
'Dropout', 'norm',
'cuba_hoyer',
]
260 changes: 260 additions & 0 deletions src/lava/lib/dl/slayer/neuron/cuba_hoyer.py
@@ -0,0 +1,260 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: BSD-3-Clause

"""CUBA neuron model."""

import torch
import torch.nn as nn

from .dynamics import leaky_integrator
from ..spike import HoyerSpike
from .cuba import Neuron

# These are tuned heuristically so that scale_grad=1 and tau_grad=1 serves
# as a good starting point

# SCALE_RHO_MULT = 0.1
# TAU_RHO_MULT = 10
SCALE_RHO_MULT = 0.1
TAU_RHO_MULT = 100
# SCALE_RHO_MULT = 1
# TAU_RHO_MULT = 1

class HoyerNeuron(Neuron):
"""This is the implementation of Loihi CUBA neuron.

.. math::
u[t] &= (1 - \\alpha_u)\\,u[t-1] + x[t] \\

v[t] &= (1 - \\alpha_v)\\,v[t-1] + u[t] + \\text{bias} \\

s[t] &= v[t] \\geq \\vartheta \\

v[t] &= v[t]\\,(1-s[t])

The internal state representations are scaled down compared to
the actual hardware implementation. This allows for a natural range of
synaptic weight values as well as the gradient parameters.

The neuron parameters like threshold, decays are represented as real
values. They internally get converted to fixed precision representation of
the hardware. It also provides properties to access the neuron
parameters in fixed precision states. The parameters are internally clamped
to the valid range.

Parameters
----------
threshold : float
neuron threshold.
current_decay : float or tuple
the fraction of current decay per time step. If ``shared_param``
is False, then it can be specified as a tuple (min_decay, max_decay).
voltage_decay : float or tuple
the fraction of voltage decay per time step. If ``shared_param`` is
False, then it can be specified as a tuple (min_decay, max_decay).
tau_grad : float, optional
time constant of spike function derivative. Defaults to 1.
scale_grad : float, optional
scale of spike function derivative. Defaults to 1.
scale : int, optional
scale of the internal state. ``scale=1`` will result in values in the
range expected of the Loihi hardware. Defaults to 1<<6.
norm : fx-ptr or lambda, optional
normalization function on the dendrite output. None means no
normalization. Defaults to None.
dropout : fx-ptr or lambda, optional
neuron dropout method. None means no dropout. Defaults to None.
shared_param : bool, optional
flag to enable/disable shared parameter neuron group. If it is
False, individual parameters are assigned on a per-channel basis.
Defaults to True.
persistent_state : bool, optional
flag to enable/disable persistent state between iterations. Defaults to
False.
requires_grad : bool, optional
flag to enable/disable learning on neuron parameter. Defaults to False.
graded_spike : bool, optional
flag to enable/disable graded spike output. Defaults to False.
num_features : int, optional
number of feature channels when the Hoyer neuron follows a Conv layer;
use 1 when it follows a Dense layer. Defaults to 1.
hoyer_mode : bool, optional
flag to enable/disable the Hoyer spiking path. If False, the neuron
falls back to the standard CUBA behavior. Defaults to True.
T : int, optional
number of time steps. Defaults to 1.
hoyer_type : str, optional
'sum': the Hoyer ext is averaged across all feature channels;
'cw': the Hoyer ext is computed channel-wise. Defaults to 'sum'.
momentum : float, optional
momentum used to update ``running_hoyer_ext``. Defaults to 0.9.
delay : bool, optional
flag to delay the running Hoyer ext by one time step when updating it.
Defaults to False.
"""
def __init__(
self, threshold, current_decay, voltage_decay,
tau_grad=1, scale_grad=1, scale=1 << 6,
norm=None, dropout=None,
shared_param=True, persistent_state=False, requires_grad=False, graded_spike=False,
num_features=1, hoyer_mode=True, T=1, hoyer_type='sum', momentum=0.9, delay=False
):
super(HoyerNeuron, self).__init__(
threshold=threshold,
current_decay=current_decay,
voltage_decay=voltage_decay,
tau_grad=tau_grad,
scale_grad=scale_grad,
scale=scale,
norm=norm,
dropout=dropout,
shared_param=shared_param,
persistent_state=persistent_state,
requires_grad=requires_grad
)

# add some attributes for hoyer spiking
# self.learnable_thr = nn.Parameter(torch.FloatTensor([self.threshold]), requires_grad=True)
self.register_parameter(
'learnable_thr',
torch.nn.Parameter(
torch.FloatTensor([self.threshold]),
requires_grad=self.requires_grad
),
)
self.T = T
self.hoyer_type = hoyer_type
self.hoyer_mode = hoyer_mode
self.num_features = num_features
self.momentum = momentum
if self.num_features > 1:
self.bias = nn.Parameter(torch.zeros(1,num_features,1,1,1), requires_grad=True)
if self.hoyer_mode:
self.bn = nn.BatchNorm2d(num_features=self.num_features)
self.delay = delay

if self.num_features > 1:
# Conv layer B,C,H,W,T
# self.register_buffer('running_hoyer_ext', torch.zeros([1, self.num_features, 1, 1, T], **factory_kwargs))
if self.hoyer_type == 'sum':
self.register_buffer('running_hoyer_ext', torch.zeros([1, 1, 1, 1, T]))
else:
self.register_buffer('running_hoyer_ext', torch.zeros([1, self.num_features, 1, 1, T]))
else:
# Linear layer B,C,T
self.register_buffer('running_hoyer_ext', torch.zeros([1, 1, T]))

if norm is not None:
if self.complex is False:
self.norm = norm(num_features=num_features)
if hasattr(self.norm, 'pre_hook_fx'):
self.norm.pre_hook_fx = self.quantize_8bit
else:
self.real_norm = norm(num_features=num_features)
self.imag_norm = norm(num_features=num_features)
if hasattr(self.real_norm, 'pre_hook_fx'):
self.real_norm.pre_hook_fx = self.quantize_8bit
if hasattr(self.imag_norm, 'pre_hook_fx'):
self.imag_norm.pre_hook_fx = self.quantize_8bit
else:
self.norm = None
if self.complex is True:
self.real_norm = None
self.imag_norm = None

# self.register_buffer('ref_delay', torch.FloatTensor([ref_delay]))

self.clamp()

def thr_clamp(self):
"""Clamps the threshold value to
:math:`[\\verb~1/scale~, \\infty)`."""
self.learnable_thr.data.clamp_(1 / self.scale)

def spike(self, voltage, hoyer_ext=1.0):
"""Extracts spike points from the voltage timeseries. It assumes the
reset dynamics is already applied.

Parameters
----------
voltage : torch tensor
neuron voltage dynamics
hoyer_ext : torch tensor
Hoyer ext value used by the Hoyer spike function. Defaults to 1.0.

Returns
-------
torch tensor
spike output

"""
spike = HoyerSpike.apply(
voltage,
hoyer_ext,
self.tau_rho * TAU_RHO_MULT,
self.scale_rho * SCALE_RHO_MULT,
self.graded_spike,
self.voltage_state,
# self.s_scale,
1,
)

if self.persistent_state is True:
with torch.no_grad():
self.voltage_state = leaky_integrator.persistent_state(
self.voltage_state, spike[..., -1]
).detach().clone()

if self.drop is not None:
spike = self.drop(spike)

return spike

def cal_hoyer_loss(self, x, thr=None):
"""Hoyer sparsity loss of ``x``: after clipping to ``[0, thr]``, returns
:math:`(\\sum |x|)^2 / \\sum x^2` with a small epsilon in the denominator."""
if thr:
x[x>thr] = thr
x[x<0.0] = 0.0
# avoid division by zero
return (torch.sum(torch.abs(x))**2) / (torch.sum(x**2) + 1e-9)

def forward(self, input):
"""Computes the full response of the neuron instance to an input.
The input shape must match with the neuron shape. For the first time,
the neuron shape is determined from the input automatically.

Parameters
----------
input : torch tensor
Input tensor.

Returns
-------
torch tensor
spike response of the neuron.

"""
if not self.hoyer_mode:
out = super().forward(input)
return out
if self.num_features > 1 and hasattr(self, 'bn'):
B,C,H,W,T = input.shape
input = self.bn(input.permute(4,0,1,2,3).reshape(T*B,C, H, W).contiguous()).reshape(T,B,C,H,W).permute(1,2,3,4,0).contiguous()
_, voltage = self.dynamics(input)
self.hoyer_loss = self.cal_hoyer_loss(torch.clamp(voltage.clone(), min=0.0, max=1.0), 1.0)
self.clamp()
self.thr_clamp()
voltage = voltage / self.learnable_thr
if self.training:
clamped_input = torch.clamp(voltage.clone().detach(), min=0.0, max=1.0)
dim = tuple(range(clamped_input.ndim-1))
if self.hoyer_type == 'sum':
hoyer_ext = torch.sum(clamped_input**2, dim=dim) / (torch.sum(torch.abs(clamped_input), dim=dim))
else:
hoyer_ext = torch.sum((clamped_input)**2, dim=(0,2,3), keepdim=True) / torch.sum(torch.abs(clamped_input), dim=(0,2,3), keepdim=True)

hoyer_ext = torch.nan_to_num(hoyer_ext, nan=1.0)
with torch.no_grad():
if self.delay:
# delay hoyer ext
self.running_hoyer_ext[..., 0] = 0
self.running_hoyer_ext = torch.roll(self.running_hoyer_ext, shifts=-1, dims=-1)
self.running_hoyer_ext = self.momentum * hoyer_ext + (1 - self.momentum) * self.running_hoyer_ext
else:
# do not delay hoyer ext
self.running_hoyer_ext = self.momentum * hoyer_ext + (1 - self.momentum) * self.running_hoyer_ext
else:
hoyer_ext = self.running_hoyer_ext
output = self.spike(voltage, hoyer_ext)
return output
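To make the Hoyer quantities concrete, a minimal sketch that exercises the neuron directly (constructor values and shapes are illustrative; the input follows the Dense-style ``batch, neurons, time`` layout):

```python
import torch
from lava.lib.dl.slayer.neuron.cuba_hoyer import HoyerNeuron

neuron = HoyerNeuron(
    threshold=1.0, current_decay=0.25, voltage_decay=0.03,
    num_features=1, T=16, hoyer_type='sum',
)

x = torch.rand(4, 128, 16)  # batch, neurons, time
s = neuron(x)               # spike output with the same shape as the input

# Hoyer sparsity loss of the clamped membrane potential, accumulated during
# the forward pass; it can be added to the training objective as a regularizer.
print(float(neuron.hoyer_loss))

# Running Hoyer ext, updated with momentum during training and fed to the
# spike function in place of the batch statistic at eval time.
print(neuron.running_hoyer_ext.shape)  # torch.Size([1, 1, 16])
```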