Add InstanceNormalization operator to QNN EP #14867
Merged
79 commits (by HectorSVC, SatyaJandhyalaAtMS, and adrianlizarraga):

b2c334c Merged PR 8479: Add QNN ExecutionProvider which enables OnnxRuntime i…
dcc47d9 Merged PR 8567: Support Where op with QDQ nodes as a node unit group
6a2db15 Merged PR 8613: Make Softmax fallback more accurate
a5eab82 Merged PR 9177: Make StyleGAN work on Linux and enable on build pipeline
0162e1b Merged PR 9204: Replaced qnn_model_wrapper pointer with a reference.
4b25b1f Merged PR 9226: Remove high threshold for QNN QDQ model test
410016c Merged PR 9203: QNN v2 integration
83a63d7 Merged PR 9328: Fix the issues relate to execute Text Prediction QDQ …
3bdfe1a Merged PR 9361: Change the way to identify the back-end capacity
cd401c0 Merged PR 9386: Create Windows ARM64 build pipeline
ff4d46e Merged PR 9456: Add missing part Qnn device creation/releasing which …
2f22ca1 Merged PR 9476: Revert 'Add missing part Qnn device creation/releasin…
aa18a71 Merged PR 9429: Add support for single Transpose node in QDQ model
028498b Merged PR 9432: Enable Quantized MobileNet test and create an interna…
3d70683 Merged PR 9521: minor fix to the unreachable code
9f8b209 Merged PR 9567: Moved QNN SDK version from 2.3 to 2.5
55e2879 Merged PR 9572: Sync with latest Github main
6ce8f78 Merged PR 9583: Enable ConvTest UT for Qnn. Exclude tests with dynami…
892d3f0 Merged PR 9661: Added TopK operator and removed unnecessary variable …
0f4311b Merged PR 9822: Fixed unused parameter warning.
b75c3c8 Merged PR 9836: Enable Tanh, ReduceMin, Slice for node unit support
c15e7ac Merged PR 9877: Enable Tile Op
cbe2339 Merged PR 10032: Convert Gather indices initializer data from int64 t…
f6b622e Merged PR 9968: Add NonMaxSuppression operator support on QNN EP
4e72ed8 Merged PR 10131: Revert 'Add NonMaxSuppression operator support on QN…
d358d87 Merged PR 9866: Limit the transpose optimizer works for Transpose wit…
837a70c Merged PR 10311: Update QNN version to 2.6.0
bb0cb62 Merge branch 'main' into qnn_ep_github
93f2883 resolve conflicts
7bc5bb3 disable some new tests for Qnn EP
80e714e Disable LayerNormalization test for Qnn EP
1feb5f0 disable onnx node tests: resize_downsample_scales_linear_antialias & …
ff485ac extend timeout limit
2b22726 Merge branch 'main' into qnn_ep_github
ea75d78 resolve merge conflicts
1b3aebb disable some LayerNormTest tests
93dda20 Back out the fix for transpose optimizer issue. Wait for Scott's chan…
053def6 correct typo
ca046d4 Update tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-cr…
bd2cbdd remove provider options "runtime"
44728cb Merge branch 'qnn_ep_github' of https://github.com/microsoft/onnxrunt…
277a860 Remove cmake_extra_defines from Linux build pipeline
1284fb0 Update tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
f43f377 Reorganize QNN EP model tests into separate float32 and qdq folders
906f93d use enum class
1aab226 Remove Qnn graph creation during graph partitioning since it's not re…
6ad1138 Update linux-qnn-ci-pipeline.yml to use a new pool
2d9363a Update onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_build…
947a026 Update onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_build…
8f1f3c0 Update onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_build…
83431b5 Add InstanceNormalization operator to QNN EP
6e55965 Fix param indentation
328237b Add InstanceNormalization unit test that includes QNN EP
5e175d8 Start supporting inputs with rank > 2
202821e Merge latest commits from main; Expect QNN InstanceNorm op input to h…
e198d44 Remove unnecessary comments
22f418f Merge latest commits from main
e5843b2 Allow DQ->InstanceNorm->Q to be treated as a node unit. Fix bug when …
687d063 Add QNN unit tests for QDQ Conv and InstanceNorm ops
9b86707 Add unused parameter macro. Add comments
4a684d6 Make InstanceNormalization op layout sensitive for all EPs (not just …
ce229fc Remove unnecessary epsilon attr validation; Remove unnecessary overri…
8c4dd8a Add runtime check for Windows ARM64 that skips qdq op test if HTP bac…
6905816 Update onnxruntime/test/providers/qnn/qnn_basic_test.cc
9c5f880 Address comments
818341d Merge latest commits from main
842da60 Cache result of runtime query for HTP support on Windows ARM64
92cccb7 Fix merge conflicts
7e73772 Make HTPBackendTestFixture visible on linux
b5af1d8 Clean up tests
123f523 Add new QDQ selector for InstanceNormalization op
0982f1f Update QDQ InstanceNorm test case
5f44994 Add static casts, clean up
67ef4ce Add QDQ support for the InstanceNormalization operator to the quantiz…
28b5e56 Run python black linter on test_op_instance_normalization.py
64d04d0 Remove unused imports from test_op_instance_normalization.py
cdf885c Run python black linter on instnorm.py
a91eefc Run python isort on quantize.py
83b75f8 Fix pylint warnings in test_op_instance_normalization.py
onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc (126 additions, 0 deletions)
```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/common/safeint.h"
#include "onnx/defs/data_type_utils.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace qnn {

class InstanceNormOpBuilder : public BaseOpBuilder {
 public:
  InstanceNormOpBuilder() : BaseOpBuilder("InstanceNormOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InstanceNormOpBuilder);

  Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger,
                       bool is_quantized_model) const override final ORT_MUST_USE_RESULT;

 protected:
  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool is_quantized_model,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Instance normalization op is sensitive to data layout.
// The nodes from 1st call of GetCapability do not get layout transformer applied, so their shapes are still NCHW.
// The nodes from 2nd call of GetCapability get their layout transformed to NHWC.
// Therefore, we need to check the node domain to determine if the layout has been transformed.
Status InstanceNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                                            const NodeUnit& node_unit,
                                            const logging::Logger& logger,
                                            bool is_quantized_model) const {
  ORT_UNUSED_PARAMETER(logger);

  const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float");

  // Check input type is float for CPU.
  const auto& inputs = node_unit.Inputs();
  ONNX_NAMESPACE::DataType input_data_type = inputs[0].node_arg.Type();
  if (!is_quantized_model && input_data_type != float_elem_type) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm data type " + *input_data_type + " is not supported in CPU backend.");
  }

  // Also check output type is float for CPU.
  const auto& outputs = node_unit.Outputs();
  ONNX_NAMESPACE::DataType output_data_type = outputs[0].node_arg.Type();
  if (!is_quantized_model && output_data_type != float_elem_type) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm data type " + *output_data_type + " is not supported in CPU backend.");
  }

  std::vector<uint32_t> input_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0");
  const size_t input_rank = input_shape.size();

  if (input_rank <= 2 || input_rank > 4) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm only supports input ranks of size 3 or 4.");
  }

  const uint32_t num_channels = (node_unit.Domain() == kMSInternalNHWCDomain) ? input_shape.back() : input_shape[1];

  std::vector<uint32_t> scale_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)");
  if (scale_shape.size() != 1 || scale_shape[0] != num_channels) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm input 1 (scale) must have 1D shape [channel].");
  }

  std::vector<uint32_t> bias_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)");
  if (bias_shape.size() != 1 || bias_shape[0] != num_channels) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm input 2 (bias) must have 1D shape [channel].");
  }

  NodeAttrHelper node_helper(node_unit);
  const float epsilon = node_helper.Get("epsilon", 1e-05f);  // Default is 1e-05 according to ONNX spec.
  if (epsilon <= 0.0f) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm epsilon must be greater than 0.0");
  }

  return Status::OK();
}

Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                          const NodeUnit& node_unit,
                                                          std::vector<std::string>&& input_names,
                                                          const logging::Logger& logger,
                                                          bool is_quantized_model,
                                                          bool do_op_validation) const {
  NodeAttrHelper node_helper(node_unit);
  std::vector<std::string> param_tensor_names;

  const float epsilon = node_helper.Get("epsilon", 1e-05f);  // Default is 1e-05 according to ONNX spec.
  Qnn_Scalar_t epsilon_param = QNN_SCALAR_INIT;
  epsilon_param.dataType = QNN_DATATYPE_FLOAT_32;
  epsilon_param.floatValue = epsilon;
  QnnParamWrapper epsilon_param_wrapper(node_unit.Index(),
                                        node_unit.Name(),
                                        qnn_def::epsilon,
                                        epsilon_param);
  param_tensor_names.push_back(epsilon_param_wrapper.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(epsilon_param_wrapper));

  ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
                                     std::move(input_names),
                                     std::move(param_tensor_names),
                                     logger, is_quantized_model, do_op_validation));

  return Status::OK();
}

void CreateInstanceNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<InstanceNormOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime
```
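For readers unfamiliar with the operator being validated here: ONNX InstanceNormalization computes `y = scale * (x - mean) / sqrt(variance + epsilon) + bias`, where mean and variance are taken independently for each (batch, channel) pair over the spatial axes. A minimal NumPy sketch (illustrative only, not code from this PR; it uses the spec's channels-first layout, whereas the QNN backend consumes NHWC after layout transformation):

```python
import numpy as np

def instance_norm(x, scale, bias, epsilon=1e-5):
    """Reference InstanceNormalization, channels-first layout (N, C, spatial...).

    y = scale * (x - mean) / sqrt(var + epsilon) + bias, with mean/var computed
    per (batch, channel) pair over the spatial axes. Handles rank-3 and rank-4
    inputs, the ranks the op builder accepts.
    """
    spatial_axes = tuple(range(2, x.ndim))
    mean = x.mean(axis=spatial_axes, keepdims=True)
    var = x.var(axis=spatial_axes, keepdims=True)
    # Reshape the 1-D scale/bias (shape [C]) so they broadcast over the
    # batch and spatial dimensions.
    bshape = (1, -1) + (1,) * (x.ndim - 2)
    return scale.reshape(bshape) * (x - mean) / np.sqrt(var + epsilon) + bias.reshape(bshape)
```

With unit scale and zero bias, each (batch, channel) slice of the output has mean approximately 0 and variance approximately 1, which is what the epsilon validation above protects (a non-positive epsilon would make the denominator ill-defined).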
onnxruntime/python/tools/quantization/operators/instnorm.py (29 additions, 0 deletions)
```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from .qdq_base_operator import QDQOperatorBase


class QDQInstanceNormalization(QDQOperatorBase):
    def __init__(self, onnx_quantizer, onnx_node):
        super().__init__(onnx_quantizer, onnx_node)

    def quantize(self):
        node = self.node
        assert node.op_type == "InstanceNormalization"

        # Input
        self.quantizer.quantize_activation_tensor(node.input[0])
        if not self.disable_qdq_for_node_output:
            self.quantizer.quantize_activation_tensor(node.output[0])

        # Scale
        if self.quantizer.is_per_channel():
            self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis=1)
        else:
            self.quantizer.quantize_weight_tensor(node.input[1])

        # Bias
        self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1])
```
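For context on the per-tensor versus per-channel branch in `quantize()` above: symmetric int8 weight quantization maps a float tensor to int8 values plus one scale per tensor, or one scale per channel. The helper below is a hypothetical sketch of that idea (its name and details are not from the ONNX Runtime quantizer):

```python
import numpy as np

def quantize_symmetric_int8(w, channel_axis=None):
    """Hypothetical sketch of symmetric int8 weight quantization.

    channel_axis=None gives per-tensor quantization (one scale for the whole
    tensor); an integer axis gives per-channel quantization (one scale per
    slice along that axis).
    """
    if channel_axis is None:
        scale = np.max(np.abs(w)) / 127.0
    else:
        reduce_axes = tuple(i for i in range(w.ndim) if i != channel_axis)
        scale = np.max(np.abs(w), axis=reduce_axes, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)  # avoid division by zero for all-zero slices
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale
```

Per-channel scales track each channel's dynamic range separately, which is why the quantizer prefers them for weight-like inputs such as the InstanceNormalization scale tensor when `is_per_channel()` is enabled.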
Review comment: Documentation states that the input rank must be 4, but I've tested with ranks 3 and 4 on both the CPU and HTP backends. This is covered by unit tests as well.