Skip to content

Commit 4d7abe5

Browse files
[mlir][arith] Add support for cmpf to ArithToAPFloat (#169753)
Add support for `arith.cmpf`.
1 parent a751ed9 commit 4d7abe5

File tree

4 files changed

+177
-5
lines changed

4 files changed

+177
-5
lines changed

mlir/lib/Conversion/ArithToAPFloat/ArithToAPFloat.cpp

Lines changed: 147 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,17 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
4141
}
4242

4343
/// Helper function to look up or create the symbol for a runtime library
44-
/// function with the given parameter types. Always returns an int64_t.
44+
/// function with the given parameter types. Returns an int64_t, unless a
45+
/// different result type is specified.
4546
static FailureOr<FuncOp>
4647
lookupOrCreateApFloatFn(OpBuilder &b, SymbolOpInterface symTable,
4748
StringRef name, TypeRange paramTypes,
48-
SymbolTableCollection *symbolTables = nullptr) {
49-
auto i64Type = IntegerType::get(symTable->getContext(), 64);
50-
49+
SymbolTableCollection *symbolTables = nullptr,
50+
Type resultType = {}) {
51+
if (!resultType)
52+
resultType = IntegerType::get(symTable->getContext(), 64);
5153
std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
52-
auto funcT = FunctionType::get(b.getContext(), paramTypes, {i64Type});
54+
auto funcT = FunctionType::get(b.getContext(), paramTypes, {resultType});
5355
FailureOr<FuncOp> func =
5456
lookupFnDecl(symTable, funcName, funcT, symbolTables);
5557
// Failed due to type mismatch.
@@ -308,6 +310,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
308310
bool isUnsigned;
309311
};
310312

