From 02d22f7da8dd485f5ed978b4b3c3ec791a7c594d Mon Sep 17 00:00:00 2001 From: Deepyaman Datta Date: Fri, 25 Aug 2023 16:53:26 -0600 Subject: [PATCH] feat(flink): fine-tune numeric literal translation --- .../test_translate_filter/out.sql | 2 +- .../test_translate_having/out.sql | 2 +- .../test_window/test_rows_window/out.sql | 2 +- ibis/backends/flink/tests/test_literals.py | 4 ++-- ibis/backends/flink/utils.py | 20 +++++++++++++++++-- ibis/backends/tests/test_numeric.py | 11 ++++++++++ 6 files changed, 34 insertions(+), 7 deletions(-) diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql index da36f49f52e36..a74d83d269b56 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_filter/out.sql @@ -1,4 +1,4 @@ SELECT t0.* FROM table t0 -WHERE ((t0.`c` > 0) OR (t0.`c` < 0)) AND +WHERE ((t0.`c` > CAST(0 AS TINYINT)) OR (t0.`c` < CAST(0 AS TINYINT))) AND (t0.`g` IN ('A', 'B')) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql index c34fa980e5ae7..1f4153d357e69 100644 --- a/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_translator/test_translate_having/out.sql @@ -1,4 +1,4 @@ SELECT t0.`g`, sum(t0.`b`) AS `b_sum` FROM table t0 GROUP BY t0.`g` -HAVING count(*) >= 1000 \ No newline at end of file +HAVING count(*) >= CAST(1000 AS SMALLINT) \ No newline at end of file diff --git a/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql b/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql index 6e0f14b726157..3173072eafb64 100644 --- a/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql +++ b/ibis/backends/flink/tests/snapshots/test_window/test_rows_window/out.sql @@ -1,2 +1,2 @@ -SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN 1000 PRECEDING AND CURRENT ROW) AS `Sum(f)` +SELECT sum(t0.`f`) OVER (ORDER BY t0.`f` ASC ROWS BETWEEN CAST(1000 AS SMALLINT) PRECEDING AND CURRENT ROW) AS `Sum(f)` FROM table t0 \ No newline at end of file diff --git a/ibis/backends/flink/tests/test_literals.py b/ibis/backends/flink/tests/test_literals.py index a691c8c1d2335..ab85ded90b7d3 100644 --- a/ibis/backends/flink/tests/test_literals.py +++ b/ibis/backends/flink/tests/test_literals.py @@ -14,8 +14,8 @@ @pytest.mark.parametrize( "value,expected", [ - param(5, "5", id="int"), - param(1.5, "1.5", id="float"), + param(5, "CAST(5 AS TINYINT)", id="int"), + param(1.5, "CAST(1.5 AS DOUBLE)", id="float"), param(True, "TRUE", id="true"), param(False, "FALSE", id="false"), ], diff --git a/ibis/backends/flink/utils.py b/ibis/backends/flink/utils.py index dd85b7fd8b452..19762b365fe8b 100644 --- a/ibis/backends/flink/utils.py +++ b/ibis/backends/flink/utils.py @@ -5,6 +5,8 @@ from abc import ABC, abstractmethod from collections import defaultdict +from pyflink.table.types import DataTypes + import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.common.temporal import IntervalUnit @@ -31,7 +33,6 @@ IntervalUnit.SECOND: 9, } - MICROSECONDS_IN_UNIT = { unit: datetime.timedelta(**{unit.plural: 1}).total_seconds() * 10**6 for unit in [ @@ -244,6 +245,21 @@ def _translate_interval(value, dtype): return interval.format_as_string() +_to_pyflink_types = { + dt.Int8: DataTypes.TINYINT(), + dt.Int16: DataTypes.SMALLINT(), + dt.Int32: DataTypes.INT(), + dt.Int64: DataTypes.BIGINT(), + dt.UInt8: DataTypes.TINYINT(), + dt.UInt16: DataTypes.SMALLINT(), + dt.UInt32: DataTypes.INT(), + dt.UInt64: DataTypes.BIGINT(), + dt.Float16: DataTypes.FLOAT(), + dt.Float32: DataTypes.FLOAT(), + dt.Float64: DataTypes.DOUBLE(), +} + + def translate_literal(op: ops.Literal) -> str: value = op.value dtype = op.dtype @@ -266,7 +282,7 @@ def translate_literal(op: ops.Literal) -> str: raise ValueError("NaN is not supported in Flink SQL") elif math.isinf(value): raise ValueError("Infinity is not supported in Flink SQL") - return repr(value) + return f"CAST({value} AS {_to_pyflink_types[type(dtype)]!s})" elif dtype.is_timestamp(): # TODO(chloeh13q): support timestamp with local timezone if isinstance(value, datetime.datetime): diff --git a/ibis/backends/tests/test_numeric.py b/ibis/backends/tests/test_numeric.py index be4a849325589..688370be44d24 100644 --- a/ibis/backends/tests/test_numeric.py +++ b/ibis/backends/tests/test_numeric.py @@ -71,6 +71,7 @@ "trino": "integer", "duckdb": "TINYINT", "postgres": "integer", + "flink": "TINYINT NOT NULL", }, id="int8", ), @@ -85,6 +86,7 @@ "trino": "integer", "duckdb": "SMALLINT", "postgres": "integer", + "flink": "SMALLINT NOT NULL", }, id="int16", ), @@ -99,6 +101,7 @@ "trino": "integer", "duckdb": "INTEGER", "postgres": "integer", + "flink": "INT NOT NULL", }, id="int32", ), @@ -113,6 +116,7 @@ "trino": "integer", "duckdb": "BIGINT", "postgres": "integer", + "flink": "BIGINT NOT NULL", }, id="int64", ), @@ -127,6 +131,7 @@ "trino": "integer", "duckdb": "UTINYINT", "postgres": "integer", + "flink": "TINYINT NOT NULL", }, id="uint8", ), @@ -141,6 +146,7 @@ "trino": "integer", "duckdb": "USMALLINT", "postgres": "integer", + "flink": "SMALLINT NOT NULL", }, id="uint16", ), @@ -155,6 +161,7 @@ "trino": "integer", "duckdb": "UINTEGER", "postgres": "integer", + "flink": "INT NOT NULL", }, id="uint32", ), @@ -169,6 +176,7 @@ "trino": "integer", "duckdb": "UBIGINT", "postgres": "integer", + "flink": "BIGINT NOT NULL", }, id="uint64", ), @@ -183,6 +191,7 @@ "trino": "double", "duckdb": "FLOAT", "postgres": "numeric", + "flink": "FLOAT NOT NULL", }, marks=[ pytest.mark.notimpl( @@ -209,6 +218,7 @@ "trino": "double", "duckdb": "FLOAT", "postgres": "numeric", + "flink": "FLOAT", }, id="float32", ), @@ -223,6 +233,7 @@ "trino": "double", "duckdb": "DOUBLE", "postgres": "numeric", + "flink": "DOUBLE", }, id="float64", ),