Skip to content

Commit

Permalink
[MLIR][TORCH] Add support for Short(si16) data type
Browse files Browse the repository at this point in the history
Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 committed Dec 9, 2023
1 parent fb21a85 commit 07c3e11
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions lib/Dialect/Torch/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) {
return torch_upstream::ScalarType::Long;
if (type.isSignedInteger(32))
return torch_upstream::ScalarType::Int;
if (type.isSignedInteger(16))
return torch_upstream::ScalarType::Short;
if (type.isSignlessInteger(1))
return torch_upstream::ScalarType::Bool;
if (type.isBF16())
Expand Down Expand Up @@ -95,6 +97,8 @@ Torch::getTypeForScalarType(MLIRContext *context,
return IntegerType::get(context, 64, mlir::IntegerType::Signed);
case torch_upstream::ScalarType::Int:
return IntegerType::get(context, 32, mlir::IntegerType::Signed);
case torch_upstream::ScalarType::Short:
return IntegerType::get(context, 16, mlir::IntegerType::Signed);
case torch_upstream::ScalarType::Bool:
return IntegerType::get(context, 1);
case torch_upstream::ScalarType::BFloat16:
Expand Down Expand Up @@ -213,8 +217,8 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,
Location loc, float value,
Type dtype) {
// Creating constants satisfying backend contract.
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(8) ||
dtype.isInteger(1))
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(16) ||
dtype.isInteger(8) || dtype.isInteger(1))
return rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((int64_t)value));
if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16())
Expand Down

0 comments on commit 07c3e11

Please sign in to comment.