Skip to content

Commit 258cb46

Browse files
[mlir][acc] Add acc serial to acc parallel conversion (#170189)
This patch introduces a new transformation pass that converts `acc.serial` constructs into `acc.parallel` constructs with num_gangs(1), num_workers(1), and vector_length(1). The transformation is semantically equivalent since an OpenACC serial region executes sequentially, which is identical to a parallel region with a single gang, worker, and vector. This unification simplifies processing of acc regions by enabling code reuse in later compilation stages. Co-authored-by: Vijay Kandiah <vkandiah@nvidia.com>
1 parent da76a48 commit 258cb46

File tree

4 files changed

+298
-0
lines changed

4 files changed

+298
-0
lines changed

mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,4 +136,20 @@ def ACCImplicitRoutine : Pass<"acc-implicit-routine", "mlir::ModuleOp"> {
136136
];
137137
}
138138

139+
// Declares the `acc-legalize-serial` pass. The second template argument
// anchors the pass on `mlir::func::FuncOp`, so it is scheduled per function.
def ACCLegalizeSerial : Pass<"acc-legalize-serial", "mlir::func::FuncOp"> {
140+
let summary = "Legalize OpenACC serial constructs";
141+
let description = [{
142+
This pass converts `acc.serial` constructs into `acc.parallel` constructs
143+
with `num_gangs(1)`, `num_workers(1)`, and `vector_length(1)`.
144+
145+
This transformation simplifies processing of acc regions by unifying the
146+
handling of serial and parallel constructs. Since an OpenACC serial region
147+
executes sequentially (like a parallel region with a single gang, worker,
148+
and vector), this conversion is semantically equivalent while enabling code
149+
reuse in later compilation stages.
150+
}];
151+
// Dialects the pass depends on: OpenACC for the `acc.parallel` op it
// produces, and Arith for the constant `1` it materializes for the
// num_gangs / num_workers / vector_length operands.
let dependentDialects = ["mlir::acc::OpenACCDialect",
152+
"mlir::arith::ArithDialect"];
153+
}
154+
139155
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
//===- ACCLegalizeSerial.cpp - Legalize ACC Serial region -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This pass converts acc.serial into acc.parallel with num_gangs(1)
10+
// num_workers(1) vector_length(1).
11+
//
12+
// This transformation simplifies processing of acc regions by unifying the
13+
// handling of serial and parallel constructs. Since an OpenACC serial region
14+
// executes sequentially (like a parallel region with a single gang, worker, and
15+
// vector), this conversion is semantically equivalent while enabling code reuse
16+
// in later compilation stages.
17+
//
18+
//===----------------------------------------------------------------------===//
19+
20+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
21+
22+
#include "mlir/Dialect/Arith/IR/Arith.h"
23+
#include "mlir/Dialect/Func/IR/FuncOps.h"
24+
#include "mlir/Dialect/OpenACC/OpenACC.h"
25+
#include "mlir/IR/Builders.h"
26+
#include "mlir/IR/BuiltinAttributes.h"
27+
#include "mlir/IR/Location.h"
28+
#include "mlir/IR/MLIRContext.h"
29+
#include "mlir/IR/PatternMatch.h"
30+
#include "mlir/IR/Region.h"
31+
#include "mlir/IR/Value.h"
32+
#include "mlir/Support/LLVM.h"
33+
#include "mlir/Support/LogicalResult.h"
34+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35+
#include "llvm/Support/Debug.h"
36+
37+
namespace mlir {
38+
namespace acc {
39+
#define GEN_PASS_DEF_ACCLEGALIZESERIAL
40+
#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
41+
} // namespace acc
42+
} // namespace mlir
43+
44+
#define DEBUG_TYPE "acc-legalize-serial"
45+
46+
namespace {
47+
using namespace mlir;
48+
49+
/// Rewrite pattern that replaces an `acc.serial` op with an `acc.parallel` op
/// carrying `num_gangs(1)`, `num_workers(1)` and `vector_length(1)`.
///
/// Every clause operand/attribute read from the serial op below (async, wait,
/// if, self, reduction, private, firstprivate, data operands, default and
/// combined) is forwarded unchanged to the new parallel op, and the body of
/// the serial region is moved (not cloned) into the parallel op.
struct ACCSerialOpConversion : public OpRewritePattern<acc::SerialOp> {
  using OpRewritePattern<acc::SerialOp>::OpRewritePattern;

