Skip to content

Conversation

@matthias-springer
Copy link
Member

Add support for arith.cmpf.

@llvmbot
Copy link
Member

llvmbot commented Nov 27, 2025

@llvm/pr-subscribers-mlir-arith
@llvm/pr-subscribers-mlir-execution-engine

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add support for arith.cmpf.


Full diff: https://github.com/llvm/llvm-project/pull/169753.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp (+145-4)
  • (modified) mlir/lib/ExecutionEngine/APFloatWrappers.cpp (+11)
  • (modified) mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir (+15)
  • (modified) mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir (+4)
diff --git a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp index 81fbdb1611deb..20b99e4a6e516 100644 --- a/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp +++ b/mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp @@ -45,11 +45,12 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable, static FailureOr<FuncOp> lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, - SymbolTableCollection *symbolTables = nullptr) { - auto i64Type = IntegerType::get(symTable->getContext(), 64); - + SymbolTableCollection *symbolTables = nullptr, + Type resultType = {}) { + if (!resultType) + resultType = IntegerType::get(symTable->getContext(), 64); std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str(); - auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type}); + auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType}); FailureOr<FuncOp> func = lookupFnDecl(symTable, funcName, funcT, symbolTables); // Failed due to type mismatch. @@ -308,6 +309,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> { bool isUnsigned; }; +struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> { + CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, + PatternBenefit benefit = 1) + : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {} + + LogicalResult matchAndRewrite(arith::CmpFOp op, + PatternRewriter &rewriter) const override { + // Get APFloat function from runtime library. + auto i1Type = IntegerType::get(symTable->getContext(), 1); + auto i8Type = IntegerType::get(symTable->getContext(), 8); + auto i32Type = IntegerType::get(symTable->getContext(), 32); + auto i64Type = IntegerType::get(symTable->getContext(), 64); + FailureOr<FuncOp> fn = + lookupOrCreateApFloatFn(rewriter, symTable, "compare", + {i32Type, i64Type, i64Type}, nullptr, i8Type); + if (failed(fn)) + return fn; + + // Cast operands to 64-bit integers. + rewriter.setInsertionPoint(op); + Location loc = op.getLoc(); + auto floatTy = cast<FloatType>(op.getLhs().getType()); + auto intWType = rewriter.getIntegerType(floatTy.getWidth()); + Value lhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs())); + Value rhsBits = arith::ExtUIOp::create( + rewriter, loc, i64Type, + arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs())); + + // Call APFloat function. + Value semValue = getSemanticsValue(rewriter, loc, floatTy); + SmallVector<Value> params = {semValue, lhsBits, rhsBits}; + Value comparisonResult = + func::CallOp::create(rewriter, loc, TypeRange(i8Type), + SymbolRefAttr::get(*fn), params) + ->getResult(0); + + // Generate an i1 SSA value that is "true" if the comparison result matches + // the given `val`. + auto checkValue = [&](llvm::APFloat::cmpResult val) { + return arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, comparisonResult, + arith::ConstantOp::create( + rewriter, loc, i8Type, + rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val))) + .getResult()); + }; + // Generate an i1 SSA value that is "true" if the comparison result matches + // any of the given `vals`. + std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkValues = + [&](ArrayRef<llvm::APFloat::cmpResult> vals) { + Value first = checkValue(vals.front()); + if (vals.size() == 1) + return first; + Value rest = checkValues(vals.drop_front()); + return arith::OrIOp::create(rewriter, loc, first, rest).getResult(); + }; + + // This switch-case statement was taken from arith::applyCmpPredicate. + Value result; + switch (op.getPredicate()) { + case arith::CmpFPredicate::AlwaysFalse: + result = arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 0)) + .getResult(); + break; + case arith::CmpFPredicate::OEQ: + result = checkValue(llvm::APFloat::cmpEqual); + break; + case arith::CmpFPredicate::OGT: + result = checkValue(llvm::APFloat::cmpGreaterThan); + break; + case arith::CmpFPredicate::OGE: + result = + checkValues({llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::OLT: + result = checkValue(llvm::APFloat::cmpLessThan); + break; + case arith::CmpFPredicate::OLE: + result = + checkValues({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ONE: + // Not cmpUnordered and not cmpUnordered. + result = checkValues( + {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::ORD: + // Not cmpUnordered. + result = + checkValues({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UEQ: + result = + checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UGT: + result = checkValues( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan}); + break; + case arith::CmpFPredicate::UGE: + result = + checkValues({llvm::APFloat::cmpUnordered, + llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::ULT: + result = checkValues( + {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan}); + break; + case arith::CmpFPredicate::ULE: + result = + checkValues({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpEqual}); + break; + case arith::CmpFPredicate::UNE: + // Not cmpEqual. + result = checkValues({llvm::APFloat::cmpLessThan, + llvm::APFloat::cmpGreaterThan, + llvm::APFloat::cmpUnordered}); + break; + case arith::CmpFPredicate::UNO: + result = checkValue(llvm::APFloat::cmpUnordered); + break; + case arith::CmpFPredicate::AlwaysTrue: + result = arith::ConstantOp::create(rewriter, loc, i1Type, + rewriter.getIntegerAttr(i1Type, 1)) + .getResult(); + break; + } + rewriter.replaceOp(op, result); + return success(); + } + + SymbolOpInterface symTable; +}; + namespace { struct ArithToAPFloatConversionPass final : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> { @@ -340,6 +480,7 @@ void ArithToAPFloatConversionPass::runOnOperation() { /*isUnsigned=*/false); patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(), /*isUnsigned=*/true); + patterns.add<CmpFOpToAPFloatConversion>(context, getOperation()); LogicalResult result = success(); ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) { if (diag.getSeverity() == DiagnosticSeverity::Error) { diff --git a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp index 44980ccd77491..77f7137264888 100644 --- a/mlir/lib/ExecutionEngine/APFloatWrappers.cpp +++ b/mlir/lib/ExecutionEngine/APFloatWrappers.cpp @@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int( llvm::RoundingMode::NearestTiesToEven); return result.bitcastToAPInt().getZExtValue(); } + +MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics, + uint64_t a, + uint64_t b) { + const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics( + static_cast<llvm::APFloatBase::Semantics>(semantics)); + unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem); + llvm::APFloat x(sem, llvm::APInt(bitWidth, a)); + llvm::APFloat y(sem, llvm::APInt(bitWidth, b)); + return static_cast<int8_t>(x.compare(y)); +} } diff --git a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir index d71d81dddcd4f..78ce3640ecc67 100644 --- a/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir +++ b/mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir @@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) { %0 = arith.uitofp %arg0 : i32 to f4E2M1FN return } + +// ----- + +// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8 +// CHECK: %[[sem:.*]] = arith.constant 18 : i32 +// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8 +// CHECK: %[[c3:.*]] = arith.constant 3 : i8 +// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8 +// CHECK: %[[c0:.*]] = arith.constant 0 : i8 +// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8 +// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1 +func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) { + %0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN + return +} diff --git a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir index 8046610d479a8..433d058d025cf 100644 --- a/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir +++ b/mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir @@ -43,6 +43,10 @@ func.func @entry() { %cvt = arith.truncf %b2 : f32 to f8E4M3FN vector.print %cvt : f8E4M3FN + // CHECK-NEXT: 1 + %cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN + vector.print %cmp1 : i1 + // CHECK-NEXT: 1 // Bit pattern: 01, interpreted as signed integer: 1 %cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2 
@matthias-springer matthias-springer force-pushed the users/matthias-springer/apfloat_cmpf branch from 97977f8 to 98853ee Compare November 27, 2025 01:16
@matthias-springer matthias-springer force-pushed the users/matthias-springer/apfloat_cmpf branch from 98853ee to 396f4f9 Compare November 27, 2025 01:19
@matthias-springer matthias-springer merged commit 4d7abe5 into main Dec 1, 2025
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/apfloat_cmpf branch December 1, 2025 08:12
aahrun pushed a commit to aahrun/llvm-project that referenced this pull request Dec 1, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment