Skip to content

Commit

Permalink
further fixing type inferring (ydb-platform#7456)
Browse files Browse the repository at this point in the history
  • Loading branch information
evanevanevanevannnn authored Aug 5, 2024
1 parent da7d9ee commit e4c34ae
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,21 @@ namespace NKikimr::NExternalSource::NObjectStorage::NInference {

namespace {

bool ArrowToYdbType(Ydb::Type& optionalType, const arrow::DataType& type) {
auto& resType = *optionalType.mutable_optional_type()->mutable_item();
// Decides whether a column of the given Arrow type is emitted wrapped in an
// optional type. Returns false for NA, STRING, BINARY, LARGE_BINARY and
// FIXED_SIZE_BINARY; returns true for every other Arrow type id.
bool ShouldBeOptional(const arrow::DataType& type) {
    const auto typeId = type.id();
    const bool emittedAsPlainType =
        typeId == arrow::Type::NA ||
        typeId == arrow::Type::STRING ||
        typeId == arrow::Type::BINARY ||
        typeId == arrow::Type::LARGE_BINARY ||
        typeId == arrow::Type::FIXED_SIZE_BINARY;
    return !emittedAsPlainType;
}

bool ArrowToYdbType(Ydb::Type& maybeOptionalType, const arrow::DataType& type) {
auto& resType = ShouldBeOptional(type) ? *maybeOptionalType.mutable_optional_type()->mutable_item() : maybeOptionalType;
switch (type.id()) {
case arrow::Type::NA:
resType.set_type_id(Ydb::Type::UTF8);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ TEST_F(ArrowInferenceTest, csv_simple) {
ASSERT_EQ(fields[0].type().optional_type().item().type_id(), Ydb::Type::INT64);
ASSERT_EQ(fields[0].name(), "A");

ASSERT_TRUE(fields[1].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[1].type().optional_type().item().type_id(), Ydb::Type::UTF8);
ASSERT_TRUE(fields[1].type().has_type_id());
ASSERT_EQ(fields[1].type().type_id(), Ydb::Type::UTF8);
ASSERT_EQ(fields[1].name(), "B");

ASSERT_TRUE(fields[2].type().optional_type().item().has_type_id());
Expand Down Expand Up @@ -133,8 +133,8 @@ TEST_F(ArrowInferenceTest, tsv_simple) {
ASSERT_EQ(fields[0].type().optional_type().item().type_id(), Ydb::Type::INT64);
ASSERT_EQ(fields[0].name(), "A");

ASSERT_TRUE(fields[1].type().optional_type().item().has_type_id());
ASSERT_EQ(fields[1].type().optional_type().item().type_id(), Ydb::Type::UTF8);
ASSERT_TRUE(fields[1].type().has_type_id());
ASSERT_EQ(fields[1].type().type_id(), Ydb::Type::UTF8);
ASSERT_EQ(fields[1].name(), "B");

ASSERT_TRUE(fields[2].type().optional_type().item().has_type_id());
Expand Down
10 changes: 5 additions & 5 deletions ydb/tests/fq/s3/test_s3_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_inference(self, kikimr, s3, client, unique_prefix):
assert result_set.columns[0].name == "Date"
assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.DATE
assert result_set.columns[1].name == "Fruit"
assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[1].type.type_id == ydb.Type.UTF8
assert result_set.columns[2].name == "Price"
assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
assert result_set.columns[3].name == "Weight"
Expand Down Expand Up @@ -176,9 +176,9 @@ def test_inference_null_column(self, kikimr, s3, client, unique_prefix):
logging.debug(str(result_set))
assert len(result_set.columns) == 3
assert result_set.columns[0].name == "Fruit"
assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[0].type.type_id == ydb.Type.UTF8
assert result_set.columns[1].name == "Missing column"
assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[1].type.type_id == ydb.Type.UTF8
assert result_set.columns[2].name == "Price"
assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
assert len(result_set.rows) == 3
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_inference_optional_types(self, kikimr, s3, client, unique_prefix):
assert result_set.columns[0].name == "Date"
assert result_set.columns[0].type.optional_type.item.type_id == ydb.Type.DATE
assert result_set.columns[1].name == "Fruit"
assert result_set.columns[1].type.optional_type.item.type_id == ydb.Type.UTF8
assert result_set.columns[1].type.type_id == ydb.Type.UTF8
assert result_set.columns[2].name == "Price"
assert result_set.columns[2].type.optional_type.item.type_id == ydb.Type.INT64
assert result_set.columns[3].name == "Weight"
Expand All @@ -248,7 +248,7 @@ def test_inference_optional_types(self, kikimr, s3, client, unique_prefix):
assert result_set.rows[1].items[2].int64_value == 2
assert result_set.rows[1].items[3].int64_value == 22
assert result_set.rows[2].items[0].uint32_value == 19849
assert result_set.rows[2].items[1].null_flag_value == NullValue.NULL_VALUE
assert result_set.rows[2].items[1].text_value == ""
assert result_set.rows[2].items[2].int64_value == 15
assert result_set.rows[2].items[3].int64_value == 33

Expand Down

0 comments on commit e4c34ae

Please sign in to comment.