  /// Builds the replacement `acc.parallel` at the location of `serialOp`,
  /// transfers the region body into it and replaces `serialOp`.
  /// This implementation has no failure path and always returns `success()`.
  LogicalResult matchAndRewrite(acc::SerialOp serialOp,
                                PatternRewriter &rewriter) const override {

    const Location loc = serialOp.getLoc();

    // Create a container holding the constant value of 1 for use as the
    // num_gangs, num_workers, and vector_length operands. The same i32
    // constant is shared by all three clauses.
    llvm::SmallVector<mlir::Value> numValues;
    auto value = arith::ConstantIntOp::create(rewriter, loc, 1, 32);
    numValues.push_back(value);

    // Since num_gangs is specified as both attributes and values, create a
    // segment attribute describing how many values belong to the (single)
    // num_gangs group.
    llvm::SmallVector<int32_t> numGangsSegments;
    numGangsSegments.push_back(numValues.size());
    auto gangSegmentsAttr = rewriter.getDenseI32ArrayAttr(numGangsSegments);

    // Create a device_type attribute set to `none` which ensures that
    // the parallel dimensions specification applies to the default clauses.
    llvm::SmallVector<mlir::Attribute> crtDeviceTypes;
    auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
        rewriter.getContext(), mlir::acc::DeviceType::None);
    crtDeviceTypes.push_back(crtDeviceTypeAttr);
    auto devTypeAttr =
        mlir::ArrayAttr::get(rewriter.getContext(), crtDeviceTypes);

    LLVM_DEBUG(llvm::dbgs() << "acc.serial OP: " << serialOp << "\n");

    // Create a new acc.parallel op with the same operands - except include the
    // num_gangs, num_workers, and vector_length attributes.
    acc::ParallelOp parOp = acc::ParallelOp::create(
        rewriter, loc, serialOp.getAsyncOperands(),
        serialOp.getAsyncOperandsDeviceTypeAttr(), serialOp.getAsyncOnlyAttr(),
        serialOp.getWaitOperands(), serialOp.getWaitOperandsSegmentsAttr(),
        serialOp.getWaitOperandsDeviceTypeAttr(),
        serialOp.getHasWaitDevnumAttr(), serialOp.getWaitOnlyAttr(), numValues,
        gangSegmentsAttr, devTypeAttr, numValues, devTypeAttr, numValues,
        devTypeAttr, serialOp.getIfCond(), serialOp.getSelfCond(),
        serialOp.getSelfAttrAttr(), serialOp.getReductionOperands(),
        serialOp.getPrivateOperands(), serialOp.getFirstprivateOperands(),
        serialOp.getDataClauseOperands(), serialOp.getDefaultAttrAttr(),
        serialOp.getCombinedAttr());

    // Move the serial body into the new op *through the rewriter* rather than
    // with Region::takeBody, so that the pattern driver is notified of the
    // block moves (all IR mutation inside a pattern must go via the rewriter).
    // NOTE(review): assumes the region of the freshly built acc.parallel has
    // no blocks yet, so appending at end() yields exactly the serial body.
    rewriter.inlineRegionBefore(serialOp.getRegion(), parOp.getRegion(),
                                parOp.getRegion().end());

    LLVM_DEBUG(llvm::dbgs() << "acc.parallel OP: " << parOp << "\n");
    rewriter.replaceOp(serialOp, parOp);

    return success();
  }
};
103+
104+
class ACCLegalizeSerial
105+
: public mlir::acc::impl::ACCLegalizeSerialBase<ACCLegalizeSerial> {
106+
public:
107+
using ACCLegalizeSerialBase<ACCLegalizeSerial>::ACCLegalizeSerialBase;
108+
void runOnOperation() override {
109+
func::FuncOp funcOp = getOperation();
110+
MLIRContext *context = funcOp.getContext();
111+
RewritePatternSet patterns(context);
112+
patterns.insert<ACCSerialOpConversion>(context);
113+
(void)applyPatternsGreedily(funcOp, std::move(patterns));
114+
}
115+
};
116+
117+
} // namespace

mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIROpenACCTransforms
22
ACCImplicitData.cpp
33
ACCImplicitDeclare.cpp
44
ACCImplicitRoutine.cpp
5+
ACCLegalizeSerial.cpp
56
LegalizeDataValues.cpp
67

78
ADDITIONAL_HEADER_DIRS
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
// RUN: mlir-opt %s -acc-legalize-serial | FileCheck %s
2+
3+
acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
4+
^bb0(%arg0: memref<10xf32>):
5+
%0 = memref.alloc() : memref<10xf32>
6+
acc.yield %0 : memref<10xf32>
7+
} destroy {
8+
^bb0(%arg0: memref<10xf32>):
9+
memref.dealloc %arg0 : memref<10xf32>
10+
acc.terminator
11+
}
12+
13+
acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
14+
^bb0(%arg0: memref<10x10xf32>):
15+
%0 = memref.alloc() : memref<10x10xf32>
16+
acc.yield %0 : memref<10x10xf32>
17+
} destroy {
18+
^bb0(%arg0: memref<10x10xf32>):
19+
memref.dealloc %arg0 : memref<10x10xf32>
20+
acc.terminator
21+
}
22+
23+
acc.firstprivate.recipe @firstprivatization_memref_10xf32 : memref<10xf32> init {
24+
^bb0(%arg0: memref<10xf32>):
25+
%0 = memref.alloc() : memref<10xf32>
26+
acc.yield %0 : memref<10xf32>
27+
} copy {
28+
^bb0(%arg0: memref<10xf32>, %arg1: memref<10xf32>):
29+
acc.terminator
30+
} destroy {
31+
^bb0(%arg0: memref<10xf32>):
32+
memref.dealloc %arg0 : memref<10xf32>
33+
acc.terminator
34+
}
35+
36+
acc.reduction.recipe @reduction_add_i64 : i64 reduction_operator<add> init {
37+
^bb0(%0: i64):
38+
%1 = arith.constant 0 : i64
39+
acc.yield %1 : i64
40+
} combiner {
41+
^bb0(%0: i64, %1: i64):
42+
%2 = arith.addi %0, %1 : i64
43+
acc.yield %2 : i64
44+
}
45+
46+
acc.reduction.recipe @reduction_add_memref_i64 : memref<i64> reduction_operator<add> init {
47+
^bb0(%arg0: memref<i64>):
48+
%0 = memref.alloca() : memref<i64>
49+
%c0 = arith.constant 0 : i64
50+
memref.store %c0, %0[] : memref<i64>
51+
acc.yield %0 : memref<i64>
52+
} combiner {
53+
^bb0(%arg0: memref<i64>, %arg1: memref<i64>):
54+
%0 = memref.load %arg0[] : memref<i64>
55+
%1 = memref.load %arg1[] : memref<i64>
56+
%2 = arith.addi %0, %1 : i64
57+
memref.store %2, %arg0[] : memref<i64>
58+
acc.terminator
59+
}
60+
61+
// CHECK: func.func @testserialop(%[[VAL_0:.*]]: memref<10xf32>, %[[VAL_1:.*]]: memref<10xf32>, %[[VAL_2:.*]]: memref<10x10xf32>) {
62+
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
63+
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : i32
64+
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
65+
// CHECK: acc.parallel async(%[[VAL_3]] : i64) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
66+
// CHECK: }
67+
// CHECK: acc.parallel async(%[[VAL_4]] : i32) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
68+
// CHECK: }
69+
// CHECK: acc.parallel async(%[[VAL_5]] : index) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
70+
// CHECK: }
71+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64}) {
72+
// CHECK: }
73+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_4]] : i32}) {
74+
// CHECK: }
75+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_5]] : index}) {
76+
// CHECK: }
77+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) wait({%[[VAL_3]] : i64, %[[VAL_4]] : i32, %[[VAL_5]] : index}) {
78+
// CHECK: }
79+
// CHECK: %[[VAL_6:.*]] = acc.firstprivate varPtr(%[[VAL_1]] : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
80+
// CHECK: %[[VAL_9:.*]] = acc.private varPtr(%[[VAL_2]] : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
81+
// CHECK: acc.parallel firstprivate(%[[VAL_6]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) private(%[[VAL_9]] : memref<10x10xf32>) vector_length(%[[VAL_4]] : i32) {
82+
// CHECK: }
83+
// CHECK: %[[VAL_7:.*]] = acc.copyin varPtr(%[[VAL_0]] : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
84+
// CHECK: acc.parallel dataOperands(%[[VAL_7]] : memref<10xf32>) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
85+
// CHECK: }
86+
// CHECK: %[[I64MEM:.*]] = memref.alloca() : memref<i64>
87+
// CHECK: memref.store %[[VAL_3]], %[[I64MEM]][] : memref<i64>
88+
// CHECK: %[[VAL_10:.*]] = acc.reduction varPtr(%[[I64MEM]] : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
89+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) reduction(%[[VAL_10]] : memref<i64>) {
90+
// CHECK: }
91+
// CHECK: acc.parallel combined(loop) num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
92+
// CHECK: acc.loop combined(serial) control(%{{.*}} : index) = (%[[VAL_5]] : index) to (%[[VAL_5]] : index) step (%[[VAL_5]] : index) {
93+
// CHECK: acc.yield
94+
// CHECK: } attributes {seq = [#acc.device_type<none>]}
95+
// CHECK: acc.terminator
96+
// CHECK: }
97+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
98+
// CHECK: } attributes {defaultAttr = #acc<defaultvalue none>}
99+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
100+
// CHECK: } attributes {defaultAttr = #acc<defaultvalue present>}
101+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
102+
// CHECK: }
103+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
104+
// CHECK: }
105+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
106+
// CHECK: } attributes {selfAttr}
107+
// CHECK: acc.parallel num_gangs({%[[VAL_4]] : i32}) num_workers(%[[VAL_4]] : i32) vector_length(%[[VAL_4]] : i32) {
108+
// CHECK: acc.yield
109+
// CHECK: } attributes {selfAttr}
110+
// CHECK: return
111+
// CHECK: }
112+
113+
func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
114+
%i64value = arith.constant 1 : i64
115+
%i32value = arith.constant 1 : i32
116+
%idxValue = arith.constant 1 : index
117+
acc.serial async(%i64value: i64) {
118+
}
119+
acc.serial async(%i32value: i32) {
120+
}
121+
acc.serial async(%idxValue: index) {
122+
}
123+
acc.serial wait({%i64value: i64}) {
124+
}
125+
acc.serial wait({%i32value: i32}) {
126+
}
127+
acc.serial wait({%idxValue: index}) {
128+
}
129+
acc.serial wait({%i64value : i64, %i32value : i32, %idxValue : index}) {
130+
}
131+
%firstprivate = acc.firstprivate varPtr(%b : memref<10xf32>) recipe(@firstprivatization_memref_10xf32) -> memref<10xf32>
132+
%c_private = acc.private varPtr(%c : memref<10x10xf32>) recipe(@privatization_memref_10_10_f32) -> memref<10x10xf32>
133+
acc.serial private(%c_private : memref<10x10xf32>) firstprivate(%firstprivate : memref<10xf32>) {
134+
}
135+
%copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = #acc<data_clause acc_copy>}
136+
acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) {
137+
}
138+
%i64mem = memref.alloca() : memref<i64>
139+
memref.store %i64value, %i64mem[] : memref<i64>
140+
%i64reduction = acc.reduction varPtr(%i64mem : memref<i64>) recipe(@reduction_add_memref_i64) -> memref<i64>
141+
acc.serial reduction(%i64reduction : memref<i64>) {
142+
}
143+
acc.serial combined(loop) {
144+
acc.loop combined(serial) control(%arg3 : index) = (%idxValue : index) to (%idxValue : index) step (%idxValue : index) {
145+
acc.yield
146+
} attributes {seq = [#acc.device_type<none>]}
147+
acc.terminator
148+
}
149+
acc.serial {
150+
} attributes {defaultAttr = #acc<defaultvalue none>}
151+
acc.serial {
152+
} attributes {defaultAttr = #acc<defaultvalue present>}
153+
acc.serial {
154+
} attributes {asyncAttr}
155+
acc.serial {
156+
} attributes {waitAttr}
157+
acc.serial {
158+
} attributes {selfAttr}
159+
acc.serial {
160+
acc.yield
161+
} attributes {selfAttr}
162+
return
163+
}
164+

0 commit comments

Comments
 (0)