Skip to content

Commit 9448a69

Browse files
committed
Revert "[mlir][linalg] Add folder for linalg.index (llvm#136640)"
This reverts commit f010725.
1 parent 2ce50fe commit 9448a69

File tree

4 files changed: +3 additions, -96 deletions

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
8888

8989
let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
9090
let hasVerifier = 1;
91-
let hasFolder = 1;
9291
}
9392

9493
def Linalg_SoftmaxOp : Linalg_Op<"softmax",

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2283,19 +2283,6 @@ LogicalResult IndexOp::verify() {
22832283
return success();
22842284
}
22852285

2286-
OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2287-
auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
2288-
2289-
// Index of unit dims is always 0.
2290-
SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2291-
uint64_t dim = getDim();
2292-
assert(dim < loopBounds.size() && "Dim is out of bounds");
2293-
if (loopBounds[dim] == 1)
2294-
return IntegerAttr::get(IndexType::get(getContext()), 0);
2295-
2296-
return OpFoldResult{};
2297-
}
2298-
22992286
/////// Operations corresponding to library calls defined with Tablegen ////////
23002287

23012288
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -305,86 +305,6 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
305305
}
306306

307307
// -----
308-
309-
// CHECK: func @fold_linalg_index_tensor_static
310-
func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
311-
%2: tensor<4x1xi32>) -> tensor<4x1xi32> {
312-
// CHECK-NEXT: linalg.generic
313-
// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
314-
// CHECK-NOT: linalg.index 1
315-
// CHECK: %[[IDX_2:.+]] = linalg.index 2 : index
316-
// CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
317-
// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]]
318-
// CHECK: linalg.yield %[[CAST]]
319-
%res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
320-
affine_map<(d0, d1, d2) -> (d1, d2)>,
321-
affine_map<(d0, d1, d2) -> (d0, d1)>],
322-
iterator_types = ["parallel", "parallel", "reduction"]}
323-
ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
324-
outs(%2 : tensor<4x1xi32>) {
325-
^bb0(%lhs: i32, %rhs: i32, %out: i32):
326-
%idx0 = linalg.index 0 : index
327-
%idx1 = linalg.index 1 : index
328-
%idx2 = linalg.index 2 : index
329-
%add0 = arith.addi %idx0, %idx1 : index
330-
%add1 = arith.addi %add0, %idx2 : index
331-
%int = arith.index_cast %add1 : index to i32
332-
linalg.yield %int : i32
333-
} -> tensor<4x1xi32>
334-
return %res : tensor<4x1xi32>
335-
}
336-
337-
// -----
338-
339-
// CHECK: func @fold_linalg_index_tensor_dynamic
340-
func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
341-
%1: tensor<?x1xi32>) -> tensor<?x1xi32> {
342-
// CHECK-NEXT: linalg.generic
343-
// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
344-
// CHECK-NOT: linalg.index 1
345-
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
346-
// CHECK: linalg.yield %[[CAST]]
347-
%res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
348-
affine_map<(d0, d1) -> (d1, d1)>],
349-
iterator_types = ["parallel", "parallel"]}
350-
ins(%0 : tensor<?x1xi32>)
351-
outs(%1 : tensor<?x1xi32>) {
352-
^bb0(%lhs: i32, %out: i32):
353-
%idx0 = linalg.index 0 : index
354-
%idx1 = linalg.index 1 : index
355-
%add = arith.addi %idx0, %idx1 : index
356-
%int = arith.index_cast %add : index to i32
357-
linalg.yield %int : i32
358-
} -> tensor<?x1xi32>
359-
return %res : tensor<?x1xi32>
360-
}
361-
362-
// -----
363-
364-
// CHECK: func @fold_linalg_index_memref
365-
func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
366-
// CHECK-NEXT: linalg.generic
367-
// CHECK-NOT: linalg.index 0
368-
// CHECK: %[[IDX_1:.+]] = linalg.index 1 : index
369-
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
370-
// CHECK: linalg.yield %[[CAST]]
371-
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
372-
affine_map<(d0, d1) -> (d1, d1)>],
373-
iterator_types = ["parallel", "parallel"]}
374-
ins(%0 : memref<1x?xi32>)
375-
outs(%1 : memref<1x?xi32>) {
376-
^bb0(%lhs: i32, %out: i32):
377-
%idx0 = linalg.index 0 : index
378-
%idx1 = linalg.index 1 : index
379-
%add = arith.addi %idx0, %idx1 : index
380-
%int = arith.index_cast %add : index to i32
381-
linalg.yield %int : i32
382-
}
383-
return
384-
}
385-
386-
// -----
387-
388308
// CHECK-LABEL: func @fold_fill_reshape()
389309
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
390310
%zero = arith.constant 0.0 : f32

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,12 @@ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%
278278
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
279279
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
280280
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
281+
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
281282
// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
282283
// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
283284
// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
284-
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<8x1xindex>
285-
// CHECK: %[[ADDI:.*]] = arith.addi %[[B2]], %[[T]] : vector<8x1xindex>
285+
// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
286+
// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
286287
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
287288
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
288289

0 commit comments

Comments (0)