From 195355a8d295e1b21ed1cf95640cfcf37dd14238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Sun, 13 Oct 2024 04:55:34 +0000 Subject: [PATCH] fix: infer Nullable/Array/LowCardinality with reflection helps re #152 couple fixes needed to get Array(Nullable(...)) & Nullable(DateTime64(...)) cases to work --- proto/cmd/ch-gen-col/infer.go.tmpl | 4 -- proto/col_auto.go | 59 ++++++++++++----- proto/col_auto_gen.go | 100 ----------------------------- proto/col_auto_test.go | 2 + proto/col_datetime64.go | 4 ++ proto/col_nullable.go | 6 ++ 6 files changed, 55 insertions(+), 120 deletions(-) diff --git a/proto/cmd/ch-gen-col/infer.go.tmpl b/proto/cmd/ch-gen-col/infer.go.tmpl index 01c46ed8..1ffbd85b 100644 --- a/proto/cmd/ch-gen-col/infer.go.tmpl +++ b/proto/cmd/ch-gen-col/infer.go.tmpl @@ -6,10 +6,6 @@ package proto func inferGenerated(t ColumnType) Column { switch t { {{- range . }} - case ColumnTypeArray.Sub({{ .ColumnType }}): - return new({{ .Type }}).Array() - case ColumnTypeNullable.Sub({{ .ColumnType }}): - return new({{ .Type }}).Nullable() case {{ .ColumnType }}: return new({{ .Type }}) {{- end }} diff --git a/proto/col_auto.go b/proto/col_auto.go index c72615da..d5f614de 100644 --- a/proto/col_auto.go +++ b/proto/col_auto.go @@ -1,6 +1,7 @@ package proto import ( + "reflect" "strconv" "strings" @@ -37,20 +38,8 @@ func (c *ColAuto) Infer(t ColumnType) error { switch t { case ColumnTypeNothing: c.Data = new(ColNothing) - case ColumnTypeNullable.Sub(ColumnTypeNothing): - c.Data = new(ColNothing).Nullable() - case ColumnTypeArray.Sub(ColumnTypeNothing): - c.Data = new(ColNothing).Array() case ColumnTypeString: c.Data = new(ColStr) - case ColumnTypeArray.Sub(ColumnTypeString): - c.Data = new(ColStr).Array() - case ColumnTypeNullable.Sub(ColumnTypeString): - c.Data = new(ColStr).Nullable() - case ColumnTypeLowCardinality.Sub(ColumnTypeString): - c.Data = new(ColStr).LowCardinality() - case ColumnTypeArray.Sub(ColumnTypeLowCardinality.Sub(ColumnTypeString)): - c.Data = new(ColStr).LowCardinality().Array() case ColumnTypeBool: c.Data = new(ColBool) case ColumnTypeDateTime: @@ -61,12 +50,50 @@ func (c *ColAuto) Infer(t ColumnType) error { c.Data = NewMap[string, string](new(ColStr), new(ColStr)) case ColumnTypeUUID: c.Data = new(ColUUID) - case ColumnTypeArray.Sub(ColumnTypeUUID): - c.Data = new(ColUUID).Array() - case ColumnTypeNullable.Sub(ColumnTypeUUID): - c.Data = new(ColUUID).Nullable() default: switch t.Base() { + case ColumnTypeArray: + inner := new(ColAuto) + if err := inner.Infer(t.Elem()); err != nil { + return errors.Wrap(err, "array") + } + innerValue := reflect.ValueOf(inner.Data) + arrayMethod := innerValue.MethodByName("Array") + if arrayMethod.IsValid() && arrayMethod.Type().NumOut() == 1 { + if col, ok := arrayMethod.Call(nil)[0].Interface().(Column); ok { + c.Data = col + c.DataType = t + return nil + } + } + case ColumnTypeNullable: + inner := new(ColAuto) + if err := inner.Infer(t.Elem()); err != nil { + return errors.Wrap(err, "nullable") + } + innerValue := reflect.ValueOf(inner.Data) + nullableMethod := innerValue.MethodByName("Nullable") + if nullableMethod.IsValid() && nullableMethod.Type().NumOut() == 1 { + if col, ok := nullableMethod.Call(nil)[0].Interface().(Column); ok { + c.Data = col + c.DataType = t + return nil + } + } + case ColumnTypeLowCardinality: + inner := new(ColAuto) + if err := inner.Infer(t.Elem()); err != nil { + return errors.Wrap(err, "low cardinality") + } + innerValue := reflect.ValueOf(inner.Data) + lowCardinalityMethod := innerValue.MethodByName("LowCardinality") + if lowCardinalityMethod.IsValid() && lowCardinalityMethod.Type().NumOut() == 1 { + if col, ok := lowCardinalityMethod.Call(nil)[0].Interface().(Column); ok { + c.Data = col + c.DataType = t + return nil + } + } case ColumnTypeDateTime: v := new(ColDateTime) if err := v.Infer(t); err != nil { diff --git a/proto/col_auto_gen.go b/proto/col_auto_gen.go index 70928c65..b297d927 100644 --- a/proto/col_auto_gen.go +++ b/proto/col_auto_gen.go @@ -4,154 +4,54 @@ package proto func inferGenerated(t ColumnType) Column { switch t { - case ColumnTypeArray.Sub(ColumnTypeFloat32): - return new(ColFloat32).Array() - case ColumnTypeNullable.Sub(ColumnTypeFloat32): - return new(ColFloat32).Nullable() case ColumnTypeFloat32: return new(ColFloat32) - case ColumnTypeArray.Sub(ColumnTypeFloat64): - return new(ColFloat64).Array() - case ColumnTypeNullable.Sub(ColumnTypeFloat64): - return new(ColFloat64).Nullable() case ColumnTypeFloat64: return new(ColFloat64) - case ColumnTypeArray.Sub(ColumnTypeIPv4): - return new(ColIPv4).Array() - case ColumnTypeNullable.Sub(ColumnTypeIPv4): - return new(ColIPv4).Nullable() case ColumnTypeIPv4: return new(ColIPv4) - case ColumnTypeArray.Sub(ColumnTypeIPv6): - return new(ColIPv6).Array() - case ColumnTypeNullable.Sub(ColumnTypeIPv6): - return new(ColIPv6).Nullable() case ColumnTypeIPv6: return new(ColIPv6) - case ColumnTypeArray.Sub(ColumnTypeDate): - return new(ColDate).Array() - case ColumnTypeNullable.Sub(ColumnTypeDate): - return new(ColDate).Nullable() case ColumnTypeDate: return new(ColDate) - case ColumnTypeArray.Sub(ColumnTypeDate32): - return new(ColDate32).Array() - case ColumnTypeNullable.Sub(ColumnTypeDate32): - return new(ColDate32).Nullable() case ColumnTypeDate32: return new(ColDate32) - case ColumnTypeArray.Sub(ColumnTypeInt8): - return new(ColInt8).Array() - case ColumnTypeNullable.Sub(ColumnTypeInt8): - return new(ColInt8).Nullable() case ColumnTypeInt8: return new(ColInt8) - case ColumnTypeArray.Sub(ColumnTypeUInt8): - return new(ColUInt8).Array() - case ColumnTypeNullable.Sub(ColumnTypeUInt8): - return new(ColUInt8).Nullable() case ColumnTypeUInt8: return new(ColUInt8) - case ColumnTypeArray.Sub(ColumnTypeInt16): - return new(ColInt16).Array() - case ColumnTypeNullable.Sub(ColumnTypeInt16): - return new(ColInt16).Nullable() case ColumnTypeInt16: return new(ColInt16) - case ColumnTypeArray.Sub(ColumnTypeUInt16): - return new(ColUInt16).Array() - case ColumnTypeNullable.Sub(ColumnTypeUInt16): - return new(ColUInt16).Nullable() case ColumnTypeUInt16: return new(ColUInt16) - case ColumnTypeArray.Sub(ColumnTypeInt32): - return new(ColInt32).Array() - case ColumnTypeNullable.Sub(ColumnTypeInt32): - return new(ColInt32).Nullable() case ColumnTypeInt32: return new(ColInt32) - case ColumnTypeArray.Sub(ColumnTypeUInt32): - return new(ColUInt32).Array() - case ColumnTypeNullable.Sub(ColumnTypeUInt32): - return new(ColUInt32).Nullable() case ColumnTypeUInt32: return new(ColUInt32) - case ColumnTypeArray.Sub(ColumnTypeInt64): - return new(ColInt64).Array() - case ColumnTypeNullable.Sub(ColumnTypeInt64): - return new(ColInt64).Nullable() case ColumnTypeInt64: return new(ColInt64) - case ColumnTypeArray.Sub(ColumnTypeUInt64): - return new(ColUInt64).Array() - case ColumnTypeNullable.Sub(ColumnTypeUInt64): - return new(ColUInt64).Nullable() case ColumnTypeUInt64: return new(ColUInt64) - case ColumnTypeArray.Sub(ColumnTypeInt128): - return new(ColInt128).Array() - case ColumnTypeNullable.Sub(ColumnTypeInt128): - return new(ColInt128).Nullable() case ColumnTypeInt128: return new(ColInt128) - case ColumnTypeArray.Sub(ColumnTypeUInt128): - return new(ColUInt128).Array() - case ColumnTypeNullable.Sub(ColumnTypeUInt128): - return new(ColUInt128).Nullable() case ColumnTypeUInt128: return new(ColUInt128) - case ColumnTypeArray.Sub(ColumnTypeInt256): - return new(ColInt256).Array() - case ColumnTypeNullable.Sub(ColumnTypeInt256): - return new(ColInt256).Nullable() case ColumnTypeInt256: return new(ColInt256) - case ColumnTypeArray.Sub(ColumnTypeUInt256): - return new(ColUInt256).Array() - case ColumnTypeNullable.Sub(ColumnTypeUInt256): - return new(ColUInt256).Nullable() case ColumnTypeUInt256: return new(ColUInt256) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("8")): - return new(ColFixedStr8).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("8")): - return new(ColFixedStr8).Nullable() case ColumnTypeFixedString.With("8"): return new(ColFixedStr8) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("16")): - return new(ColFixedStr16).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("16")): - return new(ColFixedStr16).Nullable() case ColumnTypeFixedString.With("16"): return new(ColFixedStr16) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("32")): - return new(ColFixedStr32).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("32")): - return new(ColFixedStr32).Nullable() case ColumnTypeFixedString.With("32"): return new(ColFixedStr32) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("64")): - return new(ColFixedStr64).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("64")): - return new(ColFixedStr64).Nullable() case ColumnTypeFixedString.With("64"): return new(ColFixedStr64) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("128")): - return new(ColFixedStr128).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("128")): - return new(ColFixedStr128).Nullable() case ColumnTypeFixedString.With("128"): return new(ColFixedStr128) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("256")): - return new(ColFixedStr256).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("256")): - return new(ColFixedStr256).Nullable() case ColumnTypeFixedString.With("256"): return new(ColFixedStr256) - case ColumnTypeArray.Sub(ColumnTypeFixedString.With("512")): - return new(ColFixedStr512).Array() - case ColumnTypeNullable.Sub(ColumnTypeFixedString.With("512")): - return new(ColFixedStr512).Nullable() case ColumnTypeFixedString.With("512"): return new(ColFixedStr512) default: diff --git a/proto/col_auto_test.go b/proto/col_auto_test.go index 5c0b3f4b..f935d8b3 100644 --- a/proto/col_auto_test.go +++ b/proto/col_auto_test.go @@ -57,6 +57,8 @@ func TestColAuto_Infer(t *testing.T) { "Decimal64(2)", "Decimal128(3)", "Decimal256(4)", + "Array(Nullable(Int8))", + "Nullable(DateTime64(3))", } { r := AutoResult("foo") require.NoError(t, r.Data.(Inferable).Infer(columnType)) diff --git a/proto/col_datetime64.go b/proto/col_datetime64.go index f4d96a49..1f93a0c4 100644 --- a/proto/col_datetime64.go +++ b/proto/col_datetime64.go @@ -126,6 +126,10 @@ func (c ColDateTime64) Raw() *ColDateTime64Raw { return &ColDateTime64Raw{ColDateTime64: c} } +func (c *ColDateTime64) Nullable() *ColNullable[time.Time] { + return &ColNullable[time.Time]{Values: c} +} + func (c *ColDateTime64) Array() *ColArr[time.Time] { return &ColArr[time.Time]{Data: c} } diff --git a/proto/col_nullable.go b/proto/col_nullable.go index dfac8fa2..fd3615d2 100644 --- a/proto/col_nullable.go +++ b/proto/col_nullable.go @@ -117,6 +117,12 @@ func (c ColNullable[T]) Row(i int) Nullable[T] { } } +func (c *ColNullable[T]) Array() *ColArr[Nullable[T]] { + return &ColArr[Nullable[T]]{ + Data: c, + } +} + func (c *ColNullable[T]) Reset() { c.Nulls.Reset() c.Values.Reset()