Skip to content

Conversation

@ingomueller-net
Copy link
Contributor

This PR adds a new transform op that replaces memref.allocas with memref.get_globals to newly inserted memref.globals. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir mlir:memref labels Sep 15, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 15, 2023

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Changes This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals. -- Full diff: https://github.com//pull/66511.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td (+65)
  • (modified) mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp (+90)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+3-2)
  • (modified) mlir/python/mlir/dialects/_memref_transform_ops_ext.py (+58)
  • (modified) mlir/test/Dialect/MemRef/transform-ops.mlir (+39)
  • (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+12)
  • (modified) mlir/test/python/dialects/transform_memref_ext.py (+48)
diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 681759f970cb910..6a78784d74dd53c 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -144,6 +144,71 @@ def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op<Transform_Dialect, } def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">; +def Transform_MemRefAllocaOp : Transform_ConcreteOpType<"memref.alloca">; + +def MemRefAllocaToGlobalOp : + Op<Transform_Dialect, "memref.alloca_to_global", + [TransformOpInterface, + DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, + DeclareOpInterfaceMethods<TransformOpInterface>]> { + let description = [{ + Inserts a new `memref.global` for each provided `memref.alloca` into the + provided module and replaces it with a `memref.get_global`. This is useful, + for example, for allocations that should reside in the shared memory of + a GPU, which have to be declared as globals. + + #### Example + + Consider the following transform op: + + ```mlir + %get_global, %global = + transform.memref.alloca_to_global %alloca in %module + : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) + ``` + + and the following input payload: + + ```mlir + module { + func.func @func() { + %alloca = memref.alloca() : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + then applying the transform op to the payload would result in the following + output IR: + + ```mlir + module { + memref.global "private" @alloc : memref<2x32xf32> + func.func @func() { + %alloca = memref.get_global @alloc : memref<2x32xf32> + // usages of %alloca... + } + } + ``` + + #### Return modes + + Emits a definite failure if not exactly one `module` payload op was provided + or any of the `alloca` payload ops is not inside that module, and succeeds + otherwise. The returned handles refer to the `memref.get_global` and + `memref.global` ops that were inserted by the transformation. + }]; + + let arguments = (ins Transform_ConcreteOpType<"builtin.module">:$module, + Transform_MemRefAllocaOp:$alloca); + let results = (outs TransformHandleTypeInterface:$get_global, + TransformHandleTypeInterface:$global); + + let assemblyFormat = [{ + $alloca `in` $module attr-dict `:` functional-type(operands, results) + }]; +} def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer", [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 58f4d8d8f6d21fe..7467359da83c37f 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -126,6 +126,96 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); } +//===----------------------------------------------------------------------===// +// AllocaToGlobalOp +//===----------------------------------------------------------------------===// + +namespace { +static llvm::SmallString<64> getUniqueSymbol(llvm::StringRef prefix, + ModuleOp module) { + llvm::SmallString<64> candidateNameStorage; + StringRef candidateName(prefix); + int uniqueNumber = 0; + while (true) { + if (!module.lookupSymbol(candidateName)) { + break; + } + candidateNameStorage.clear(); + candidateName = (prefix + Twine("_") + Twine(uniqueNumber)) + .toStringRef(candidateNameStorage); + uniqueNumber++; + } + return candidateName; +} +} // namespace + +DiagnosedSilenceableFailure +transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + auto allocaOps = state.getPayloadOps(getAlloca()); + + SmallVector<memref::GlobalOp> globalOps; + SmallVector<memref::GetGlobalOp> getGlobalOps; + + // Get `builtin.module`. + auto moduleOps = state.getPayloadOps(getModule()); + if (!llvm::hasSingleElement(moduleOps)) { + return emitDefiniteFailure() + << Twine("expected exactly one 'module' payload, but found ") + + std::to_string(llvm::range_size(moduleOps)); + } + ModuleOp module = cast<ModuleOp>(*moduleOps.begin()); + + // Transform `memref.alloca`s. + for (auto *op : allocaOps) { + auto alloca = cast<memref::AllocaOp>(op); + MLIRContext *ctx = rewriter.getContext(); + Location loc = alloca->getLoc(); + + memref::GlobalOp globalOp; + { + // Insert a `memref.global` at the beginning of the module. + if (module != alloca->getParentOfType<ModuleOp>()) { + return emitDefiniteFailure() + << "expected 'alloca' payload to be inside 'module' payload"; + } + IRRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&module.getBodyRegion().front()); + Type resultType = alloca.getResult().getType(); + llvm::SmallString<64> symName = getUniqueSymbol("alloca", module); + // XXX: Add a better builder for this. + globalOp = rewriter.create<memref::GlobalOp>( + loc, StringAttr::get(ctx, symName), StringAttr::get(ctx, "private"), + TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + } + + // Replace the `memref.alloca` with a `memref.get_global` accessing the + // global symbol inserted above. + rewriter.setInsertionPoint(alloca); + auto getGlobalOp = rewriter.replaceOpWithNewOp<memref::GetGlobalOp>( + alloca, globalOp.getType(), globalOp.getName()); + + globalOps.push_back(globalOp); + getGlobalOps.push_back(getGlobalOp); + } + + // Assemble results. + results.set(getGlobal().cast<OpResult>(), globalOps); + results.set(getGetGlobal().cast<OpResult>(), getGlobalOps); + + return DiagnosedSilenceableFailure::success(); +} + +void transform::MemRefAllocaToGlobalOp::getEffects( + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { + onlyReadsHandle(getModule(), effects); + producesHandle(getGlobal(), effects); + producesHandle(getGetGlobal(), effects); + consumesHandle(getAlloca(), effects); + modifiesPayload(effects); +} + //===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index de3cd1b28e435bc..f1d07b85adb7576 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1233,7 +1233,7 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, DenseSet<Operation *> resultSet; for (Operation *target : state.getPayloadOps(getTarget())) { Operation *parent = target->getParentOp(); - do { + while (parent) { bool checkIsolatedFromAbove = !getIsolatedFromAbove() || parent->hasTrait<OpTrait::IsIsolatedFromAbove>(); @@ -1241,7 +1241,8 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter, parent->getName().getStringRef() == *getOpName(); if (checkIsolatedFromAbove && checkOpName) break; - } while ((parent = parent->getParentOp())); + parent = parent->getParentOp(); + } if (!parent) { DiagnosedSilenceableFailure diag = emitSilenceableError() diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py index 4afe8e7b887f68e..56dcfbe5655e9b6 100644 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -11,6 +11,64 @@ from typing import Optional, overload, Union +class MemRefAllocaToGlobalOp: + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + module: Union[Operation, OpView, Value], + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + module: Union[Operation, OpView, Value], + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + def __init__( + self, + get_global_type_or_module: Union[Operation, OpView, Type, Value], + global_type_or_alloca: Union[Operation, OpView, Type, Value], + module_or_none: Optional[Union[Operation, OpView, Value]] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None + ): + if isinstance(get_global_type_or_module, Type): + get_global_type = get_global_type_or_module + global_type = global_type_or_alloca + module = module_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + module = get_global_type_or_module + alloca = global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + module, + alloca, + loc=loc, + ip=ip, + ) + + class MemRefMultiBufferOp: """Specialization for MemRefMultiBufferOp class.""" diff --git a/mlir/test/Dialect/MemRef/transform-ops.mlir b/mlir/test/Dialect/MemRef/transform-ops.mlir index b19db447af1c28a..aeeb2a6b0abedc5 100644 --- a/mlir/test/Dialect/MemRef/transform-ops.mlir +++ b/mlir/test/Dialect/MemRef/transform-ops.mlir @@ -1,5 +1,44 @@ // RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -allow-unregistered-dialect -split-input-file | FileCheck %s +// CHECK-DAG: memref.global "private" @[[ALLOC0:alloc.*]] : memref<2x32xf32> +// CHECK-DAG: memref.global "private" @[[ALLOC1:alloc.*]] : memref<2x32xf32> + +// CHECK: func.func @func( +func.func @func(%arg0: f32) { + %c3 = arith.constant 3 : index + %c1 = arith.constant 1 : index + // CHECK: scf.forall + scf.forall (%arg1, %arg2) in (%c3, %c1) { + // CHECK-DAG: %[[MR0:.*]] = memref.get_global @[[ALLOC0]] : memref<2x32xf32> + // CHECK-DAG: %[[MR1:.*]] = memref.get_global @[[ALLOC1]] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR0]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + // CHECK-DAG: memref.store %{{.*}}, %[[MR1]][%{{.*}}, %{{.*}}] : memref<2x32xf32> + %alloca = memref.alloca() : memref<2x32xf32> + %alloca_0 = memref.alloca() : memref<2x32xf32> + memref.store %arg0, %alloca[%arg1, %arg2] : memref<2x32xf32> + memref.store %arg0, %alloca_0[%arg1, %arg2] : memref<2x32xf32> + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg0: !transform.any_op): + %alloca = transform.structured.match ops{["memref.alloca"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %module = transform.structured.match ops{["builtin.module"]} in %arg0 + : (!transform.any_op) -> !transform.any_op + %alloca_typed = transform.cast %alloca + : !transform.any_op to !transform.op<"memref.alloca"> + %module_typed = transform.cast %module + : !transform.any_op to !transform.op<"builtin.module"> + %get_global, %global = + transform.memref.alloca_to_global %alloca_typed in %module_typed + : (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + -> (!transform.any_op, !transform.any_op) +} + +// ----- + // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 68e3a4851539690..91a283c799941bb 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) { test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op } + +// ----- + +// expected-note @below {{target op}} +module { + transform.sequence failures(propagate) { + ^bb0(%arg0: !pdl.operation): + // expected-error @below{{could not find a parent op that matches all requirements}} + %3 = get_parent_op %arg0 {op_name = "builtin.module"} : (!pdl.operation) -> !transform.any_op + } +} + // ----- func.func @cast(%arg0: f32) -> f64 { diff --git a/mlir/test/python/dialects/transform_memref_ext.py b/mlir/test/python/dialects/transform_memref_ext.py index f89005cb2f86d1b..8278019bbab3b89 100644 --- a/mlir/test/python/dialects/transform_memref_ext.py +++ b/mlir/test/python/dialects/transform_memref_ext.py @@ -16,6 +16,54 @@ def run(f): return f +@run +def testMemRefAllocaToAllocOpCompact(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloc"), + ) + with InsertionPoint(sequence.body): + module = transform.CastOp( + transform.OperationType.get("builtin.module"), sequence.bodyTarget + ) + alloca = transform.CastOp( + transform.OperationType.get("memref.alloca"), sequence.bodyTarget + ) + memref.MemRefAllocaToGlobalOp(module, alloca) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpCompact + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: (!transform.op<"builtin.module">, !transform.op<"memref.alloca">) + # CHECK-SAME: -> (!transform.any_op, !transform.any_op) + + +@run +def testMemRefAllocaToAllocOpTyped(): + sequence = transform.SequenceOp( + transform.FailurePropagationMode.Propagate, + [], + transform.OperationType.get("memref.alloc"), + ) + with InsertionPoint(sequence.body): + module = transform.CastOp( + transform.OperationType.get("builtin.module"), sequence.bodyTarget + ) + alloca = transform.CastOp( + transform.OperationType.get("memref.alloca"), sequence.bodyTarget + ) + memref.MemRefAllocaToGlobalOp( + transform.OperationType.get("memref.get_global"), + transform.OperationType.get("memref.global"), + module, + alloca, + ) + transform.YieldOp() + # CHECK-LABEL: TEST: testMemRefAllocaToAllocOpTyped + # CHECK: = transform.memref.alloca_to_global + # CHECK-SAME: -> (!transform.op<"memref.get_global">, !transform.op<"memref.global">) + + @run def testMemRefMultiBufferOpCompact(): sequence = transform.SequenceOp( 
@ingomueller-net
Copy link
Contributor Author

I am not sure whether the module argument is really necessary. I do need access to the surrounding module to insert the memref.globals, but I can do that through code (with getParentOfType). Also, I leave the module itself (and all other ops exact for the alloca inputs) intact and only add new ones.

I had an original version with that argument and that ran, except for some crashes that I retrospectively relate to #66357. It may, thus, work without the argument but is it legal to do so?

@ingomueller-net ingomueller-net force-pushed the transform-alloc-to-global branch 2 times, most recently from eca2474 to a9cbcf5 Compare September 15, 2023 15:01
@ingomueller-net
Copy link
Contributor Author

I am not sure whether the module argument is really necessary. I do need access to the surrounding module to insert the memref.globals, but I can do that through code (with getParentOfType). Also, I leave the module itself (and all other ops exact for the alloca inputs) intact and only add new ones.

I had an original version with that argument and that ran, except for some crashes that I retrospectively relate to #66357. It may, thus, work without the argument but is it legal to do so?

@ftynse: Can you help us with this? @nicolasvasilache and @matthias-springer seemed to say that this might be safe but weren't very confident in their assessment...

This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals.
In particular: * Accept any op type with `SymbolTable` trait as containing op rather than only `builtin.module` and rename op argument accordingly. * Use `SymbolTable::insert` to unique the name of the globals rather than some hand-rolled function. * Use more sane semantics in Python mix-in test.
@ingomueller-net ingomueller-net force-pushed the transform-alloc-to-global branch from 8ee31eb to a91c93d Compare September 21, 2023 13:49
@ingomueller-net ingomueller-net merged commit 991cb14 into llvm:main Sep 21, 2023
@ingomueller-net ingomueller-net deleted the transform-alloc-to-global branch September 21, 2023 16:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir:memref mlir:python MLIR Python bindings mlir

3 participants