
Commit

dev nanoo fp8
ScXfjiang committed Nov 14, 2024
1 parent 3e6f261 commit 67fcf83
Showing 51 changed files with 537 additions and 109 deletions.
9 changes: 4 additions & 5 deletions tensorflow/c/tf_datatype.h
@@ -52,13 +52,12 @@ typedef enum TF_DataType {
TF_VARIANT = 21,
TF_UINT32 = 22,
TF_UINT64 = 23,
TF_FLOAT8_E5M2 = 24, // 5 exponent bits, 2 mantissa bits.
TF_FLOAT8_E4M3FN = 25, // 4 exponent bits, 3 mantissa bits, finite-only, with
// 2 NaNs (0bS1111111).
TF_FLOAT8_E5M2 = 24, // 5 exponent bits, 2 mantissa bits.
TF_FLOAT8_E4M3FN = 25, // 4 exponent bits, 3 mantissa bits, finite-only, with two NaNs and two Zeros.
TF_FLOAT8_E4M3FNUZ = 26, // 4 exponent bits, 3 mantissa bits, finite-only, with one NaN and one Zero.
// TODO - b/299182407: Leaving room for remaining float8 types.
// TF_FLOAT8_E4M3FNUZ = 26,
// TF_FLOAT8_E4M3B11FNUZ = 27,
// TF_FLOAT8_E5M2FNUZ = 28,
TF_FLOAT8_E5M2FNUZ = 28, // 5 exponent bits, 2 mantissa bits, finite-only, with one NaN and one Zero.
TF_INT4 = 29,
TF_UINT4 = 30,
} TF_DataType;
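
Note on the two new formats, with a minimal sketch. The FNUZ ("finite, no negative zero") encodings keep a single NaN (the pattern that would otherwise be negative zero) and shift the exponent bias up by one, which moves the representable range toward zero: TF_FLOAT8_E4M3FNUZ tops out at 240 versus 448 for TF_FLOAT8_E4M3FN. The sketch below illustrates this with the ml_dtypes package that backs TensorFlow's NumPy-level fp8 types; the limits in the comments are standard properties of these formats, not output verified against this branch.

import ml_dtypes
import numpy as np

# Compare the dynamic range of the OCP fp8 formats and their FNUZ variants.
for t in (ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3fnuz,
          ml_dtypes.float8_e5m2, ml_dtypes.float8_e5m2fnuz):
  info = np.finfo(t)
  # e4m3fn max is 448, e4m3fnuz max is 240 (its bias is one larger and the
  # 0b10000000 pattern is its single NaN); both e5m2 variants top out at 57344.
  print(t.__name__, "max =", info.max, "smallest normal =", info.tiny)

# FNUZ formats have exactly one NaN and no negative zero.
print(float(ml_dtypes.float8_e4m3fnuz(float("nan"))))  # nan
print(float(ml_dtypes.float8_e4m3fnuz(-0.0)))          # 0.0: -0 has no encoding
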
27 changes: 22 additions & 5 deletions tensorflow/compiler/jit/xla_gpu_device.cc
@@ -157,11 +157,28 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);

// Kernel registrations

constexpr std::array<DataType, 20> kAllXlaGpuTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8,
DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL,
DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_INT4, DT_UINT4}};
constexpr std::array<DataType, 22> kAllXlaGpuTypes = {{DT_UINT8,
DT_QUINT8,
DT_UINT16,
DT_INT8,
DT_QINT8,
DT_INT16,
DT_INT32,
DT_QINT32,
DT_INT64,
DT_HALF,
DT_FLOAT,
DT_DOUBLE,
DT_COMPLEX64,
DT_COMPLEX128,
DT_BOOL,
DT_BFLOAT16,
DT_FLOAT8_E5M2,
DT_FLOAT8_E4M3FN,
DT_FLOAT8_E5M2FNUZ,
DT_FLOAT8_E4M3FNUZ,
DT_INT4,
DT_UINT4}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
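
With the two FNUZ entries added to kAllXlaGpuTypes, the XLA launch and compile kernels accept these dtypes on DEVICE_XLA_GPU. A hedged usage sketch, assuming a TensorFlow build from this branch with an XLA device available and the Python dtypes exposed as the test changes later in this diff suggest:

import tensorflow as tf
from tensorflow.python.framework import dtypes

@tf.function(jit_compile=True)
def fp8_round_trip(x):
  # Quantize to float8_e4m3fnuz and back; the sample values below are all
  # exactly representable in that format, so the round trip should be lossless.
  return tf.cast(tf.cast(x, dtypes.float8_e4m3fnuz), tf.float32)

print(fp8_round_trip(tf.constant([0.5, 1.5, -8.0, 240.0])))
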
@@ -36,6 +36,7 @@ TEST(IsLargeFloatTypeTest, scalars) {
auto context = CreateContext();

EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNType::get(context.get())));
EXPECT_FALSE(IsLargeFloatType(Float8E4M3FNUZType::get(context.get())));
EXPECT_FALSE(IsLargeFloatType(Float16Type::get(context.get())));
EXPECT_FALSE(IsLargeFloatType(BFloat16Type::get(context.get())));
EXPECT_TRUE(IsLargeFloatType(Float32Type::get(context.get())));
@@ -52,6 +53,8 @@ TEST(IsLargeFloatTypeTest, tensors) {

EXPECT_FALSE(IsLargeFloatType(
RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))));
EXPECT_FALSE(IsLargeFloatType(
RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get()))));
EXPECT_FALSE(IsLargeFloatType(
RankedTensorType::get({2, 2}, Float16Type::get(context.get()))));
EXPECT_FALSE(IsLargeFloatType(
@@ -76,6 +79,8 @@ TEST(ToBfloat16TypeTest, scalars) {

EXPECT_EQ(ToBfloat16Type(Float8E4M3FNType::get(context.get())),
Float8E4M3FNType::get(context.get()));
EXPECT_EQ(ToBfloat16Type(Float8E4M3FNUZType::get(context.get())),
Float8E4M3FNUZType::get(context.get()));
EXPECT_EQ(ToBfloat16Type(Float16Type::get(context.get())),
Float16Type::get(context.get()));
EXPECT_EQ(ToBfloat16Type(BFloat16Type::get(context.get())),
@@ -102,6 +107,10 @@ TEST(ToBfloat16TypeTest, tensors) {
ToBfloat16Type(
RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get()))),
RankedTensorType::get({2, 2}, Float8E4M3FNType::get(context.get())));
EXPECT_EQ(
ToBfloat16Type(RankedTensorType::get(
{2, 2}, Float8E4M3FNUZType::get(context.get()))),
RankedTensorType::get({2, 2}, Float8E4M3FNUZType::get(context.get())));
EXPECT_EQ(ToBfloat16Type(
RankedTensorType::get({2, 2}, Float16Type::get(context.get()))),
RankedTensorType::get({2, 2}, Float16Type::get(context.get())));
@@ -182,6 +182,7 @@ TEST(IsTFQintTypeTest, ValidTFQintTypeSucceeds) {

EXPECT_FALSE(IsTFQintType(TF::Int8RefType::get(context.get())));
EXPECT_FALSE(IsTFQintType(TF::Float8E5M2RefType::get(context.get())));
EXPECT_FALSE(IsTFQintType(TF::Float8E5M2FNUZRefType::get(context.get())));
}

TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) {
8 changes: 7 additions & 1 deletion tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
@@ -327,6 +327,8 @@ def TF_Float64Ref : TF_TensorFlowType<"DoubleRef", "f64ref">;
def TF_Bfloat16Ref : TF_TensorFlowType<"Bfloat16Ref", "bf16ref">;
def TF_Float8E4M3FNRef : TF_TensorFlowType<"Float8E4M3FNRef", "float8e4m3fnref">;
def TF_Float8E5M2Ref : TF_TensorFlowType<"Float8E5M2Ref", "float8e5m2ref">;
def TF_Float8E4M3FNUZRef : TF_TensorFlowType<"Float8E4M3FNUZRef", "float8e4m3fnuzref">;
def TF_Float8E5M2FNUZRef : TF_TensorFlowType<"Float8E5M2FNUZRef", "float8e5m2fnuzref">;

// Complex reference types
def TF_Complex64Ref : TF_TensorFlowType<"Complex64Ref", "complex64ref">;
@@ -443,12 +445,14 @@ def TF_Float64 : AnyTypeOf<[F64, TF_Float64Ref], "64-bit float">;
def TF_Bfloat16 : AnyTypeOf<[BF16, TF_Bfloat16Ref], "bfloat16">;
def TF_Float8E4M3FN : AnyTypeOf<[F8E4M3FN, TF_Float8E4M3FNRef], "float8e4m3fn">;
def TF_Float8E5M2 : AnyTypeOf<[F8E5M2, TF_Float8E5M2Ref], "float8e5m2">;
def TF_Float8E4M3FNUZ : AnyTypeOf<[F8E4M3FNUZ, TF_Float8E4M3FNUZRef], "float8e4m3fnuz">;
def TF_Float8E5M2FNUZ : AnyTypeOf<[F8E5M2FNUZ, TF_Float8E5M2FNUZRef], "float8e5m2fnuz">;

def TF_F32OrF64 : AnyTypeOf<[TF_Float32, TF_Float64], "32/64-bit float">;

def TF_Float : AnyTypeOf<
[TF_Float16, TF_Float32, TF_Float64, TF_Bfloat16, TF_Float8E4M3FN,
TF_Float8E5M2],
TF_Float8E5M2, TF_Float8E4M3FNUZ, TF_Float8E5M2FNUZ],
"floating-point">;

// Tensor types
@@ -460,6 +464,8 @@ def TF_Float64Tensor : TensorOf<[TF_Float64]>;
def TF_Bfloat16Tensor : TensorOf<[TF_Bfloat16]>;
def TF_Float8E4M3FNTensor : TensorOf<[TF_Float8E4M3FN]>;
def TF_Float8E5M2Tensor : TensorOf<[TF_Float8E5M2]>;
def TF_Float8E4M3FNUZTensor : TensorOf<[TF_Float8E4M3FNUZ]>;
def TF_Float8E5M2FNUZTensor : TensorOf<[TF_Float8E5M2FNUZ]>;

//===----------------------------------------------------------------------===//
// Complex types (including corresponding reference types)
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/ir/tf_types.def
@@ -68,6 +68,8 @@ HANDLE_TF_REF_TYPE(HalfRef, HALF_REF, "halfref")
HANDLE_TF_REF_TYPE(ResourceRef, RESOURCE_REF, "resourceref")
HANDLE_TF_REF_TYPE(Float8E4M3FNRef, FLOAT8_E4M3FN_REF, "float8e4m3fnref")
HANDLE_TF_REF_TYPE(Float8E5M2Ref, FLOAT8_E5M2_REF, "float8e5m2ref")
HANDLE_TF_REF_TYPE(Float8E4M3FNUZRef, FLOAT8_E4M3FNUZ_REF, "float8e4m3fnuzref")
HANDLE_TF_REF_TYPE(Float8E5M2FNUZRef, FLOAT8_E5M2FNUZ_REF, "float8e5m2fnuzref")

#ifndef HANDLE_LAST_TF_TYPE
#define HANDLE_LAST_TF_TYPE(class, enumerant, name) \
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc
@@ -149,6 +149,8 @@ absl::StatusOr<ElementsAttr> ConvertTensor(const Tensor& input_tensor,
case DT_HALF:
case DT_FLOAT8_E5M2:
case DT_FLOAT8_E4M3FN:
case DT_FLOAT8_E5M2FNUZ:
case DT_FLOAT8_E4M3FNUZ:
return ConvertTensorOfCustomFloatType(input_tensor, type);
case DT_STRING:
return ConvertStringTensor(input_tensor, type);
@@ -466,6 +468,14 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) {
ConvertFloat8ElementsAttr<tsl::float8_e4m3fn>(
dense_attr, output->mutable_float8_val());
break;
case DT_FLOAT8_E5M2FNUZ:
ConvertFloat8ElementsAttr<tsl::float8_e5m2fnuz>(
dense_attr, output->mutable_float8_val());
break;
case DT_FLOAT8_E4M3FNUZ:
ConvertFloat8ElementsAttr<tsl::float8_e4m3fnuz>(
dense_attr, output->mutable_float8_val());
break;
case tensorflow::DT_INT4:
ConvertIntElementsAttr<int, tsl::int4>(dense_attr,
output->mutable_int_val(),
@@ -148,6 +148,12 @@ TEST_F(ConvertTensorTest, Simple) {
ASSERT_NO_FATAL_FAILURE(VerifyConversion<tsl::float8_e4m3fn>(
{tsl::float8_e4m3fn{1.0}, tsl::float8_e4m3fn{-1.0}}, DT_FLOAT8_E4M3FN,
mlir::FloatType::getFloat8E4M3FN(&context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<tsl::float8_e5m2fnuz>(
{tsl::float8_e5m2fnuz{1.0}, tsl::float8_e5m2fnuz{-1.0}},
DT_FLOAT8_E5M2FNUZ, mlir::FloatType::getFloat8E5M2FNUZ(&context)));
ASSERT_NO_FATAL_FAILURE(VerifyConversion<tsl::float8_e4m3fnuz>(
{tsl::float8_e4m3fnuz{1.0}, tsl::float8_e4m3fnuz{-1.0}},
DT_FLOAT8_E4M3FNUZ, mlir::FloatType::getFloat8E4M3FNUZ(&context)));

ASSERT_NO_FATAL_FAILURE(VerifyConversion<int4>(
{static_cast<int4>(1), static_cast<int4>(-1)}, DT_INT4,
12 changes: 12 additions & 0 deletions tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc
@@ -88,6 +88,12 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) {
case tensorflow::DT_FLOAT8_E5M2:
*type = builder.getFloat8E5M2Type();
return absl::OkStatus();
case tensorflow::DT_FLOAT8_E4M3FNUZ:
*type = builder.getFloat8E4M3FNUZType();
return absl::OkStatus();
case tensorflow::DT_FLOAT8_E5M2FNUZ:
*type = builder.getFloat8E5M2FNUZType();
return absl::OkStatus();
case DT_INT4:
*type = builder.getIntegerType(4, /*isSigned=*/true);
return absl::OkStatus();
@@ -125,6 +131,12 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) {
} else if (type.isFloat8E5M2()) {
*dtype = DT_FLOAT8_E5M2;
return absl::OkStatus();
} else if (type.isFloat8E4M3FNUZ()) {
*dtype = DT_FLOAT8_E4M3FNUZ;
return absl::OkStatus();
} else if (type.isFloat8E5M2FNUZ()) {
*dtype = DT_FLOAT8_E5M2FNUZ;
return absl::OkStatus();
} else if (auto itype = mlir::dyn_cast<mlir::IntegerType>(type)) {
switch (itype.getWidth()) {
case 1:
2 changes: 2 additions & 0 deletions tensorflow/compiler/tests/const_test.py
@@ -48,6 +48,8 @@ def testConst(self):
dtypes.float64,
dtypes.float8_e5m2,
dtypes.float8_e4m3fn,
dtypes.float8_e5m2fnuz,
dtypes.float8_e4m3fnuz,
}
for dtype in types:
with self.subTest(dtype=dtype):
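
For reference, a minimal stand-alone version of what the extended const test exercises; a sketch assuming a build from this branch, with the NumPy-side types coming from ml_dtypes:

import numpy as np
import ml_dtypes
import tensorflow as tf
from tensorflow.python.framework import dtypes

for np_type, tf_type in [(ml_dtypes.float8_e5m2fnuz, dtypes.float8_e5m2fnuz),
                         (ml_dtypes.float8_e4m3fnuz, dtypes.float8_e4m3fnuz)]:
  values = np.array([1.0, -1.0, 0.5], dtype=np_type)
  c = tf.constant(values, dtype=tf_type)
  np.testing.assert_array_equal(c.numpy(), values)  # round-trips exactly
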
3 changes: 2 additions & 1 deletion tensorflow/compiler/tests/unary_ops_test.py
@@ -891,7 +891,8 @@ def testCastFp8(self):
# TODO(b/271327511): Fix issue where casts to FP8 very rarely result in
# NaN on Mac
self.skipTest("Casts to FP8 sometimes result in NaN on Mac")
fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn}
fp8_types = {dtypes.float8_e5m2, dtypes.float8_e4m3fn,
dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz}
other_types = {
dtypes.bool, dtypes.float32, dtypes.float64, dtypes.complex64,
dtypes.int32, dtypes.int64, dtypes.uint32, dtypes.uint64
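
The cast test now covers four fp8 variants against the other dtypes. A hedged sketch of the same idea, assuming the cast kernels for the FNUZ dtypes are registered in this build:

import tensorflow as tf
from tensorflow.python.framework import dtypes

fp8_types = [dtypes.float8_e5m2, dtypes.float8_e4m3fn,
             dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz]
x = tf.constant([0.0, 1.0, -1.5, 8.0], dtype=tf.float32)
for fp8 in fp8_types:
  # Cast down to fp8 and back up to float32.
  y = tf.cast(tf.cast(x, fp8), tf.float32)
  print(fp8.name, y.numpy())  # the sample values are representable in every format
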
8 changes: 8 additions & 0 deletions tensorflow/compiler/tf2xla/type_util.cc
@@ -71,6 +71,12 @@ absl::Status DataTypeToPrimitiveType(DataType data_type,
case tensorflow::DT_FLOAT8_E4M3FN:
*type = xla::F8E4M3FN;
return absl::OkStatus();
case tensorflow::DT_FLOAT8_E5M2FNUZ:
*type = xla::F8E5M2FNUZ;
return absl::OkStatus();
case tensorflow::DT_FLOAT8_E4M3FNUZ:
*type = xla::F8E4M3FNUZ;
return absl::OkStatus();
case tensorflow::DT_BFLOAT16:
*type = xla::BF16;
return absl::OkStatus();
@@ -103,6 +109,8 @@ absl::StatusOr<DataType> EncodePrimitiveTypeAsDataType(
{xla::PRED, DT_BOOL},
{xla::F8E5M2, DT_FLOAT8_E5M2},
{xla::F8E4M3FN, DT_FLOAT8_E4M3FN},
{xla::F8E5M2FNUZ, DT_FLOAT8_E5M2FNUZ},
{xla::F8E4M3FNUZ, DT_FLOAT8_E4M3FNUZ},
{xla::BF16, DT_BFLOAT16},
{xla::F16, DT_HALF},
{xla::F32, DT_FLOAT},
32 changes: 19 additions & 13 deletions tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -65,19 +65,25 @@ constexpr std::array<DataType, 14> kNumericTypes = {
DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
DT_BFLOAT16}};

constexpr std::array<DataType, 22> kCpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64,
DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32,
DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN,
DT_INT4, DT_UINT4}};

constexpr std::array<DataType, 22> kGpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64,
DT_INT8, DT_QINT8, DT_INT16, DT_INT32, DT_QINT32,
DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16, DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN,
DT_INT4, DT_UINT4}};
constexpr std::array<DataType, 24> kCpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_INT8,
DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16,
DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_FLOAT8_E5M2FNUZ,
DT_FLOAT8_E4M3FNUZ, DT_INT4, DT_UINT4}};

constexpr std::array<DataType, 24> kGpuAllTypes = {
{DT_UINT8, DT_QUINT8, DT_UINT16,
DT_UINT32, DT_UINT64, DT_INT8,
DT_QINT8, DT_INT16, DT_INT32,
DT_QINT32, DT_INT64, DT_HALF,
DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16,
DT_FLOAT8_E5M2, DT_FLOAT8_E4M3FN, DT_FLOAT8_E5M2FNUZ,
DT_FLOAT8_E4M3FNUZ, DT_INT4, DT_UINT4}};

// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
6 changes: 6 additions & 0 deletions tensorflow/core/framework/register_types.h
@@ -89,6 +89,8 @@ limitations under the License.

#define TF_CALL_float8_e5m2(m) m(::tensorflow::float8_e5m2)
#define TF_CALL_float8_e4m3fn(m) m(::tensorflow::float8_e4m3fn)
#define TF_CALL_float8_e5m2fnuz(m) m(::tensorflow::float8_e5m2fnuz)
#define TF_CALL_float8_e4m3fnuz(m) m(::tensorflow::float8_e4m3fnuz)

#define TF_CALL_int4(m) m(::tensorflow::int4)
#define TF_CALL_uint4(m) m(::tensorflow::uint4)
@@ -127,6 +129,8 @@ limitations under the License.

#define TF_CALL_float8_e5m2(m)
#define TF_CALL_float8_e4m3fn(m)
#define TF_CALL_float8_e5m2fnuz(m)
#define TF_CALL_float8_e4m3fnuz(m)

#define TF_CALL_int4(m)
#define TF_CALL_uint4(m)
@@ -164,6 +168,8 @@ limitations under the License.

#define TF_CALL_float8_e5m2(m)
#define TF_CALL_float8_e4m3fn(m)
#define TF_CALL_float8_e5m2fnuz(m)
#define TF_CALL_float8_e4m3fnuz(m)

#define TF_CALL_int4(m)
#define TF_CALL_uint4(m)
24 changes: 24 additions & 0 deletions tensorflow/core/framework/tensor.cc
@@ -563,6 +563,14 @@ struct ProtoHelper<float8_e5m2> : public Float8ProtoHelper<float8_e5m2> {};
template <>
struct ProtoHelper<float8_e4m3fn> : public Float8ProtoHelper<float8_e4m3fn> {};

template <>
struct ProtoHelper<float8_e5m2fnuz>
: public Float8ProtoHelper<float8_e5m2fnuz> {};

template <>
struct ProtoHelper<float8_e4m3fnuz>
: public Float8ProtoHelper<float8_e4m3fnuz> {};

template <typename T>
Buffer<T>::Buffer(Allocator* a, int64_t n)
: BufferBase(a, TypedAllocator::Allocate<T>(a, n, AllocationAttributes())),
@@ -950,6 +958,8 @@ int Tensor::RefCount() const {
CASE(Variant, SINGLE_ARG(STMTS)) \
CASE(float8_e5m2, SINGLE_ARG(STMTS)) \
CASE(float8_e4m3fn, SINGLE_ARG(STMTS)) \
CASE(float8_e5m2fnuz, SINGLE_ARG(STMTS)) \
CASE(float8_e4m3fnuz, SINGLE_ARG(STMTS)) \
CASE(int4, SINGLE_ARG(STMTS)) \
CASE(uint4, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
@@ -1243,6 +1253,14 @@ inline float PrintOneElement(float8_e4m3fn f, bool print_v2) {
return static_cast<float>(f);
}

inline float PrintOneElement(float8_e5m2fnuz f, bool print_v2) {
return static_cast<float>(f);
}

inline float PrintOneElement(float8_e4m3fnuz f, bool print_v2) {
return static_cast<float>(f);
}

inline int16_t PrintOneElement(int4 a, bool print_v2) {
return static_cast<int16_t>(a);
}
@@ -1429,6 +1447,12 @@ string Tensor::SummarizeValue(int64_t max_entries, bool print_v2) const {
case DT_FLOAT8_E4M3FN:
return SummarizeArray<float8_e4m3fn>(limit, num_elts, shape_, data,
print_v2);
case DT_FLOAT8_E5M2FNUZ:
return SummarizeArray<float8_e5m2fnuz>(limit, num_elts, shape_, data,
print_v2);
case DT_FLOAT8_E4M3FNUZ:
return SummarizeArray<float8_e4m3fnuz>(limit, num_elts, shape_, data,
print_v2);
case DT_FLOAT:
return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
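
The SummarizeValue/PrintOneElement additions are what let a tensor of the new dtypes print its values (each element is widened to float for display). A quick hedged check, assuming a build from this branch:

import numpy as np
import ml_dtypes
import tensorflow as tf
from tensorflow.python.framework import dtypes

t = tf.constant(np.array([0.25, 1.0, -2.0], dtype=ml_dtypes.float8_e5m2fnuz),
                dtype=dtypes.float8_e5m2fnuz)
print(t)          # the printed values come from Tensor::SummarizeValue
print(t.numpy())
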