@@ -41,15 +41,17 @@ static FuncOp createFnDecl(OpBuilder &b, SymbolOpInterface symTable,
4141}
4242
4343// / Helper function to look up or create the symbol for a runtime library
44- // / function with the given parameter types. Always returns an int64_t.
44+ // / function with the given parameter types. Returns an int64_t, unless a
45+ // / different result type is specified.
4546static FailureOr<FuncOp>
4647lookupOrCreateApFloatFn (OpBuilder &b, SymbolOpInterface symTable,
4748 StringRef name, TypeRange paramTypes,
48- SymbolTableCollection *symbolTables = nullptr ) {
49- auto i64Type = IntegerType::get (symTable->getContext (), 64 );
50-
49+ SymbolTableCollection *symbolTables = nullptr ,
50+ Type resultType = {}) {
51+ if (!resultType)
52+ resultType = IntegerType::get (symTable->getContext (), 64 );
5153 std::string funcName = (llvm::Twine (" _mlir_apfloat_" ) + name).str ();
52- auto funcT = FunctionType::get (b.getContext (), paramTypes, {i64Type });
54+ auto funcT = FunctionType::get (b.getContext (), paramTypes, {resultType });
5355 FailureOr<FuncOp> func =
5456 lookupFnDecl (symTable, funcName, funcT, symbolTables);
5557 // Failed due to type mismatch.
@@ -308,6 +310,145 @@ struct IntToFpConversion final : OpRewritePattern<OpTy> {
308310 bool isUnsigned;
309311};
310312
313+ struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
314+ CmpFOpToAPFloatConversion (MLIRContext *context, SymbolOpInterface symTable,
315+ PatternBenefit benefit = 1 )
316+ : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
317+
318+ LogicalResult matchAndRewrite (arith::CmpFOp op,
319+ PatternRewriter &rewriter) const override {
320+ // Get APFloat function from runtime library.
321+ auto i1Type = IntegerType::get (symTable->getContext (), 1 );
322+ auto i8Type = IntegerType::get (symTable->getContext (), 8 );
323+ auto i32Type = IntegerType::get (symTable->getContext (), 32 );
324+ auto i64Type = IntegerType::get (symTable->getContext (), 64 );
325+ FailureOr<FuncOp> fn =
326+ lookupOrCreateApFloatFn (rewriter, symTable, " compare" ,
327+ {i32Type, i64Type, i64Type}, nullptr , i8Type);
328+ if (failed (fn))
329+ return fn;
330+
331+ // Cast operands to 64-bit integers.
332+ rewriter.setInsertionPoint (op);
333+ Location loc = op.getLoc ();
334+ auto floatTy = cast<FloatType>(op.getLhs ().getType ());
335+ auto intWType = rewriter.getIntegerType (floatTy.getWidth ());
336+ Value lhsBits = arith::ExtUIOp::create (
337+ rewriter, loc, i64Type,
338+ arith::BitcastOp::create (rewriter, loc, intWType, op.getLhs ()));
339+ Value rhsBits = arith::ExtUIOp::create (
340+ rewriter, loc, i64Type,
341+ arith::BitcastOp::create (rewriter, loc, intWType, op.getRhs ()));
342+
343+ // Call APFloat function.
344+ Value semValue = getSemanticsValue (rewriter, loc, floatTy);
345+ SmallVector<Value> params = {semValue, lhsBits, rhsBits};
346+ Value comparisonResult =
347+ func::CallOp::create (rewriter, loc, TypeRange (i8Type),
348+ SymbolRefAttr::get (*fn), params)
349+ ->getResult (0 );
350+
351+ // Generate an i1 SSA value that is "true" if the comparison result matches
352+ // the given `val`.
353+ auto checkResult = [&](llvm::APFloat::cmpResult val) {
354+ return arith::CmpIOp::create (
355+ rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
356+ arith::ConstantOp::create (
357+ rewriter, loc, i8Type,
358+ rewriter.getIntegerAttr (i8Type, static_cast <int8_t >(val)))
359+ .getResult ());
360+ };
361+ // Generate an i1 SSA value that is "true" if the comparison result matches
362+ // any of the given `vals`.
363+ std::function<Value (ArrayRef<llvm::APFloat::cmpResult>)> checkResults =
364+ [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
365+ Value first = checkResult (vals.front ());
366+ if (vals.size () == 1 )
367+ return first;
368+ Value rest = checkResults (vals.drop_front ());
369+ return arith::OrIOp::create (rewriter, loc, first, rest).getResult ();
370+ };
371+
372+ // This switch-case statement was taken from arith::applyCmpPredicate.
373+ Value result;
374+ switch (op.getPredicate ()) {
375+ case arith::CmpFPredicate::AlwaysFalse:
376+ result = arith::ConstantOp::create (rewriter, loc, i1Type,
377+ rewriter.getIntegerAttr (i1Type, 0 ))
378+ .getResult ();
379+ break ;
380+ case arith::CmpFPredicate::OEQ:
381+ result = checkResult (llvm::APFloat::cmpEqual);
382+ break ;
383+ case arith::CmpFPredicate::OGT:
384+ result = checkResult (llvm::APFloat::cmpGreaterThan);
385+ break ;
386+ case arith::CmpFPredicate::OGE:
387+ result = checkResults (
388+ {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
389+ break ;
390+ case arith::CmpFPredicate::OLT:
391+ result = checkResult (llvm::APFloat::cmpLessThan);
392+ break ;
393+ case arith::CmpFPredicate::OLE:
394+ result =
395+ checkResults ({llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
396+ break ;
397+ case arith::CmpFPredicate::ONE:
398+ // Not cmpUnordered and not cmpUnordered.
399+ result = checkResults (
400+ {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
401+ break ;
402+ case arith::CmpFPredicate::ORD:
403+ // Not cmpUnordered.
404+ result = checkResults ({llvm::APFloat::cmpLessThan,
405+ llvm::APFloat::cmpGreaterThan,
406+ llvm::APFloat::cmpEqual});
407+ break ;
408+ case arith::CmpFPredicate::UEQ:
409+ result =
410+ checkResults ({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
411+ break ;
412+ case arith::CmpFPredicate::UGT:
413+ result = checkResults (
414+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
415+ break ;
416+ case arith::CmpFPredicate::UGE:
417+ result = checkResults ({llvm::APFloat::cmpUnordered,
418+ llvm::APFloat::cmpGreaterThan,
419+ llvm::APFloat::cmpEqual});
420+ break ;
421+ case arith::CmpFPredicate::ULT:
422+ result = checkResults (
423+ {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
424+ break ;
425+ case arith::CmpFPredicate::ULE:
426+ result =
427+ checkResults ({llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan,
428+ llvm::APFloat::cmpEqual});
429+ break ;
430+ case arith::CmpFPredicate::UNE:
431+ // Not cmpEqual.
432+ result = checkResults ({llvm::APFloat::cmpLessThan,
433+ llvm::APFloat::cmpGreaterThan,
434+ llvm::APFloat::cmpUnordered});
435+ break ;
436+ case arith::CmpFPredicate::UNO:
437+ result = checkResult (llvm::APFloat::cmpUnordered);
438+ break ;
439+ case arith::CmpFPredicate::AlwaysTrue:
440+ result = arith::ConstantOp::create (rewriter, loc, i1Type,
441+ rewriter.getIntegerAttr (i1Type, 1 ))
442+ .getResult ();
443+ break ;
444+ }
445+ rewriter.replaceOp (op, result);
446+ return success ();
447+ }
448+
449+ SymbolOpInterface symTable;
450+ };
451+
311452namespace {
312453struct ArithToAPFloatConversionPass final
313454 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
@@ -340,6 +481,7 @@ void ArithToAPFloatConversionPass::runOnOperation() {
340481 /* isUnsigned=*/ false );
341482 patterns.add <IntToFpConversion<arith::UIToFPOp>>(context, getOperation (),
342483 /* isUnsigned=*/ true );
484+ patterns.add <CmpFOpToAPFloatConversion>(context, getOperation ());
343485 LogicalResult result = success ();
344486 ScopedDiagnosticHandler scopedHandler (context, [&result](Diagnostic &diag) {
345487 if (diag.getSeverity () == DiagnosticSeverity::Error) {
0 commit comments