Skip to content

Commit 3d3169b

Browse files
authored
fix DataLayout::kNHWC in binary.cc (#76562)
1 parent f6ba4f2 commit 3d3169b

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

paddle/phi/infermeta/binary.cc

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -831,8 +831,7 @@ void ConvTransposeInferMeta(const MetaTensor& x,
831831
"should be the same."));
832832

833833
const int64_t C =
834-
(data_layout != DataLayout::kNHWC ? x_dims[1]
835-
: x_dims[x_dims.size() - 1]);
834+
(data_layout != DataLayout::NHWC ? x_dims[1] : x_dims[x_dims.size() - 1]);
836835
PADDLE_ENFORCE_EQ(
837836
C,
838837
filter_dims[0],
@@ -849,7 +848,7 @@ void ConvTransposeInferMeta(const MetaTensor& x,
849848
data_format));
850849

851850
DDim x_data_dims;
852-
if (data_layout != DataLayout::kNHWC) {
851+
if (data_layout != DataLayout::NHWC) {
853852
x_data_dims = slice_ddim(x_dims, 2, x_dims.size());
854853
} else {
855854
x_data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1);
@@ -860,10 +859,10 @@ void ConvTransposeInferMeta(const MetaTensor& x,
860859
&paddings_, &dilations_, padding_algorithm, x_data_dims, strides, ksize);
861860

862861
std::vector<int64_t> output_shape({x_dims[0]});
863-
if (data_layout != DataLayout::kNHWC) {
862+
if (data_layout != DataLayout::NHWC) {
864863
output_shape.push_back(filter_dims[1] * groups);
865864
}
866-
const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1);
865+
const int offset = (data_layout != DataLayout::NHWC ? 2 : 1);
867866
for (int i = 0; i < static_cast<int>(strides.size()); ++i) {
868867
auto filter_extent = dilations_[i] * (filter_dims[i + 2] - 1) + 1;
869868
auto infer_shape = (config.is_runtime || x_dims[i + offset] > 0)
@@ -932,7 +931,7 @@ void ConvTransposeInferMeta(const MetaTensor& x,
932931
output_shape.push_back(infer_shape);
933932
}
934933
}
935-
if (data_layout == DataLayout::kNHWC) {
934+
if (data_layout == DataLayout::NHWC) {
936935
output_shape.push_back(filter_dims[1] * groups);
937936
}
938937

@@ -1700,7 +1699,7 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
17001699
bool should_rotate =
17011700
config.is_run_onednn_kernel &&
17021701
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
1703-
phi::DataLayout::kNHWC) &&
1702+
phi::DataLayout::NHWC) &&
17041703
(x_dims.size() >= 3 || y_dims.size() >= 3);
17051704
if (should_rotate) {
17061705
// Pick bigger shape and rotate this one
@@ -1743,10 +1742,10 @@ void ElementwiseRawInferMeta(const MetaTensor& x,
17431742
}
17441743
out->set_dtype(promote_result);
17451744

1746-
// layout need change when meet input layout contain kNHWC
1745+
// layout need change when meet input layout contain NHWC
17471746
auto layout = [&]() {
1748-
if (x.layout() == DataLayout::kNHWC || y.layout() == DataLayout::kNHWC)
1749-
return DataLayout::kNHWC;
1747+
if (x.layout() == DataLayout::NHWC || y.layout() == DataLayout::NHWC)
1748+
return DataLayout::NHWC;
17501749
return x.layout();
17511750
}();
17521751

0 commit comments

Comments
 (0)