- Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][arith] Add support for cmpf to ArithToAPFloat #169753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
+177 −5
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
Member
| @llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesAdd support for Full diff: https://github.com/llvm/llvm-project/pull/169753.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 81fbdb1611deb..20b99e4a6e516 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -45,11 +45,12 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, static FailureOr<FuncOp> lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, - SymbolTableCollection *symbolTables = nullptr) { - auto i64Type = IntegerType::get(symTable->getContext(), 64); - + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); - auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type}); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); FailureOr<FuncOp> func = lookupFnDecl(symTable, funcName, funcT, symbolTables); // Failed due to type mismatch. @@ -308,6 +309,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> { bool isUnsigned; }; +struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> { + CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i8Type = IntegerType::get(symTable->getContext(), 8); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); + if (failed(fn)) + return fn; + + // Cast operands to 64-bit integers. + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto floatTy = cast<FloatType>(op.getLhs().getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + Value comparisonResult = + func::CallOp::create(rewriter, loc, TypeRange(i8Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Generate an i1 SSA value that is "true" if the comparison result matches + // the given `val`. + auto checkValue = [&](llvm::APFloat::cmpResult val) { + return arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, + arith::ConstantOp::create( + rewriter, loc, i8Type, + rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val))) + .getResult()); + }; + // Generate an i1 SSA value that is "true" if the comparison result matches + // any of the given `vals`. + std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkValues = + [&](ArrayRef<llvm::APFloat::cmpResult> vals) { + Value first = checkValue(vals.front()); + if (vals.size() == 1) + return first; + Value rest = checkValues(vals.drop_front()); + return arith::OrIOp::create(rewriter, loc, first, rest).getResult(); + }; + + // This switch-case statement was taken from arith::applyCmpPredicate. + Value result; + switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + result = arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 0)) + .getResult(); + break; + case arith::CmpFPredicate::OEQ: + result = checkValue(llvm::APFloat::cmpEqual); + break; + case arith::CmpFPredicate::OGT: + result = checkValue(llvm::APFloat::cmpGreaterThan); + break; + case arith::CmpFPredicate::OGE: + result = + checkValues({llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::OLT: + result = checkValue(llvm::APFloat::cmpLessThan); + break; + case arith::CmpFPredicate::OLE: + result = + checkValues({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ONE: + // Not cmpUnordered and not cmpUnordered. + result = checkValues( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::ORD: + // Not cmpUnordered. + result = + checkValues({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UEQ: + result = + checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UGT: + result = checkValues( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::UGE: + result = + checkValues({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ULT: + result = checkValues( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); + break; + case arith::CmpFPredicate::ULE: + result = + checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UNE: + // Not cmpEqual. + result = checkValues({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpUnordered}); + break; + case arith::CmpFPredicate::UNO: + result = checkValue(llvm::APFloat::cmpUnordered); + break; + case arith::CmpFPredicate::AlwaysTrue: + result = arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 1)) + .getResult(); + break; + } + rewriter.replaceOp(op, result); + return success(); + } + + SymbolOpInterface symTable; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { @@ -340,6 +480,7 @@ void ArithToAPFloatConversionPass::runOnOperation() { /*isUnsigned=*/false); patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(), /*isUnsigned=*/true); + patterns.add<CmpFOpToAPFloatConversion>(context, getOperation()); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 44980ccd77491..77f7137264888 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int( llvm::RoundingMode::NearestTiesToEven); return result.bitcastToAPInt().getZExtValue(); } + +MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, + uint64_t a, + uint64_t b) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + llvm::APFloat y(sem, llvm::APInt(bitWidth, b)); + return static_cast<int8_t>(x.compare(y)); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index d71d81dddcd4f..78ce3640ecc67 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) { %0 = arith.uitofp %arg0 : i32 to f4E2M1FN return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8 +// CHECK: %[[c3:.*]] = arith.constant 3 : i8 +// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8 +// CHECK: %[[c0:.*]] = arith.constant 0 : i8 +// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8 +// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1 +func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 8046610d479a8..433d058d025cf 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -43,6 +43,10 @@ func.func @entry() { %cvt = arith.truncf %b2 : f32 to f8E4M3FN vector.print %cvt : f8E4M3FN + // CHECK-NEXT: 1 + %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN + vector.print %cmp1 : i1 + // CHECK-NEXT: 1 // Bit pattern: 01, interpreted as signed integer: 1 %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2 |
97977f8 to 98853ee Compare 98853ee to 396f4f9 Compare kuhar approved these changes Nov 30, 2025
aahrun pushed a commit to aahrun/llvm-project that referenced this pull request Dec 1, 2025
Add support for `arith.cmpf`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Add this suggestion to a batch that can be applied as a single commit. This suggestion is invalid because no changes were made to the code. Suggestions cannot be applied while the pull request is closed. Suggestions cannot be applied while viewing a subset of changes. Only one suggestion per line can be applied in a batch. Add this suggestion to a batch that can be applied as a single commit. Applying suggestions on deleted lines is not supported. You must change the existing code in this line in order to create a valid suggestion. Outdated suggestions cannot be applied. This suggestion has been applied or marked resolved. Suggestions cannot be applied from pending reviews. Suggestions cannot be applied on multi-line comments. Suggestions cannot be applied while the pull request is queued to merge. Suggestion cannot be applied right now. Please check back later.
Add support for
arith.cmpf.