Skip to content

[Zero-Dim] support input 0D Tensor for std/var#49735

Merged
zhwesky2010 merged 54 commits intoPaddlePaddle:developfrom
mhy-666:0dtensor_1
Feb 10, 2023
Merged

[Zero-Dim] support input 0D Tensor for std/var#49735
zhwesky2010 merged 54 commits intoPaddlePaddle:developfrom
mhy-666:0dtensor_1

Conversation

@mhy-666
Copy link
Contributor

@mhy-666 mhy-666 commented Jan 11, 2023

PR types

New features

PR changes

APIs

Describe

[Zero-Dim] support input 0D Tensor for std/var:

为以下API支持输入0D Tensor,不修改其输出端行为,属于新增功能,无不兼容影响:
paddle.std
paddle.var

@paddle-bot
Copy link

paddle-bot bot commented Jan 11, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这两个直接加到reduce_api_list应该可以

@zhwesky2010 zhwesky2010 changed the title add test_std [Zero-Dim] support input 0D Tensor for std/var Feb 8, 2023
out.backward()

# checkout shape of out
self.assertEqual(out.shape, [])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要对反向进行测试,梯度为0.

def test_std(self):
x = paddle.rand([])
x.stop_gradient = False
out = paddle.std(x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

再加一种 paddle.std(x, []) 的情形

# checkout value of out
self.assertEqual(out, 0)

def test_var(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同std

self.assertEqual(res[2].shape, ())
np.testing.assert_allclose(res[2], 1.0)

@prog_scope()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

与动态图问题一致

"""
if not in_dygraph_mode():
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'var')
if len(x.shape) == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样写可能不太美观,如果直接通过原来的分支实现会有问题吗?

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit 86cc694 into PaddlePaddle:develop Feb 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

2 participants