Skip to content

Conversation

@qedawkins
Copy link
Contributor

This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.

This does the same as llvm#72142 for vector.transfer_write. Previously the pattern would silently drop the mask.
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2023

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes

This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.


Full diff: https://github.com/llvm/llvm-project/pull/74038.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+21-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+44)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..0dc097158a4a55d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
-}
-
 /// Converts OpFoldResults to int64_t shape without unit dims.
 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
   SmallVector<int64_t> reducedShape;
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor type.
-    if (!sourceType || !sourceType.hasStaticShape())
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (!sourceType)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
       return failure();
     // Check if the reduced vector shape matches the reduced destination shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    Value maskOp = transferWriteOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            transferWriteOp,
+            "unsupported mask op, only 'vector.create_mask' is "
+            "currently supported");
+      FailureOr<Value> rankReducedCreateMask =
+          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      if (failed(rankReducedCreateMask))
+        return failure();
+      maskOp = *rankReducedCreateMask;
+    }
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    VectorType reducedVectorType = VectorType::get(
-        getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
+        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 735915d43565389..d65708068862f46 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
 // CHECK:     %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
 // CHECK:     vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
 
+func.func @masked_transfer_write_and_vector_rank_reducing(
+    %arg : memref<1x1x3x1x16x1xf32>,
+    %vec : vector<1x3x1x16x1xf32>,
+    %mask_dim1 : index,
+    %mask_dim2 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+    vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+  return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//  CHECK-SAME:     {{.*}}: vector<1x3x1x16x1xf32>,
+//  CHECK-SAME:     %[[MASKDIM1:.+]]: index,
+//  CHECK-SAME:     %[[MASKDIM2:.+]]: index
+//       CHECK:     %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
+//       CHECK:     %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:       memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:     vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing(
+    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+    %vec : vector<[16]x1xi8>,
+    %mask_dim0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %pad = arith.constant 0 : i8
+  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+  vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+  return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//  CHECK-SAME:     %{{.*}}: vector<[16]x1xi8>,
+//  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
+//       CHECK:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+//       CHECK:     %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:     %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:     vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
 /// Only masks operands of vector.create_mask are currently supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
     %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2023

@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

Changes

This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.


Full diff: https://github.com/llvm/llvm-project/pull/74038.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+21-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+44)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..0dc097158a4a55d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
-}
-
 /// Converts OpFoldResults to int64_t shape without unit dims.
 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
   SmallVector<int64_t> reducedShape;
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor type.
-    if (!sourceType || !sourceType.hasStaticShape())
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (!sourceType)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
      return failure();
     // Check if the reduced vector shape matches the reduced destination shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    Value maskOp = transferWriteOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            transferWriteOp,
+            "unsupported mask op, only 'vector.create_mask' is "
+            "currently supported");
+      FailureOr<Value> rankReducedCreateMask =
+          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      if (failed(rankReducedCreateMask))
+        return failure();
+      maskOp = *rankReducedCreateMask;
+    }
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    VectorType reducedVectorType = VectorType::get(
-        getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
+        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 735915d43565389..d65708068862f46 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
 // CHECK:     %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
 // CHECK:     vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
 
+func.func @masked_transfer_write_and_vector_rank_reducing(
+    %arg : memref<1x1x3x1x16x1xf32>,
+    %vec : vector<1x3x1x16x1xf32>,
+    %mask_dim1 : index,
+    %mask_dim2 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
+  vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+    vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+  return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//  CHECK-SAME:     {{.*}}: vector<1x3x1x16x1xf32>,
+//  CHECK-SAME:     %[[MASKDIM1:.+]]: index,
+//  CHECK-SAME:     %[[MASKDIM2:.+]]: index
+//       CHECK:     %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
+//       CHECK:     %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:       memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:     vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing(
+    %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+    %vec : vector<[16]x1xi8>,
+    %mask_dim0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %pad = arith.constant 0 : i8
+  %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+  vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+    vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+  return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//  CHECK-SAME:     %{{.*}}: vector<[16]x1xi8>,
+//  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
+//       CHECK:     %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:     %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+//       CHECK:     %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:     %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:     vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
 /// Only masks operands of vector.create_mask are currently supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
     %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@MacDue
Copy link
Member

MacDue commented Dec 1, 2023

Buildbot can't find my commit?

I've hit this a few times recently too, seems like something has changed 😕

@qedawkins
Copy link
Contributor Author

Buildbot can't find my commit?

I've hit this a few times recently too, seems like something has changed 😕

Well it worked now I guess :/

@qedawkins qedawkins merged commit fdf84cb into llvm:main Dec 1, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment