@razvanlupusoru
Contributor

This patch introduces a new transformation pass that converts acc.serial constructs into acc.parallel constructs with num_gangs(1), num_workers(1), and vector_length(1).

The transformation preserves semantics: an OpenACC serial region executes sequentially, which matches a parallel region with one gang, one worker, and a vector length of 1. Unifying the two constructs lets later compilation stages reuse the same handling for both.
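
To make the equivalence concrete, here is a minimal standalone sketch of the rewrite's logic. It deliberately does not use the MLIR API: `ComputeRegion`, `Kind`, and `legalizeSerial` are hypothetical names invented for illustration, and the struct models only the three parallelism sizes plus opaque clause and body lists.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an OpenACC compute construct; not an MLIR class.
enum class Kind { Serial, Parallel };

struct ComputeRegion {
  Kind kind;
  std::optional<int> numGangs;      // unset on a serial construct
  std::optional<int> numWorkers;    // unset on a serial construct
  std::optional<int> vectorLength;  // unset on a serial construct
  std::vector<std::string> clauses; // e.g. "async", "wait", "private"
  std::vector<std::string> body;    // opaque region contents
};

// Rewrites a serial construct as a parallel one with all three sizes set to 1.
// Clauses and body carry over unchanged; any other construct is returned as-is.
ComputeRegion legalizeSerial(ComputeRegion region) {
  if (region.kind != Kind::Serial)
    return region;
  region.kind = Kind::Parallel;
  region.numGangs = 1;
  region.numWorkers = 1;
  region.vectorLength = 1;
  return region;
}
```

In this model, a consumer that already handles `Kind::Parallel` needs no separate serial path after the rewrite, which is the code reuse the description refers to.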

@llvmbot
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-openacc

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

Changes

This patch introduces a new transformation pass that converts acc.serial constructs into acc.parallel constructs with num_gangs(1), num_workers(1), and vector_length(1).

The transformation is semantically equivalent since an OpenACC serial region executes sequentially, which is identical to a parallel region with a single gang, worker, and vector. This unification simplifies processing of acc regions by enabling code reuse in later compilation stages.


Full diff: https://github.com/llvm/llvm-project/pull/170189.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td (+16)
  • (added) mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp (+117)
  • (modified) mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt (+1)
  • (added) mlir/test/Dialect/OpenACC/legalize-serial.mlir (+164)
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index 713aaabee65f0..b37cc282d4555 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -136,4 +136,20 @@ def ACCImplicitRoutine : Pass<"acc-implicit-routine", "mlir::ModuleOp"> {
   ];
 }
 
+def ACCLegalizeSerial : Pass<"acc-legalize-serial", "mlir::func::FuncOp"> {
+  let summary = "Legalize OpenACC serial constructs";
+  let description = [{
+    This pass converts `acc.serial` constructs into `acc.parallel` constructs
+    with `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`.
+
+    This transformation simplifies processing of acc regions by unifying the
+    handling of serial and parallel constructs. Since an OpenACC serial region
+    executes sequentially (like a parallel region with a single gang, worker,
+    and vector), this conversion is semantically equivalent while enabling code
+    reuse in later compilation stages.
+  }];
+  let dependentDialects = ["mlir::acc::OpenACCDialect",
+      "mlir::arith::ArithDialect"];
+}
+
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
new file mode 100644
index 0000000000000..f41ce276f994f
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCLegalizeSerial.cpp
@@ -0,0 +1,117 @@
+//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass converts acc.serial into acc.parallel with num_gangs(1)
+// num_workers(1) vector_length(1).
+//
+// This transformation simplifies processing of acc regions by unifying the
+// handling of serial and parallel constructs. Since an OpenACC serial region
+// executes sequentially (like a parallel region with a single gang, worker, and
+// vector), this conversion is semantically equivalent while enabling code reuse
+// in later compilation stages.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCLEGALIZESERIAL
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-legalize-serial"
+
+namespace {
+using namespace mlir;
+
+struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
+  using OpRewritePattern<acc::SerialOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(acc::SerialOp serialOp,
+                                PatternRewriter &rewriter) const override {
+
+    const Location loc = serialOp.getLoc();
+
+    // Create a container holding the constant value of 1 for use as the
+    // num_gangs, num_workers, and vector_length attributes.
+    llvm::SmallVector<mlir::Value> numValues;
+    auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
+    numValues.push_back(value);
+
+    // Since num_gangs is specified as both attributes and values, create a
+    // segment attribute.
+    llvm::SmallVector<int32_t> numGangsSegments;
+    numGangsSegments.push_back(numValues.size());
+    auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);
+
+    // Create a device_type attribute set to `none` which ensures that
+    // the parallel dimensions specification applies to the default clauses.
+    llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
+    auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
+        rewriter.getContext(), mlir::acc::DeviceType::None);
+    crtDeviceTypes.push_back(crtDeviceTypeAttr);
+    auto devTypeAttr =
+        mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);
+
+    LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");
+
+    // Create a new acc.parallel op with the same operands - except include the
+    // num_gangs, num_workers, and vector_length attributes.
+    acc::ParallelOp parOp = acc::ParallelOp::create(
+        rewriter, loc, serialOp.getAsyncOperands(),
+        serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
+        serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
+        serialOp.getWaitOperandsDeviceTypeAttr(),
+        serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
+        gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
+        devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
+        serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
+        serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
+        serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
+        serialOp.getCombinedAttr());
+
+    parOp.getRegion().takeBody(serialOp.getRegion());
+
+    LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
+    rewriter.replaceOp(serialOp, parOp);
+
+    return success();
+  }
+};
+
+class ACCLegalizeSerial
+    : public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
+public:
+  using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;
+  void runOnOperation() override {
+    func::FuncOp funcOp = getOperation();
+    MLIRContext *context = funcOp.getContext();
+    RewritePatternSet patterns(context);
+    patterns.insert<ACCSerialOpConversion>(context);
+    (void)applyPatternsGreedily(funcOp, std::move(patterns));
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 2c6da87c66a11..10a1796972044 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms
   ACCImplicitData.cpp
   ACCImplicitDeclare.cpp
   ACCImplicitRoutine.cpp
+  ACCLegalizeSerial.cpp
   LegalizeDataValues.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/OpenACC/legalize-serial.mlir b/mlir/test/Dialect/OpenACC/legalize-serial.mlir
new file mode 100644
index 0000000000000..774c6b6f65ce3
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/legalize-serial.mlir
@@ -0,0 +1,164 @@
+// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s
+
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32>
+  acc.terminator
+}
+
+acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
+^bb0(%arg0: memref<10x10xf32>):
+  %0 = memref.alloc() : memref<10x10xf32>
+  acc.yield %0 : memref<10x10xf32>
+} destroy {
+^bb0(%arg0: memref<10x10xf32>):
+  memref.dealloc %arg0 : memref<10x10xf32>
+  acc.terminator
+}
+
+acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} copy {
+^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>):
+  acc.terminator
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32>
+  acc.terminator
+}
+
+acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
+^bb0(%0: i64):
+  %1 = arith.constant 0 : i64
+  acc.yield %1 : i64
+} combiner {
+^bb0(%0: i64, %1: i64):
+  %2 = arith.addi %0, %1 : i64
+  acc.yield %2 : i64
+}
+
+acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init {
+^bb0(%arg0: memref<i64>):
+  %0 = memref.alloca() : memref<i64>
+  %c0 = arith.constant 0 : i64
+  memref.store %c0, %0[] : memref<i64>
+  acc.yield %0 : memref<i64>
+} combiner {
+^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
+  %0 = memref.load %arg0[] : memref<i64>
+  %1 = memref.load %arg1[] : memref<i64>
+  %2 = arith.addi %0, %1 : i64
+  memref.store %2, %arg0[] : memref<i64>
+  acc.terminator
+}
+
+// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) {
+// CHECK: }
+// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
+// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64>
+// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64>
+// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) {
+// CHECK: }
+// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) {
+// CHECK: acc.yield
+// CHECK: } attributes {seq = [#acc.device_type<none>]}
+// CHECK: acc.terminator
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>}
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>}
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: }
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: } attributes {selfAttr}
+// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
+// CHECK: acc.yield
+// CHECK: } attributes {selfAttr}
+// CHECK: return
+// CHECK: }
+
+func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+  %i64value = arith.constant 1 : i64
+  %i32value = arith.constant 1 : i32
+  %idxValue = arith.constant 1 : index
+  acc.serial async(%i64value: i64) {
+  }
+  acc.serial async(%i32value: i32) {
+  }
+  acc.serial async(%idxValue: index) {
+  }
+  acc.serial wait({%i64value: i64}) {
+  }
+  acc.serial wait({%i32value: i32}) {
+  }
+  acc.serial wait({%idxValue: index}) {
+  }
+  acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
+  }
+  %firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
+  %c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
+  acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) {
+  }
+  %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
+  acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) {
+  }
+  %i64mem = memref.alloca() : memref<i64>
+  memref.store %i64value, %i64mem[] : memref<i64>
+  %i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
+  acc.serial reduction(%i64reduction : memref<i64>) {
+  }
+  acc.serial combined(loop) {
+    acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) {
+      acc.yield
+    } attributes {seq = [#acc.device_type<none>]}
+    acc.terminator
+  }
+  acc.serial {
+  } attributes {defaultAttr = #acc<defaultvalue none>}
+  acc.serial {
+  } attributes {defaultAttr = #acc<defaultvalue present>}
+  acc.serial {
+  } attributes {asyncAttr}
+  acc.serial {
+  } attributes {waitAttr}
+  acc.serial {
+  } attributes {selfAttr}
+  acc.serial {
+    acc.yield
+  } attributes {selfAttr}
+  return
+}
+
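
The pattern in the diff passes `num_gangs` to the new op as three pieces: a flat value list (`numValues`), a segment-size array (`gangSegmentsAttr`), and a device-type array (`devTypeAttr`). The standalone sketch below shows one way such an encoding could be read back. It assumes that each segment entry gives the number of values belonging to the device type at the same index; that reading is an assumption here, not something the patch states. `GangSpec`, `DeviceType`, and `valuesFor` are hypothetical names and not part of the OpenACC dialect API.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical subset of device types; the pass itself only uses `none`.
enum class DeviceType { None, Host, Nvidia };

// Hypothetical plain-data model of the three pieces the pass creates.
struct GangSpec {
  std::vector<int> values;             // all operands, flattened
  std::vector<int32_t> segments;       // number of values per device-type entry
  std::vector<DeviceType> deviceTypes; // one entry per segment
};

// Returns the values of the first segment tagged with `dt`, or an empty
// vector when no segment carries that device type. Assumes `segments` and
// `deviceTypes` have the same length.
std::vector<int> valuesFor(const GangSpec &spec, DeviceType dt) {
  std::ptrdiff_t offset = 0;
  for (std::size_t i = 0; i < spec.segments.size(); ++i) {
    const std::ptrdiff_t count = spec.segments[i];
    if (spec.deviceTypes[i] == dt)
      return std::vector<int>(spec.values.begin() + offset,
                              spec.values.begin() + offset + count);
    offset += count;
  }
  return {};
}
```

Under this model, the legalized op corresponds to `GangSpec{{1}, {1}, {DeviceType::None}}`: a single value of 1 in a single segment tagged `none`, which is why the pass pushes `numValues.size()` as the only segment entry.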
@llvmbot
Member

llvmbot commented Dec 1, 2025

@llvm/pr-subscribers-mlir-openacc

Contributor

@clementval clementval left a comment

LGTM

Contributor

@VijayKandiah VijayKandiah left a comment

LGTM!

@razvanlupusoru razvanlupusoru merged commit 258cb46 into llvm:main Dec 1, 2025
14 checks passed
rupprecht added a commit to rupprecht/llvm-project that referenced this pull request Dec 1, 2025