Merged

79 commits
b2c334c
Merged PR 8479: Add QNN ExecutionProvider which enables OnnxRuntime i…
HectorSVC Sep 23, 2022
dcc47d9
Merged PR 8567: Support Where op with QDQ nodes as a node unit group
HectorSVC Sep 29, 2022
6a2db15
Merged PR 8613: Make Softmax fallback more accurate
HectorSVC Oct 5, 2022
a5eab82
Merged PR 9177: Make StyleGAN work on Linux and enable on build pipeline
SatyaJandhyalaAtMS Nov 8, 2022
0162e1b
Merged PR 9204: Replaced qnn_model_wrapper pointer with a reference.
SatyaJandhyalaAtMS Nov 11, 2022
4b25b1f
Merged PR 9226: Remove high threshold for QNN QDQ model test
SatyaJandhyalaAtMS Nov 16, 2022
410016c
Merged PR 9203: QNN v2 integration
HectorSVC Nov 17, 2022
83a63d7
Merged PR 9328: Fix the issues relate to execute Text Prediction QDQ …
HectorSVC Nov 21, 2022
3bdfe1a
Merged PR 9361: Change the way to identify the back-end capacity
HectorSVC Nov 28, 2022
cd401c0
Merged PR 9386: Create Windows ARM64 build pipeline
SatyaJandhyalaAtMS Nov 29, 2022
ff4d46e
Merged PR 9456: Add missing part Qnn device creation/releasing which …
HectorSVC Dec 2, 2022
2f22ca1
Merged PR 9476: Revert 'Add missing part Qnn device creation/releasin…
HectorSVC Dec 6, 2022
aa18a71
Merged PR 9429: Add support for single Transpose node in QDQ model
HectorSVC Dec 6, 2022
028498b
Merged PR 9432: Enable Quantized MobileNet test and create an interna…
SatyaJandhyalaAtMS Dec 6, 2022
3d70683
Merged PR 9521: minor fix to the unreachable code
HectorSVC Dec 8, 2022
9f8b209
Merged PR 9567: Moved QNN SDK version from 2.3 to 2.5
SatyaJandhyalaAtMS Dec 12, 2022
55e2879
Merged PR 9572: Sync with latest Github main
HectorSVC Dec 12, 2022
6ce8f78
Merged PR 9583: Enable ConvTest UT for Qnn. Exclude tests with dynami…
HectorSVC Dec 13, 2022
892d3f0
Merged PR 9661: Added TopK operator and removed unnecessary variable …
SatyaJandhyalaAtMS Dec 22, 2022
0f4311b
Merged PR 9822: Fixed unused parameter warning.
SatyaJandhyalaAtMS Jan 9, 2023
b75c3c8
Merged PR 9836: Enable Tanh, ReduceMin, Slice for node unit support
HectorSVC Jan 9, 2023
c15e7ac
Merged PR 9877: Enable Tile Op
HectorSVC Jan 13, 2023
cbe2339
Merged PR 10032: Convert Gather indices initializer data from int64 t…
HectorSVC Jan 24, 2023
f6b622e
Merged PR 9968: Add NonMaxSuppression operator support on QNN EP
SatyaJandhyalaAtMS Jan 26, 2023
4e72ed8
Merged PR 10131: Revert 'Add NonMaxSuppression operator support on QN…
SatyaJandhyalaAtMS Jan 29, 2023
d358d87
Merged PR 9866: Limit the transpose optimizer works for Transpose wit…
HectorSVC Feb 8, 2023
837a70c
Merged PR 10311: Update QNN version to 2.6.0
adrianlizarraga Feb 9, 2023
bb0cb62
Merge branch 'main' into qnn_ep_github
HectorSVC Feb 18, 2023
93f2883
resolve conflicts
HectorSVC Feb 22, 2023
7bc5bb3
disable some new tests for Qnn EP
HectorSVC Feb 22, 2023
80e714e
Disable LayerNormalization test for Qnn EP
HectorSVC Feb 22, 2023
1feb5f0
disable onnx node tests: resize_downsample_scales_linear_antialias & …
HectorSVC Feb 22, 2023
ff485ac
extend timeout limit
HectorSVC Feb 22, 2023
2b22726
Merge branch 'main' into qnn_ep_github
HectorSVC Feb 22, 2023
ea75d78
resolve merge conflicts
HectorSVC Feb 22, 2023
1b3aebb
disable some LayerNormTest tests
HectorSVC Feb 22, 2023
93dda20
Back out the fix for transpose optimizer issue. Wait for Scott's chan…
HectorSVC Feb 23, 2023
053def6
correct typo
HectorSVC Feb 23, 2023
ca046d4
Update tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-cr…
HectorSVC Feb 23, 2023
bd2cbdd
remove provider options "runtime"
HectorSVC Feb 24, 2023
44728cb
Merge branch 'qnn_ep_github' of https://github.com/microsoft/onnxrunt…
HectorSVC Feb 24, 2023
277a860
Remove cmake_extra_defines from Linux build pipeline
HectorSVC Feb 24, 2023
1284fb0
Update tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml
HectorSVC Feb 24, 2023
f43f377
Reorganize QNN EP model tests into separate float32 and qdq folders
adrianlizarraga Feb 24, 2023
906f93d
use enum class
HectorSVC Feb 25, 2023
1aab226
Remove Qnn graph creation during graph partitioning since it's not re…
HectorSVC Feb 27, 2023
6ad1138
Update linux-qnn-ci-pipeline.yml to use a new pool
adrianlizarraga Feb 27, 2023
2d9363a
Update onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_build…
HectorSVC Feb 27, 2023
947a026
Update onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_build…
HectorSVC Feb 27, 2023
8f1f3c0
Update onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_build…
HectorSVC Feb 27, 2023
83431b5
Add InstanceNormalization operator to QNN EP
adrianlizarraga Mar 1, 2023
6e55965
Fix param indentation
adrianlizarraga Mar 1, 2023
328237b
Add InstanceNormalization unit test that includes QNN EP
adrianlizarraga Mar 1, 2023
5e175d8
Start supporting inputs with rank > 2
adrianlizarraga Mar 1, 2023
202821e
Merge latest commits from main; Expect QNN InstanceNorm op input to h…
adrianlizarraga Mar 1, 2023
e198d44
Remove unnecessary comments
adrianlizarraga Mar 2, 2023
22f418f
Merge latest commits from main
adrianlizarraga Mar 2, 2023
e5843b2
Allow DQ->InstanceNorm->Q to be treated as a node unit. Fix bug when …
adrianlizarraga Mar 3, 2023
687d063
Add QNN unit tests for QDQ Conv and InstanceNorm ops
adrianlizarraga Mar 3, 2023
9b86707
Add unused parameter macro. Add comments
adrianlizarraga Mar 3, 2023
4a684d6
Make InstanceNormalization op layout sensitive for all EPs (not just …
adrianlizarraga Mar 3, 2023
ce229fc
Remove unnecessary epsilon attr validation; Remove unnecessary overri…
adrianlizarraga Mar 4, 2023
8c4dd8a
Add runtime check for Windows ARM64 that skips qdq op test if HTP bac…
adrianlizarraga Mar 4, 2023
6905816
Update onnxruntime/test/providers/qnn/qnn_basic_test.cc
adrianlizarraga Mar 4, 2023
9c5f880
Address comments
adrianlizarraga Mar 4, 2023
818341d
Merge latest commits from main
adrianlizarraga Mar 4, 2023
842da60
Cache result of runtime query for HTP support on Windows ARM64
adrianlizarraga Mar 4, 2023
92cccb7
Fix merge conflicts
adrianlizarraga Mar 4, 2023
7e73772
Make HTPBackendTestFixture visible on linux
adrianlizarraga Mar 4, 2023
b5af1d8
Clean up tests
adrianlizarraga Mar 4, 2023
123f523
Add new QDQ selector for InstanceNormalization op
adrianlizarraga Mar 8, 2023
0982f1f
Update QDQ InstanceNorm test case
adrianlizarraga Mar 8, 2023
5f44994
Add static casts, clean up
adrianlizarraga Mar 8, 2023
67ef4ce
Add QDQ support for the InstanceNormalization operator to the quantiz…
adrianlizarraga Mar 10, 2023
28b5e56
Run python black linter on test_op_instance_normalization.py
adrianlizarraga Mar 10, 2023
64d04d0
Remove unused imports from test_op_instance_normalization.py
adrianlizarraga Mar 10, 2023
cdf885c
Run python black linter on instnorm.py
adrianlizarraga Mar 10, 2023
a91eefc
Run python isort on quantize.py
adrianlizarraga Mar 10, 2023
83b75f8
Fix pylint warnings in test_op_instance_normalization.py
adrianlizarraga Mar 10, 2023
4 changes: 4 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -926,6 +926,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)

if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/target/${QNN_ARCH_ABI}/lib/*.so" "${onnxruntime_QNN_HOME}/target/${QNN_ARCH_ABI}/lib/*.dll")
if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc")
file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/target/hexagon-v68/lib/unsigned/libQnnHtpV68Skel.so")
list(APPEND QNN_LIB_FILES ${EXTRA_HTP_LIB})
endif()
message(STATUS "QNN lib files: " ${QNN_LIB_FILES})
add_custom_command(
TARGET ${test_data_target} POST_BUILD
@@ -333,6 +333,25 @@ bool WhereNodeGroupSelector::Check(const GraphViewer &graph_viewer, const Node &

}

bool InstanceNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer,
const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const {
if (!CheckQDQNodes(graph_viewer, node, dq_nodes, q_nodes)) {
return false;
}

int32_t dt_input = dq_nodes[0]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_scale = dq_nodes[1]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_bias = dq_nodes[2]->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type();
int32_t dt_output = q_nodes[0]->OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type();

// Input, output, and scale need to be the same type. The bias is int32.
return (dt_input == dt_output) &&
(dt_input == dt_scale) &&
(dt_bias == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32);
}

} // namespace QDQ
} // namespace onnxruntime
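For context, here is a minimal sketch of the QDQ node group this selector is meant to match, built with the public onnx helper API. The tensor names, shapes, and quantization parameters are illustrative only, not taken from this PR:

import onnx
from onnx import TensorProto, helper

# Three DequantizeLinear inputs (input, scale, int32 bias), the target node,
# and one QuantizeLinear on the output.
dq_x = helper.make_node("DequantizeLinear", ["x_q", "x_scale", "x_zp"], ["x"])
dq_s = helper.make_node("DequantizeLinear", ["s_q", "s_scale", "s_zp"], ["s"])
dq_b = helper.make_node("DequantizeLinear", ["b_q", "b_scale"], ["b"])  # int32 bias, no zero point
inorm = helper.make_node("InstanceNormalization", ["x", "s", "b"], ["y"], epsilon=1e-5)
q_y = helper.make_node("QuantizeLinear", ["y", "y_scale", "y_zp"], ["y_q"])

graph = helper.make_graph(
    [dq_x, dq_s, dq_b, inorm, q_y],
    "qdq_instance_norm",
    inputs=[helper.make_tensor_value_info("x_q", TensorProto.UINT8, [1, 2, 4, 4])],
    outputs=[helper.make_tensor_value_info("y_q", TensorProto.UINT8, [1, 2, 4, 4])],
    initializer=[
        helper.make_tensor("x_scale", TensorProto.FLOAT, [], [0.02]),
        helper.make_tensor("x_zp", TensorProto.UINT8, [], [128]),
        helper.make_tensor("s_q", TensorProto.UINT8, [2], [127, 127]),
        helper.make_tensor("s_scale", TensorProto.FLOAT, [], [0.01]),
        helper.make_tensor("s_zp", TensorProto.UINT8, [], [0]),
        helper.make_tensor("b_q", TensorProto.INT32, [2], [0, 0]),
        helper.make_tensor("b_scale", TensorProto.FLOAT, [], [0.0002]),
        helper.make_tensor("y_scale", TensorProto.FLOAT, [], [0.02]),
        helper.make_tensor("y_zp", TensorProto.UINT8, [], [128]),
    ],
)
onnx.checker.check_model(helper.make_model(graph))
# Matches the Check() constraints: dt_input == dt_output == dt_scale (uint8), bias is int32.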

@@ -141,6 +141,15 @@ class GemmNodeGroupSelector : public NodeGroupSelector {
const std::vector<const Node*>& q_nodes) const override;
};

// Input: DQ nodes for input, scale, and B
// Output: Q node for output
class InstanceNormalizationNodeGroupSelector : public NodeGroupSelector {
private:
bool Check(const GraphViewer& graph_viewer, const Node& node,
const std::vector<const Node*>& dq_nodes,
const std::vector<const Node*>& q_nodes) const override;
};

/*
* NodeSelector instances for use in the QDQ::SelectorActionTransformer.
*/
@@ -232,6 +241,14 @@ class GemmSelector : public BaseSelector {
void UpdateBuilder(NodesToOptimizeIndicesBuilder&) const override;
};

// Input: DQ nodes for input, scale, and B (bias)
// Output: Q node for output
class InstanceNormalizationSelector : public BaseSelector {
public:
InstanceNormalizationSelector()
: BaseSelector(std::make_unique<InstanceNormalizationNodeGroupSelector>()) {}
};

} // namespace QDQ
} // namespace onnxruntime

@@ -74,6 +74,9 @@ static const OpVersionsAndSelector::OpVersionsMap GetMatMulOpVersionsMap() {
static const OpVersionsAndSelector::OpVersionsMap GetGemmOpVersionsMap() {
return {{"Gemm", {}}};
}
static const OpVersionsAndSelector::OpVersionsMap GetInstanceNormalizationOpVersionsMap() {
return {{"InstanceNormalization", {}}};
}

/* Selector rules registration related */
void RegisterMiscSelectors(Selectors& qdq_selectors) {
@@ -133,6 +136,13 @@ void RegisterGemmSelector(Selectors& qdq_selectors) {
std::move(selector));
}

void RegisterInstanceNormalizationSelector(Selectors& qdq_selectors) {
/* register selector for InstanceNormalization op */
std::unique_ptr<NodeGroupSelector> selector = std::make_unique<InstanceNormalizationNodeGroupSelector>();
qdq_selectors.RegisterSelector(GetInstanceNormalizationOpVersionsMap(),
std::move(selector));
}

void SelectorManager::CreateSelectors() {
RegisterMiscSelectors(qdq_selectors_);
RegisterUnarySelectors(qdq_selectors_);
@@ -142,6 +152,7 @@ void SelectorManager::CreateSelectors() {
RegisterConvTransposeSelector(qdq_selectors_);
RegisterMatMulSelector(qdq_selectors_);
RegisterGemmSelector(qdq_selectors_);
RegisterInstanceNormalizationSelector(qdq_selectors_);
}

void SelectorManager::InitializeSelectorsMap() {
@@ -2043,7 +2043,7 @@ const std::unordered_set<std::string_view>& GetLayoutSensitiveOps() {
"Conv", "QLinearConv", "BatchNormalization",
"AveragePool", "GlobalAveragePool", "MaxPool",
"GlobalMaxPool", "LRN", "GridSample",
"DepthToSpace", "SpaceToDepth", "ConvTranspose", "MaxUnpool"};
"DepthToSpace", "SpaceToDepth", "ConvTranspose", "MaxUnpool", "InstanceNormalization"};

return layout_sensitive_ops;
}
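A short numpy sketch (not ORT code) of why InstanceNormalization belongs in this list: mean and variance are reduced over the spatial axes only, so the reduction axes move when the layout transformer rewrites NCHW as NHWC:

import numpy as np

def instance_norm_nchw(x, scale, bias, eps=1e-5):
    # Reduce over H and W (axes 2 and 3 in NCHW).
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    return scale[None, :, None, None] * (x - mean) / np.sqrt(var + eps) + bias[None, :, None, None]

def instance_norm_nhwc(x, scale, bias, eps=1e-5):
    # Same reduction, but H and W are now axes 1 and 2.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return scale * (x - mean) / np.sqrt(var + eps) + bias

x = np.random.rand(1, 3, 4, 4).astype(np.float32)
scale = np.random.rand(3).astype(np.float32)
bias = np.random.rand(3).astype(np.float32)
ref = instance_norm_nchw(x, scale, bias)
out = instance_norm_nhwc(x.transpose(0, 2, 3, 1), scale, bias).transpose(0, 3, 1, 2)
assert np.allclose(ref, out, atol=1e-5)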
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
@@ -120,6 +120,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
{
CreateTileOpBuilder("Tile", *this);
}

{
CreateInstanceNormOpBuilder("InstanceNormalization", *this);
}
}

const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
@@ -76,5 +76,7 @@ void CreateTopKOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_

void CreateTileOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateInstanceNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

} // namespace qnn
} // namespace onnxruntime
@@ -158,7 +158,8 @@ class BaseOpBuilder : public IOpBuilder {
{"ArgMin", "Argmin"},
{"ConvTranspose", "TransposeConv2d"},
{"Tile", "Tile"},
{"TopK", "TopK"}};
{"TopK", "TopK"},
{"InstanceNormalization", "InstanceNorm"}};
auto it = onnx_op_type_to_qnn_op_type.find(onnx_op_type);
ORT_ENFORCE(it != onnx_op_type_to_qnn_op_type.end());
return it->second;
@@ -54,6 +54,7 @@ class ConvOpBuilder : public BaseOpBuilder {
// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
// Need to do op validation in 1st call of GetCapability
// TODO: Check if node domain == kMSInternalNHWCDomain to determine if the layout has been transformed.
Status ConvOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/common/safeint.h"
#include "onnx/defs/data_type_utils.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace qnn {

class InstanceNormOpBuilder : public BaseOpBuilder {
public:
InstanceNormOpBuilder() : BaseOpBuilder("InstanceNormOpBuilder") {}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InstanceNormOpBuilder);

Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model) const override final ORT_MUST_USE_RESULT;

protected:
Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool is_quantized_model,
bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Instance normalization op is sensitive to data layout.
// The nodes from 1st call of GetCapability do not get layout transformer applied, so their shapes are still NCHW.
// The nodes from 2nd call of GetCapability get their layout transformed to NHWC.
// Therefore, we need to check the node domain to determine if the layout has been transformed.
Status InstanceNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
bool is_quantized_model) const {
ORT_UNUSED_PARAMETER(logger);

const auto float_elem_type = ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float");

// Check input type is float for CPU.
const auto& inputs = node_unit.Inputs();
ONNX_NAMESPACE::DataType input_data_type = inputs[0].node_arg.Type();
if (!is_quantized_model && input_data_type != float_elem_type) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm data type " + *input_data_type + " is not supported in CPU backend.");
}

// Also check output type is float for CPU.
const auto& outputs = node_unit.Outputs();
ONNX_NAMESPACE::DataType output_data_type = outputs[0].node_arg.Type();
if (!is_quantized_model && output_data_type != float_elem_type) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm data type " + *output_data_type + " is not supported in CPU backend.");
}

std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0");
const size_t input_rank = input_shape.size();

if (input_rank <= 2 || input_rank > 4) {
Contributor Author: Documentation states that input rank must be 4, but I've tested with ranks 3 & 4 for both cpu and htp backends. This is unit tested as well.

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm only supports input ranks of size 3 or 4.");
}

const uint32_t num_channels = (node_unit.Domain() == kMSInternalNHWCDomain) ? input_shape.back() : input_shape[1];

std::vector<uint32_t> scale_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)");
if (scale_shape.size() != 1 || scale_shape[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm input 1 (scale) must have 1D shape [channel].");
}

std::vector<uint32_t> bias_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)");
if (bias_shape.size() != 1 || bias_shape[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm input 2 (bias) must have 1D shape [channel].");
}

NodeAttrHelper node_helper(node_unit);
const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec.
if (epsilon <= 0.0f) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN InstanceNorm epsilon must be greater than 0.0");
}

return Status::OK();
}

Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool is_quantized_model,
bool do_op_validation) const {
NodeAttrHelper node_helper(node_unit);
std::vector<std::string> param_tensor_names;

const float epsilon = node_helper.Get("epsilon", 1e-05f); // Default is 1e-05 according to ONNX spec.
Qnn_Scalar_t epsilon_param = QNN_SCALAR_INIT;
epsilon_param.dataType = QNN_DATATYPE_FLOAT_32;
epsilon_param.floatValue = epsilon;
QnnParamWrapper epsilon_param_wrapper(node_unit.Index(),
node_unit.Name(),
qnn_def::epsilon,
epsilon_param);
param_tensor_names.push_back(epsilon_param_wrapper.GetParamTensorName());
qnn_model_wrapper.AddParamWrapper(std::move(epsilon_param_wrapper));

ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit,
std::move(input_names),
std::move(param_tensor_names),
logger, is_quantized_model, do_op_validation));

return Status::OK();
}

void CreateInstanceNormOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
op_registrations.AddOpBuilder(op_type, std::make_unique<InstanceNormOpBuilder>());
}

} // namespace qnn
} // namespace onnxruntime
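A hedged usage sketch for the new op builder: running a float32 model that contains an InstanceNormalization node with the QNN EP from Python. The provider name and the backend_path option follow the QNN EP added in this PR series; the model file and input name are placeholders:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    "instance_norm_model.onnx",  # placeholder model containing InstanceNormalization
    providers=["QNNExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"backend_path": "QnnCpu.dll"}, {}],  # QnnHtp.dll for the HTP backend
)

# IsOpSupported above requires float32 data on the CPU backend, input rank 3 or 4,
# and 1D scale/bias of shape [channel].
x = np.random.rand(1, 3, 8, 8).astype(np.float32)
outputs = sess.run(None, {"x": x})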
@@ -44,6 +44,7 @@ class PoolOpBuilder : public BaseOpBuilder {
// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
// Need to do op validation in 1st call of GetCapability
// TODO: Check if node domain == kMSInternalNHWCDomain to determine if the layout has been transformed.
Status PoolOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
@@ -44,6 +44,7 @@ class ResizeOpBuilder : public BaseOpBuilder {
// The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW
// The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC
// Need to do op validation in 1st call of GetCapability
// TODO: Check if node domain == kMSInternalNHWCDomain to determine if the layout has been transformed.
Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
@@ -40,7 +40,8 @@ class SimpleOpBuilder : public BaseOpBuilder {
const std::string input_name) const;
Status HandleSingleTransposeNode(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
-                                   std::vector<std::string>&& input_names) const;
+                                   std::vector<std::string>&& input_names,
+                                   bool is_quantized_model) const;
};

Status SimpleOpBuilder::ExplictOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
@@ -153,14 +154,42 @@ Status SimpleOpBuilder::ProcessAlphaAttribute(QnnModelWrapper& qnn_model_wrapper

// Support a single Transpose node in a QDQ model since it just changes the data layout
// A single node doesn't have any quantization parameters
-// Input tensors are created by previous node, output tensors created by next node
+// Input tensors are created by the previous node. Output tensors are created by the next node,
+// unless the output is the graph's final output.
Status SimpleOpBuilder::HandleSingleTransposeNode(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
-                                                  std::vector<std::string>&& input_names) const {
+                                                  std::vector<std::string>&& input_names,
+                                                  bool is_quantized_model) const {
std::vector<std::string> param_tensor_names;
ORT_RETURN_IF_ERROR(ProcessPermAttribute(qnn_model_wrapper, node_unit, param_tensor_names));
const auto& outputs = node_unit.Outputs();
ORT_ENFORCE(outputs.size() == 1, "QNN Transpose node must have a single output.");
const auto& output = outputs[0];
auto& output_name = output.node_arg.Name();

const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_name);

// Need to add output to the QNN model wrapper if this Transpose node's output is also
// the graph's output.
if (is_graph_output) {
const auto* type_proto = output.node_arg.TypeAsProto();
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED;
ORT_RETURN_IF_ERROR(GetQnnDataType(is_quantized_model, type_proto, qnn_data_type));

Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT;
std::vector<uint32_t> output_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape),
"Cannot get shape for QNN Transpose output");

QnnTensorWrapper output_tensorwrapper(output_name,
QNN_TENSOR_TYPE_APP_READ,
qnn_data_type,
quantize_param,
std::move(output_shape));
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)),
"Failed to add output tensor for QNN Transpose");
}

-  auto& output_name = node_unit.Outputs()[0].node_arg.Name();
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(GetNodeName(node_unit),
qnn_def::package_name,
GetQnnOpType(node_unit.OpType()),
Expand All @@ -186,7 +215,7 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
} else if (is_quantized_model && NodeUnit::Type::SingleNode == node_unit.UnitType() &&
node_unit.OpType() == "Transpose") {
LOGS(logger, VERBOSE) << "Add single Transpose node: " << node_unit.Name();
-    return HandleSingleTransposeNode(qnn_model_wrapper, node_unit, std::move(input_names));
+    return HandleSingleTransposeNode(qnn_model_wrapper, node_unit, std::move(input_names), is_quantized_model);
}

std::vector<std::string> param_tensor_names;
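To illustrate the case HandleSingleTransposeNode now covers, a sketch (public onnx helper API, illustrative names) of a QDQ model whose final node is a bare Transpose: it operates directly on quantized data, and its output is also the graph output, so the EP must register the output tensor itself:

import onnx
from onnx import TensorProto, helper

dq = helper.make_node("DequantizeLinear", ["x_q", "scale", "zp"], ["x"])
relu = helper.make_node("Relu", ["x"], ["r"])
q = helper.make_node("QuantizeLinear", ["r", "scale", "zp"], ["r_q"])
# Bare Transpose with no Q/DQ around it; "y" is the graph output.
transpose = helper.make_node("Transpose", ["r_q"], ["y"], perm=[0, 2, 3, 1])

graph = helper.make_graph(
    [dq, relu, q, transpose],
    "trailing_transpose",
    inputs=[helper.make_tensor_value_info("x_q", TensorProto.UINT8, [1, 3, 4, 4])],
    outputs=[helper.make_tensor_value_info("y", TensorProto.UINT8, [1, 4, 4, 3])],
    initializer=[
        helper.make_tensor("scale", TensorProto.FLOAT, [], [0.05]),
        helper.make_tensor("zp", TensorProto.UINT8, [], [0]),
    ],
)
onnx.checker.check_model(helper.make_model(graph))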
1 change: 1 addition & 0 deletions onnxruntime/core/providers/qnn/builder/qnn_def.h
@@ -426,6 +426,7 @@ const std::string nearest_mode = "nearest_mode";
const std::string rounding_mode = "rounding_mode";
const std::string topk = "k";
const std::string multiples = "multiples";
const std::string epsilon = "epsilon";
} // namespace qnn_def

} // namespace qnn
29 changes: 29 additions & 0 deletions onnxruntime/python/tools/quantization/operators/instnorm.py
@@ -0,0 +1,29 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from .qdq_base_operator import QDQOperatorBase


class QDQInstanceNormalization(QDQOperatorBase):
def __init__(self, onnx_quantizer, onnx_node):
super().__init__(onnx_quantizer, onnx_node)

def quantize(self):
node = self.node
assert node.op_type == "InstanceNormalization"

# Input
self.quantizer.quantize_activation_tensor(node.input[0])
if not self.disable_qdq_for_node_output:
self.quantizer.quantize_activation_tensor(node.output[0])

# Scale
if self.quantizer.is_per_channel():
self.quantizer.quantize_weight_tensor_per_channel(node.input[1], axis=1)
else:
self.quantizer.quantize_weight_tensor(node.input[1])

# Bias
self.quantizer.quantize_bias_tensor(node.input[2], node.input[0], node.input[1])
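With QDQInstanceNormalization registered, a hedged end-to-end sketch of producing this QDQ pattern through the quantization tool; the model file names and the random calibration reader are placeholders:

import numpy as np
from onnxruntime.quantization import CalibrationDataReader, QuantFormat, QuantType, quantize_static

class RandomReader(CalibrationDataReader):
    # Feeds a few random NCHW batches for calibration; a real reader would use
    # representative data.
    def __init__(self, n=8):
        self.data = iter([{"x": np.random.rand(1, 3, 8, 8).astype(np.float32)} for _ in range(n)])

    def get_next(self):
        return next(self.data, None)

quantize_static(
    "instance_norm_fp32.onnx",  # placeholder float32 model
    "instance_norm_qdq.onnx",   # placeholder output model
    RandomReader(),
    quant_format=QuantFormat.QDQ,
    activation_type=QuantType.QUInt8,
    weight_type=QuantType.QUInt8,
)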