From 989d7bee1dcd9601f8a8fd2091bd0323a693033b Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Thu, 8 Dec 2022 09:01:09 +0000 Subject: [PATCH 1/2] fix eigh backward test case --- api/tests/eigh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/tests/eigh.py b/api/tests/eigh.py index 186be5734b..5712444d24 100644 --- a/api/tests/eigh.py +++ b/api/tests/eigh.py @@ -31,7 +31,7 @@ def build_graph(self, config): self.feed_list = [x] self.fetch_list = [out_w, out_v] if config.backward: - self.append_gradients(out_w, [x]) + self.append_gradients([out_w, out_v], [x]) @benchmark_registry.register("eigh") @@ -43,4 +43,4 @@ def build_graph(self, config): self.feed_list = [x] self.fetch_list = [out_w, out_v] if config.backward: - self.append_gradients(out_w, [x]) + self.append_gradients([out_w, out_v], [x]) From cea43d1e351360403358f6fa2cfa5aed931660b8 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Thu, 8 Dec 2022 09:56:00 +0000 Subject: [PATCH 2/2] fix bug caused by backward diff between paddle and torch --- api/tests/eigh.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/tests/eigh.py b/api/tests/eigh.py index 5712444d24..611eb5a3df 100644 --- a/api/tests/eigh.py +++ b/api/tests/eigh.py @@ -31,7 +31,7 @@ def build_graph(self, config): self.feed_list = [x] self.fetch_list = [out_w, out_v] if config.backward: - self.append_gradients([out_w, out_v], [x]) + self.append_gradients(out_w.sum() + paddle.abs(out_v).sum(), [x]) @benchmark_registry.register("eigh") @@ -43,4 +43,4 @@ def build_graph(self, config): self.feed_list = [x] self.fetch_list = [out_w, out_v] if config.backward: - self.append_gradients([out_w, out_v], [x]) + self.append_gradients(out_w.sum() + torch.abs(out_v).sum(), [x])