Skip to content

Commit cb0ed6f

Browse files
[mlir][vector] Implement subset op interface for xfer ops
1 parent ff614a5 commit cb0ed6f

File tree

7 files changed

+425
-471
lines changed

7 files changed

+425
-471
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- SubsetOpInterfaceImpl.h - Tensor subsets -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace vector {
16+
void registerSubsetOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace vector
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_VECTOR_SUBSETOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@
8585
#include "mlir/Dialect/UB/IR/UBOps.h"
8686
#include "mlir/Dialect/Vector/IR/VectorOps.h"
8787
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
88+
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
8889
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
8990
#include "mlir/IR/Dialect.h"
9091
#include "mlir/Interfaces/CastInterfaces.h"
@@ -171,6 +172,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
171172
tensor::registerTilingInterfaceExternalModels(registry);
172173
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
173174
vector::registerBufferizableOpInterfaceExternalModels(registry);
175+
vector::registerSubsetOpInterfaceExternalModels(registry);
174176
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
175177
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
176178
}

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
1010
LowerVectorShapeCast.cpp
1111
LowerVectorTransfer.cpp
1212
LowerVectorTranspose.cpp
13+
SubsetOpInterfaceImpl.cpp
1314
VectorDistribute.cpp
1415
VectorDropLeadUnitDim.cpp
1516
VectorEmulateNarrowType.cpp
@@ -40,6 +41,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
4041
MLIRMemRefUtils
4142
MLIRSCFDialect
4243
MLIRSideEffectInterfaces
44+
MLIRSubsetOpInterface
4345
MLIRTensorDialect
4446
MLIRTransforms
4547
MLIRVectorDialect
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- SubsetOpInterfaceImpl.cpp - Tensor subsets -------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
12+
#include "mlir/Interfaces/SubsetOpInterface.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::vector;
16+
17+
namespace {
18+
19+
template <typename OpTy>
20+
struct XferOpSubsetOpInterface
21+
: public SubsetOpInterface::ExternalModel<XferOpSubsetOpInterface<OpTy>,
22+
OpTy> {
23+
FailureOr<HyperrectangularSlice>
24+
getAccessedHyperrectangularSlice(Operation *op) const {
25+
auto xferOp = cast<OpTy>(op);
26+
Builder b(xferOp->getContext());
27+
SmallVector<OpFoldResult> offsets = llvm::map_to_vector(
28+
xferOp.getIndices(), [](Value v) -> OpFoldResult { return v; });
29+
SmallVector<OpFoldResult> sizes = llvm::map_to_vector(
30+
xferOp.getTransferChunkAccessed(),
31+
[&](int64_t sz) -> OpFoldResult { return b.getIndexAttr(sz); });
32+
return HyperrectangularSlice(offsets, sizes);
33+
}
34+
};
35+
36+
struct TransferReadOpSubsetExtractionOpInterface
37+
: public SubsetExtractionOpInterface::ExternalModel<
38+
TransferReadOpSubsetExtractionOpInterface, vector::TransferReadOp> {
39+
OpOperand &getSourceOperand(Operation *op) const {
40+
return cast<vector::TransferReadOp>(op).getSourceMutable();
41+
}
42+
};
43+
44+
struct TransferWriteOpSubsetInsertionOpInterface
45+
: public SubsetInsertionOpInterface::ExternalModel<
46+
TransferWriteOpSubsetInsertionOpInterface, vector::TransferWriteOp> {
47+
OpOperand &getSourceOperand(Operation *op) const {
48+
return cast<vector::TransferWriteOp>(op).getVectorMutable();
49+
}
50+
51+
OpOperand &getDestinationOperand(Operation *op) const {
52+
return cast<vector::TransferWriteOp>(op).getSourceMutable();
53+
}
54+
55+
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
56+
Location loc) const {
57+
// TODO: Implement when needed.
58+
return Value();
59+
}
60+
61+
SmallVector<Value>
62+
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
63+
// TODO: Implement when needed.
64+
return {};
65+
}
66+
};
67+
68+
} // namespace
69+
70+
void mlir::vector::registerSubsetOpInterfaceExternalModels(
71+
DialectRegistry &registry) {
72+
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
73+
TransferReadOp::attachInterface<XferOpSubsetOpInterface<TransferReadOp>>(
74+
*ctx);
75+
TransferReadOp::attachInterface<TransferReadOpSubsetExtractionOpInterface>(
76+
*ctx);
77+
TransferWriteOp::attachInterface<XferOpSubsetOpInterface<TransferWriteOp>>(
78+
*ctx);
79+
TransferWriteOp::attachInterface<TransferWriteOpSubsetInsertionOpInterface>(
80+
*ctx);
81+
});
82+
}

0 commit comments

Comments
 (0)