Skip to content

Commit 727c97c

Browse files
[mlir][arith] Add support for sitofp, uitofp to ArithToAPFloat
1 parent 3db8ed0 commit 727c97c

File tree

4 files changed

+119
-0
lines changed

4 files changed

+119
-0
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,72 @@ struct FpToIntConversion final : OpRewritePattern<OpTy> {
241241
bool isUnsigned;
242242
};
243243

244+
template <typename OpTy>
245+
struct IntToFpConversion final : OpRewritePattern<OpTy> {
246+
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
247+
bool isUnsigned, PatternBenefit benefit = 1)
248+
: OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
249+
isUnsigned(isUnsigned){};
250+
251+
LogicalResult matchAndRewrite(OpTy op,
252+
PatternRewriter &rewriter) const override {
253+
// Get APFloat function from runtime library.
254+
auto i1Type = IntegerType::get(symTable->getContext(), 1);
255+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
256+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
257+
FailureOr<FuncOp> fn =
258+
lookupOrCreateApFloatFn(rewriter, symTable, "convert_from_int",
259+
{i32Type, i32Type, i1Type, i64Type});
260+
if (failed(fn))
261+
return fn;
262+
263+
rewriter.setInsertionPoint(op);
264+
// Cast operands to 64-bit integers.
265+
Location loc = op.getLoc();
266+
auto inIntTy = cast<IntegerType>(op.getOperand().getType());
267+
auto int64Type = rewriter.getI64Type();
268+
Value operandBits = op.getOperand();
269+
if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
270+
if (isUnsigned) {
271+
operandBits =
272+
arith::ExtUIOp::create(rewriter, loc, int64Type, operandBits);
273+
} else {
274+
operandBits =
275+
arith::ExtSIOp::create(rewriter, loc, int64Type, operandBits);
276+
}
277+
} else if (operandBits.getType().getIntOrFloatBitWidth() > 64) {
278+
return rewriter.notifyMatchFailure(
279+
loc, "integer bitwidth > 64 is not supported");
280+
}
281+
282+
// Call APFloat function.
283+
auto outFloatTy = cast<FloatType>(op.getType());
284+
Value outSemValue = getSemanticsValue(rewriter, loc, outFloatTy);
285+
Value inWidthValue = arith::ConstantOp::create(
286+
rewriter, loc, i32Type,
287+
rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
288+
Value isUnsignedValue = arith::ConstantOp::create(
289+
rewriter, loc, i1Type, rewriter.getIntegerAttr(i1Type, isUnsigned));
290+
SmallVector<Value> params = {outSemValue, inWidthValue, isUnsignedValue,
291+
operandBits};
292+
auto resultOp =
293+
func::CallOp::create(rewriter, loc, TypeRange(rewriter.getI64Type()),
294+
SymbolRefAttr::get(*fn), params);
295+
296+
// Truncate result to the original width.
297+
auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
298+
Value truncatedBits = arith::TruncIOp::create(rewriter, loc, outIntWType,
299+
resultOp->getResult(0));
300+
Value result =
301+
arith::BitcastOp::create(rewriter, loc, outFloatTy, truncatedBits);
302+
rewriter.replaceOp(op, result);
303+
return success();
304+
}
305+
306+
SymbolOpInterface symTable;
307+
bool isUnsigned;
308+
};
309+
244310
namespace {
245311
struct ArithToAPFloatConversionPass final
246312
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -269,6 +335,10 @@ void ArithToAPFloatConversionPass::runOnOperation() {
269335
/*isUnsigned=*/false);
270336
patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
271337
/*isUnsigned=*/true);
338+
patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
339+
/*isUnsigned=*/false);
340+
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
341+
/*isUnsigned=*/true);
272342
LogicalResult result = success();
273343
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
274344
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,16 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_to_int(
119119
// result to the desired result width.
120120
return result.getZExtValue();
121121
}
122+
123+
MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
124+
int32_t semantics, int32_t inputWidth, bool isUnsigned, uint64_t a) {
125+
llvm::APInt val(inputWidth, a, /*isSigned=*/!isUnsigned);
126+
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
127+
static_cast<llvm::APFloatBase::Semantics>(semantics));
128+
llvm::APFloat result(sem);
129+
// TODO: Custom rounding modes are not supported yet.
130+
result.convertFromAPInt(val, /*IsSigned=*/!isUnsigned,
131+
llvm::RoundingMode::NearestTiesToEven);
132+
return result.bitcastToAPInt().getZExtValue();
133+
}
122134
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,3 +174,27 @@ func.func @fptoui(%arg0: f16) {
174174
%0 = arith.fptoui %arg0 : f16 to i4
175175
return
176176
}
177+
178+
// -----
179+
180+
// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
181+
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
182+
// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
183+
// CHECK: %[[is_unsigned:.*]] = arith.constant false
184+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
185+
func.func @sitofp(%arg0: i32) {
186+
%0 = arith.sitofp %arg0 : i32 to f4E2M1FN
187+
return
188+
}
189+
190+
// -----
191+
192+
// CHECK: func.func private @_mlir_apfloat_convert_from_int(i32, i32, i1, i64) -> i64
193+
// CHECK: %[[sem_out:.*]] = arith.constant 18 : i32
194+
// CHECK: %[[in_width:.*]] = arith.constant 32 : i32
195+
// CHECK: %[[is_unsigned:.*]] = arith.constant true
196+
// CHECK: %[[res:.*]] = call @_mlir_apfloat_convert_from_int(%[[sem_out]], %[[in_width]], %[[is_unsigned]], %{{.*}}) : (i32, i32, i1, i64) -> i64
197+
func.func @uitofp(%arg0: i32) {
198+
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
199+
return
200+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,5 +53,18 @@ func.func @entry() {
5353
%cvt_int_unsigned = arith.fptoui %cvt : f8E4M3FN to i2
5454
vector.print %cvt_int_unsigned : i2
5555

56+
// CHECK-NEXT: -6
57+
// Bit pattern: 1...11110111, interpreted as signed: -9
58+
// Closest f4E2M1FN value: -6.0
59+
%c9 = arith.constant -9 : i16
60+
%cvt_from_signed_int = arith.sitofp %c9 : i16 to f4E2M1FN
61+
vector.print %cvt_from_signed_int : f4E2M1FN
62+
63+
// CHECK-NEXT: 6
64+
// Bit pattern: 1...11110111, interpreted as unsigned: 65527
65+
// Closest f4E2M1FN value: 6.0
66+
%cvt_from_unsigned_int = arith.uitofp %c9 : i16 to f4E2M1FN
67+
vector.print %cvt_from_unsigned_int : f4E2M1FN
68+
5669
return
5770
}

0 commit comments

Comments
 (0)