Skip to content

Conversation

@bjacob
Copy link
Contributor

@bjacob bjacob commented Oct 26, 2023

linalg.batch_vecmat was just added in #70218, but I forgot then to add the standard isBatchVecmat utilities

@bjacob bjacob marked this pull request as ready for review October 26, 2023 02:53
@llvmbot
Copy link
Member

llvmbot commented Oct 26, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

Changes

linalg.batch_vecmat was just added in #70218, but I forgot then to add the standard isBatchVecmat utilities


Full diff: https://github.com/llvm/llvm-project/pull/70284.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td (+11)
  • (modified) mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h (+6)
  • (modified) mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp (+25)
  • (modified) mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp (+52)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td index 44e82f452b3cef1..69ca888a8acdbe0 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -98,6 +98,17 @@ def LinalgContractionOpInterface : OpInterface<"ContractionOpInterface"> { return mlir::isVecmat($_op.getIndexingMaps()); }]>, InterfaceMethod< + /*desc=*/[{ + Returns whether the given op has indexing maps that correspond to a + batched vector-matrix multiplication. + }], + /*retTy=*/"bool", + /*methodName=*/"isBatchVecmat", + /*args=*/(ins), + /*methodBody=*/[{ + return mlir::isBatchVecmat($_op.getIndexingMaps()); + }]>, + InterfaceMethod< /*desc=*/[{ Returns whether the given op has indexing maps that correspond to a matrix-vector multiplication. diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 225b9f287d340db..134c5569fbb2f3e 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -55,6 +55,12 @@ bool isRowMajorBatchMatmul(ArrayAttr indexingMaps); /// performed within the reduction. bool isVecmat(ArrayAttr indexingMaps); +/// Tests whether the given maps describe a batch vector matrix multiplication. +/// The test is permutation-invariant. Note that this only checks the affine +/// maps from an operation, so does not perform any checks on the math being +/// performed within the reduction. +bool isBatchVecmat(ArrayAttr indexingMaps); + /// Tests whether the given maps describe a matrix vector multiplication. The /// test is permutation-invariant. Note that this only checks the affine maps /// from an operation, so does not perform any checks on the math being diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp index 641ddf3f91cb2d9..383ef1cea53fd30 100644 --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -120,6 +120,31 @@ bool mlir::isVecmat(ArrayAttr indexingMaps) { return indexingMaps == maps; } +bool mlir::isBatchVecmat(ArrayAttr indexingMaps) { + if (indexingMaps.size() != 3) + return false; + AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue(); + AffineMap map1 = cast<AffineMapAttr>(indexingMaps[1]).getValue(); + AffineMap map2 = cast<AffineMapAttr>(indexingMaps[2]).getValue(); + + if (map0.getNumResults() != 2 || map1.getNumResults() != 3 || + map2.getNumResults() != 2 || map0.getNumInputs() != 3 || + map1.getNumInputs() != 3 || map2.getNumInputs() != 3) { + return false; + } + + // Extract dimensions for B*K * B*K*N -> B*N + AffineExpr b = map0.getResult(0); + AffineExpr k = map0.getResult(1); + AffineExpr n = map2.getResult(1); + auto *context = indexingMaps.getContext(); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {b, k}, context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {b, k, n}, context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {b, n}, context)); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); + return indexingMaps == maps; +} + bool mlir::isMatvec(ArrayAttr indexingMaps) { if (indexingMaps.size() != 3) return false; diff --git a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp index 3f576bacebf6aad..d257fc5d6e041d1 100644 --- a/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp +++ b/mlir/unittests/Dialect/Utils/StructuredOpsUtilsTest.cpp @@ -370,4 +370,56 @@ TEST(isBatchMatvec, WrongDimOrderMatrix) { EXPECT_THAT(maps, Not(Truly(isBatchMatvec))); } +TEST(isBatchVecmat, Simple) { + MLIRContext context; + + AffineExpr batch, k, n; + bindDims(&context, batch, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isBatchVecmat)); +} + +TEST(isBatchVecmat, BindingSwapped) { + MLIRContext context; + + AffineExpr batch, k, n; + bindDims(&context, batch, n, k); // bind in different order + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Truly(isBatchVecmat)); +} + +TEST(isBatchVecmat, Matmul) { + MLIRContext context; + + AffineExpr m, n, k; + bindDims(&context, m, n, k); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isBatchVecmat))); +} + +TEST(isBatchVecmat, WrongDimOrderMatrix) { + MLIRContext context; + + AffineExpr batch, k, n; + bindDims(&context, batch, k, n); + auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {batch, k}, &context)); + auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n, k}, &context)); + auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {batch, n}, &context)); + auto maps = ArrayAttr::get(&context, {mapA, mapB, mapC}); + + EXPECT_THAT(maps, Not(Truly(isBatchVecmat))); +} + } // namespace 
bool mlir::isBatchVecmat(ArrayAttr indexingMaps) {
if (indexingMaps.size() != 3)
return false;
AffineMap map0 = cast<AffineMapAttr>(indexingMaps[0]).getValue();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you could use the infer utils which I find nicer to work with:

if (maps == infer({{m, k}, {k, n}, {m, n}})) {
your call :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see. nicer indeed, but i was copy and pasting other instances in this file. The move to infer looks like it would be a nice improvement, but should be done for the whole file togeteher, so, in a separate PR.

@bjacob bjacob merged commit 8a80e33 into llvm:main Oct 26, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Oct 26, 2023
`linalg.batch_vecmat` was just added in llvm#70218, but I forgot then to add the standard `isBatchVecmat` utilities
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

4 participants