
Commit

Clean up
adityagoel4512 committed Oct 15, 2023
1 parent 2cc6aaa commit f806ec4
Showing 1 changed file with 17 additions and 62 deletions.
79 changes: 17 additions & 62 deletions onnxruntime/core/optimizer/insert_cast_transformer.cc
@@ -231,68 +231,6 @@ static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, cons
return Status::OK();
}

enum TypeGroup {
Unknown = -1,
Bool = 0,
UnsignedInteger = 1,
Integer = 2,
Float = 3,
};

TypeGroup GetTypeGroup(DataType type) {
if (*type == "tensor(bool)") {
return Bool;
}

if (*type == "tensor(uint16)" || *type == "tensor(uint32)" || *type == "tensor(uint64)" || *type == "tensor(uint8)") {
return UnsignedInteger;
}

if (*type == "tensor(int16)" || *type == "tensor(int32)" || *type == "tensor(int64)" || *type == "tensor(int8)") {
return Integer;
}

if (*type == "tensor(bfloat16)" || *type == "tensor(double)" || *type == "tensor(float)" || *type == "tensor(float16)") {
return Float;
}

return Unknown;
}

int TypeWidth(DataType type) {
if (*type == "tensor(bool)") {
return 1;
}

if (*type == "tensor(uint8)" || *type == "tensor(int8)") {
return 8;
}

if (*type == "tensor(uint16)" || *type == "tensor(int16)" || *type == "tensor(uint16)" || *type == "tensor(int16)" || *type == "tensor(bfloat16)" || *type == "tensor(float16)") {
return 16;
}

if (*type == "tensor(uint32)" || *type == "tensor(int32)" || *type == "tensor(float)") {
return 32;
}

if (*type == "tensor(uint64)" || *type == "tensor(int64)" || *type == "tensor(double)") {
return 64;
}

return -1;
}

inline bool LossOfPrecision(DataType src_type, DataType dst_type, const Node& node) {
TypeGroup src_type_group = GetTypeGroup(src_type);
TypeGroup dst_type_group = GetTypeGroup(dst_type);
if (src_type_group == TypeGroup::Unknown || dst_type_group == TypeGroup::Unknown) {
return true;
}
// The comparison with "InsertedPrecisionFreeCast_" identifies cast nodes inserted by InsertCastTransformer.
// Such casts should not be considered a loss of precision: the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) exist to support kernels on a CPU EP without F16 support.
return (dst_type_group < src_type_group) || (TypeWidth(dst_type) < TypeWidth(src_type) && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_")));
}

/** Transformer to remove duplicate Cast nodes. */
class RemoveDuplicateCastTransformer : public GraphTransformer {
@@ -301,6 +239,23 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
}

private:
InlinedVector<std::string> cast_ordering{
"tensor(bool)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)", "tensor(int8)", "tensor(int16)",
"tensor(int32)", "tensor(int64)", "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"
};

inline bool LossOfPrecision(DataType src_type, DataType dst_type, const Node& node) const {
// The comparison with "InsertedPrecisionFreeCast_" identifies cast nodes inserted by InsertCastTransformer.
// Such casts should not be considered a loss of precision: the inserted upcasts (f16 -> f32) and downcasts (f32 -> f16) exist to support kernels on a CPU EP without F16 support.
auto src_pos = std::find(cast_ordering.begin(), cast_ordering.end(), *src_type);
auto dst_pos = std::find(cast_ordering.begin(), cast_ordering.end(), *dst_type);
if (src_pos == cast_ordering.end() || dst_pos == cast_ordering.end()) {
return true;
}

return std::distance(src_pos, dst_pos) < 0 && (node.Name().compare(0, 26, "InsertedPrecisionFreeCast_"));
}

Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override {
auto output_args = graph.GetOutputs();
InlinedHashSet<const onnxruntime::NodeArg*> graph_outputs;
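
Below is a minimal, standalone sketch (not part of the ONNX Runtime sources) of the ordering-based check added in this commit. It substitutes std::vector<std::string> for InlinedVector and DataType, and replaces the node-name prefix test with a plain boolean flag; the names kCastOrdering, MayLosePrecision, inserted_by_transformer, and the main() driver are illustrative assumptions, not code from the repository.

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

// Type strings ordered as in cast_ordering above.
static const std::vector<std::string> kCastOrdering{
    "tensor(bool)", "tensor(uint8)", "tensor(uint16)", "tensor(uint32)", "tensor(uint64)",
    "tensor(int8)", "tensor(int16)", "tensor(int32)", "tensor(int64)",
    "tensor(bfloat16)", "tensor(float16)", "tensor(float)", "tensor(double)"};

// Returns true when casting src -> dst may lose precision.
// `inserted_by_transformer` stands in for the "InsertedPrecisionFreeCast_" name-prefix test:
// casts inserted by InsertCastTransformer are treated as precision-free.
bool MayLosePrecision(const std::string& src, const std::string& dst, bool inserted_by_transformer) {
  auto src_pos = std::find(kCastOrdering.begin(), kCastOrdering.end(), src);
  auto dst_pos = std::find(kCastOrdering.begin(), kCastOrdering.end(), dst);
  if (src_pos == kCastOrdering.end() || dst_pos == kCastOrdering.end()) {
    return true;  // unknown type string: be conservative
  }
  // A negative distance means dst sits earlier in the ordering than src, i.e. a narrowing cast.
  return std::distance(src_pos, dst_pos) < 0 && !inserted_by_transformer;
}

int main() {
  std::cout << std::boolalpha
            << MayLosePrecision("tensor(float)", "tensor(double)", false) << "\n"   // false: widening
            << MayLosePrecision("tensor(float)", "tensor(float16)", false) << "\n"  // true: narrowing
            << MayLosePrecision("tensor(float)", "tensor(float16)", true) << "\n";  // false: inserted cast
}

The single vector encodes, in one position lookup, roughly the same group-then-width comparison that the removed TypeGroup and TypeWidth helpers performed across two functions.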
