Skip to content

Commit

Permalink
update batch norm to use layout gen
Browse files Browse the repository at this point in the history
Differential Revision: D69937208

Pull Request resolved: #8600
  • Loading branch information
nathanaelsee authored Feb 20, 2025
1 parent a454be5 commit 735f16e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
38 changes: 16 additions & 22 deletions backends/vulkan/runtime/graph/ops/glsl/batchnorm.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,18 @@

layout(std430) buffer;

layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
layout(set = 0, binding = 2) uniform PRECISION sampler3D weight_in;
layout(set = 0, binding = 3) uniform PRECISION sampler3D bias_in;
layout(set = 0, binding = 4) uniform PRECISION sampler3D mean_in;
layout(set = 0, binding = 5) uniform PRECISION sampler3D var_in;
#include "indexing_utils.h"

layout(set = 0, binding = 6) uniform PRECISION restrict OutLimits {
ivec3 out_limits;
};
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "weight_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "mean_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "var_in", DTYPE, STORAGE)}

layout(set = 0, binding = 7) uniform PRECISION restrict Params {
float eps;
};

layout(set = 0, binding = 8) uniform PRECISION restrict Params2 {
int num_texel_per_batch;
};
${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "float", "eps")}
${layout_declare_ubo(B, "int", "num_texel_per_batch")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

Expand All @@ -40,16 +34,16 @@ void main() {
return;
}

VEC4_T v = VEC4_T(texelFetch(image_in, pos, 0));
VEC4_T v = VEC4_T(load_texel(t_in, pos));

ivec3 param_pos = ivec3(pos.z % num_texel_per_batch, 0, 0);

VEC4_T weight = VEC4_T(texelFetch(weight_in, param_pos, 0));
VEC4_T bias = VEC4_T(texelFetch(bias_in, param_pos, 0));
VEC4_T mean = VEC4_T(texelFetch(mean_in, param_pos, 0));
VEC4_T var = VEC4_T(texelFetch(var_in, param_pos, 0));
VEC4_T weight = VEC4_T(load_texel(weight_in, param_pos));
VEC4_T bias = VEC4_T(load_texel(bias_in, param_pos));
VEC4_T mean = VEC4_T(load_texel(mean_in, param_pos));
VEC4_T var = VEC4_T(load_texel(var_in, param_pos));

v = ((v - mean) / sqrt(var + eps)) * weight + bias;

imageStore(image_out, pos, v);
write_texel(t_out, pos, v);
}
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/glsl/batchnorm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ batchnorm:
parameter_names_with_default_values:
DTYPE: float
NDIM: 3
STORAGE: texture3d
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down

0 comments on commit 735f16e

Please sign in to comment.