Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1748140: Modify schema_expression to be structured type aware. #2659

Merged
22 changes: 20 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
return "TRY_TO_GEOGRAPHY(NULL)"
if isinstance(data_type, GeometryType):
return "TRY_TO_GEOMETRY(NULL)"
if isinstance(data_type, ArrayType):
if isinstance(data_type, ArrayType) and not data_type.structured:
return "PARSE_JSON('NULL') :: ARRAY"
if isinstance(data_type, MapType):
if isinstance(data_type, MapType) and not data_type.structured:
return "PARSE_JSON('NULL') :: OBJECT"
if isinstance(data_type, VariantType):
return "PARSE_JSON('NULL') :: VARIANT"
Expand Down Expand Up @@ -213,9 +213,27 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
else:
return "to_timestamp('2020-09-16 06:30:00')"
if isinstance(data_type, ArrayType):
if data_type.structured:
element = schema_expression(data_type.element_type, is_nullable)
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
return "to_array(0)"
if isinstance(data_type, MapType):
if data_type.structured:
key = schema_expression(data_type.key_type, is_nullable)
value = schema_expression(data_type.value_type, is_nullable)
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
[Review comment — Collaborator]: Should we determine whether to keep null values based on `is_nullable`?

[Reply — Contributor, Author]: If we don't keep nulls and either the key or value gets evaluated to a `NULL :: type` statement, then that field would be dropped from the schema altogether. For this reason I think we always want to keep nulls.
return "to_object(parse_json('0'))"
if isinstance(data_type, StructType):
if data_type.structured:
schema_strings = []
for field in data_type.fields:
# Even if nulls are allowed the cast will fail due to schema mismatch when passed a null field.
schema_strings += [
f"'{field.name}'",
schema_expression(field.datatype, is_nullable=False),
]
return f"object_construct_keep_null({', '.join(schema_strings)}) :: {convert_sp_to_sf_type(data_type)}"
return "to_object(parse_json('{}'))"
if isinstance(data_type, VariantType):
return "to_variant(0)"
if isinstance(data_type, GeographyType):
Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def convert_metadata_to_sp_type(
StructField(
quote_name(field.name, keep_case=True),
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
)
for field in metadata.fields
],
Expand Down
128 changes: 127 additions & 1 deletion tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ def test_iceberg_nested_fields(


@pytest.mark.skip(
reason="SNOW-1748140: Need to handle structured types in datatype_mapper"
reason="SNOW-1819531: Error in _contains_external_cte_ref when analyzing lqb"
)
def test_struct_dtype_iceberg_lqb(
structured_type_session, local_testing_mode, structured_type_support
Expand Down Expand Up @@ -970,3 +970,129 @@ def test_structured_type_print_schema(
df._format_schema(1, translate_columns={'"MAP"': '"map"'})
== 'root\n |-- "map": MapType (nullable = True)'
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="local testing does not fully support structured types yet.",
)
def test_structured_type_schema_expression(
structured_type_session, local_testing_mode, structured_type_support
):
if not structured_type_support:
pytest.skip("Test requires structured type support.")

table_name = f"snowpark_schema_expresion_test_{uuid.uuid4().hex[:5]}".upper()
non_null_table_name = (
f"snowpark_schema_expresion_nonnull_test_{uuid.uuid4().hex[:5]}".upper()
)
nested_table_name = (
f"snowpark_schema_expresion_nested_test_{uuid.uuid4().hex[:5]}".upper()
)

expected_schema = StructType(
[
StructField(
"MAP",
MapType(StringType(), DoubleType(), structured=True),
nullable=True,
),
StructField("ARR", ArrayType(DoubleType(), structured=True), nullable=True),
StructField(
"OBJ",
StructType(
[
StructField("FIELD1", StringType(), nullable=True),
StructField("FIELD2", DoubleType(), nullable=True),
],
structured=True,
),
nullable=True,
),
]
)

expected_non_null_schema = StructType(
[
StructField(
"MAP",
MapType(StringType(), DoubleType(), structured=True),
nullable=False,
),
StructField(
"ARR", ArrayType(DoubleType(), structured=True), nullable=False
),
StructField(
"OBJ",
StructType(
[
StructField("FIELD1", StringType(), nullable=False),
StructField("FIELD2", DoubleType(), nullable=False),
],
structured=True,
),
nullable=False,
),
]
)

expected_nested_schema = StructType(
[
StructField(
"MAP",
MapType(
StringType(),
StructType(
[StructField("ARR", ArrayType(DoubleType(), structured=True))],
structured=True,
),
structured=True,
),
)
]
)

try:
# SNOW-1819428: Nullability doesn't seem to be respected when creating
# a structured type dataframe so use a table instead.
structured_type_session.sql(
f"create table {table_name} (MAP MAP(VARCHAR, DOUBLE), ARR ARRAY(DOUBLE), "
"OBJ OBJECT(FIELD1 VARCHAR, FIELD2 DOUBLE))"
).collect()
structured_type_session.sql(
f"create table {non_null_table_name} (MAP MAP(VARCHAR, DOUBLE) NOT NULL, "
"ARR ARRAY(DOUBLE) NOT NULL, OBJ OBJECT(FIELD1 VARCHAR NOT NULL, FIELD2 "
"DOUBLE NOT NULL) NOT NULL)"
).collect()
structured_type_session.sql(
f"create table {nested_table_name} (MAP MAP(VARCHAR, OBJECT(ARR ARRAY(DOUBLE))))"
).collect()

table = structured_type_session.table(table_name)
non_null_table = structured_type_session.table(non_null_table_name)
nested_table = structured_type_session.table(nested_table_name)

assert table.schema == expected_schema
assert non_null_table.schema == expected_non_null_schema
assert nested_table.schema == expected_nested_schema

# Dataframe.union forces a schema_expression call
assert table.union(table).schema == expected_schema
# Functions used in schema generation don't respect nested nullability so compare query string instead
non_null_union = non_null_table.union(non_null_table)
assert non_null_union._plan.schema_query == (
[Review comment — Collaborator]: Can you also create a test case for nested array and object? For example, `to_array(... to_array(...))`.

[Reply — Contributor, Author]: Added.
"( SELECT object_construct_keep_null('a' :: STRING (16777216), 0 :: DOUBLE) :: "
'MAP(STRING(16777216), DOUBLE) AS "MAP", to_array(0 :: DOUBLE) :: ARRAY(DOUBLE) AS "ARR",'
" object_construct_keep_null('FIELD1', 'a' :: STRING (16777216), 'FIELD2', 0 :: "
'DOUBLE) :: OBJECT(FIELD1 STRING(16777216), FIELD2 DOUBLE) AS "OBJ") UNION ( SELECT '
"object_construct_keep_null('a' :: STRING (16777216), 0 :: DOUBLE) :: "
'MAP(STRING(16777216), DOUBLE) AS "MAP", to_array(0 :: DOUBLE) :: ARRAY(DOUBLE) AS "ARR", '
"object_construct_keep_null('FIELD1', 'a' :: STRING (16777216), 'FIELD2', 0 :: "
'DOUBLE) :: OBJECT(FIELD1 STRING(16777216), FIELD2 DOUBLE) AS "OBJ")'
)

assert nested_table.union(nested_table).schema == expected_nested_schema
finally:
Utils.drop_table(structured_type_session, table_name)
Utils.drop_table(structured_type_session, non_null_table_name)
Utils.drop_table(structured_type_session, nested_table_name)
1 change: 1 addition & 0 deletions tests/unit/test_datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,4 @@ def test_schema_expression():
schema_expression(VectorType(float, 3), False)
== "[0.0, 1.0, 2.0] :: VECTOR(float,3)"
)
assert schema_expression(StructType([]), False) == "to_object(parse_json('{}'))"
Loading