Skip to content

Conversation

@kuhar
Copy link
Member

@kuhar kuhar commented Oct 24, 2025

Use the same format as introduced for wmma by
#164920.

Also make blocks default to 1.

Use the same format as introduced for wmma by llvm#164920. Also make `blocks` default to 1.
@llvmbot
Copy link
Member

llvmbot commented Oct 24, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-backend-amdgpu

Author: Jakub Kuderski (kuhar)

Changes

Use the same format as introduced for wmma by
#164920.

Also make blocks default to 1.


Patch is 45.36 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/165037.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td (+11-8)
  • (modified) mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp (+3-3)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir (+27-27)
  • (modified) mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir (+39-39)
  • (modified) mlir/test/Dialect/AMDGPU/invalid.mlir (+42-30)
  • (modified) mlir/test/Dialect/AMDGPU/ops.mlir (+10-3)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td index d74abc22acd5e..090a93a1a95a6 100644 --- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td +++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td @@ -923,10 +923,10 @@ def AMDGPU_MFMAOp : AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>, Pure]>, Arguments<(ins - I32Attr:$m, - I32Attr:$n, - I32Attr:$k, - I32Attr:$blocks, + ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32]>]>:$m, + ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32]>]>:$n, + ConfinedAttr<I32Attr, [IntIsOneOf<[1, 2, 4, 8, 16, 32, 64, 128]>]>:$k, + DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[1, 2, 4, 16]>]>, "1">:$blocks, MFMAInTypes:$sourceA, MFMAInTypes:$sourceB, MFMAOutTypes:$destC, @@ -969,14 +969,17 @@ def AMDGPU_MFMAOp : Example: ```mlir - %0 = amdgpu.mfma %matA * %matB + %matC - { abid = 1 : i32, cbsz = 1 : i32, - m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 } + %0 = amdgpu.mfma 16x16x16 %matA * %matB + %matC + { abid = 0 : i32, cbsz = 0 : i32 } + blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + + %1 = amdgpu.mfma 32x32x1 %matD * %matE + %matF + { abid = 1 : i32, cbsz = 1 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32> ``` }]; let assemblyFormat = [{ - $sourceA `*` $sourceB `+` $destC + custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC attr-dict `blgp` `=` $blgp `:` type($sourceA) `,` type($sourceB) `,` type($destC) diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index 4c4965e67676e..585b6dacfa648 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -422,11 +422,11 @@ LogicalResult MFMAOp::verify() { Type sourceElem = sourceType, destElem = destType; uint32_t sourceLen = 1, destLen = 1; - if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) { + if (auto sourceVector = dyn_cast<VectorType>(sourceType)) { sourceLen = sourceVector.getNumElements(); sourceElem = sourceVector.getElementType(); } - if (auto destVector = llvm::dyn_cast<VectorType>(destType)) { + if (auto destVector = dyn_cast<VectorType>(destType)) { destLen = destVector.getNumElements(); destElem = destVector.getElementType(); } @@ -451,7 +451,7 @@ LogicalResult MFMAOp::verify() { return emitOpError("expected both non-small-float source operand types " "to match exactly"); } - // Normalize the wider integer types the compiler expects to i8 + // Normalize the wider integer types the compiler expects to i8. if (sourceElem.isInteger(32)) { sourceLen *= 4; sourceElem = b.getI8Type(); diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir index 39c31d5bf2fa3..6de55d534affb 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir @@ -8,46 +8,46 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>, // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32> + amdgpu.mfma 32x32x16 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32> // CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32> + amdgpu.mfma 16x16x32 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32> // CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32> + amdgpu.mfma 32x32x16 %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32> // CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32> + amdgpu.mfma 16x16x32 %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32> // CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32> - amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32> + amdgpu.mfma 32x32x32 %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32> // CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32> - amdgpu.mfma %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32> + amdgpu.mfma 16x16x64 %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32> // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32> + amdgpu.mfma 32x32x64 %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32> + amdgpu.mfma 16x16x128 %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32> // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32> + amdgpu.mfma 32x32x64 %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32> + amdgpu.mfma 16x16x128 %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32> // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32> + amdgpu.mfma 32x32x64 %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32> + amdgpu.mfma 16x16x128 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32> // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32> + amdgpu.mfma 32x32x64 %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32> + amdgpu.mfma 16x16x128 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32> // CHECK-DAG: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32> + amdgpu.mfma 32x32x64 %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32> + amdgpu.mfma 16x16x128 %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32> // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32> + amdgpu.mfma 32x32x64 %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32> + amdgpu.mfma 16x16x128 %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32> func.return } @@ -57,9 +57,9 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>, func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>, %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>, %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>, - %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,  + %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>, %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) { -  + // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32 // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32 // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32 @@ -69,21 +69,21 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>, amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32> -  + // CHECK: llvm.bitcast -  + // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32> -  + // CHECK: llvm.bitcast -  + // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32> -  + // CHECK: llvm.bitcast // CHECK: llvm.mlir.constant(3 : i32) : i32 @@ -91,10 +91,10 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>, amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32> -  + // CHECK: llvm.bitcast // CHECK: llvm.mlir.constant(4 : i32) : i32 -  + // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32> amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32> // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32> diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir index 52db1421dc3c6..e292d98183cd5 100644 --- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir +++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir @@ -9,89 +9,89 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>, %arg14 : vector<2xf32>, %arg15 : vector<8xf8E5M2FNUZ>, %arg16 : vector<8xf8E4M3FNUZ>) { // CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32> - amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : f32, f32, vector<32xf32> + amdgpu.mfma 32x32x1 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : f32, f32, vector<32xf32> // CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : f32, f32, vector<16xf32> + amdgpu.mfma 16x16x1 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : f32, f32, vector<16xf32> // CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : f32, f32, vector<4xf32> + amdgpu.mfma 4x4x1 %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : f32, f32, vector<4xf32> // CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32> - amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : f32, f32, vector<16xf32> + amdgpu.mfma 32x32x2 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<16xf32> // CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32> - amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f32, f32, vector<4xf32> + amdgpu.mfma 16x16x4 %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<4xf32> // CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32> - amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<32xf32> + amdgpu.mfma 32x32x4 %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = ... [truncated] 
@kuhar kuhar requested a review from Groverkss October 24, 2025 20:29
kuhar added a commit to kuhar/llvm-project that referenced this pull request Oct 24, 2025
Use the same format as introduced for wmma by llvm#164920 and for mfma by llvm#165037.
@kuhar kuhar merged commit f248010 into llvm:main Oct 25, 2025
10 checks passed
kuhar added a commit to kuhar/llvm-project that referenced this pull request Oct 25, 2025
Use the same format as introduced for wmma by llvm#164920 and for mfma by llvm#165037.
kuhar added a commit that referenced this pull request Oct 25, 2025
#165044) Use the same format as introduced for wmma by #164920 and for mfma by #165037.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Oct 25, 2025
…rinsic shape (#165044) Use the same format as introduced for wmma by llvm/llvm-project#164920 and for mfma by llvm/llvm-project#165037.
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
…#165037) Use the same format as introduced for wmma by llvm#164920. Also make `blocks` default to 1.
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
…#165037) Use the same format as introduced for wmma by llvm#164920. Also make `blocks` default to 1.
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
…#165037) Use the same format as introduced for wmma by llvm#164920. Also make `blocks` default to 1.
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment