diff --git a/proto/column.go b/proto/column.go index a9224f2..d9a46f3 100644 --- a/proto/column.go +++ b/proto/column.go @@ -2,6 +2,7 @@ package proto import ( "fmt" + "strconv" "strings" "github.com/go-faster/errors" @@ -83,30 +84,58 @@ func (c ColumnType) Base() ColumnType { return c[:start] } +// reduces Decimal(P, ...) to Decimal32/Decimal64/Decimal128/Decimal256 +// returns c if any errors occur during conversion +func (c ColumnType) decimalDowncast() ColumnType { + if c.Base() != ColumnTypeDecimal { + return c + } + elem := c.Elem() + precStr, _, _ := strings.Cut(string(elem), ",") + precStr = strings.TrimSpace(precStr) + prec, err := strconv.Atoi(precStr) + if err != nil { + return c + } + switch { + case prec < 10: + return ColumnTypeDecimal32 + case prec < 19: + return ColumnTypeDecimal64 + case prec < 39: + return ColumnTypeDecimal128 + case prec < 77: + return ColumnTypeDecimal256 + default: + return c + } +} + // Conflicts reports whether two types conflict. func (c ColumnType) Conflicts(b ColumnType) bool { if c == b { return false } - { - a := c - if b.Base() == ColumnTypeEnum8 || b.Base() == ColumnTypeEnum16 { - a, b = b, a - } - switch { - case a.Base() == ColumnTypeEnum8 && b == ColumnTypeInt8: - return false - case a.Base() == ColumnTypeEnum16 && b == ColumnTypeInt16: - return false - } + cBase := c.Base() + bBase := b.Base() + if (cBase == ColumnTypeEnum8 && b == ColumnTypeInt8) || + (cBase == ColumnTypeEnum16 && b == ColumnTypeInt16) || + (bBase == ColumnTypeEnum8 && c == ColumnTypeInt8) || + (bBase == ColumnTypeEnum16 && c == ColumnTypeInt16) { + return false + } + if cBase == ColumnTypeDecimal || bBase == ColumnTypeDecimal { + return c.decimalDowncast() != b.decimalDowncast() } - if c.Base() != b.Base() { + if cBase != bBase { return true } if c.normalizeCommas() == b.normalizeCommas() { return false } - switch c.Base() { + switch cBase { + case ColumnTypeArray, ColumnTypeNullable, ColumnTypeLowCardinality: + return c.Elem().Conflicts(b.Elem()) case ColumnTypeDateTime, ColumnTypeDateTime64: // TODO(ernado): improve check return false diff --git a/proto/column_test.go b/proto/column_test.go index ccad996..970fb59 100644 --- a/proto/column_test.go +++ b/proto/column_test.go @@ -60,6 +60,8 @@ func TestColumnType_Elem(t *testing.T) { {A: "Map(String,String)", B: "Map(String, String)"}, {A: "Enum8('increment' = 1, 'gauge' = 2)", B: "Int8"}, {A: "Int8", B: "Enum8('increment' = 1, 'gauge' = 2)"}, + {A: "Decimal256", B: "Decimal(76, 38)"}, + {A: "Nullable(Decimal256)", B: "Nullable(Decimal(76, 38))"}, } { assert.False(t, tt.A.Conflicts(tt.B), "%s ~ %s", tt.A, tt.B, @@ -76,6 +78,7 @@ func TestColumnType_Elem(t *testing.T) { {A: ColumnTypeArray.Sub(ColumnTypeInt32), B: ColumnTypeArray.Sub(ColumnTypeInt64)}, {A: "Map(String,String)", B: "Map(String,Int32)"}, {A: "Enum16('increment' = 1, 'gauge' = 2)", B: "Int8"}, + {A: "Int8", B: "Enum16('increment' = 1, 'gauge' = 2)"}, } { assert.True(t, tt.A.Conflicts(tt.B), "%s !~ %s", tt.A, tt.B,