313+
struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
314+
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
315+
PatternBenefit benefit = 1)
316+
: OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
317+
318+
LogicalResult matchAndRewrite(arith::CmpFOp op,
319+
PatternRewriter &rewriter) const override {
320+
// Get APFloat function from runtime library.
321+
auto i1Type = IntegerType::get(symTable->getContext(), 1);
322+
auto i8Type = IntegerType::get(symTable->getContext(), 8);
323+
auto i32Type = IntegerType::get(symTable->getContext(), 32);
324+
auto i64Type = IntegerType::get(symTable->getContext(), 64);
325+
FailureOr<FuncOp> fn =
326+
lookupOrCreateApFloatFn(rewriter, symTable, "compare",
327+
{i32Type, i64Type, i64Type}, nullptr, i8Type);
328+
if (failed(fn))
329+
return fn;
330+
331+
// Cast operands to 64-bit integers.
332+
rewriter.setInsertionPoint(op);
333+
Location loc = op.getLoc();
334+
auto floatTy = cast<FloatType>(op.getLhs().getType());
335+
auto intWType = rewriter.getIntegerType(floatTy.getWidth());
336+
Value lhsBits = arith::ExtUIOp::create(
337+
rewriter, loc, i64Type,
338+
arith::BitcastOp::create(rewriter, loc, intWType, op.getLhs()));
339+
Value rhsBits = arith::ExtUIOp::create(
340+
rewriter, loc, i64Type,
341+
arith::BitcastOp::create(rewriter, loc, intWType, op.getRhs()));
342+
343+
// Call APFloat function.
344+
Value semValue = getSemanticsValue(rewriter, loc, floatTy);
345+
SmallVector<Value> params = {semValue, lhsBits, rhsBits};
346+
Value comparisonResult =
347+
func::CallOp::create(rewriter, loc, TypeRange(i8Type),
348+
SymbolRefAttr::get(*fn), params)
349+
->getResult(0);
350+
351+
// Generate an i1 SSA value that is "true" if the comparison result matches
352+
// the given `val`.
353+
auto checkResult = [&](llvm::APFloat::cmpResult val) {
354+
return arith::CmpIOp::create(
355+
rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
356+
arith::ConstantOp::create(
357+
rewriter, loc, i8Type,
358+
rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
359+
.getResult());
360+
};
361+
// Generate an i1 SSA value that is "true" if the comparison result matches
362+
// any of the given `vals`.
363+
std::function<Value(ArrayRef<llvm::APFloat::cmpResult>)> checkResults =
364+
[&](ArrayRef<llvm::APFloat::cmpResult> vals) {
365+
Value first = checkResult(vals.front());
366+
if (vals.size() == 1)
367+
return first;
368+
Value rest = checkResults(vals.drop_front());
369+
return arith::OrIOp::create(rewriter, loc, first, rest).getResult();
370+
};
371+
372+
// This switch-case statement was taken from arith::applyCmpPredicate.
373+
Value result;
374+
switch (op.getPredicate()) {
375+
case arith::CmpFPredicate::AlwaysFalse:
376+
result = arith::ConstantOp::create(rewriter, loc, i1Type,
377+
rewriter.getIntegerAttr(i1Type, 0))
378+
.getResult();
379+
break;
380+
case arith::CmpFPredicate::OEQ:
381+
result = checkResult(llvm::APFloat::cmpEqual);
382+
break;
383+
case arith::CmpFPredicate::OGT:
384+
result = checkResult(llvm::APFloat::cmpGreaterThan);
385+
break;
386+
case arith::CmpFPredicate::OGE:
387+
result = checkResults(
388+
{llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
389+
break;
390+
case arith::CmpFPredicate::OLT:
391+
result = checkResult(llvm::APFloat::cmpLessThan);
392+
break;
393+
case arith::CmpFPredicate::OLE:
394+
result =
395+
checkResults({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
396+
break;
397+
case arith::CmpFPredicate::ONE:
398+
// Not cmpUnordered and not cmpUnordered.
399+
result = checkResults(
400+
{llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
401+
break;
402+
case arith::CmpFPredicate::ORD:
403+
// Not cmpUnordered.
404+
result = checkResults({llvm::APFloat::cmpLessThan,
405+
llvm::APFloat::cmpGreaterThan,
406+
llvm::APFloat::cmpEqual});
407+
break;
408+
case arith::CmpFPredicate::UEQ:
409+
result =
410+
checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
411+
break;
412+
case arith::CmpFPredicate::UGT:
413+
result = checkResults(
414+
{llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
415+
break;
416+
case arith::CmpFPredicate::UGE:
417+
result = checkResults({llvm::APFloat::cmpUnordered,
418+
llvm::APFloat::cmpGreaterThan,
419+
llvm::APFloat::cmpEqual});
420+
break;
421+
case arith::CmpFPredicate::ULT:
422+
result = checkResults(
423+
{llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
424+
break;
425+
case arith::CmpFPredicate::ULE:
426+
result =
427+
checkResults({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
428+
llvm::APFloat::cmpEqual});
429+
break;
430+
case arith::CmpFPredicate::UNE:
431+
// Not cmpEqual.
432+
result = checkResults({llvm::APFloat::cmpLessThan,
433+
llvm::APFloat::cmpGreaterThan,
434+
llvm::APFloat::cmpUnordered});
435+
break;
436+
case arith::CmpFPredicate::UNO:
437+
result = checkResult(llvm::APFloat::cmpUnordered);
438+
break;
439+
case arith::CmpFPredicate::AlwaysTrue:
440+
result = arith::ConstantOp::create(rewriter, loc, i1Type,
441+
rewriter.getIntegerAttr(i1Type, 1))
442+
.getResult();
443+
break;
444+
}
445+
rewriter.replaceOp(op, result);
446+
return success();
447+
}
448+
449+
SymbolOpInterface symTable;
450+
};
451+
311452
namespace {
312453
struct ArithToAPFloatConversionPass final
313454
: impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +481,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
340481
/*isUnsigned=*/false);
341482
patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
342483
/*isUnsigned=*/true);
484+
patterns.add<CmpFOpToAPFloatConversion>(context, getOperation());
343485
LogicalResult result = success();
344486
ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
345487
if (diag.getSeverity() == DiagnosticSeverity::Error) {

mlir/lib/ExecutionEngine/APFloatWrappers.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,15 @@ MLIR_APFLOAT_WRAPPERS_EXPORT uint64_t _mlir_apfloat_convert_from_int(
131131
llvm::RoundingMode::NearestTiesToEven);
132132
return result.bitcastToAPInt().getZExtValue();
133133
}
134+
135+
MLIR_APFLOAT_WRAPPERS_EXPORT int8_t _mlir_apfloat_compare(int32_t semantics,
136+
uint64_t a,
137+
uint64_t b) {
138+
const llvm::fltSemantics &sem = llvm::APFloatBase::EnumToSemantics(
139+
static_cast<llvm::APFloatBase::Semantics>(semantics));
140+
unsigned bitWidth = llvm::APFloatBase::semanticsSizeInBits(sem);
141+
llvm::APFloat x(sem, llvm::APInt(bitWidth, a));
142+
llvm::APFloat y(sem, llvm::APInt(bitWidth, b));
143+
return static_cast<int8_t>(x.compare(y));
144+
}
134145
}

mlir/test/Conversion/ArithToApfloat/arith-to-apfloat.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,18 @@ func.func @uitofp(%arg0: i32) {
198198
%0 = arith.uitofp %arg0 : i32 to f4E2M1FN
199199
return
200200
}
201+
202+
// -----
203+
204+
// CHECK: func.func private @_mlir_apfloat_compare(i32, i64, i64) -> i8
205+
// CHECK: %[[sem:.*]] = arith.constant 18 : i32
206+
// CHECK: %[[cmp:.*]] = call @_mlir_apfloat_compare(%[[sem]], %{{.*}}, %{{.*}}) : (i32, i64, i64) -> i8
207+
// CHECK: %[[c3:.*]] = arith.constant 3 : i8
208+
// CHECK: %[[is_unordered:.*]] = arith.cmpi eq, %[[cmp]], %[[c3]] : i8
209+
// CHECK: %[[c0:.*]] = arith.constant 0 : i8
210+
// CHECK: %[[is_lt:.*]] = arith.cmpi eq, %[[cmp]], %[[c0]] : i8
211+
// CHECK: arith.ori %[[is_unordered]], %[[is_lt]] : i1
212+
func.func @cmpf(%arg0: f4E2M1FN, %arg1: f4E2M1FN) {
213+
%0 = arith.cmpf "ult", %arg0, %arg1 : f4E2M1FN
214+
return
215+
}

mlir/test/Integration/Dialect/Arith/CPU/test-apfloat-emulation.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ func.func @entry() {
4343
%cvt = arith.truncf %b2 : f32 to f8E4M3FN
4444
vector.print %cvt : f8E4M3FN
4545

46+
// CHECK-NEXT: 1
47+
%cmp1 = arith.cmpf "olt", %cvt, %c1 : f8E4M3FN
48+
vector.print %cmp1 : i1
49+
4650
// CHECK-NEXT: 1
4751
// Bit pattern: 01, interpreted as signed integer: 1
4852
%cvt_int_signed = arith.fptosi %cvt : f8E4M3FN to i2

0 commit comments

Comments
 (0)