From 3185ec8e438d9bcd52dda66f69d25f3391e993cd Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Thu, 11 Jan 2024 17:17:12 +0100 Subject: [PATCH] FIX: compute ArraySize with `len()` in `tnp` backend --- src/tensorwaves/function/sympy/_printer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tensorwaves/function/sympy/_printer.py b/src/tensorwaves/function/sympy/_printer.py index 228abe70..c2ced124 100644 --- a/src/tensorwaves/function/sympy/_printer.py +++ b/src/tensorwaves/function/sympy/_printer.py @@ -81,5 +81,6 @@ class TensorflowPrinter(CustomNumPyPrinter): def __init__(self) -> None: # https://github.com/sympy/sympy/blob/f1384c2/sympy/printing/printer.py#L21-L72 super().__init__() + self.known_functions["ArraySize"] = "len" self.known_functions["ComplexSqrt"] = "sqrt" self.printmethod = "_tensorflow_code"