- Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][Transforms] Dialect conversion: Add support for replaceUsesWithIf #169606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Dialect conversion: Add support for replaceUsesWithIf #169606
Conversation
| @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit adds support for Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the This commit turns Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call Full diff: https://github.com/llvm/llvm-project/pull/169606.diff 5 Files Affected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 576481a6e7215..35f7290a235c2 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -675,9 +675,9 @@ class RewriterBase : public OpBuilder { /// true. Also notify the listener about every in-place op modification (for /// every use that was replaced). The optional `allUsesReplaced` flag is set /// to "true" if all uses were replaced. - void replaceUsesWithIf(Value from, Value to, - function_ref<bool(OpOperand &)> functor, - bool *allUsesReplaced = nullptr); + virtual void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref<bool(OpOperand &)> functor, bool *allUsesReplaced = nullptr); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ac9e26e8636d..9f449080b0f37 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -903,6 +903,27 @@ class ConversionPatternRewriter final : public PatternRewriter { replaceAllUsesWith(from, ValueRange{to}); } + /// Replace the uses of `from` with `to` for which the `functor` returns + /// "true". The conversion driver will try to reconcile all type mismatches + /// that still exist at the end of the conversion with materializations. + /// This function supports both 1:1 and 1:N replacements. + /// + /// Note: The functor is also applied to builtin.unrealized_conversion_cast + /// ops that may have been inserted by the conversion driver. Some uses may + /// have been wrapped in unrealized_conversion_cast ops due to type changes. + /// + /// Note: This function is not supported in rollback mode. Calling it in + /// rollback mode will trigger an assertion. Furthermore, the + /// `allUsesReplaced` flag is not supported yet. + void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr) override { + replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced); + } + void replaceUsesWithIf(Value from, ValueRange to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); + /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 27e3ec6f64c8f..c9f1596c07cbe 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -976,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); /// Replace the uses of the given value with the given values. The specified - /// converter is used to build materializations (if necessary). - void replaceAllUsesWith(Value from, ValueRange to, - const TypeConverter *converter); + /// converter is used to build materializations (if necessary). If `functor` + /// is specified, only the uses that the functor returns "true" for are + /// replaced. + void replaceValueUses(Value from, ValueRange to, + const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor = nullptr); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1203,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() { } /// Replace all uses of `from` with `repl`. -static void performReplaceValue(RewriterBase &rewriter, Value from, - Value repl) { +static void +performReplaceValue(RewriterBase &rewriter, Value from, Value repl, + function_ref<bool(OpOperand &)> functor = nullptr) { if (isa<BlockArgument>(repl)) { // `repl` is a block argument. Directly replace all uses. - rewriter.replaceAllUsesWith(from, repl); + if (functor) { + rewriter.replaceUsesWithIf(from, repl, functor); + } else { + rewriter.replaceAllUsesWith(from, repl); + } return; } @@ -1238,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from, Block *replBlock = replOp->getBlock(); rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + bool result = + user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + if (functor) + result &= functor(operand); + return result; }); } @@ -1646,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*isPureTypeConversion=*/false) .front(); - replaceAllUsesWith(origArg, mat, converter); + replaceValueUses(origArg, mat, converter); continue; } @@ -1655,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceAllUsesWith(origArg, inputMap->replacementValues, converter); + replaceValueUses(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceAllUsesWith(origArg, replArgs, converter); + replaceValueUses(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1962,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceAllUsesWith( - Value from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceValueUses( + Value from, ValueRange to, const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor) { + LLVM_DEBUG({ + logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast<BlockArgument>(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")"; + } else { + logger.getOStream() << " (unlinked block)"; + } + } + if (functor) { + logger.getOStream() << ", conditional replacement"; + } + }); + if (!config.allowPatternRollback) { SmallVector<Value> toConv = llvm::to_vector(to); SmallVector<Value> repls = @@ -1973,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( if (!repl) return; - performReplaceValue(r, from, repl); + performReplaceValue(r, from, repl, functor); return; } @@ -1992,6 +2020,8 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( replacedValues.insert(from); #endif // NDEBUG + assert(!functor && + "conditional value replacement is not supported in rollback mode"); mapping.map(from, to); appendRewrite<ReplaceValueRewrite>(from, converter); } @@ -2190,18 +2220,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( } void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { - LLVM_DEBUG({ - impl->logger.startLine() << "** Replace Value : '" << from << "'"; - if (auto blockArg = dyn_cast<BlockArgument>(from)) { - if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; - } - } - }); - impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); + impl->replaceValueUses(from, to, impl->currentTypeConverter); +} + +void ConversionPatternRewriter::replaceUsesWithIf( + Value from, ValueRange to, function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced) { + assert(!allUsesReplaced && + "allUsesReplaced is not supported in a dialect conversion"); + impl->replaceValueUses(from, to, impl->currentTypeConverter, functor); } Value ConversionPatternRewriter::getRemappedValue(Value key) { diff --git a/mlir/test/Transforms/test-legalizer-no-rollback.mlir b/mlir/test/Transforms/test-legalizer-no-rollback.mlir new file mode 100644 index 0000000000000..5f421a35d956b --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-rollback.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @conditional_replacement( +// CHECK-SAME: %[[arg0:.*]]: i43) +// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42 +// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42 +// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42 +// Uses were replaced for dummy_user_1. +// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> () +// Uses were also replaced for dummy_user_2, but not by value_replace. The uses +// were replaced due to the block signature conversion. +// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> () +// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> () +func.func @conditional_replacement(%arg0: i42) { + %repl = "test.legal_op"() : () -> (i42) + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> () + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_2"(%arg0) {} : (i42) -> () + // Perform a conditional 1:N replacement. + "test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 9b64bc691588d..7eabaaeb41500 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern { // Replace the first operand with 2x the second operand. Value from = op->getOperand(0); Value repl = op->getOperand(1); - rewriter.replaceAllUsesWith(from, {repl, repl}); + if (op->hasAttr("conditional")) { + rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) { + return use.getOwner()->hasAttr("replace_uses"); + }); + } else { + rewriter.replaceAllUsesWith(from, {repl, repl}); + } rewriter.modifyOpInPlace(op, [&] { // If the "trigger_rollback" attribute is set, keep the op illegal, so // that a rollback is triggered. |
| Thanks, Matthias!
My understanding is that the rollback implementation should perform a conditional in-place replacement for the existing users at the time of the invocation + a conditional differed replacement of future users. Is the problem the evaluation of the condition for the differed users? A couple of points:
Would it be possible to preserve the existing functionality for the rollback driver under a different name (e.g., |
…thIf` (llvm#169606) This commit adds support for `replaceUsesWithIf` (and variants such as `replaceAllUsesExcept`) to the `ConversionPatternRewriter`. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.) This commit brings us a bit closer towards removing [this](https://github.com/llvm/llvm-project/blob/76ec25f729fcc7ae576caf21293cc393e68e7cf7/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1214) workaround. Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the `ConversionValueMapping` for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later. This commit turns `replaceUsesWithIf` into a virtual function, so that the `ConversionPatternRewriter` can override it. All other API functions for conditional value replacements call that function. Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call `replaceUsesWithIf` etc. directly on the `Value` object.
…thIf` (llvm#169606) This commit adds support for `replaceUsesWithIf` (and variants such as `replaceAllUsesExcept`) to the `ConversionPatternRewriter`. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.) This commit brings us a bit closer towards removing [this](https://github.com/llvm/llvm-project/blob/76ec25f729fcc7ae576caf21293cc393e68e7cf7/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1214) workaround. Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the `ConversionValueMapping` for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later. This commit turns `replaceUsesWithIf` into a virtual function, so that the `ConversionPatternRewriter` can override it. All other API functions for conditional value replacements call that function. Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call `replaceUsesWithIf` etc. directly on the `Value` object.
This commit adds support for
replaceUsesWithIf(and variants such asreplaceAllUsesExcept) to theConversionPatternRewriter. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.)This commit brings us a bit closer towards removing this workaround.
Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the
ConversionValueMappingfor conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later.This commit turns
replaceUsesWithIfinto a virtual function, so that theConversionPatternRewritercan override it. All other API functions for conditional value replacements call that function.Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call
replaceUsesWithIfetc. directly on theValueobject.