Skip to content

Commit

Permalink
Add T=bfloat16 to custom_ops registration (#2688)
Browse files Browse the repository at this point in the history
  • Loading branch information
szutenberg authored Apr 12, 2022
1 parent 50530e8 commit c585796
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ REGISTER_OP("Addons>AdjustHsvInYiq")
.Input("scale_s: float")
.Input("scale_v: float")
.Output("output: T")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.Attr(
"T: {uint8, int8, int16, int32, int64, half, float, double, bfloat16}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle images, delta_h, scale_s, scale_v;

Expand Down Expand Up @@ -70,4 +71,4 @@ output: The hsv-adjusted image or images. No clipping will be done in this op.
)Doc");

} // end namespace addons
} // namespace tensorflow
} // namespace tensorflow
6 changes: 3 additions & 3 deletions tensorflow_addons/custom_ops/image/cc/ops/image_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ components: Component ids for each pixel in "image". Same shape as "image". Zero

REGISTER_OP("Addons>EuclideanDistanceTransform")
.Input("images: uint8")
.Attr("dtype: {float16, float32, float64}")
.Attr("dtype: {bfloat16, float16, float32, float64}")
.Output("transformed_images: dtype")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(EuclideanDistanceTransformDoc);
Expand All @@ -65,9 +65,9 @@ REGISTER_OP("Addons>ImageConnectedComponents")
.Output("components: int64")
.Attr(
"dtype: {int64, int32, uint16, int16, uint8, int8, half, float, "
"double, bool, string}")
"bfloat16, double, bool, string}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(ImageConnectedComponentsDoc);

} // end namespace addons
} // namespace tensorflow
} // namespace tensorflow
6 changes: 3 additions & 3 deletions tensorflow_addons/custom_ops/image/cc/ops/resampler_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ REGISTER_OP("Addons>Resampler")
.Input("data: T")
.Input("warp: T")
.Output("output: T")
.Attr("T: {half, float, double}")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle data;
ShapeHandle warp;
Expand All @@ -53,7 +53,7 @@ REGISTER_OP("Addons>ResamplerGrad")
.Input("grad_output: T")
.Output("grad_data: T")
.Output("grad_warp: T")
.Attr("T: {half, float, double}")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->input(0));
c->set_output(1, c->input(1));
Expand All @@ -62,4 +62,4 @@ REGISTER_OP("Addons>ResamplerGrad")
.Doc(R"doc(Resampler Grad op.)doc");

} // namespace addons
} // namespace tensorflow
} // namespace tensorflow
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ REGISTER_OP("Addons>EmbeddingBag")
.Input("params: T")
.Input("weights: T")
.Output("output: T")
.Attr("T: {half, float, double}")
.Attr("T: {bfloat16, half, float, double}")
.Attr("Tindices: {int32, int64}")
.Attr("combiner: {'SUM', 'MEAN'} = 'SUM'")
.SetShapeFn([](InferenceContext* c) {
Expand All @@ -51,7 +51,7 @@ REGISTER_OP("Addons>EmbeddingBagGrad")
.Input("grads: T")
.Output("params_grads: T")
.Output("weights_grads: T")
.Attr("T: {half, float, double}")
.Attr("T: {bfloat16, half, float, double}")
.Attr("Tindices: {int32, int64}")
.Attr("combiner: {'SUM', 'MEAN'} = 'SUM'")
.SetShapeFn([](InferenceContext* c) {
Expand Down

0 comments on commit c585796

Please sign in to comment.