Container modules with advanced control flow & modules with multiple inputs #306
Hi @m-lyon, thanks for your question. May I ask which BackPACK extensions you are planning to use?
Of course, my apologies for not including that info. I'm attempting to use the Laplace framework for a model; in my current implementation it's using the
custom_module = torch.nn.Sequential(
    torch.nn.Sequential(
        # layers of 'OtherCustomModule'
    ),
    torch.nn.Sequential(
        # layers of 'AnotherCustomModule'
    ),
    torch.nn.Sequential(
        # layers of 'OtherCustomModule'
    ),
)

Now you can fill in the layers that are already supported by BackPACK, as well as the ones you wish to implement yourself. As an alternative, you could also consider looking into extending the alternative backend of the Laplace library (ASDL). Let me know if that helps.
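If the inner layers end up being standard PyTorch modules, the refactored container can then be passed to BackPACK as usual. A minimal sketch, assuming placeholder Linear/ReLU layers and the DiagGGNExact extension (neither is taken from this thread):

```python
import torch
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact

# placeholder layers standing in for the contents of the custom sub-modules
custom_module = torch.nn.Sequential(
    torch.nn.Sequential(torch.nn.Linear(10, 8), torch.nn.ReLU()),
    torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU()),
    torch.nn.Sequential(torch.nn.Linear(8, 1)),
)

model = extend(custom_module)
loss_func = extend(torch.nn.MSELoss())

X, y = torch.rand(16, 10), torch.rand(16, 1)
loss = loss_func(model(X), y)

with backpack(DiagGGNExact()):
    loss.backward()

for param in model.parameters():
    print(param.diag_ggn_exact.shape)  # GGN diagonal, same shape as the parameter
```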
Thank you for pointing me in the right direction. I'll take a look at the second-order extension example and see if I can't implement something myself. Unfortunately, the Laplace framework currently does not support ASDL for regression, only classification, so I'd assume that would be an equal amount of work. Thanks
It doesn't seem completely obvious how I would do this, because my network (and consequently the custom layers within it) has several inputs. Therefore, without modifying

Conceptually, my model looks something like this:

class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = CustomLayer1()
        self.layer2 = CustomLayer2()
        self.layer3 = CustomLayer3()

    def forward(self, tensor1, tensor2, tensor3):
        out = self.layer1(tensor1, tensor2, tensor2)
        out = self.layer2(out, tensor2, tensor3)
        out = self.layer3(out, tensor3, tensor3)
        return out
Hey, thanks for the update! You are right that this is indeed challenging for your model. The problem with the above forward pass is that
They are somewhat complex. I think being able to break the problem down into smaller parts would be a much more feasible solution than fusing them into one, especially considering that the number of layers in reality is much greater than in this example. I think something that is quite important to solving this problem, which I'm currently unsure how to do, is writing extensions for

I've taken a look at the
BackPACK's design somewhat conflicts with this feature, the problem being that a module which does not act like a container (such as

In your case, the container

The short answer is that in such a case it is still possible to get the backpropagation working (extensive answer below).
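To illustrate the constraint with made-up layers (not code from this thread): a container whose forward pass only calls its submodules keeps every operation visible to BackPACK's module hooks, whereas a call into torch.nn.functional bypasses them:

```python
import torch
from torch.nn import Linear, Module, ReLU
from torch.nn import functional as F

class HookFriendlyContainer(Module):
    """Forward pass only calls submodules, so BackPACK's hooks see every operation."""
    def __init__(self):
        super().__init__()
        self.fc1, self.act, self.fc2 = Linear(10, 10), ReLU(), Linear(10, 1)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))

class ProblematicContainer(Module):
    """The F.relu call has no module hooks, so second-order backpropagation breaks there."""
    def __init__(self):
        super().__init__()
        self.fc1, self.fc2 = Linear(10, 10), Linear(10, 1)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))
```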
I put together a self-consistent example below. It demonstrates how to add support for a layer

I had to perform a slight adjustment to BackPACK's backpropagation mechanism. You will have to install from the

I think this should get you started. There might be additional pitfalls on the way, though. Let me know if you run into issues. Here is the code:

"""Second-order extensions for modules with multiple inputs and slightly advanced control flow containers."""
from typing import List, Tuple
from torch import Tensor, allclose, einsum, manual_seed, rand, zeros
from torch.nn import Module, MSELoss, Parameter
from torch.nn.utils.convert_parameters import parameters_to_vector
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact
from backpack.extensions.module_extension import ModuleExtension
from backpack.hessianfree.ggnvp import ggn_vector_product
from backpack.utils.convert_parameters import vector_to_parameter_list
class MultiplyModule(Module):
"""Module that multiplies all its inputs with itself and a weight."""
def __init__(self, weight: float = 1.0):
super().__init__()
self.weight = Parameter(Tensor([weight]))
def forward(self, *inputs: Tensor):
"""Multiply all inputs, then multiply by the weight and return result."""
# accept batched scalars only for simplicity
assert len({i.shape[0] for i in inputs}) == 1
assert all(i.dim() == 2 and i.shape[1] == 1 for i in inputs)
result = self.weight
for i in inputs:
result = result * i
return result
class DiagGGNMultiplyModule(ModuleExtension):
"""Describes how to compute the GGN diagonal for a ``MultiplyModule``."""
def __init__(self):
super().__init__(params=["weight"])
def backpropagate(
self,
ext: DiagGGNExact,
module: MultiplyModule,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
bpQuantities: Tensor,
) -> Tuple[Tensor]:
"""Backprop GGN matrix square root from output to inputs of ``MultiplyModule``.
This multiplies the backpropagated object with output-input Jacobian for each
input.
Returns a tuple with the backpropagated GGN matrix square root for each input.
"""
sqrt_ggn = bpQuantities
inputs = self.get_inputs(module) # stored by BackPACK in the forward pass
backpropagate_to_inputs = []
# apply the output-input Jacobian for all inputs
for i in range(len(inputs)):
other = [inp for j, inp in enumerate(inputs) if j != i]
jac_inp_i = module.weight
for inp_other in other:
jac_inp_i = jac_inp_i * inp_other
backpropagate_to_inputs.append(sqrt_ggn * jac_inp_i)
# tuple signalizes each entry will be backpropped to its input
return tuple(backpropagate_to_inputs)
def weight(
self,
ext: DiagGGNExact,
module: MultiplyModule,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
bpQuantities: Tensor,
) -> Tensor:
"""Compute the GGN diagonal for the weight of ``MultiplyModule``."""
sqrt_ggn = bpQuantities
inputs = self.get_inputs(module)
jac = inputs.pop()
while inputs:
jac = jac * inputs.pop()
jac_sqrt_ggn = einsum("vni,ni->vni", sqrt_ggn, jac)
return einsum("vni,vni->i", jac_sqrt_ggn, jac_sqrt_ggn)
@staticmethod
def get_inputs(module: Module) -> List[Tensor]:
"""Get all inputs of ``MultiplyModule``'s forward pass."""
layer_inputs = []
i = 0
while hasattr(module, f"input{i}"):
layer_inputs.append(getattr(module, f"input{i}"))
i += 1
return layer_inputs
class MySimpleContainer(Module):
"""Container module that feeds outputs of submodules into other children.
This is okay as long as there are no calls to ``nn.functional``'s in the
``forward`` method.
"""
def __init__(self) -> None:
super().__init__()
self.layer1 = MultiplyModule(0.5)
self.layer2 = MultiplyModule(-1.5)
self.layer3 = MultiplyModule(3.0)
def forward(self, x1, x2, x3):
out = self.layer1(x1, x2, x2)
out = self.layer2(out, x2, x3)
return self.layer3(out, x3, x3)
###############################################################################
# Set up toy problem #
###############################################################################
batch_size = 10
manual_seed(0)
x1 = rand(batch_size, 1)
x2 = rand(batch_size, 1)
x3 = rand(batch_size, 1)
label = rand(batch_size, 1)
model = MySimpleContainer()
loss_func = MSELoss()
###############################################################################
# Compute the GGN diagonal with autograd #
###############################################################################
output = model(x1, x2, x3)
loss = loss_func(output, label)
parameters = list(model.parameters())
num_params = sum(p.numel() for p in parameters)
diag_ggn_autograd = zeros(num_params)
# compute GGN column by column, extracting the diagonal element
for i in range(num_params):
e_i = zeros(num_params)
e_i[i] = 1.0
e_i_list = vector_to_parameter_list(e_i, parameters)
ggn_diag_i_list = ggn_vector_product(loss, output, model, e_i_list)
diag_ggn_autograd[i] = parameters_to_vector(ggn_diag_i_list)[i]
###############################################################################
# Compute the GGN diagonal with BackPACK #
###############################################################################
# only extend sub-modules; BackPACK does not know ``MySimpleContainer``
for submodule in model.children():
extend(submodule)
loss_func = extend(loss_func)
loss = loss_func(model(x1, x2, x3), label)
ext = DiagGGNExact()
# tell the extension for the GGN diagonal to use ``DiagGGNMultiplyModule`` when it
# encounters ``MultiplyModule``.
ext.set_module_extension(MultiplyModule, DiagGGNMultiplyModule())
with backpack(ext):
loss.backward()
diag_ggn_backpack = parameters_to_vector([p.diag_ggn_exact for p in parameters])
###############################################################################
# Compare #
###############################################################################
assert allclose(diag_ggn_autograd, diag_ggn_backpack)
Thanks for this example, I'll have a play around with it and see if I can't intuit how to extend this to my problem.
So while this is true for the simplistic example I gave, the real module is pretty involved. To give a clearer idea of the complexity, the module (called a

Hence why I was hoping to be able to break down the problem, so that for each step I could figure out what I needed to implement for that given function call within a
That's okay, but since you are implementing this as a

Am I correct that the above example at least solves your problem of handling the data flow of input tensors to the module?
Yes, as far as I can tell from looking at the example, though I haven't tested this yet.
Thanks for pointing this out. So, as far as I understand (please correct me if I've misunderstood), I need to implement the

Additionally, my
The operation performed in

I think the way to progress on this is to first formulate your layer in terms of modules. For instance, it would be good to decide whether you want the
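As an illustration of that refactoring (the operation and names here are made up, not taken from this thread): a bare tensor operation can be wrapped in its own Module so that BackPACK's hooks can store its inputs and outputs, and a custom module extension can later be registered for it with set_module_extension:

```python
import torch
from torch.nn import Module

class EinsumMix(Module):
    """Hypothetical wrapper that turns one bare tensor operation into a Module."""

    def forward(self, x: torch.Tensor, basis: torch.Tensor) -> torch.Tensor:
        # stand-in for one step of the real layer's forward pass
        return torch.einsum("nci,ij->ncj", x, basis)

# usage: out = EinsumMix()(features, basis_matrix)
```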
It seems like, given the number of operations in the forward pass, refactoring these into a
I understand; in that case I'm unsure how to derive the Jacobian for the kind of matrix-manipulation operations like
You can try to take a look at the documentation of
BackPACK's focus is on standard DNN architectures. So if an operation is missing, chances are it can still be added at relatively little overhead. With the examples from the website we tried to simplify this procedure for others. But in the end you won't get around implementing multiplication by the Jacobian. I am happy to give feedback and review your code if you decide to tackle this, but won't have enough time to write code. I would proceed as follows:
Best,
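To make the "multiplication by the Jacobian" step concrete, here is a minimal sketch (not from this thread) of a transposed-Jacobian-vector product for an elementwise sin operation, computed with torch.autograd; for elementwise operations the Jacobian is diagonal, so the result can be checked by hand:

```python
import torch

x = torch.rand(5, requires_grad=True)
y = torch.sin(x)   # the operation whose Jacobian-transpose we need
v = torch.rand(5)  # a vector arriving from the layer above

# transposed-Jacobian-vector product J^T v via reverse-mode autodiff
jtv = torch.autograd.grad(y, x, grad_outputs=v)[0]

# for elementwise sin, J = diag(cos(x)), hence J^T v = cos(x) * v
assert torch.allclose(jtv, torch.cos(x) * v)
```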
Thank you for your suggestions. I've set up a dev environment with

I think this is probably a naive question, but if you can verify the
Hi, great to hear you're making progress! You are right that in principle one could use
BackPACK uses PyTorch's Python API and might be slower than the

I'd be happy to merge functionality that adds support for arbitrary layers and uses
Hey Felix, I've written the following functionality to compute the

# backpack/core/derivatives/model.py
from typing import List, Tuple, Optional
import torch
from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.hessianfree.lop import transposed_jacobian_vector_product
class ArbitraryModelDerivatives(BaseDerivatives):
def _jac_t_mat_prod(
self,
module: torch.nn.Module,
g_inp: Tuple[torch.Tensor],
g_out: Tuple[torch.Tensor],
mat: torch.Tensor,
subsampling: Optional[List[int]] = None,
) -> torch.Tensor:
# Just 1 input for now
if not module.input0.requires_grad:
raise RuntimeError('requires_grad needed for arbitrary jac_t_mat_prod')
return torch.stack(transposed_jacobian_vector_product(module.output, module.input0, mat))

# test_simple.py
import torch
from backpack import extend
from backpack.core.derivatives.model import ArbitraryModelDerivatives
class MyModel(torch.nn.Module):
def forward(self, x):
return torch.sin(x)
def run_test():
model = extend(MyModel())
derivs = ArbitraryModelDerivatives()
# Just one input in this case
inputs = (torch.arange(3.0).view(3, 1, 1).expand(3, 3, 3),)  # torch.arange replaces the deprecated torch.range
for inp in inputs:
inp.requires_grad = True
output = model(*inputs)
# Use a matrix of ones to compare result with known answer
mat = torch.ones(output.shape)
res = derivs._jac_t_mat_prod(model, None, None, mat, None) # shape -> (1, 3, 3, 3)
# Known answer
ans = torch.cos(inputs[0]) # shape -> (3, 3, 3)
print(torch.allclose(res, ans)) # prints 'True'
if __name__ == '__main__':
run_test()

If I've made any incorrect assumptions here or have any mistakes, please let me know. One thing I'm unsure about is the required dimensionality of
Hi, this looks like a good start! Let me clarify some of the points I meant in my previous post:
I've made the following edits:

# backpack.core.derivatives.model.py
from typing import List, Tuple, Optional
from itertools import count
import torch
from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.hessianfree.lop import transposed_jacobian_vector_product
class ArbitraryModelDerivatives(BaseDerivatives):
'''Arbitrary Model Derivative'''
def _get_num_inputs(self, module) -> List[str]:
inputs = []
for i in count():
if hasattr(module, f'input{i}'):
inputs.append(f'input{i}')
else:
break
return inputs
def _jac_t_mat_prod(
self,
module: torch.nn.Module,
g_inp: Tuple[torch.Tensor],
g_out: Tuple[torch.Tensor],
mat: torch.Tensor,
subsampling: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, ...]:
# Start with just 1 input
input_names = self._get_num_inputs(module)
res = []
for inp_name in input_names:
inp = getattr(module, inp_name)
if not inp.requires_grad:
raise RuntimeError(f'requires_grad needed for module.{inp_name} jac_t_mat_prod')
if subsampling is None:
jtmp = torch.stack(transposed_jacobian_vector_product(module.output, inp, mat))
else:
raise NotImplementedError('Subsampling not currently implemented')
res.append(jtmp)
return tuple(res)

# backpack.extensions.backprop_extension.py
class BackpropExtension(ABC):
...
def __get_module_extension(self, module: Module) -> Union[ModuleExtension, None]:
module_extension = self.__module_extensions.get(module.__class__)
if module_extension is None:
module_extension = self._get_arbitrary_extension(module)
...
def _get_arbitrary_extension(self, module):
return None  # None in abstract base class

# backpack.extensions.secondorder.diag_ggn.__init__.py
from . import model
class DiagGGN(SecondOrderBackpropExtension):
...
def _get_arbitrary_extension(self, module):
return model.DiagGGNArbitraryModel()
...

Adding all this, using a custom module that has various layers inside of it, including

bp_quantity = self.__get_backproped_quantity(
    extension, module.output, delete_old_quantities
)

evaluates to
Hi,
For Hessian-related quantities, BackPACK backpropagates multiple vectors in parallel. These vectors are stacked, which yields the leading dimension V.
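A small shape sketch of that convention (the sizes are arbitrary, not from this thread): the quantity arriving at a module during backpropagation is a stack of V vectors, each with the shape of the module's output, and applying the transposed Jacobian to each of them preserves the leading dimension V:

```python
import torch

N, C, V = 10, 3, 3                # batch size, output features, number of stacked vectors
x = torch.rand(N, C, requires_grad=True)
y = torch.sin(x)                  # stand-in for a module's forward pass
mat = torch.rand(V, N, C)         # stacked vectors backpropagated from the next module

# apply the transposed Jacobian to each of the V vectors and re-stack the results
jac_t_mat = torch.stack(
    [torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)[0] for v in mat]
)
print(jac_t_mat.shape)            # torch.Size([3, 10, 3]) -> leading dimension V is preserved
```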
I tweaked your example to make it work and commented on some of the details that relate to my previous posts:

"""Backpropagation through ReLU for GGN diagonal via ``torch.autograd``."""
from typing import List, Optional, Tuple
from torch import Tensor, allclose, enable_grad, rand, stack
from torch.autograd import grad
from torch.nn import Linear, Module, MSELoss, ReLU, Sequential
from torch.nn.functional import relu
from backpack import backpack, extend, extensions
from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule
class ArbitraryModelDerivatives(BaseDerivatives):
"""Arbitrary Model Derivative"""
def __init__(self, forward_func) -> None:
super().__init__()
self.forward_func = forward_func
def _jac_t_mat_prod(
self,
module: Module,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
mat: Tensor,
subsampling: Optional[List[int]] = None,
) -> Tensor:
if subsampling is not None:
raise NotImplementedError("Subsampling not currently implemented")
print("Using arbitrary model derivatives.")
# regenerate computation graph for differentiation
with enable_grad():
# NOTE: Cannot use module(module.input0) since this triggers its
# forward hook and messes up the backpropagation internals (the
# internals rely on the memory address of module.input0, but the
# old module.input0 will be overwritten during
# module(module.input0)).
re_input = module.input0.clone().detach().requires_grad_(True)
re_output = self.forward_func(re_input)
# V vectors of shape [*module.input0.shape]
vjps = [grad(re_output, re_input, v, retain_graph=True)[0] for v in mat]
return stack(vjps) # shape [V, *module.input0.shape]
class DiagGGNReLUArbitrary(DiagGGNBaseModule):
"""Implements DiagGGN backpropagation for ReLU layer using ``torch.autograd``."""
def __init__(self):
super().__init__(derivatives=ArbitraryModelDerivatives(self.forward_func))
@staticmethod
def forward_func(input0):
return relu(input0)
X, y = rand(10, 5), rand(10, 3)
model = Sequential(Linear(5, 4), ReLU(), Linear(4, 3))
loss_func = MSELoss()
model = extend(model)
loss_func = extend(loss_func)
# ground truth
with backpack(extensions.DiagGGNExact()):
loss = loss_func(model(X), y)
loss.backward()
diag_ggn = [p.diag_ggn_exact for p in model.parameters()]
# now using arbitrary derivatives under the hood
ext = extensions.DiagGGNExact()
ext.set_module_extension(
ReLU,
DiagGGNReLUArbitrary(),
overwrite=True, # force overwrite as ReLU already exists within BackPACK
)
with backpack(ext):
loss = loss_func(model(X), y)
loss.backward()
diag_ggn_arbitrary = [p.diag_ggn_exact for p in model.parameters()]
for diag, diag_arbitrary in zip(diag_ggn, diag_ggn_arbitrary):
print(allclose(diag, diag_arbitrary))
I have a somewhat complicated torch.nn.Module, let's say for argument's sake its structure is a bit like this:

Whilst OtherCustomModule and AnotherCustomModule are themselves composed of some custom functionality, there are some standard layers within them like nn.Linear, but there's other stuff going on too.

I've read that as long as the direct children are standard torch modules like nn.Linear, backpack can detect and deal with that; however, that isn't the case here.

Looking at the example custom module docs with ScaleModuleBatchGrad, I'm not sure how I can implement my own class here, since self.layer1 etc. are nn.Modules, not nn.Parameters?
s?The text was updated successfully, but these errors were encountered: