Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,9 @@ class RewriterBase : public OpBuilder {
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
virtual void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
replaceAllUsesWith(from, ValueRange{to});
}

/// Replace the uses of `from` with `to` for which the `functor` returns
/// "true". The conversion driver will try to reconcile all type mismatches
/// that still exist at the end of the conversion with materializations.
/// This function supports both 1:1 and 1:N replacements.
///
/// Note: The functor is also applied to builtin.unrealized_conversion_cast
/// ops that may have been inserted by the conversion driver. Some uses may
/// have been wrapped in unrealized_conversion_cast ops due to type changes.
///
/// Note: This function is not supported in rollback mode. Calling it in
/// rollback mode will trigger an assertion. Furthermore, the
/// `allUsesReplaced` flag is not supported yet.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr) override {
replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced);
}
void replaceUsesWithIf(Value from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);

/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
Expand Down
78 changes: 53 additions & 25 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);

/// Replace the uses of the given value with the given values. The specified
/// converter is used to build materializations (if necessary).
void replaceAllUsesWith(Value from, ValueRange to,
const TypeConverter *converter);
/// converter is used to build materializations (if necessary). If `functor`
/// is specified, only the uses that the functor returns "true" for are
/// replaced.
void replaceValueUses(Value from, ValueRange to,
const TypeConverter *converter,
function_ref<bool(OpOperand &)> functor = nullptr);

/// Erase the given block and its contents.
void eraseBlock(Block *block);
Expand Down Expand Up @@ -1203,11 +1206,16 @@ void BlockTypeConversionRewrite::rollback() {
}

/// Replace all uses of `from` with `repl`.
static void performReplaceValue(RewriterBase &rewriter, Value from,
Value repl) {
static void
performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
function_ref<bool(OpOperand &)> functor = nullptr) {
if (isa<BlockArgument>(repl)) {
// `repl` is a block argument. Directly replace all uses.
rewriter.replaceAllUsesWith(from, repl);
if (functor) {
rewriter.replaceUsesWithIf(from, repl, functor);
} else {
rewriter.replaceAllUsesWith(from, repl);
}
return;
}

Expand Down Expand Up @@ -1238,7 +1246,11 @@ static void performReplaceValue(RewriterBase &rewriter, Value from,
Block *replBlock = replOp->getBlock();
rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
bool result =
user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
if (result && functor)
result &= functor(operand);
return result;
});
}

Expand Down Expand Up @@ -1646,7 +1658,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*isPureTypeConversion=*/false)
.front();
replaceAllUsesWith(origArg, mat, converter);
replaceValueUses(origArg, mat, converter);
continue;
}

Expand All @@ -1655,14 +1667,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
replaceValueUses(origArg, inputMap->replacementValues, converter);
continue;
}

// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
replaceAllUsesWith(origArg, replArgs, converter);
replaceValueUses(origArg, replArgs, converter);
}

if (config.allowPatternRollback)
Expand Down Expand Up @@ -1962,8 +1974,24 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}

void ConversionPatternRewriterImpl::replaceAllUsesWith(
Value from, ValueRange to, const TypeConverter *converter) {
void ConversionPatternRewriterImpl::replaceValueUses(
Value from, ValueRange to, const TypeConverter *converter,
function_ref<bool(OpOperand &)> functor) {
LLVM_DEBUG({
logger.startLine() << "** Replace Value : '" << from << "'";
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
logger.getOStream() << " (in region of '" << parentOp->getName()
<< "' (" << parentOp << ")";
} else {
logger.getOStream() << " (unlinked block)";
}
}
if (functor) {
logger.getOStream() << ", conditional replacement";
}
});

if (!config.allowPatternRollback) {
SmallVector<Value> toConv = llvm::to_vector(to);
SmallVector<Value> repls =
Expand All @@ -1973,7 +2001,7 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
if (!repl)
return;

performReplaceValue(r, from, repl);
performReplaceValue(r, from, repl, functor);
return;
}

Expand All @@ -1992,6 +2020,9 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
replacedValues.insert(from);
#endif // NDEBUG

if (functor)
llvm::report_fatal_error(
"conditional value replacement is not supported in rollback mode");
mapping.map(from, to);
appendRewrite<ReplaceValueRewrite>(from, converter);
}
Expand Down Expand Up @@ -2190,18 +2221,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}

void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
LLVM_DEBUG({
impl->logger.startLine() << "** Replace Value : '" << from << "'";
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
<< "' (" << parentOp << ")\n";
} else {
impl->logger.getOStream() << " (unlinked block)\n";
}
}
});
impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
impl->replaceValueUses(from, to, impl->currentTypeConverter);
}

void ConversionPatternRewriter::replaceUsesWithIf(
Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced) {
assert(!allUsesReplaced &&
"allUsesReplaced is not supported in a dialect conversion");
impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
}

Value ConversionPatternRewriter::getRemappedValue(Value key) {
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Transforms/test-legalizer-no-rollback.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @conditional_replacement(
// CHECK-SAME: %[[arg0:.*]]: i43)
// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42
// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42
// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42
// Uses were replaced for dummy_user_1.
// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> ()
// Uses were also replaced for dummy_user_2, but not by value_replace. The uses
// were replaced due to the block signature conversion.
// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> ()
// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> ()
func.func @conditional_replacement(%arg0: i42) {
%repl = "test.legal_op"() : () -> (i42)
// expected-remark @+1 {{is not legalizable}}
"test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> ()
// expected-remark @+1 {{is not legalizable}}
"test.dummy_user_2"(%arg0) {} : (i42) -> ()
// Perform a conditional 1:N replacement.
"test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> ()
"test.return"() : () -> ()
}
8 changes: 7 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern {
// Replace the first operand with 2x the second operand.
Value from = op->getOperand(0);
Value repl = op->getOperand(1);
rewriter.replaceAllUsesWith(from, {repl, repl});
if (op->hasAttr("conditional")) {
rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) {
return use.getOwner()->hasAttr("replace_uses");
});
} else {
rewriter.replaceAllUsesWith(from, {repl, repl});
}
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
Expand Down