
Conversation

mahabadm (Contributor) commented Aug 6, 2025

This PR fixes #152012, where serialization of TensorARM values into OpConstantComposite resulted in an invalid binary.

llvmbot (Member) commented Aug 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Mohammadreza Ameri Mahabadian (mahabadm)

Changes

This addresses issue #152012, where serialization of TensorARM values into OpConstantComposite resulted in an invalid binary.
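As a concrete illustration, consider a constant like the one below (taken from the updated test file in this PR). Judging from the removed serializer code, all six scalars used to be flattened directly into a single OpConstantComposite, whereas with this fix a rank-2 TensorARM constant is built from rank-1 tensor constituents, so the emitted binary validates:

    %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>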


Full diff: https://github.com/llvm/llvm-project/pull/152391.diff

3 Files Affected:

  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+13-1)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+9-24)
  • (modified) mlir/test/Target/SPIRV/arm-tensor-constant.mlir (+48-8)
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e863554fc..d8c54ec5f88c3 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
   }
 
   auto resultID = operands[1];
-  if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+  if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+    SmallVector<Attribute> flattenedElems;
+    for (Attribute element : elements) {
+      if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+        for (auto value : denseElemAttr.getValues<Attribute>())
+          flattenedElems.push_back(value);
+      } else {
+        flattenedElems.push_back(element);
+      }
+    }
+    auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+    constantMap.try_emplace(resultID, attr, tensorType);
+  } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
     auto attr = DenseElementsAttr::get(shapedType, elements);
     // For normal constants, we just record the attribute (and its type) for
     // later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574fbc9e3..04277be1a192d 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -956,6 +956,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
   uint32_t resultID = getNextID();
   SmallVector<uint32_t, 4> operands = {typeID, resultID};
   auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+  if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+    ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+    if (innerShape.size() > 0)
+      elementType = spirv::TensorArmType::get(innerShape, elementType);
+  }
 
   // "If the Result Type is a cooperative matrix type, then there must be only
   // one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +984,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
     } else {
       return 0;
     }
-  } else if (isa<spirv::TensorArmType>(constType)) {
-    if (isZeroValue(valueAttr)) {
-      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
-                            {typeID, resultID});
-      return resultID;
-    }
-    numberOfConstituents = shapedType.getNumElements();
-    operands.reserve(numberOfConstituents + 2);
-    for (int i = 0; i < numberOfConstituents; ++i) {
-      uint32_t elementID = 0;
-      if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
-        elementID =
-            elementType.isInteger(1)
-                ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
-                : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
-      }
-      if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
-        elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
-      }
-      if (!elementID) {
-        return 0;
-      }
-      operands.push_back(elementID);
-    }
+  } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+                          {typeID, resultID});
+    return resultID;
   } else {
     operands.reserve(numberOfConstituents + 2);
     for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
index 275e586f70634..7fb8af1904388 100644
--- a/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
+++ b/mlir/test/Target/SPIRV/arm-tensor-constant.mlir
@@ -1,17 +1,36 @@
 // RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s
-// DISABLED: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
-
-// FIXME(#152012): Fix arm tensor constant validation errors and reenable spirv-val tests.
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv %s | spirv-val %}
 
 spirv.module Logical Vulkan requires #spirv.vce<v1.3,
     [VulkanMemoryModel, Shader, TensorsARM, Linkage],
     [SPV_KHR_vulkan_memory_model, SPV_ARM_tensors]> {
-  // CHECK-LABEL: @arm_tensor_of_i32
-  spirv.func @arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
+  // CHECK-LABEL: @rank_1_arm_tensor_of_i32
+  spirv.func @rank_1_arm_tensor_of_i32() -> (!spirv.arm.tensor<3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+    %0 = spirv.Constant dense<[1, 2, 3]> : !spirv.arm.tensor<3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xi32>
+  }
+
+  // CHECK-LABEL: @rank_2_arm_tensor_of_i32
+  spirv.func @rank_2_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
     %0 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : !spirv.arm.tensor<2x3xi32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
+  // CHECK-LABEL: @rank_3_arm_tensor_of_i32
+  spirv.func @rank_3_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x2x3xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1, 2, 3], [4, 5, 6]], {{\[}}[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+    %0 = spirv.Constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : !spirv.arm.tensor<2x2x3xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xi32>
+  }
+
+  // CHECK-LABEL: @rank_4_arm_tensor_of_i32
+  spirv.func @rank_4_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3x4x5xi32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+    %0 = spirv.Constant dense<5> : !spirv.arm.tensor<2x3x4x5xi32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xi32>
+  }
+
   // CHECK-LABEL: @splat_arm_tensor_of_i32
   spirv.func @splat_arm_tensor_of_i32() -> (!spirv.arm.tensor<2x3xi32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<2> : !spirv.arm.tensor<2x3xi32>
@@ -19,13 +38,34 @@ spirv.module Logical Vulkan requires #spirv.vce<v1.3,
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xi32>
   }
 
-  // CHECK-LABEL: @arm_tensor_of_f32
-  spirv.func @arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
+  // CHECK-LABEL: @rank_1_arm_tensor_of_f32
+  spirv.func @rank_1_arm_tensor_of_f32() -> (!spirv.arm.tensor<3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : !spirv.arm.tensor<3xf32>
+    %0 = spirv.Constant dense<[1.0, 2.0, 3.0]> : !spirv.arm.tensor<3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<3xf32>
+  }
+
+  // CHECK-LABEL: @rank_2_arm_tensor_of_f32
+  spirv.func @rank_2_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : !spirv.arm.tensor<2x3xf32>
-    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]>: !spirv.arm.tensor<2x3xf32>
+    %0 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : !spirv.arm.tensor<2x3xf32>
     spirv.ReturnValue %0 : !spirv.arm.tensor<2x3xf32>
   }
 
+  // CHECK-LABEL: @rank_3_arm_tensor_of_f32
+  spirv.func @rank_3_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x2x3xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<{{\[}}{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]], {{\[}}[7.000000e+00, 8.000000e+00, 9.000000e+00], [1.000000e+01, 1.100000e+01, 1.200000e+01]]]> : !spirv.arm.tensor<2x2x3xf32>
+    %0 = spirv.Constant dense<[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]]> : !spirv.arm.tensor<2x2x3xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x2x3xf32>
+  }
+
+  // CHECK-LABEL: @rank_4_arm_tensor_of_f32
+  spirv.func @rank_4_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3x4x5xf32>) "None" {
+    // CHECK: {{%.*}} = spirv.Constant dense<5.000000e+00> : !spirv.arm.tensor<2x3x4x5xf32>
+    %0 = spirv.Constant dense<5.0> : !spirv.arm.tensor<2x3x4x5xf32>
+    spirv.ReturnValue %0 : !spirv.arm.tensor<2x3x4x5xf32>
+  }
+
   // CHECK-LABEL: @splat_arm_tensor_of_f32
   spirv.func @splat_arm_tensor_of_f32() -> (!spirv.arm.tensor<2x3xf32>) "None" {
     // CHECK: {{%.*}} = spirv.Constant dense<2.000000e+00> : !spirv.arm.tensor<2x3xf32>
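To sketch the effect of the serializer change: with the inner-shape peeling added above, a rank-n TensorARM constant is emitted as a composite whose constituents are rank-(n-1) tensor constants rather than a flat list of scalars. For the rank-2 i32 example, the output should look roughly like the following spirv-dis-style pseudo-disassembly (illustrative only; the result ids and type names are hypothetical, not taken from an actual disassembly):

    %c0 = OpConstantComposite %tensor_3xi32 %int_1 %int_2 %int_3
    %c1 = OpConstantComposite %tensor_3xi32 %int_4 %int_5 %int_6
    %c  = OpConstantComposite %tensor_2x3xi32 %c0 %c1

On the way back, the new deserializer path flattens the nested DenseElementsAttr constituents (%c0 and %c1 above) into a single dense<[[1, 2, 3], [4, 5, 6]]> attribute of type !spirv.arm.tensor<2x3xi32>, so such constants round-trip through the binary format.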
@kuhar kuhar changed the title [mlir][spirv]Fix serialization of TensorARM with rank higher than one [mlir][spirv] Fix serialization of TensorARM with rank higher than one Aug 7, 2025
kuhar (Member) left a comment

Thanks for fixing this. Could you rebase your PR? With #152124, we should be able to run spirv-val in the CI and make sure it doesn't complain.
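(For checking this locally, the same validation the test performs can be run by hand, assuming mlir-translate and spirv-val are on PATH; this mirrors the RUN line the PR re-enables:

    mlir-translate --no-implicit-module --serialize-spirv mlir/test/Target/SPIRV/arm-tensor-constant.mlir | spirv-val
)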

mahabadm (Contributor, Author) commented Aug 7, 2025

@kuhar Thanks for your note. I have rebased, and it seems the tests have passed.

@kuhar kuhar requested review from Hardcode84 and IgWod-IMG August 7, 2025 16:31
IgWod-IMG (Contributor) left a comment

LGTM

mahabadm (Contributor, Author) commented Aug 8, 2025

@kuhar Would you please merge this patch if there are no further comments? Many thanks.

@kuhar kuhar merged commit 688551f into llvm:main Aug 8, 2025
10 checks passed

4 participants