- Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][acc] Add acc serial to acc parallel conversion #170189
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch introduces a new transformation pass that converts `acc.serial` constructs into `acc.parallel` constructs with num_gangs(1), num_workers(1), and vector_length(1). The transformation is semantically equivalent since an OpenACC serial region executes sequentially, which is identical to a parallel region with a single gang, worker, and vector. This unification simplifies processing of acc regions by enabling code reuse in later compilation stages.
| @llvm/pr-subscribers-openacc @llvm/pr-subscribers-mlir Author: Razvan Lupusoru (razvanlupusoru) ChangesThis patch introduces a new transformation pass that converts The transformation is semantically equivalent since an OpenACC serial region executes sequentially, which is identical to a parallel region with a single gang, worker, and vector. This unification simplifies processing of acc regions by enabling code reuse in later compilation stages. Full diff: https://github.com/llvm/llvm-project/pull/170189.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td index 713aaabee65f0..b37cc282d4555 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td @@ -136,4 +136,20 @@ def ACCImplicitRoutine : Pass<"acc-implicit-routine", "mlir::ModuleOp"> { ]; } +def ACCLegalizeSerial : Pass<"acc-legalize-serial", "mlir::func::FuncOp"> { + let summary = "Legalize OpenACC serial constructs"; + let description = [{ + This pass converts `acc.serial` constructs into `acc.parallel` constructs + with `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`. + + This transformation simplifies processing of acc regions by unifying the + handling of serial and parallel constructs. Since an OpenACC serial region + executes sequentially (like a parallel region with a single gang, worker, + and vector), this conversion is semantically equivalent while enabling code + reuse in later compilation stages. + }]; + let dependentDialects = ["mlir::acc::OpenACCDialect", + "mlir::arith::ArithDialect"]; +} + #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp new file mode 100644 index 0000000000000..f41ce276f994f --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp @@ -0,0 +1,117 @@ +//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass converts acc.serial into acc.parallel with num_gangs(1) +// num_workers(1) vector_length(1). +// +// This transformation simplifies processing of acc regions by unifying the +// handling of serial and parallel constructs. Since an OpenACC serial region +// executes sequentially (like a parallel region with a single gang, worker, and +// vector), this conversion is semantically equivalent while enabling code reuse +// in later compilation stages. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCLEGALIZESERIAL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-legalize-serial" + +namespace { +using namespace mlir; + +struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> { + using OpRewritePattern<acc::SerialOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::SerialOp serialOp, + PatternRewriter &rewriter) const override { + + const Location loc = serialOp.getLoc(); + + // Create a container holding the constant value of 1 for use as the + // num_gangs, num_workers, and vector_length attributes. + llvm::SmallVector<mlir::Value> numValues; + auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + numValues.push_back(value); + + // Since num_gangs is specified as both attributes and values, create a + // segment attribute. + llvm::SmallVector<int32_t> numGangsSegments; + numGangsSegments.push_back(numValues.size()); + auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments); + + // Create a device_type attribute set to `none` which ensures that + // the parallel dimensions specification applies to the default clauses. + llvm::SmallVector<mlir::Attribute> crtDeviceTypes; + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + rewriter.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); + auto devTypeAttr = + mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes); + + LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n"); + + // Create a new acc.parallel op with the same operands - except include the + // num_gangs, num_workers, and vector_length attributes. + acc::ParallelOp parOp = acc::ParallelOp::create( + rewriter, loc, serialOp.getAsyncOperands(), + serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(), + serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(), + serialOp.getWaitOperandsDeviceTypeAttr(), + serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues, + gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues, + devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(), + serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(), + serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(), + serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(), + serialOp.getCombinedAttr()); + + parOp.getRegion().takeBody(serialOp.getRegion()); + + LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n"); + rewriter.replaceOp(serialOp, parOp); + + return success(); + } +}; + +class ACCLegalizeSerial + : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> { +public: + using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase; + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert<ACCSerialOpConversion>(context); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index 2c6da87c66a11..10a1796972044 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms ACCImplicitData.cpp ACCImplicitDeclare.cpp ACCImplicitRoutine.cpp + ACCLegalizeSerial.cpp LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/test/Dialect/OpenACC/legalize-serial.mlir b/mlir/test/Dialect/OpenACC/legalize-serial.mlir new file mode 100644 index 0000000000000..774c6b6f65ce3 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/legalize-serial.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %0 = memref.alloc() : memref<10x10xf32> + acc.yield %0 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { +^bb0(%0: i64): + %1 = arith.constant 0 : i64 + acc.yield %1 : i64 +} combiner { +^bb0(%0: i64, %1: i64): + %2 = arith.addi %0, %1 : i64 + acc.yield %2 : i64 +} + +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init { +^bb0(%arg0: memref<i64>): + %0 = memref.alloca() : memref<i64> + %c0 = arith.constant 0 : i64 + memref.store %c0, %0[] : memref<i64> + acc.yield %0 : memref<i64> +} combiner { +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.terminator +} + +// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) { +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) { +// CHECK: } +// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} +// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64> +// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64> +// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) { +// CHECK: } +// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) { +// CHECK: acc.yield +// CHECK: } attributes {seq = [#acc.device_type<none>]} +// CHECK: acc.terminator +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {selfAttr} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.yield +// CHECK: } attributes {selfAttr} +// CHECK: return +// CHECK: } + +func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %i64value = arith.constant 1 : i64 + %i32value = arith.constant 1 : i32 + %idxValue = arith.constant 1 : index + acc.serial async(%i64value: i64) { + } + acc.serial async(%i32value: i32) { + } + acc.serial async(%idxValue: index) { + } + acc.serial wait({%i64value: i64}) { + } + acc.serial wait({%i32value: i32}) { + } + acc.serial wait({%idxValue: index}) { + } + acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + } + %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + %c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) { + } + %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} + acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) { + } + %i64mem = memref.alloca() : memref<i64> + memref.store %i64value, %i64mem[] : memref<i64> + %i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.serial reduction(%i64reduction : memref<i64>) { + } + acc.serial combined(loop) { + acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) { + acc.yield + } attributes {seq = [#acc.device_type<none>]} + acc.terminator + } + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue none>} + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue present>} + acc.serial { + } attributes {asyncAttr} + acc.serial { + } attributes {waitAttr} + acc.serial { + } attributes {selfAttr} + acc.serial { + acc.yield + } attributes {selfAttr} + return +} + |
| @llvm/pr-subscribers-mlir-openacc Author: Razvan Lupusoru (razvanlupusoru) ChangesThis patch introduces a new transformation pass that converts The transformation is semantically equivalent since an OpenACC serial region executes sequentially, which is identical to a parallel region with a single gang, worker, and vector. This unification simplifies processing of acc regions by enabling code reuse in later compilation stages. Full diff: https://github.com/llvm/llvm-project/pull/170189.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td index 713aaabee65f0..b37cc282d4555 100644 --- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td @@ -136,4 +136,20 @@ def ACCImplicitRoutine : Pass<"acc-implicit-routine", "mlir::ModuleOp"> { ]; } +def ACCLegalizeSerial : Pass<"acc-legalize-serial", "mlir::func::FuncOp"> { + let summary = "Legalize OpenACC serial constructs"; + let description = [{ + This pass converts `acc.serial` constructs into `acc.parallel` constructs + with `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`. + + This transformation simplifies processing of acc regions by unifying the + handling of serial and parallel constructs. Since an OpenACC serial region + executes sequentially (like a parallel region with a single gang, worker, + and vector), this conversion is semantically equivalent while enabling code + reuse in later compilation stages. + }]; + let dependentDialects = ["mlir::acc::OpenACCDialect", + "mlir::arith::ArithDialect"]; +} + #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp new file mode 100644 index 0000000000000..f41ce276f994f --- /dev/null +++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp @@ -0,0 +1,117 @@ +//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass converts acc.serial into acc.parallel with num_gangs(1) +// num_workers(1) vector_length(1). +// +// This transformation simplifies processing of acc regions by unifying the +// handling of serial and parallel constructs. Since an OpenACC serial region +// executes sequentially (like a parallel region with a single gang, worker, and +// vector), this conversion is semantically equivalent while enabling code reuse +// in later compilation stages. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/Support/Debug.h" + +namespace mlir { +namespace acc { +#define GEN_PASS_DEF_ACCLEGALIZESERIAL +#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc" +} // namespace acc +} // namespace mlir + +#define DEBUG_TYPE "acc-legalize-serial" + +namespace { +using namespace mlir; + +struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> { + using OpRewritePattern<acc::SerialOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(acc::SerialOp serialOp, + PatternRewriter &rewriter) const override { + + const Location loc = serialOp.getLoc(); + + // Create a container holding the constant value of 1 for use as the + // num_gangs, num_workers, and vector_length attributes. + llvm::SmallVector<mlir::Value> numValues; + auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32); + numValues.push_back(value); + + // Since num_gangs is specified as both attributes and values, create a + // segment attribute. + llvm::SmallVector<int32_t> numGangsSegments; + numGangsSegments.push_back(numValues.size()); + auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments); + + // Create a device_type attribute set to `none` which ensures that + // the parallel dimensions specification applies to the default clauses. + llvm::SmallVector<mlir::Attribute> crtDeviceTypes; + auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get( + rewriter.getContext(), mlir::acc::DeviceType::None); + crtDeviceTypes.push_back(crtDeviceTypeAttr); + auto devTypeAttr = + mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes); + + LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n"); + + // Create a new acc.parallel op with the same operands - except include the + // num_gangs, num_workers, and vector_length attributes. + acc::ParallelOp parOp = acc::ParallelOp::create( + rewriter, loc, serialOp.getAsyncOperands(), + serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(), + serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(), + serialOp.getWaitOperandsDeviceTypeAttr(), + serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues, + gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues, + devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(), + serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(), + serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(), + serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(), + serialOp.getCombinedAttr()); + + parOp.getRegion().takeBody(serialOp.getRegion()); + + LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n"); + rewriter.replaceOp(serialOp, parOp); + + return success(); + } +}; + +class ACCLegalizeSerial + : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> { +public: + using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase; + void runOnOperation() override { + func::FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + RewritePatternSet patterns(context); + patterns.insert<ACCSerialOpConversion>(context); + (void)applyPatternsGreedily(funcOp, std::move(patterns)); + } +}; + +} // namespace diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt index 2c6da87c66a11..10a1796972044 100644 --- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms ACCImplicitData.cpp ACCImplicitDeclare.cpp ACCImplicitRoutine.cpp + ACCLegalizeSerial.cpp LegalizeDataValues.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/test/Dialect/OpenACC/legalize-serial.mlir b/mlir/test/Dialect/OpenACC/legalize-serial.mlir new file mode 100644 index 0000000000000..774c6b6f65ce3 --- /dev/null +++ b/mlir/test/Dialect/OpenACC/legalize-serial.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s + +acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init { +^bb0(%arg0: memref<10x10xf32>): + %0 = memref.alloc() : memref<10x10xf32> + acc.yield %0 : memref<10x10xf32> +} destroy { +^bb0(%arg0: memref<10x10xf32>): + memref.dealloc %arg0 : memref<10x10xf32> + acc.terminator +} + +acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init { +^bb0(%arg0: memref<10xf32>): + %0 = memref.alloc() : memref<10xf32> + acc.yield %0 : memref<10xf32> +} copy { +^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>): + acc.terminator +} destroy { +^bb0(%arg0: memref<10xf32>): + memref.dealloc %arg0 : memref<10xf32> + acc.terminator +} + +acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init { +^bb0(%0: i64): + %1 = arith.constant 0 : i64 + acc.yield %1 : i64 +} combiner { +^bb0(%0: i64, %1: i64): + %2 = arith.addi %0, %1 : i64 + acc.yield %2 : i64 +} + +acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init { +^bb0(%arg0: memref<i64>): + %0 = memref.alloca() : memref<i64> + %c0 = arith.constant 0 : i64 + memref.store %c0, %0[] : memref<i64> + acc.yield %0 : memref<i64> +} combiner { +^bb0(%arg0: memref<i64>, %arg1: memref<i64>): + %0 = memref.load %arg0[] : memref<i64> + %1 = memref.load %arg1[] : memref<i64> + %2 = arith.addi %0, %1 : i64 + memref.store %2, %arg0[] : memref<i64> + acc.terminator +} + +// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) { +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index +// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) { +// CHECK: } +// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> +// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> +// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} +// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64> +// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64> +// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) { +// CHECK: } +// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) { +// CHECK: acc.yield +// CHECK: } attributes {seq = [#acc.device_type<none>]} +// CHECK: acc.terminator +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: } attributes {selfAttr} +// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) { +// CHECK: acc.yield +// CHECK: } attributes {selfAttr} +// CHECK: return +// CHECK: } + +func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () { + %i64value = arith.constant 1 : i64 + %i32value = arith.constant 1 : i32 + %idxValue = arith.constant 1 : index + acc.serial async(%i64value: i64) { + } + acc.serial async(%i32value: i32) { + } + acc.serial async(%idxValue: index) { + } + acc.serial wait({%i64value: i64}) { + } + acc.serial wait({%i32value: i32}) { + } + acc.serial wait({%idxValue: index}) { + } + acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) { + } + %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32> + %c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32> + acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) { + } + %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>} + acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) { + } + %i64mem = memref.alloca() : memref<i64> + memref.store %i64value, %i64mem[] : memref<i64> + %i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64> + acc.serial reduction(%i64reduction : memref<i64>) { + } + acc.serial combined(loop) { + acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) { + acc.yield + } attributes {seq = [#acc.device_type<none>]} + acc.terminator + } + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue none>} + acc.serial { + } attributes {defaultAttr = #acc<defaultvalue present>} + acc.serial { + } attributes {asyncAttr} + acc.serial { + } attributes {waitAttr} + acc.serial { + } attributes {selfAttr} + acc.serial { + acc.yield + } attributes {selfAttr} + return +} + |
clementval left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
VijayKandiah left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This patch introduces a new transformation pass that converts
acc.serialconstructs intoacc.parallelconstructs with num_gangs(1), num_workers(1), and vector_length(1).The transformation is semantically equivalent since an OpenACC serial region executes sequentially, which is identical to a parallel region with a single gang, worker, and vector. This unification simplifies processing of acc regions by enabling code reuse in later compilation stages.