【complex op】 No.10 add complex support for exp/expm1 #56398
Conversation
Your PR has been submitted successfully. Thank you for contributing to the open source project!

✅ This PR's description meets the template requirements!
@@ -183,28 +263,72 @@ def test_api_int(self):
        paddle.enable_static()


class TestParameter:
Please put this class back in its original position; otherwise the diff is too large.
class TestExpm1(TestActivation):
    def setUp(self):
        self.op_type = "expm1"
        self.prim_op_type = "prim"
This line should not need to be added.
        self.python_api = paddle.expm1
        self.public_python_api = paddle.exp
This line should not need to be added either, and if it is added it should be paddle.expm1.
        out = np.expm1(x)

        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
        self.outputs = {'Out': out}
        self.convert_input_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
        if self.dtype == np.complex64 or self.dtype == np.complex128:
This part should not need to change either: check_prim in check_grad defaults to False. The previous logic did not check prim for non-complex dtypes either, so please keep the logic consistent.
            pass


class TestExp_Complex128(OpTest):
This class should be able to inherit from the TestExp_Complex64 class.
@@ -96,6 +96,25 @@ struct Cosine<dtype::bfloat16> {
  }
};

template <typename T>
struct Exp {
This part probably does not need an explicit definition of Exp?
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
It should be (Device d, X x UNUSED, Out out, dOut dout, dX dx).
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) =
        dout * x.unaryExpr(Exp<ComplexType<T>>()).unaryExpr(Conj<T>());
Taking the conjugate of out should be enough here.
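(A minimal standalone sketch of what this suggestion amounts to; plain std::complex stands in for Paddle's ComplexType and the Eigen expressions, so everything here is illustrative only. Since out == exp(x), the backward pass can reuse out directly instead of recomputing exp(x).)

```cpp
// Illustrative check: dout * conj(out) gives the same gradient as
// recomputing dout * conj(exp(x)), because out already holds exp(x).
#include <complex>
#include <iostream>

int main() {
  std::complex<double> x(0.3, -1.2);
  std::complex<double> out = std::exp(x);   // forward result
  std::complex<double> dout(1.0, 0.5);      // incoming gradient

  std::complex<double> dx_recompute = dout * std::conj(std::exp(x));
  std::complex<double> dx_reuse_out = dout * std::conj(out);

  std::cout << dx_recompute << " == " << dx_reuse_out << "\n";
  return 0;
}
```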
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
It should be (Device d, X x UNUSED, Out out, dOut dout, dX dx).
            typename dX>
  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
    dx.device(d) =
        dout * x.unaryExpr(Exp<ComplexType<T>>()).unaryExpr(Conj<T>());
The gradient here does not look right.
@Wanglongzhi2001 You can …
Thank you for the suggestions. This is my first PR to Paddle, so I am not very familiar with the process yet~
@ScottWong98 Hello, the current situation is that the tests fail as shown in the screenshot, and I would like to ask for your advice. The first problem is that for the complex128 dtype the output gradient tolerance is 1e-6, while mine reaches 0.005. After looking at the OpTest class, I found that the tolerance for complex128 is hard-coded to 1e-6, so the check cannot pass. Paddle/test/legacy_test/eager_op_test.py Lines 2570 to 2576 in 8495377
The second problem is with check_output on GPU. The output is correct on CPU, but on GPU the imaginary part is always 0. The GPU path simply uses the thrust library for the computation, and I have tested thrust's exp function locally without any problem, so I do not know where the issue comes from.
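(For reference, a minimal host-side check, assuming the CUDA Thrust headers are on the include path; the input value is arbitrary and the snippet is only illustrative, not part of this PR. It shows thrust::exp on thrust::complex producing a non-zero imaginary part, which would rule out thrust itself as the source of the all-zero imaginary parts seen on GPU.)

```cpp
// Sanity check that thrust::exp handles thrust::complex correctly.
#include <thrust/complex.h>
#include <iostream>

int main() {
  thrust::complex<double> x(1.0, 1.0);
  thrust::complex<double> y = thrust::exp(x);  // expected roughly 1.4687 + 2.2874i
  std::cout << y.real() << " + " << y.imag() << "i\n";
  return 0;
}
```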
OK, thanks~
@ScottWong98 @GGBond8488 Please review.
            typename dOut,
            typename dX>
  void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out.unaryExpr(Conj<T>()) + dout;
Could you explain why the gradient of expm1 (exp - 1) here is dout * (exp + 1)? :)
Actually I do not fully understand it either; TensorFlow's gradient implementation does not add this dout:
https://github.com/tensorflow/tensorflow/blob/f82986df65bea201e5aa466e6993504372132cec/tensorflow/python/ops/math_grad.py#L688-L695
However, I saw that Paddle's gradient implementations for the other data types do add this dout, and the gradient error check indeed fails if I leave it out, so I added it. I would also like to ask the Paddle maintainers whether Paddle's operator gradient implementation is different here, which would explain this result.
The gradient of exp(x) - 1 is exp(x), and exp(x) = (exp(x) - 1) + 1 = out + 1.
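(A standalone numeric check of that identity, again with plain std::complex and purely illustrative values: with out = expm1(x) = exp(x) - 1, the extra + dout term is exactly what turns dout * conj(out) into dout * conj(exp(x)), since conj(out) + 1 == conj(exp(x)).)

```cpp
// Verify: dout * (conj(out) + 1) == dout * conj(exp(x)) for out = exp(x) - 1.
#include <complex>
#include <iostream>

int main() {
  std::complex<double> x(0.7, 0.4);
  std::complex<double> out = std::exp(x) - 1.0;  // forward: expm1(x)
  std::complex<double> dout(0.2, -1.0);          // incoming gradient

  std::complex<double> dx_from_out = dout * (std::conj(out) + 1.0);
  std::complex<double> dx_direct   = dout * std::conj(std::exp(x));

  std::cout << dx_from_out << " == " << dx_direct << "\n";
  return 0;
}
```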
@@ -640,7 +640,16 @@ def expm1(x, name=None):
    check_variable_and_dtype(
        x,
        'x',
        ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
        [
Please update the docstring of this method accordingly as well.
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.006)
Is max_relative_error needed here for both complex64 and complex128?
Yes.
@ScottWong98 The docstring has been updated. Is there anything else that needs to be fixed? Please review ^_^
LGTM. @GGBond8488, please take another look.
};

template <typename T>
struct Expm1 {
Why define Expm1 separately here? This file is meant for defining functors, and this effectively defines the expm1 operation itself, which duplicates the one below. If you want to define the complex expm1 operation, it is recommended to put it in complex.h.
Hi, because thrust does not support the expm1 operator and C++'s expm1 does not support complex types, the complex implementation of expm1 has to be composed from the complex implementation of exp. I originally put it in complex.h, but @ScottWong98 previously suggested that things not supported by C++ should go in activation_functor.h.
OK, but are you sure an extra definition of expm1 is needed? I see that the non-complex expm1 below also just calls a function. If it is indeed needed, you can define it in activation_functor, but it should be a specialization for complex types only, not cover all types.
That was indeed an oversight on my part; only the complex specialization should be needed. Thanks for the suggestion~
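(For reference, a rough standalone sketch of the kind of complex-only specialization being discussed. std::complex stands in for phi::dtype::complex, and expm1_complex is a hypothetical name for illustration, not Paddle's actual functor.)

```cpp
// Hypothetical complex-only expm1 composed from exp, since neither
// thrust nor std::expm1 accepts complex arguments.
#include <complex>
#include <iostream>

template <typename T>
std::complex<T> expm1_complex(const std::complex<T>& a) {
  return std::exp(a) - std::complex<T>(1);
}

int main() {
  std::cout << expm1_complex(std::complex<double>(0.5, 0.5)) << "\n";
  return 0;
}
```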
It has been fixed. Please review, @GGBond8488.
@@ -466,6 +466,16 @@ HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#endif
}

template <typename T>
A previous PR should already have added exp here; please sync with the latest code so it is not added twice.
OK.
PR types
New features
PR changes
OPs
Description
Running the tests produces a problem with the following error:
[Screenshot 2023-08-17 14-35-55: test failure output]
Could you please help take a look, @ScottWong98?
#56145