Skip to content

Conversation

@CoTinker
Copy link
Contributor

@CoTinker CoTinker commented Sep 22, 2024

This PR fixes multiple bugs in DuplicateFunctionElimination.

  • Prevents elimination of function declarations.
  • Updates all symbol uses to reference unique function representatives.

Fixes #93483.

@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2024

@llvm/pr-subscribers-mlir-func

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes multiple bugs in DuplicateFunctionElimination.

  • Prevents elimination of function declarations.
  • Updates constant ops to reference unique function representatives.
  • Simplifies DenseMap by using StringRef as the key instead of StringAttr.

Fixes #93483.


Full diff: https://github.com/llvm/llvm-project/pull/109571.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp (+12-3)
  • (modified) mlir/test/Dialect/Func/duplicate-function-elimination.mlir (+48)
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp index d41d6c3e8972f9..5e23207eabf9c4 100644 --- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp @@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo if (lhs == getTombstoneKey() || lhs == getEmptyKey() || rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; + + if (lhs.isDeclaration() || rhs.isDeclaration()) + return false; + // Check discardable attributes equivalence if (lhs->getDiscardableAttrDictionary() != rhs->getDiscardableAttrDictionary()) @@ -87,11 +91,11 @@ struct DuplicateFunctionEliminationPass // Find unique representant per equivalent func ops. DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps; - DenseMap<StringAttr, func::FuncOp> getRepresentant; + DenseMap<StringRef, func::FuncOp> getRepresentant; DenseSet<func::FuncOp> toBeErased; module.walk([&](func::FuncOp f) { auto [repr, inserted] = uniqueFuncOps.insert(f); - getRepresentant[f.getSymNameAttr()] = *repr; + getRepresentant[f.getSymName()] = *repr; if (!inserted) { toBeErased.insert(f); } @@ -99,9 +103,14 @@ struct DuplicateFunctionEliminationPass // Update call ops to call unique func op representants. module.walk([&](func::CallOp callOp) { - func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()]; + func::FuncOp callee = getRepresentant[callOp.getCallee()]; callOp.setCallee(callee.getSymName()); }); + // Update constant ops to reference unique func op representants. + module.walk([&](func::ConstantOp constantOp) { + func::FuncOp value = getRepresentant[constantOp.getValue()]; + constantOp.setValue(value.getSymName()); + }); // Erase redundant func ops. for (auto it : toBeErased) { diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir index 28d059a149bde8..1c6876c1327bc6 100644 --- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir +++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir @@ -366,3 +366,51 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32) // CHECK: @user // CHECK-2: call @deep_tree // CHECK: call @reverse_deep_tree + +// ----- + +func.func private @func_declaration(i32, i32) -> i32 +func.func private @func_declaration1(i32, i32) -> i32 + +func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) { + %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32 + %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32 + return %0, %1 : i32, i32 +} + +// CHECK: @func_declaration +// CHECK: @func_declaration1 +// CHECK: @user +// CHECK: call @func_declaration +// CHECK: call @func_declaration1 + + +// ----- + +func.func @identity(%arg0: tensor<f32>) -> tensor<f32> { + return %arg0 : tensor<f32> +} + +func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> { + return %arg0 : tensor<f32> +} + +func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> { + return %arg0 : tensor<f32> +} + +func.func @user(%arg0: tensor<f32>) -> tensor<f32> { + %f = constant @identity : (tensor<f32>) -> tensor<f32> + %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32> + %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32> + %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32> + %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32> + return %2 : tensor<f32> +} + +// CHECK: @identity +// CHECK-NOT: @also_identity +// CHECK-NOT: @yet_another_identity +// CHECK: @user +// CHECK-2: constant @identity +// CHECK: call @identity 
@llvmbot
Copy link
Member

llvmbot commented Sep 22, 2024

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes multiple bugs in DuplicateFunctionElimination.

  • Prevents elimination of function declarations.
  • Updates constant ops to reference unique function representatives.
  • Simplifies DenseMap by using StringRef as the key instead of StringAttr.

Fixes #93483.


Full diff: https://github.com/llvm/llvm-project/pull/109571.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp (+12-3)
  • (modified) mlir/test/Dialect/Func/duplicate-function-elimination.mlir (+48)
diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp index d41d6c3e8972f9..5e23207eabf9c4 100644 --- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp @@ -54,6 +54,10 @@ struct DuplicateFuncOpEquivalenceInfo if (lhs == getTombstoneKey() || lhs == getEmptyKey() || rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; + + if (lhs.isDeclaration() || rhs.isDeclaration()) + return false; + // Check discardable attributes equivalence if (lhs->getDiscardableAttrDictionary() != rhs->getDiscardableAttrDictionary()) @@ -87,11 +91,11 @@ struct DuplicateFunctionEliminationPass // Find unique representant per equivalent func ops. DenseSet<func::FuncOp, DuplicateFuncOpEquivalenceInfo> uniqueFuncOps; - DenseMap<StringAttr, func::FuncOp> getRepresentant; + DenseMap<StringRef, func::FuncOp> getRepresentant; DenseSet<func::FuncOp> toBeErased; module.walk([&](func::FuncOp f) { auto [repr, inserted] = uniqueFuncOps.insert(f); - getRepresentant[f.getSymNameAttr()] = *repr; + getRepresentant[f.getSymName()] = *repr; if (!inserted) { toBeErased.insert(f); } @@ -99,9 +103,14 @@ struct DuplicateFunctionEliminationPass // Update call ops to call unique func op representants. module.walk([&](func::CallOp callOp) { - func::FuncOp callee = getRepresentant[callOp.getCalleeAttr().getAttr()]; + func::FuncOp callee = getRepresentant[callOp.getCallee()]; callOp.setCallee(callee.getSymName()); }); + // Update constant ops to reference unique func op representants. + module.walk([&](func::ConstantOp constantOp) { + func::FuncOp value = getRepresentant[constantOp.getValue()]; + constantOp.setValue(value.getSymName()); + }); // Erase redundant func ops. for (auto it : toBeErased) { diff --git a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir index 28d059a149bde8..1c6876c1327bc6 100644 --- a/mlir/test/Dialect/Func/duplicate-function-elimination.mlir +++ b/mlir/test/Dialect/Func/duplicate-function-elimination.mlir @@ -366,3 +366,51 @@ func.func @user(%p0: i1, %p1: i1, %p2: i1, %p3: i1, %odd: f32, %even: f32) // CHECK: @user // CHECK-2: call @deep_tree // CHECK: call @reverse_deep_tree + +// ----- + +func.func private @func_declaration(i32, i32) -> i32 +func.func private @func_declaration1(i32, i32) -> i32 + +func.func @user(%arg0: i32, %arg1: i32) -> (i32, i32) { + %0 = call @func_declaration(%arg0, %arg1) : (i32, i32) -> i32 + %1 = call @func_declaration1(%arg0, %arg1) : (i32, i32) -> i32 + return %0, %1 : i32, i32 +} + +// CHECK: @func_declaration +// CHECK: @func_declaration1 +// CHECK: @user +// CHECK: call @func_declaration +// CHECK: call @func_declaration1 + + +// ----- + +func.func @identity(%arg0: tensor<f32>) -> tensor<f32> { + return %arg0 : tensor<f32> +} + +func.func @also_identity(%arg0: tensor<f32>) -> tensor<f32> { + return %arg0 : tensor<f32> +} + +func.func @yet_another_identity(%arg0: tensor<f32>) -> tensor<f32> { + return %arg0 : tensor<f32> +} + +func.func @user(%arg0: tensor<f32>) -> tensor<f32> { + %f = constant @identity : (tensor<f32>) -> tensor<f32> + %0 = call_indirect %f(%arg0) : (tensor<f32>) -> tensor<f32> + %f_0 = constant @also_identity : (tensor<f32>) -> tensor<f32> + %1 = call_indirect %f_0(%0) : (tensor<f32>) -> tensor<f32> + %2 = call @yet_another_identity(%1) : (tensor<f32>) -> tensor<f32> + return %2 : tensor<f32> +} + +// CHECK: @identity +// CHECK-NOT: @also_identity +// CHECK-NOT: @yet_another_identity +// CHECK: @user +// CHECK-2: constant @identity +// CHECK: call @identity 
@CoTinker CoTinker force-pushed the duplicate branch 2 times, most recently from 780b3d5 to 69f301a Compare September 22, 2024 08:54
@llvm llvm deleted a comment from github-actions bot Sep 22, 2024
@CoTinker
Copy link
Contributor Author

I've noticed that replaceAllSymbolUsesImpl uses collectSymbolScopes to traverse the region and addReplacement to perform replacements. Perhaps using SymbolTable::replaceAllSymbolUses inside the loop isn't as costly as we think.

/// The implementation of SymbolTable::replaceAllSymbolUses below.
template <typename SymbolT, typename IRUnitT>
static LogicalResult
replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
// Generate a new attribute to replace the given attribute.
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
SymbolRefAttr oldAttr = scope.symbol;
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
AttrTypeReplacer replacer;
replacer.addReplacement(
[&](SymbolRefAttr attr) -> std::pair<Attribute, WalkResult> {
// Regardless of the match, don't walk nested SymbolRefAttrs, we don't
// want to accidentally replace an inner reference.
if (attr == oldAttr)

I'm not familiar with this, could you please take a look at it. @River707

@CoTinker
Copy link
Contributor Author

In the following scenario, SymbolTable::replaceAllSymbolUses is also inside a loop.

for (auto &op : *combinedModule.getBody()) {
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
// Do not support ops with operands or results.
// Global variables, spec constants, and functions won't have
// operands/results, but just for safety here.
if (op.getNumOperands() != 0 || op.getNumResults() != 0)
continue;
// Deduplicating functions are not supported yet.
if (isa<FuncOp>(op))
continue;
auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
if (result.second)
continue;
SymbolOpInterface replacementSymOp = result.first->second;
if (failed(SymbolTable::replaceAllSymbolUses(
symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
symbolOp.emitError("unable to update all symbol uses for ")
<< symbolOp.getName() << " to " << replacementSymOp.getName();
return nullptr;
}
eraseList.push_back(symbolOp);
}

@joker-eph
Copy link
Collaborator

In the following scenario, SymbolTable::replaceAllSymbolUses is also inside a loop.

I suspect It's just as bad

@CoTinker CoTinker requested a review from River707 October 17, 2024 06:52
@christopherbate
Copy link
Contributor

@CoTinker I commented above with a suggestion for how to fix.

@CoTinker
Copy link
Contributor Author

@CoTinker I commented above with a suggestion for how to fix.

Thanks, I'll refer to it.

This PR fixes multiple bugs in `DuplicateFunctionElimination`. - Prevents elimination of function declarations. - Updates all symbol uses to reference unique function representatives.
@CoTinker CoTinker merged commit 2ce655c into llvm:main Oct 22, 2024
8 checks passed
@CoTinker CoTinker deleted the duplicate branch October 22, 2024 01:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

4 participants