1840 add tbf layer (#5757)
Fixes #1840.

### Description

I integrated the trainable bilateral filter (TBF) into the MONAI
repository as a new PyTorch filter layer. The TBF provides analytically
derived gradients with respect to its filter parameters and its noisy
input image, which enables gradient-based optimization within the
PyTorch graph. See [here](https://doi.org/10.1002/mp.15718) for more
details on the gradient derivation. Unit tests were added that check
both the filter output and the gradient computation.
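
As a rough usage sketch (the constructor arguments, default values, and expected
tensor shape below are illustrative assumptions, not the exact API of this PR):

```python
import torch
from monai.networks.layers import TrainableBilateralFilter

# Hypothetical example: denoise an image while learning the filter parameters.
noisy = torch.rand(1, 1, 64, 64)    # (batch, channel, H, W); shape is an assumption
target = torch.rand(1, 1, 64, 64)

layer = TrainableBilateralFilter(spatial_sigma=1.0, color_sigma=0.2)  # argument names assumed
optimizer = torch.optim.Adam(layer.parameters(), lr=1e-2)

for _ in range(10):
    optimizer.zero_grad()
    prediction = layer(noisy)
    loss = torch.nn.functional.mse_loss(prediction, target)
    loss.backward()   # uses the analytical TBF gradients
    optimizer.step()
```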

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Fabian Wagner <[email protected]>
faebstn96 authored Jan 17, 2023
1 parent 373e47d commit 6803061
Showing 12 changed files with 1,939 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/source/networks.rst
@@ -349,6 +349,11 @@ Layers
.. autoclass:: BilateralFilter
    :members:

`TrainableBilateralFilter`
~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: TrainableBilateralFilter
    :members:

`PHLFilter`
~~~~~~~~~~~
.. autoclass:: PHLFilter
2 changes: 2 additions & 0 deletions monai/csrc/ext.cpp
@@ -22,6 +22,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // filtering
  m.def("bilateral_filter", &BilateralFilter, "Bilateral Filter");
  m.def("phl_filter", &PermutohedralFilter, "Permutohedral Filter");
  m.def("tbf_forward", &TrainableBilateralFilterForward, "Trainable Bilateral Filter Forward");
  m.def("tbf_backward", &TrainableBilateralFilterBackward, "Trainable Bilateral Filter Backward");

  // lltm
  m.def("lltm_forward", &lltm_forward, "LLTM forward");
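These bindings expose the new forward and backward kernels to Python. For orientation, a heavily
simplified sketch of how such bound ops are typically wrapped in a `torch.autograd.Function`
(assuming the compiled extension is importable as `monai._C`; the return values of `tbf_forward`
and the argument order below are illustrative assumptions, not the PR's actual Python wrapper,
which lives in the Python layer files of this change set):

```python
import torch
from monai import _C  # compiled C++/CUDA extension

# Hypothetical sketch only: assumes tbf_forward returns the filtered image plus the
# intermediates (normalization weights and precomputed center derivatives) that the
# backward kernel consumes, mirroring the C++ signatures further below.
class _TBFFunctionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, sigma_x, sigma_y, sigma_z, color_sigma):
        out, weights, d_out_d_x_center = _C.tbf_forward(x, sigma_x, sigma_y, sigma_z, color_sigma)
        ctx.save_for_backward(x, out, weights, d_out_d_x_center)
        ctx.sigmas = (sigma_x, sigma_y, sigma_z, color_sigma)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x, out, weights, d_out_d_x_center = ctx.saved_tensors
        grad_x = _C.tbf_backward(grad_out, x, out, weights, d_out_d_x_center, *ctx.sigmas)
        # Gradients with respect to the sigmas are omitted in this sketch.
        return grad_x, None, None, None, None
```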
1 change: 1 addition & 0 deletions monai/csrc/filtering/filtering.h
@@ -15,3 +15,4 @@ limitations under the License.

#include "bilateral/bilateral.h"
#include "permutohedral/permutohedral.h"
#include "trainable_bilateral/trainable_bilateral.h"
249 changes: 249 additions & 0 deletions monai/csrc/filtering/trainable_bilateral/bf_layer_cpu_backward.cpp
@@ -0,0 +1,249 @@
/*
Copyright (c) MONAI Consortium
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "trainable_bilateral.h"

struct Indexer {
 public:
  Indexer(int dimensions, int* sizes) {
    m_dimensions = dimensions;
    m_sizes = sizes;
    m_index = new int[dimensions]{0};
  }
  ~Indexer() {
    delete[] m_index;
  }

  bool operator++(int) {
    for (int i = 0; i < m_dimensions; i++) {
      m_index[i] += 1;

      if (m_index[i] < m_sizes[i]) {
        return true;
      } else {
        m_index[i] = 0;
      }
    }

    return false;
  }

  int& operator[](int dimensionIndex) {
    return m_index[dimensionIndex];
  }

 private:
  int m_dimensions;
  int* m_sizes;
  int* m_index;
};
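`Indexer` walks an N-dimensional kernel window like an odometer: the first dimension increments
fastest, a carry resets it and moves to the next dimension, and `operator++` returns `false` once
every dimension has wrapped. A small illustrative sketch of the same iteration pattern:

```python
# Odometer-style iteration over an N-dimensional window, mirroring the Indexer above.
def iterate_window(sizes):
    index = [0] * len(sizes)
    while True:
        yield tuple(index)
        for i in range(len(sizes)):  # increment with carry, fastest dimension first
            index[i] += 1
            if index[i] < sizes[i]:
                break
            index[i] = 0
        else:  # every dimension wrapped around: iteration is complete
            return

# A 3x3x3 window yields 27 positions, starting at (0, 0, 0).
assert len(list(iterate_window([3, 3, 3]))) == 27
```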

template <typename scalar_t>
void BilateralFilterCpuBackward_3d(
    torch::Tensor gradientInputTensor,
    torch::Tensor gradientOutputTensor,
    torch::Tensor inputTensor,
    torch::Tensor outputTensor,
    torch::Tensor outputWeightsTensor,
    torch::Tensor dO_dx_ki,
    float sigma_x,
    float sigma_y,
    float sigma_z,
    float colorSigma) {
  // Getting tensor description.
  TensorDescription desc = TensorDescription(gradientInputTensor);

  // Raw tensor data pointers.
  scalar_t* gradientInputTensorData = gradientInputTensor.data_ptr<scalar_t>();
  scalar_t* gradientOutputTensorData = gradientOutputTensor.data_ptr<scalar_t>();
  scalar_t* inputTensorData = inputTensor.data_ptr<scalar_t>();
  scalar_t* outputTensorData = outputTensor.data_ptr<scalar_t>();
  scalar_t* outputWeightsTensorData = outputWeightsTensor.data_ptr<scalar_t>();
  scalar_t* dO_dx_kiData = dO_dx_ki.data_ptr<scalar_t>();

  // Pre-calculate common values
  int windowSize_x = std::max(((int)ceil(5.0f * sigma_x) | 1), 5); // ORing last bit to ensure odd window size
  int windowSize_y = std::max(((int)ceil(5.0f * sigma_y) | 1), 5); // ORing last bit to ensure odd window size
  int windowSize_z = std::max(((int)ceil(5.0f * sigma_z) | 1), 5); // ORing last bit to ensure odd window size
  int halfWindowSize_x = floor(0.5f * windowSize_x);
  int halfWindowSize_y = floor(0.5f * windowSize_y);
  int halfWindowSize_z = floor(0.5f * windowSize_z);
  int halfWindowSize_arr[] = {halfWindowSize_x, halfWindowSize_y, halfWindowSize_z};
  scalar_t spatialExpConstant_x = -1.0f / (2 * sigma_x * sigma_x);
  scalar_t spatialExpConstant_y = -1.0f / (2 * sigma_y * sigma_y);
  scalar_t spatialExpConstant_z = -1.0f / (2 * sigma_z * sigma_z);
  scalar_t colorExpConstant = -1.0f / (2 * colorSigma * colorSigma);

  // Set kernel sizes with respect to the defined spatial sigmas.
  int* kernelSizes = new int[desc.dimensions];

  kernelSizes[0] = windowSize_x;
  kernelSizes[1] = windowSize_y;
  kernelSizes[2] = windowSize_z;

  // Pre-calculate gaussian kernel and distance map in 1D.
  scalar_t* gaussianKernel_x = new scalar_t[windowSize_x];
  scalar_t* gaussianKernel_y = new scalar_t[windowSize_y];
  scalar_t* gaussianKernel_z = new scalar_t[windowSize_z];
  scalar_t* xDistanceSquared = new scalar_t[windowSize_x];
  scalar_t* yDistanceSquared = new scalar_t[windowSize_y];
  scalar_t* zDistanceSquared = new scalar_t[windowSize_z];

  for (int i = 0; i < windowSize_x; i++) {
    int distance = i - halfWindowSize_x;
    gaussianKernel_x[i] = exp(distance * distance * spatialExpConstant_x);
    xDistanceSquared[i] = distance * distance;
  }
  for (int i = 0; i < windowSize_y; i++) {
    int distance = i - halfWindowSize_y;
    gaussianKernel_y[i] = exp(distance * distance * spatialExpConstant_y);
    yDistanceSquared[i] = distance * distance;
  }
  for (int i = 0; i < windowSize_z; i++) {
    int distance = i - halfWindowSize_z;
    gaussianKernel_z[i] = exp(distance * distance * spatialExpConstant_z);
    zDistanceSquared[i] = distance * distance;
  }
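The window sizes and 1D kernels above follow a simple rule: roughly five standard deviations wide,
forced to an odd width by setting the lowest bit, and never smaller than 5. An illustrative
recomputation of the same quantities (NumPy used purely for demonstration):

```python
import math
import numpy as np

def gaussian_window(sigma):
    # ceil(5 * sigma), forced odd via the low bit, with a minimum width of 5
    window = max(math.ceil(5.0 * sigma) | 1, 5)
    half = window // 2
    distance = np.arange(window) - half
    kernel = np.exp(-(distance ** 2) / (2.0 * sigma ** 2))
    return window, kernel

window, kernel = gaussian_window(1.5)  # window == 9; kernel peaks at the center tap
```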

  // Looping over the batches
  for (int b = 0; b < desc.batchCount; b++) {
    int batchOffset = b * desc.batchStride;

    // Looping over all dimensions for the home element
    for (int z = 0; z < desc.sizes[2]; z++)
#pragma omp parallel for
      for (int y = 0; y < desc.sizes[1]; y++) {
        for (int x = 0; x < desc.sizes[0]; x++) {
          // Calculating indexing offset for the home element
          int homeOffset = batchOffset;

          int homeIndex[] = {x, y, z};
          homeOffset += x * desc.strides[0];
          homeOffset += y * desc.strides[1];
          homeOffset += z * desc.strides[2];

          // Zero kernel aggregates.
          scalar_t filter_kernel = 0;
          scalar_t valueSum = 0;

          // Looping over all dimensions for the neighbour element
          Indexer kernelIndex = Indexer(desc.dimensions, kernelSizes);
          do // while(kernelIndex++)
          {
            // Calculating buffer offset for the neighbour element
            // Index is clamped to the border in each dimension.
            int neighbourOffset = batchOffset;
            bool flagNotClamped = true;

            for (int i = 0; i < desc.dimensions; i++) {
              int neighbourIndex = homeIndex[i] + kernelIndex[i] - halfWindowSize_arr[i];
              int neighbourIndexClamped = std::min(desc.sizes[i] - 1, std::max(0, neighbourIndex));
              neighbourOffset += neighbourIndexClamped * desc.strides[i];
              if (neighbourIndex != neighbourIndexClamped) {
                flagNotClamped = false;
              }
            }

            // Euclidean color distance.
            scalar_t colorDistance = 0;
            scalar_t colorDistanceSquared = 0;

            for (int i = 0; i < desc.channelCount; i++) {
              scalar_t diff = inputTensorData[neighbourOffset + i * desc.channelStride] -
                  inputTensorData[homeOffset + i * desc.channelStride]; // Be careful: Here it is (X_k - X_i) and not (X_i - X_q)
              colorDistance += diff; // Do not take the absolute value here. Be careful with the signs.
              colorDistanceSquared += diff * diff;
            }

            // Calculating and combining the spatial
            // and color weights.
            scalar_t spatialWeight = 1;

            spatialWeight =
                gaussianKernel_x[kernelIndex[0]] * gaussianKernel_y[kernelIndex[1]] * gaussianKernel_z[kernelIndex[2]];

            scalar_t colorWeight = exp(colorDistanceSquared * colorExpConstant);
            scalar_t totalWeight = spatialWeight * colorWeight;

            // Aggregating values. Only do this if flagNotClamped: Pixels outside the image are disregarded.
            if (flagNotClamped) {
              for (int i = 0; i < desc.channelCount; i++) {
                // Distinguish cases for k!=i (calculation is done here)
                // and k==i (partial derivatives are precalculated).
                // If statement replaces center element of neighborhood/kernel.
                if (kernelIndex[0] != halfWindowSize_x || kernelIndex[1] != halfWindowSize_y ||
                    kernelIndex[2] != halfWindowSize_z) {
                  filter_kernel = -(1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) *
                          outputTensorData[neighbourOffset + i * desc.channelStride] * totalWeight * colorDistance /
                          (colorSigma * colorSigma) +
                      (1 / outputWeightsTensorData[neighbourOffset + i * desc.channelStride]) * totalWeight *
                          (1 +
                           inputTensorData[homeOffset + i * desc.channelStride] * colorDistance /
                               (colorSigma * colorSigma)); // inputTensorData[homeOffset] !!
                } else {
                  filter_kernel = dO_dx_kiData[homeOffset + i * desc.channelStride];
                }

                valueSum += gradientInputTensorData[neighbourOffset + i * desc.channelStride] * filter_kernel;
              }
            }
          } while (kernelIndex++);

          // Do the filtering and calculate the values for the backward pass.
          for (int i = 0; i < desc.channelCount; i++) {
            // Filtering:
            gradientOutputTensorData[homeOffset + i * desc.channelStride] = valueSum;
          }
        }
      }
  }

  delete[] kernelSizes;
  delete[] gaussianKernel_x;
  delete[] gaussianKernel_y;
  delete[] gaussianKernel_z;
  delete[] xDistanceSquared;
  delete[] yDistanceSquared;
  delete[] zDistanceSquared;
}
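For reference, the `filter_kernel` term accumulated into `valueSum` above corresponds, in generic
bilateral-filter notation, to the analytical derivative of the filter output with respect to a
neighbouring input pixel (this is a restatement of the code, not an additional claim). Writing the
forward filter as

$$
\hat{y}_k = \frac{1}{W_k}\sum_{q} w_{kq}\, x_q, \qquad
W_k = \sum_{q} w_{kq}, \qquad
w_{kq} = G_{\mathrm{spatial}}(p_k - p_q)\,\exp\!\left(-\frac{(x_k - x_q)^2}{2\sigma_c^2}\right),
$$

where $G_{\mathrm{spatial}}$ is the product of the three 1D spatial Gaussians, the off-center
Jacobian entries used here are

$$
\frac{\partial \hat{y}_k}{\partial x_i}
= -\frac{\hat{y}_k}{W_k}\, w_{ki}\, \frac{x_k - x_i}{\sigma_c^2}
+ \frac{w_{ki}}{W_k}\left(1 + x_i\, \frac{x_k - x_i}{\sigma_c^2}\right),
\qquad k \neq i,
$$

with `colorDistance` playing the role of $x_k - x_i$ and `totalWeight` the role of $w_{ki}$; the
center case $k = i$ is read from the precomputed `dO_dx_ki`. The loop therefore accumulates the
vector-Jacobian product $\sum_k \frac{\partial L}{\partial \hat{y}_k}\,\frac{\partial \hat{y}_k}{\partial x_i}$
into `gradientOutputTensor`, where `gradientInputTensor` holds the upstream gradient
$\partial L / \partial \hat{y}$.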

torch::Tensor BilateralFilterCpuBackward(
    torch::Tensor gradientInputTensor,
    torch::Tensor inputTensor,
    torch::Tensor outputTensor,
    torch::Tensor outputWeightsTensor,
    torch::Tensor dO_dx_ki,
    float sigma_x,
    float sigma_y,
    float sigma_z,
    float colorSigma) {
  // Preparing output tensor.
  torch::Tensor gradientOutputTensor = torch::zeros_like(gradientInputTensor);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(gradientInputTensor.scalar_type(), "BilateralFilterCpuBackward_3d", ([&] {
    BilateralFilterCpuBackward_3d<scalar_t>(
        gradientInputTensor,
        gradientOutputTensor,
        inputTensor,
        outputTensor,
        outputWeightsTensor,
        dO_dx_ki,
        sigma_x,
        sigma_y,
        sigma_z,
        colorSigma);
  }));

  return gradientOutputTensor;
}