Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EXP] GTFS protobuf experiment #299

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions tl/tt/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ func FromCsv(val any, strv string) error {
*vf = v
case canFromCsvString:
if err := vf.FromCsv(strv); err != nil {
p = errors.New("field not scannable")
p = errors.New("field not scannable, FromCsv failed")
}
case canScan:
if err := vf.Scan(strv); err != nil {
p = errors.New("field not scannable")
p = errors.New("field not scannable, Scan failed")
}
default:
p = errors.New("field not scannable")
Expand Down
7 changes: 7 additions & 0 deletions tlcsv/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ func GetString(ent tl.Entity, key string) (string, error) {
}

// Loading: fast and reflect paths //
// LoadRow populates ent from the given CSV row, returning any errors
// encountered while setting fields. It is the exported entry point to the
// package-private loadRow, which picks the fastest available loading path.
func LoadRow(ent any, row Row) []error {
return loadRow(ent, row)
}

// loadRow selects the fastest method for loading an entity.
func loadRow(ent any, row Row) []error {
Expand Down Expand Up @@ -126,6 +129,8 @@ func loadRowReflect(ent interface{}, row Row) []error {
strv = row.Row[i]
}
fieldInfo, ok := fmap[fieldName]
// fmt.Printf("FIELD: %s\n", fieldName)

// Add to extra fields if there's no struct tag
if !ok {
if extEnt, ok2 := ent.(tl.EntityWithExtra); ok2 {
Expand All @@ -145,6 +150,7 @@ func loadRowReflect(ent interface{}, row Row) []error {
}
continue
}

// Handle different known types
fieldValue := reflectx.FieldByIndexes(entValue, fieldInfo.Index).Addr().Interface()
if err := tt.FromCsv(fieldValue, strv); err != nil {
Expand All @@ -159,6 +165,7 @@ func loadRowReflect(ent interface{}, row Row) []error {
}
}
}
// fmt.Printf("ENT DONE: %T %#v\n", ent, ent)
return errs
}

Expand Down
94 changes: 94 additions & 0 deletions tlpb/codegen_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package tlpb

import (
"context"
"fmt"
"os"
"strings"
"testing"

"github.com/bufbuild/protocompile"

Check failure on line 10 in tlpb/codegen_test.go

View workflow job for this annotation

GitHub Actions / test (ubuntu-latest)

no required module provides package github.com/bufbuild/protocompile; to add it:

Check failure on line 10 in tlpb/codegen_test.go

View workflow job for this annotation

GitHub Actions / test (macos-latest)

no required module provides package github.com/bufbuild/protocompile; to add it:
"google.golang.org/protobuf/reflect/protoreflect"
)

// TestCodegen compiles gtfs.proto with protocompile and generates a plain-Go
// mirror of its enums and messages into gtfs/gtfs.go. Enums become named
// int32 types; single-field "val" wrapper messages become Option[T] wrappers;
// all other messages become structs, with a DatabaseEntity field embedded.
// The generated source is accumulated in memory and written once so that
// write errors are actually checked (the previous version ignored every
// WriteString result).
func TestCodegen(t *testing.T) {
	compiler := protocompile.Compiler{
		Resolver: &protocompile.SourceResolver{},
	}
	files, err := compiler.Compile(context.Background(), "gtfs.proto")
	if err != nil {
		t.Fatal(err)
	}

	var b strings.Builder
	b.WriteString("package gtfs\n\n")
	b.WriteString("type EnumValue int32\n\n")
	for _, lf := range files {
		enums := lf.Enums()
		for i := 0; i < enums.Len(); i++ {
			en := enums.Get(i)
			fmt.Fprintf(&b, "type %s int32\n\n", en.Name())
		}
		msgs := lf.Messages()
		for i := 0; i < msgs.Len(); i++ {
			msg := msgs.Get(i)
			fields := msg.Fields()
			// Single-field "val" messages are emitted as Option wrappers.
			if fields.Len() == 1 && fields.Get(0).Name() == "val" {
				field := fields.Get(0)
				fmt.Fprintf(&b, "type %s struct { Option[%s] }\n\n", msg.Name(), mapKind(field))
				continue
			}
			fmt.Fprintf(&b, "type %s struct {\n", msg.Name())
			for j := 0; j < fields.Len(); j++ {
				field := fields.Get(j)
				fieldName := toCamelCase(string(field.Name()))
				fieldKind := mapKind(field)
				// DatabaseEntity is embedded rather than declared as a field.
				if fieldKind == "DatabaseEntity" {
					b.WriteString("\tDatabaseEntity\n")
				} else {
					fmt.Fprintf(&b, "\t%s %s\n", fieldName, fieldKind)
				}
			}
			b.WriteString("}\n\n")
		}
	}
	if err := os.WriteFile("gtfs/gtfs.go", []byte(b.String()), 0o644); err != nil {
		t.Fatal(err)
	}
}

// mapKind returns the generated-Go type name for a proto field descriptor:
// the message name for message-typed fields, the enum name for enum fields,
// float64/float32 for double/float, and the proto kind string otherwise.
func mapKind(field protoreflect.FieldDescriptor) string {
	// Message-typed fields map directly to their message name.
	if fmsg := field.Message(); fmsg != nil {
		return string(fmsg.Name())
	}
	switch k := field.Kind().String(); k {
	case "enum":
		return string(field.Enum().Name())
	case "double":
		return "float64"
	case "float":
		return "float32"
	default:
		return k
	}
}

// toCamelCase converts a snake_case identifier to CamelCase, mapping the
// segment "id" to "ID" per Go initialism convention ("stop_id" -> "StopID").
// Empty segments (leading, trailing, or doubled underscores, or an empty
// input) are skipped; the previous version panicked on them via s[0:1].
func toCamelCase(v string) string {
	var b strings.Builder
	for _, s := range strings.Split(v, "_") {
		switch {
		case s == "":
			// Skip empty segments instead of indexing out of range.
		case s == "id":
			b.WriteString("ID")
		default:
			b.WriteString(strings.ToUpper(s[:1]) + s[1:])
		}
	}
	return b.String()
}
149 changes: 149 additions & 0 deletions tlpb/csv_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package tlpb

import (
"encoding/json"
"fmt"
"testing"

"github.com/stretchr/testify/assert"

"github.com/interline-io/transitland-lib/internal/testutil"
"github.com/interline-io/transitland-lib/tl"
"github.com/interline-io/transitland-lib/tlcsv"
"github.com/interline-io/transitland-lib/tlpb/gtfs"
"github.com/interline-io/transitland-lib/tlpb/pb"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/reflect/protoreflect"
)

// printFirst prints the JSON form of the first element, if the slice is
// non-empty.
func printFirst(v []any) {
	if len(v) > 0 {
		fmt.Println(toJson(v[0]))
	}
}
// printAll prints every element as JSON, one per line.
func printAll(v []any) {
	for i := range v {
		fmt.Println(toJson(v[i]))
	}
}

// pbJson renders a protobuf message as protojson; marshal errors yield "".
func pbJson(v protoreflect.ProtoMessage) string {
	out, _ := protojson.Marshal(v)
	return string(out)
}

// toJson renders v as JSON; marshal errors yield "".
func toJson(v any) string {
	out, _ := json.Marshal(v)
	return string(out)
}

// Test fixture: the BART GTFS zip, reading its stops table. Initialized
// directly rather than via init(); package var initializers run first either
// way, so behavior is unchanged.
var (
	TESTFILE  = testutil.RelPath("test/data/external/bart.zip")
	TESTTABLE = "stops.txt"
)

//////////////////

// TestReadPB is currently disabled; its body is commented out and ReadPB is
// exercised only by BenchmarkReadPB below.
func TestReadPB(t *testing.T) {
// ents, err := ReadPB(TESTFILE)
// if err != nil {
// t.Fatal(err)
// }
// for _, ent := range ents {
// fmt.Println(ent)
// }
}

// BenchmarkReadPB measures loading the stops table into protobuf-backed
// pb.Stop entities.
func BenchmarkReadPB(b *testing.B) {
	for i := 0; i < b.N; i++ {
		_, _ = ReadPB(TESTFILE)
	}
}

// ReadPB reads TESTTABLE from the zip at fn, loading each row into a
// protobuf-backed pb.Stop. Open and row-load failures are returned as errors
// instead of panicking (the previous version panicked on both), so callers
// can assert on them.
func ReadPB(fn string) ([]any, error) {
	a := tlcsv.NewZipAdapter(fn)
	if err := a.Open(); err != nil {
		return nil, fmt.Errorf("opening %s: %w", fn, err)
	}
	var ret []any
	var rowErr error
	err := a.ReadRows(TESTTABLE, func(row tlcsv.Row) {
		ent := &pb.Stop{}
		if errs := tlcsv.LoadRow(ent, row); len(errs) > 0 && rowErr == nil {
			// Keep the first row error; ReadRows offers no early abort.
			rowErr = errs[0]
		}
		ret = append(ret, ent)
	})
	if err == nil {
		err = rowErr
	}
	return ret, err
}

//////////////////

// TestReadTT loads the stops table via the tl.Stop reflect path and dumps
// each entity as JSON.
func TestReadTT(t *testing.T) {
	got, err := ReadTT(TESTFILE)
	assert.NoError(t, err)
	printAll(got)
}

// BenchmarkReadTT measures loading the stops table into tl.Stop entities.
func BenchmarkReadTT(b *testing.B) {
	for i := 0; i < b.N; i++ {
		ents, _ := ReadTT(TESTFILE)
		printFirst(ents)
	}
}

// ReadTT reads TESTTABLE from the zip at fn, loading each row into a tl.Stop
// value. Open and row-load failures are returned as errors instead of
// panicking (the previous version panicked on both), so callers can assert
// on them.
func ReadTT(fn string) ([]any, error) {
	a := tlcsv.NewZipAdapter(fn)
	if err := a.Open(); err != nil {
		return nil, fmt.Errorf("opening %s: %w", fn, err)
	}
	var ret []any
	var rowErr error
	err := a.ReadRows(TESTTABLE, func(row tlcsv.Row) {
		ent := tl.Stop{}
		if errs := tlcsv.LoadRow(&ent, row); len(errs) > 0 && rowErr == nil {
			// Keep the first row error; ReadRows offers no early abort.
			rowErr = errs[0]
		}
		ret = append(ret, ent)
	})
	if err == nil {
		err = rowErr
	}
	return ret, err
}

//////////////////

// TestReadG loads the stops table into generated gtfs.Stop entities and
// dumps each as JSON.
func TestReadG(t *testing.T) {
	got, err := ReadG(TESTFILE)
	assert.NoError(t, err)
	printAll(got)
}

// BenchmarkReadG measures loading the stops table into gtfs.Stop entities.
func BenchmarkReadG(b *testing.B) {
	for i := 0; i < b.N; i++ {
		ents, _ := ReadG(TESTFILE)
		printFirst(ents)
	}
}

// ReadG reads TESTTABLE from the zip at fn, loading each row into a
// generated gtfs.Stop value. Open and row-load failures are returned as
// errors instead of panicking (the previous version panicked on both), so
// callers can assert on them.
func ReadG(fn string) ([]any, error) {
	a := tlcsv.NewZipAdapter(fn)
	if err := a.Open(); err != nil {
		return nil, fmt.Errorf("opening %s: %w", fn, err)
	}
	var ret []any
	var rowErr error
	err := a.ReadRows(TESTTABLE, func(row tlcsv.Row) {
		ent := gtfs.Stop{}
		if errs := tlcsv.LoadRow(&ent, row); len(errs) > 0 && rowErr == nil {
			// Keep the first row error; ReadRows offers no early abort.
			rowErr = errs[0]
		}
		ret = append(ret, ent)
	})
	if err == nil {
		err = rowErr
	}
	return ret, err
}
Loading
Loading