Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 26, 2025

This commit adds support for replaceUsesWithIf (and variants such as replaceAllUsesExcept) to the ConversionPatternRewriter. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.)

This commit brings us a bit closer towards removing this workaround.

Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the ConversionValueMapping for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later.

This commit turns replaceUsesWithIf into a virtual function, so that the ConversionPatternRewriter can override it. All other API functions for conditional value replacements call that function.

Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call replaceUsesWithIf etc. directly on the Value object.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds support for replaceUsesWithIf (and variants such as replaceAllUsesExcept) to the ConversionPatternRewriter. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.)

Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the ConversionValueMapping for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later.

This commit turns replaceUsesWithIf into a virtual function, so that the ConversionPatternRewriter can override it. All other API functions for conditional value replacements call that function.

Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call replaceUsesWithIf etc. directly on the Value object.


Full diff: https://github.com/llvm/llvm-project/pull/169606.diff

5 Files Affected:

  • (modified) mlir/include/mlir/IR/PatternMatch.h (+3-3)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+21)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+52-25)
  • (added) mlir/test/Transforms/test-legalizer-no-rollback.mlir (+23)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+7-1)
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 576481a6e7215..35f7290a235c2 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -675,9 +675,9 @@ class RewriterBase : public OpBuilder { /// true. Also notify the listener about every in-place op modification (for /// every use that was replaced). The optional `allUsesReplaced` flag is set /// to "true" if all uses were replaced. - void replaceUsesWithIf(Value from, Value to, - function_ref<bool(OpOperand &)> functor, - bool *allUsesReplaced = nullptr); + virtual void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref<bool(OpOperand &)> functor, bool *allUsesReplaced = nullptr); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 5ac9e26e8636d..9f449080b0f37 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -903,6 +903,27 @@ class ConversionPatternRewriter final : public PatternRewriter { replaceAllUsesWith(from, ValueRange{to}); } + /// Replace the uses of `from` with `to` for which the `functor` returns + /// "true". The conversion driver will try to reconcile all type mismatches + /// that still exist at the end of the conversion with materializations. + /// This function supports both 1:1 and 1:N replacements. + /// + /// Note: The functor is also applied to builtin.unrealized_conversion_cast + /// ops that may have been inserted by the conversion driver. Some uses may + /// have been wrapped in unrealized_conversion_cast ops due to type changes. + /// + /// Note: This function is not supported in rollback mode. Calling it in + /// rollback mode will trigger an assertion. Furthermore, the + /// `allUsesReplaced` flag is not supported yet. + void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr) override { + replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced); + } + void replaceUsesWithIf(Value from, ValueRange to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); + /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 27e3ec6f64c8f..c9f1596c07cbe 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -976,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); /// Replace the uses of the given value with the given values. The specified - /// converter is used to build materializations (if necessary). - void replaceAllUsesWith(Value from, ValueRange to, - const TypeConverter *converter); + /// converter is used to build materializations (if necessary). If `functor` + /// is specified, only the uses that the functor returns "true" for are + /// replaced. + void replaceValueUses(Value from, ValueRange to, + const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor = nullptr); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1203,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() { } /// Replace all uses of `from` with `repl`. -static void performReplaceValue(RewriterBase &rewriter, Value from, - Value repl) { +static void +performReplaceValue(RewriterBase &rewriter, Value from, Value repl, + function_ref<bool(OpOperand &)> functor = nullptr) { if (isa<BlockArgument>(repl)) { // `repl` is a block argument. Directly replace all uses. - rewriter.replaceAllUsesWith(from, repl); + if (functor) { + rewriter.replaceUsesWithIf(from, repl, functor); + } else { + rewriter.replaceAllUsesWith(from, repl); + } return; } @@ -1238,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from, Block *replBlock = replOp->getBlock(); rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); - return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + bool result = + user->getBlock() != replBlock || replOp->isBeforeInBlock(user); + if (functor) + result &= functor(operand); + return result; }); } @@ -1646,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*isPureTypeConversion=*/false) .front(); - replaceAllUsesWith(origArg, mat, converter); + replaceValueUses(origArg, mat, converter); continue; } @@ -1655,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceAllUsesWith(origArg, inputMap->replacementValues, converter); + replaceValueUses(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceAllUsesWith(origArg, replArgs, converter); + replaceValueUses(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1962,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceAllUsesWith( - Value from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceValueUses( + Value from, ValueRange to, const TypeConverter *converter, + function_ref<bool(OpOperand &)> functor) { + LLVM_DEBUG({ + logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast<BlockArgument>(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")"; + } else { + logger.getOStream() << " (unlinked block)"; + } + } + if (functor) { + logger.getOStream() << ", conditional replacement"; + } + }); + if (!config.allowPatternRollback) { SmallVector<Value> toConv = llvm::to_vector(to); SmallVector<Value> repls = @@ -1973,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( if (!repl) return; - performReplaceValue(r, from, repl); + performReplaceValue(r, from, repl, functor); return; } @@ -1992,6 +2020,8 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith( replacedValues.insert(from); #endif // NDEBUG + assert(!functor && + "conditional value replacement is not supported in rollback mode"); mapping.map(from, to); appendRewrite<ReplaceValueRewrite>(from, converter); } @@ -2190,18 +2220,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( } void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { - LLVM_DEBUG({ - impl->logger.startLine() << "** Replace Value : '" << from << "'"; - if (auto blockArg = dyn_cast<BlockArgument>(from)) { - if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; - } - } - }); - impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); + impl->replaceValueUses(from, to, impl->currentTypeConverter); +} + +void ConversionPatternRewriter::replaceUsesWithIf( + Value from, ValueRange to, function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced) { + assert(!allUsesReplaced && + "allUsesReplaced is not supported in a dialect conversion"); + impl->replaceValueUses(from, to, impl->currentTypeConverter, functor); } Value ConversionPatternRewriter::getRemappedValue(Value key) { diff --git a/mlir/test/Transforms/test-legalizer-no-rollback.mlir b/mlir/test/Transforms/test-legalizer-no-rollback.mlir new file mode 100644 index 0000000000000..5f421a35d956b --- /dev/null +++ b/mlir/test/Transforms/test-legalizer-no-rollback.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: @conditional_replacement( +// CHECK-SAME: %[[arg0:.*]]: i43) +// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42 +// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42 +// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42 +// Uses were replaced for dummy_user_1. +// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> () +// Uses were also replaced for dummy_user_2, but not by value_replace. The uses +// were replaced due to the block signature conversion. +// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> () +// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> () +func.func @conditional_replacement(%arg0: i42) { + %repl = "test.legal_op"() : () -> (i42) + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> () + // expected-remark @+1 {{is not legalizable}} + "test.dummy_user_2"(%arg0) {} : (i42) -> () + // Perform a conditional 1:N replacement. + "test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 9b64bc691588d..7eabaaeb41500 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern { // Replace the first operand with 2x the second operand. Value from = op->getOperand(0); Value repl = op->getOperand(1); - rewriter.replaceAllUsesWith(from, {repl, repl}); + if (op->hasAttr("conditional")) { + rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) { + return use.getOwner()->hasAttr("replace_uses"); + }); + } else { + rewriter.replaceAllUsesWith(from, {repl, repl}); + } rewriter.modifyOpInPlace(op, [&] { // If the "trigger_rollback" attribute is set, keep the op illegal, so // that a rollback is triggered. 
@dcaballe
Copy link
Contributor

Thanks, Matthias!

Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the ConversionValueMapping for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later.

My understanding is that the rollback implementation should perform a conditional in-place replacement for the existing users at the time of the invocation + a conditional differed replacement of future users. Is the problem the evaluation of the condition for the differed users? A couple of points:

  • replaceAllUsesExcept is less generic than replaceUsesWithIf so the former might be easier to implement than the latter. For example, for replaceAllUsesExcept, the filtering condition is limited to ops (users) that must exist at the time of the invocation (I'm not sure how we could refer in the condition to future ops that haven't been created yet!). This means that the condition only applies to the "in-place" part of the replacement and will always be true for the differed part, right? This turns the replacement into "a conditional in-place replacement for the existing users at the time of the invocation + an unconditional differed replacement of future users". This should be easier to implement and would cover the hack in replaceAllUsesWith.

  • For the generic replaceUsesWithIf, could we store the conditions (lambda functions) in the driver and use them to decide if an entry should be added to ConversionValueMapping?

If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call replaceUsesWithIf etc. directly on the Value object.

Would it be possible to preserve the existing functionality for the rollback driver under a different name (e.g., replaceUsesWithIfXYZ) until we have the proper implementation, then remove it? The problem I see is that this PR could be a dead end for users of the current replaceUsesWithIf in rollback mode if the migration to the no-rollback driver is not possible (i.e., in my case, the no-rollback mode would invalidate analysis that it's needed to do the conversion).

@matthias-springer matthias-springer enabled auto-merge (squash) November 27, 2025 01:42
@matthias-springer matthias-springer merged commit 504b507 into main Nov 27, 2025
8 of 9 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/replace_uses_functor branch November 27, 2025 01:54
tanji-dg pushed a commit to tanji-dg/llvm-project that referenced this pull request Nov 27, 2025
…thIf` (llvm#169606) This commit adds support for `replaceUsesWithIf` (and variants such as `replaceAllUsesExcept`) to the `ConversionPatternRewriter`. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.) This commit brings us a bit closer towards removing [this](https://github.com/llvm/llvm-project/blob/76ec25f729fcc7ae576caf21293cc393e68e7cf7/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1214) workaround. Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the `ConversionValueMapping` for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later. This commit turns `replaceUsesWithIf` into a virtual function, so that the `ConversionPatternRewriter` can override it. All other API functions for conditional value replacements call that function. Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call `replaceUsesWithIf` etc. directly on the `Value` object.
GeneraluseAI pushed a commit to GeneraluseAI/llvm-project that referenced this pull request Nov 27, 2025
…thIf` (llvm#169606) This commit adds support for `replaceUsesWithIf` (and variants such as `replaceAllUsesExcept`) to the `ConversionPatternRewriter`. This API is supported only in no-rollback mode. An assertion is triggered in rollback mode. (This missing assertion has been confusing for users because it seemed that the API supported, while it was actually not working properly.) This commit brings us a bit closer towards removing [this](https://github.com/llvm/llvm-project/blob/76ec25f729fcc7ae576caf21293cc393e68e7cf7/mlir/lib/Transforms/Utils/DialectConversion.cpp#L1214) workaround. Additional changes are needed to support this API in rollback mode. In particular, no entries should be added to the `ConversionValueMapping` for conditional replacements. It's unclear at this point if this API can be supported in rollback mode, so this is deferred to later. This commit turns `replaceUsesWithIf` into a virtual function, so that the `ConversionPatternRewriter` can override it. All other API functions for conditional value replacements call that function. Note for LLVM integration: If you are seeing failed assertions due to this change, you are using unsupported API in your dialect conversion. You have 3 options: (1) Migrate to the no-rollback driver. (2) Rewrite your patterns without the unsupported API. (3) Last resort: bypass the rewriter and call `replaceUsesWithIf` etc. directly on the `Value` object.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

6 participants