@@ -120,8 +120,10 @@ namespace {
120120class MatchingSubsets {
121121public:
122122 // / Insert a subset op.
123- void insert (SubsetOpInterface op) {
123+ void insert (SubsetOpInterface op, bool collectHoistableOps = true ) {
124124 allSubsetOps.push_back (op);
125+ if (!collectHoistableOps)
126+ return ;
125127 if (auto extractionOp =
126128 dyn_cast<SubsetExtractionOpInterface>(op.getOperation ()))
127129 insertExtractionOp (extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
148150 });
149151 }
150152
153+ // / Populate subset ops starting from the given region iter_arg. Return
154+ // / "failure" if non-subset ops are found along the path to the loop yielding
155+ // / op or if there is no single path to the tied yielded operand. If
156+ // / `collectHoistableOps` is set to "false", subset ops are gathered
157+ // / throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158+ LogicalResult populateSubsetOpsAtIterArg (LoopLikeOpInterface loopLike,
159+ BlockArgument iterArg,
160+ bool collectHoistableOps = true );
161+
151162private:
152163 // / Helper function for equivalence of tensor values. Since only insertion
153164 // / subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
225236 return nullptr ;
226237}
227238
228- // / Hoist all subset ops that operate on the idx-th region iter_arg of the given
229- // / loop-like op and index into loop-invariant subset locations. Return the
230- // / newly created loop op (that has extra iter_args) or the original loop op if
231- // / nothing was hoisted.
232- static LoopLikeOpInterface hoistSubsetAtIterArg (LoopLikeOpInterface loopLike,
233- BlockArgument iterArg) {
234- IRRewriter rewriter (loopLike.getContext ());
239+ LogicalResult
240+ MatchingSubsets::populateSubsetOpsAtIterArg (LoopLikeOpInterface loopLike,
241+ BlockArgument iterArg,
242+ bool collectHoistableOps) {
235243 assert (iterArg.getOwner ()->getParentOp () == loopLike && " invalid iter_arg" );
236- auto it = llvm::find (loopLike.getRegionIterArgs (), iterArg);
237- int64_t iterArgIdx = std::distance (loopLike.getRegionIterArgs ().begin (), it);
238244 Value value = iterArg;
239- MatchingSubsets subsets;
240245
241246 // Traverse use-def chain. Subset ops can be hoisted only if all ops along the
242247 // use-def chain starting from the region iter_arg are subset extraction or
@@ -249,36 +254,71 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
249254 Value nextValue = {};
250255
251256 for (OpOperand &use : value.getUses ()) {
257+ if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner ())) {
258+ // Subset ops in nested loops are collected to check if there are only
259+ // disjoint subset ops, but such subset ops are not subject to hoisting.
260+ // To hoist subset ops from nested loops, the hoisting transformation
261+ // should be run on the nested loop.
262+ auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg (&use);
263+ if (!nestedIterArg)
264+ return failure ();
265+ // Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266+ // use-def chain starting at `nestedIterArg` and terminating in the
267+ // tied, yielding operand.
268+ if (failed (populateSubsetOpsAtIterArg (nestedLoop, nestedIterArg,
269+ /* collectHoistableOps=*/ false )))
270+ return failure ();
271+ nextValue = nestedLoop.getTiedLoopResult (&use);
272+ continue ;
273+ }
274+
252275 auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner ());
253276 if (!subsetOp)
254- return loopLike ;
255- subsets. insert (subsetOp);
277+ return failure () ;
278+ insert (subsetOp);
256279
257280 if (auto insertionOp =
258281 dyn_cast<SubsetInsertionOpInterface>(use.getOwner ())) {
259282 // The value must be used as a destination. (In case of a source, the
260283 // entire tensor would be read, which would prevent any hoisting.)
261284 if (&use != &insertionOp.getDestinationOperand ())
262- return loopLike ;
285+ return failure () ;
263286 // There must be a single use-def chain from the region iter_arg to the
264287 // terminator. I.e., only one insertion op. Branches are not supported.
265288 if (nextValue)
266- return loopLike ;
289+ return failure () ;
267290 nextValue = insertionOp.getUpdatedDestination ();
268291 }
269292 }
270293
271294 // Nothing can be hoisted if the chain does not continue with a loop yielding
272295 // op, a subset insertion op, or a nested loop-like op.
273296 if (!nextValue)
274- return loopLike ;
297+ return failure () ;
275298 value = nextValue;
276299 }
277300
278302 // Succeed only if the SSA use-def chain ends in the yielding terminator of
279303 // the loop and the yielded operand is the one tied to `iterArg`. (I.e.,
280304 // there is no swapping yield.)
281304 if (loopLike.getTiedLoopYieldedValue (iterArg) != yieldedOperand)
305+ return failure ();
306+
307+ return success ();
308+ }
309+
310+ // / Hoist all subset ops that operate on the idx-th region iter_arg of the given
311+ // / loop-like op and index into loop-invariant subset locations. Return the
312+ // / newly created loop op (that has extra iter_args) or the original loop op if
313+ // / nothing was hoisted.
314+ static LoopLikeOpInterface hoistSubsetAtIterArg (LoopLikeOpInterface loopLike,
315+ BlockArgument iterArg) {
316+ assert (iterArg.getOwner ()->getParentOp () == loopLike && " invalid iter_arg" );
317+ auto it = llvm::find (loopLike.getRegionIterArgs (), iterArg);
318+ int64_t iterArgIdx = std::distance (loopLike.getRegionIterArgs ().begin (), it);
319+ IRRewriter rewriter (loopLike.getContext ());
320+ MatchingSubsets subsets;
321+ if (failed (subsets.populateSubsetOpsAtIterArg (loopLike, iterArg)))
282322 return loopLike;
283323
284324 // Hoist all matching extraction-insertion pairs one-by-one.
0 commit comments