@@ -37,7 +37,100 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
3737 void runOnOperation () override ;
3838};
3939
40- // Lower scf::if to emitc::if, implementing return values as emitc::variable's
40+ // Lower scf::for to emitc::for, implementing result values using
41+ // emitc::variable's updated within the loop body.
42+ struct ForLowering : public OpRewritePattern <ForOp> {
43+ using OpRewritePattern<ForOp>::OpRewritePattern;
44+
45+ LogicalResult matchAndRewrite (ForOp forOp,
46+ PatternRewriter &rewriter) const override ;
47+ };
48+
49+ // Create an uninitialized emitc::variable op for each result of the given op.
50+ template <typename T>
51+ static SmallVector<Value> createVariablesForResults (T op,
52+ PatternRewriter &rewriter) {
53+ SmallVector<Value> resultVariables;
54+
55+ if (!op.getNumResults ())
56+ return resultVariables;
57+
58+ Location loc = op->getLoc ();
59+ MLIRContext *context = op.getContext ();
60+
61+ OpBuilder::InsertionGuard guard (rewriter);
62+ rewriter.setInsertionPoint (op);
63+
64+ for (OpResult result : op.getResults ()) {
65+ Type resultType = result.getType ();
66+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get (context, " " );
67+ emitc::VariableOp var =
68+ rewriter.create <emitc::VariableOp>(loc, resultType, noInit);
69+ resultVariables.push_back (var);
70+ }
71+
72+ return resultVariables;
73+ }
74+
75+ // Create a series of assign ops assigning given values to given variables at
76+ // the current insertion point of given rewriter.
77+ static void assignValues (ValueRange values, SmallVector<Value> &variables,
78+ PatternRewriter &rewriter, Location loc) {
79+ for (auto [value, var] : llvm::zip (values, variables))
80+ rewriter.create <emitc::AssignOp>(loc, var, value);
81+ }
82+
83+ static void lowerYield (SmallVector<Value> &resultVariables,
84+ PatternRewriter &rewriter, scf::YieldOp yield) {
85+ Location loc = yield.getLoc ();
86+ ValueRange operands = yield.getOperands ();
87+
88+ OpBuilder::InsertionGuard guard (rewriter);
89+ rewriter.setInsertionPoint (yield);
90+
91+ assignValues (operands, resultVariables, rewriter, loc);
92+
93+ rewriter.create <emitc::YieldOp>(loc);
94+ rewriter.eraseOp (yield);
95+ }
96+
97+ LogicalResult ForLowering::matchAndRewrite (ForOp forOp,
98+ PatternRewriter &rewriter) const {
99+ Location loc = forOp.getLoc ();
100+
101+ // Create an emitc::variable op for each result. These variables will be
102+ // assigned to by emitc::assign ops within the loop body.
103+ SmallVector<Value> resultVariables =
104+ createVariablesForResults (forOp, rewriter);
105+ SmallVector<Value> iterArgsVariables =
106+ createVariablesForResults (forOp, rewriter);
107+
108+ assignValues (forOp.getInits (), iterArgsVariables, rewriter, loc);
109+
110+ emitc::ForOp loweredFor = rewriter.create <emitc::ForOp>(
111+ loc, forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep ());
112+
113+ Block *loweredBody = loweredFor.getBody ();
114+
115+ // Erase the auto-generated terminator for the lowered for op.
116+ rewriter.eraseOp (loweredBody->getTerminator ());
117+
118+ SmallVector<Value> replacingValues;
119+ replacingValues.push_back (loweredFor.getInductionVar ());
120+ replacingValues.append (iterArgsVariables.begin (), iterArgsVariables.end ());
121+
122+ rewriter.mergeBlocks (forOp.getBody (), loweredBody, replacingValues);
123+ lowerYield (iterArgsVariables, rewriter,
124+ cast<scf::YieldOp>(loweredBody->getTerminator ()));
125+
126+ // Copy iterArgs into results after the for loop.
127+ assignValues (iterArgsVariables, resultVariables, rewriter, loc);
128+
129+ rewriter.replaceOp (forOp, resultVariables);
130+ return success ();
131+ }
132+
133+ // Lower scf::if to emitc::if, implementing result values as emitc::variable's
41134// updated within the then and else regions.
42135struct IfLowering : public OpRewritePattern <IfOp> {
43136 using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -52,20 +145,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
52145 PatternRewriter &rewriter) const {
53146 Location loc = ifOp.getLoc ();
54147
55- SmallVector<Value> resultVariables;
56-
57148 // Create an emitc::variable op for each result. These variables will be
58149 // assigned to by emitc::assign ops within the then & else regions.
59- if (ifOp.getNumResults ()) {
60- MLIRContext *context = ifOp.getContext ();
61- rewriter.setInsertionPoint (ifOp);
62- for (OpResult result : ifOp.getResults ()) {
63- Type resultType = result.getType ();
64- auto noInit = emitc::OpaqueAttr::get (context, " " );
65- auto var = rewriter.create <emitc::VariableOp>(loc, resultType, noInit);
66- resultVariables.push_back (var);
67- }
68- }
150+ SmallVector<Value> resultVariables =
151+ createVariablesForResults (ifOp, rewriter);
69152
70153 // Utility function to lower the contents of an scf::if region to an emitc::if
71154 // region. The contents of the scf::if regions is moved into the respective
@@ -76,16 +159,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
76159 Region &loweredRegion) {
77160 rewriter.inlineRegionBefore (region, loweredRegion, loweredRegion.end ());
78161 Operation *terminator = loweredRegion.back ().getTerminator ();
79- Location terminatorLoc = terminator->getLoc ();
80- ValueRange terminatorOperands = terminator->getOperands ();
81- rewriter.setInsertionPointToEnd (&loweredRegion.back ());
82- for (auto value2Var : llvm::zip (terminatorOperands, resultVariables)) {
83- Value resultValue = std::get<0 >(value2Var);
84- Value resultVar = std::get<1 >(value2Var);
85- rewriter.create <emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
86- }
87- rewriter.create <emitc::YieldOp>(terminatorLoc);
88- rewriter.eraseOp (terminator);
162+ lowerYield (resultVariables, rewriter, cast<scf::YieldOp>(terminator));
89163 };
90164
91165 Region &thenRegion = ifOp.getThenRegion ();
@@ -109,6 +183,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
109183}
110184
111185void mlir::populateSCFToEmitCConversionPatterns (RewritePatternSet &patterns) {
186+ patterns.add <ForLowering>(patterns.getContext ());
112187 patterns.add <IfLowering>(patterns.getContext ());
113188}
114189
@@ -118,7 +193,7 @@ void SCFToEmitCPass::runOnOperation() {
118193
119194 // Configure conversion to lower out SCF operations.
120195 ConversionTarget target (getContext ());
121- target.addIllegalOp <scf::IfOp>();
196+ target.addIllegalOp <scf::ForOp, scf:: IfOp>();
122197 target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
123198 if (failed (
124199 applyPartialConversion (getOperation (), target, std::move (patterns))))
0 commit comments