diff --git a/example/proto2/googlev2_benchmarks_test.go b/example/proto2/googlev2_benchmarks_test.go index 47cde89..b24c1eb 100644 --- a/example/proto2/googlev2_benchmarks_test.go +++ b/example/proto2/googlev2_benchmarks_test.go @@ -69,15 +69,46 @@ func BenchmarkLazyDecodeGoogleV2(b *testing.B) { evt = createGoogleV2Event(b) def = lazyproto.NewDef(1, 2, 3, 4) ) - _ = def.NestedTag(100, 1) + _ = def.NestedTag(100, 1, 2, 3, 4, 5, 6, 7, 8, 9) data, _ := proto.Marshal(evt) b.ResetTimer() for n := 0; n < b.N; n++ { - r, _ := lazyproto.Decode(data, def) + r, _ := lazyproto.Decode(data, def) //nolint: staticcheck // benchmarking deprecated function to demonstrate the difference _ = r.Close() } } +func BenchmarkLazyDecoder(b *testing.B) { + var ( + evt = createGoogleV2Event(b) + def = lazyproto.NewDef(1, 2, 3, 4) + ) + _ = def.NestedTag(100, 1, 2, 3, 4, 5, 6, 7, 8, 9) + data, _ := proto.Marshal(evt) + b.Run("safe", func(b *testing.B) { + dec, _ := lazyproto.NewDecoder(def, lazyproto.WithMaxBufferSize(3)) + b.ResetTimer() + for n := 0; n < b.N; n++ { + r, _ := dec.Decode(data) + nest, _ := r.NestedResult(100) + discardStrings, _ = nest.StringValues(4) + _ = r.Close() + } + }) + b.Run("unsafe", func(b *testing.B) { + dec, _ := lazyproto.NewDecoder(def, lazyproto.WithMode(csproto.DecoderModeFast), lazyproto.WithMaxBufferSize(3)) + b.ResetTimer() + for n := 0; n < b.N; n++ { + r, _ := dec.Decode(data) + nest, _ := r.NestedResult(100) + discardStrings, _ = nest.StringValues(4) + _ = r.Close() + } + }) +} + +var discardStrings []string + func createGoogleV2Event(t interface{ Errorf(string, ...interface{}) }) *googlev2.BaseEvent { eventType := googlev2.EventType_EVENT_TYPE_ONE baseEvent := googlev2.BaseEvent{ diff --git a/example/proto3/googlev2_benchmarks_test.go b/example/proto3/googlev2_benchmarks_test.go index 8f51368..21e86b6 100644 --- a/example/proto3/googlev2_benchmarks_test.go +++ b/example/proto3/googlev2_benchmarks_test.go @@ -68,15 +68,45 @@ func 
BenchmarkLazyDecodeGoogleV2(b *testing.B) { evt = createGoogleV2Event() def = lazyproto.NewDef(1) ) - _ = def.NestedTag(5, 1) + _ = def.NestedTag(5, 1, 2, 3, 4) data, _ := proto.Marshal(evt) b.ResetTimer() for n := 0; n < b.N; n++ { - r, _ := lazyproto.Decode(data, def) + r, _ := lazyproto.Decode(data, def) //nolint: staticcheck // benchmarking deprecated function to demonstrate the difference _ = r.Close() } } +func BenchmarkLazyDecoder(b *testing.B) { + var ( + evt = createGoogleV2Event() + def = lazyproto.NewDef(1, 2, 3, 4, 5, 6, 7, 8, 9) + ) + _ = def.NestedTag(5, 1, 2, 3, 4) + data, _ := proto.Marshal(evt) + b.Run("safe", func(b *testing.B) { + dec, _ := lazyproto.NewDecoder(def) + b.ResetTimer() + for n := 0; n < b.N; n++ { + r, _ := dec.Decode(data) + discardStrings, _ = r.StringValues(4) + _ = r.Close() + } + }) + + b.Run("unsafe", func(b *testing.B) { + dec, _ := lazyproto.NewDecoder(def, lazyproto.WithMode(csproto.DecoderModeFast)) + b.ResetTimer() + for n := 0; n < b.N; n++ { + r, _ := dec.Decode(data) + discardStrings, _ = r.StringValues(4) + _ = r.Close() + } + }) +} + +var discardStrings []string + func createGoogleV2Event() *googlev2.TestEvent { event := googlev2.TestEvent{ Name: "test-event", diff --git a/lazyproto/decode.go b/lazyproto/decode.go index 8d769be..d0bc6f7 100644 --- a/lazyproto/decode.go +++ b/lazyproto/decode.go @@ -2,11 +2,21 @@ package lazyproto import ( "fmt" + "slices" + "sync" "github.com/CrowdStrike/csproto" ) var ( + // ErrNestingNotDefined is returned by [PartialDecodeResult.FieldData] when the specified tag + // was not supplied with a nested definitions. 
This will also return true for errors.Is(err, ErrTagNotFound) but is + // more specific because this means the tag nesting was not in the original decoder definition + ErrNestingNotDefined = fmt.Errorf("%w: the requested tag was not defined as a nested tag in the decoder", ErrTagNotFound) + // ErrTagNotDefined is returned by [PartialDecodeResult.FieldData] when the specified tag + // was not defined in the decoder. This will also return true for errors.Is(err, ErrTagNotFound) but is + // more specific because this means the tag was not in the original decoder definition + ErrTagNotDefined = fmt.Errorf("%w: the requested tag was not defined in the decoder", ErrTagNotFound) // ErrTagNotFound is returned by [PartialDecodeResult.FieldData] when the specified tag(s) do not // exist in the result. ErrTagNotFound = fmt.Errorf("the requested tag does not exist in the partial decode result") @@ -26,184 +36,195 @@ var emptyResult DecodeResult // of field values are needed, so [PartialDecodeResult] and [FieldData] only support extracting // scalar values or slices of scalar values. Consumers that need to decode entire messages will need // to use [Unmarshal] instead. 
-func Decode(data []byte, def Def) (res DecodeResult, err error) { +// +// Deprecated: use NewDecoder(def) and (*Decoder).Decode(data) +// For best performance, decoders should be initialized once per definition and reused when decoding data +func Decode(data []byte, def Def) (DecodeResult, error) { if len(data) == 0 || len(def) == 0 { return emptyResult, nil } - if err := def.Validate(); err != nil { + dec := &Decoder{} + result, err := dec.newBaseResult(def) + if err != nil || result == nil { return emptyResult, err } - res.m = fieldDataMapPool.Get().(map[int]*FieldData) - defer func() { - // call res.Close() on error to clean up field data - if err != nil { - _ = res.Close() - } - }() - for dec := csproto.NewDecoder(data); dec.More(); { - tag, wt, err := dec.DecodeTag() - if err != nil { - return emptyResult, err - } - var ( - dv Def - want, wantRaw bool - ) - dv, want = def.Get(tag) - _, wantRaw = def.Get(-1 * tag) - if !want && !wantRaw { - if _, err := dec.Skip(tag, wt); err != nil { - return emptyResult, err - } - continue - } - switch wt { - case csproto.WireTypeVarint, csproto.WireTypeFixed32, csproto.WireTypeFixed64: - if wantRaw { - return emptyResult, fmt.Errorf("invalid definition: raw mode only supported for length-delimited fields (tag=%d, wire type=%s)", tag, wt) - } - // varint, fixed32, and fixed64 could be multiple Go types so - // grab the raw bytes and defer interpreting them to the consumer/caller - // . varint -> int32, int64, uint32, uint64, sint32, sint64, bool, enum - // . fixed32 -> int32, uint32, float32 - // . 
fixed64 -> int32, uint64, float64 - val, err := dec.Skip(tag, wt) - if err != nil { - return emptyResult, err - } - fd, err := res.getOrAddFieldData(tag, wt) - if err != nil { - return emptyResult, err - } - // Skip() returns the entire field contents, both the tag and the value, so we need to skip past the tag - val = val[csproto.SizeOfTagKey(tag):] - fd.data = append(fd.data, val) - case csproto.WireTypeLengthDelimited: - val, err := dec.DecodeBytes() - if err != nil { - return emptyResult, err - } - if len(dv) > 0 { - // recurse - subResult, err := Decode(val, dv) - if err != nil { - return emptyResult, err - } - fd, err := res.getOrAddFieldData(tag, wt) - if err != nil { - return emptyResult, err - } - fd.data = append(fd.data, subResult.m) - } else { - fd, err := res.getOrAddFieldData(tag, wt) - if err != nil { - return emptyResult, err - } - fd.data = append(fd.data, val) - } - if wantRaw { - fd, err := res.getOrAddFieldData(-1*tag, wt) - if err != nil { - return emptyResult, err - } - fd.data = append(fd.data, val) - } - default: - return emptyResult, fmt.Errorf("read unknown/unsupported protobuf wire type (%v)", wt) - } + for i := range result.flatData { + result.flatData[i] = new(FieldData) } - return res, nil + err = result.decode(slices.Clone(data)) + if err != nil { + return emptyResult, err + } + return *result, nil } -// DecodeResult holds a (possibly nested) mapping of integer field tags to FieldData instances -// which can be used to retrieve typed values for specific Protobuf message fields. 
-type DecodeResult struct { - m map[int]*FieldData +// Option is a functional option that allows configuring new Decoders +type Option func(*Decoder) error + +// WithMaxBufferSize will prevent slices greater than n from being cached for future reuse +func WithMaxBufferSize(n int) Option { + return func(d *Decoder) error { + if n < 0 { + return fmt.Errorf("WithMaxBuffer: negative max buffer size is not allowed") + } + d.maxBuffer = n + return nil + } } -// Close releases all internal resources held by r. +// WithBufferFilterFunc can prevent slices from being cached for future reuse. // -// Consumers should always call Close() on instances returned by [Decode] to ensure that internal -// resources are cleaned up. -func (r *DecodeResult) Close() error { - for k, v := range r.m { - if v != nil { - v.close() +// Under the hood, decoders will use sync Pools to avoid allocations for slices. +// If messages intermittently contain very large slices it can cause all of the cached +// decode results to eventually use the max slice value. +// This may result in larger memory use than is necessary. To avoid such a scenario, +// clients may set a BufferFilterFunction to clean up slices based on their capacity. +// +// fn accepts the current slice capacity and should return the target capacity. +// Negative capacities will be ignored. +func WithBufferFilterFunc(fn func(capacity int) int) Option { + return func(d *Decoder) error { + if fn == nil { + return fmt.Errorf("WithMaxBufferFunc: nil function is not allowed") } - delete(r.m, k) + d.filter = fn + return nil } - if r.m != nil { - fieldDataMapPool.Put(r.m) +} + +// WithMode will set the mode of operation for the decoder +// +// - DecoderModeSafe will create copies of the input data slice and create new slices for any returned field data results +// +// - DecoderModeFast will not reallocate input data or slices.
When the mode is DecoderModeFast it is not safe to modify the +// input data slice after calling (*Decoder).Decode(data) and it is not safe to use any slices after calling (*DecodeResult).Close() +func WithMode(mode csproto.DecoderMode) Option { + return func(d *Decoder) error { + d.mode = mode + return nil } - r.m = nil - return nil } -// The FieldData method returns a FieldData instance for the specified tag "path", if it exists. +// Decoder is a lazy decoder that will reuse DecodeResults after (DecodeResult).Close() +// is called. // -// The tags parameter is a list of one or more integer field tags that act as a "path" to a particular -// field to support retreiving fields from nested messages. Each value is used to retreieve the field -// data at the corresponding level of nesting, i.e. a value of [1, 2] would return the field data for -// tag 2 within the nested data for tag 1 at the root. -func (r *DecodeResult) FieldData(tags ...int) (*FieldData, error) { - if r == nil || len(r.m) == 0 { - return nil, ErrTagNotFound - } - if len(tags) == 0 { - return nil, fmt.Errorf("at least one tag key must be specified") - } - // special case: - // - negative tag values are used to extract the raw bytes of a field, but it must be the only - // (or last) field in the path - if len(tags) > 1 { - for i := 0; i < len(tags)-1; i++ { - if tags[i] < 0 { - return nil, fmt.Errorf("invalid tag in path at index %d, negative tags must be the last (or only) path item", i) - } - } +// A decoder is unique to a given Definition and can be reused for any protobuf message using +// that definition. +// +// Decoder methods are thread safe and can be used by concurrent/parallel processes. +type Decoder struct { + pool *sync.Pool + filter func(int) int + maxBuffer int + mode csproto.DecoderMode +} + +// NewDecoder creates a new Decoder for a given Def. See NewDef for defining a definition. 
+// +// NewDecoder will return an error if the definition or options are invalid +func NewDecoder(def Def, opts ...Option) (*Decoder, error) { + err := def.Validate() + if err != nil { + return nil, fmt.Errorf("invalid definition: %w", err) } - var ( - fd *FieldData - ok = true - ) - for dd := r.m; ok && len(tags) > 0; { - fd, ok = dd[tags[0]] - if !ok || len(fd.data) == 0 { - return nil, ErrTagNotFound - } - tags = tags[1:] - if len(tags) == 0 { - return fd, nil + + dec := &Decoder{ + pool: new(sync.Pool), + maxBuffer: -1, + mode: csproto.DecoderModeSafe, + } + for _, opt := range opts { + err := opt(dec) + if err != nil { + return nil, fmt.Errorf("invalid option: %w", err) } - dd, ok = fd.data[0].(map[int]*FieldData) } - return nil, ErrTagNotFound + + base, err := dec.newBaseResult(def) + if err != nil { + return nil, err + } + dec.pool.New = func() any { + return base.clone() + } + return dec, nil +} + +// Decode will convert the raw []byte slice to a DecodeResult +func (dec *Decoder) Decode(data []byte) (*DecodeResult, error) { + if dec.mode == csproto.DecoderModeSafe { + return dec.decodeWithPool(slices.Clone(data)) + } + return dec.decodeWithPool(data) +} + +func (dec *Decoder) decodeWithPool(data []byte) (*DecodeResult, error) { + if len(data) == 0 { + return nil, nil + } + res, ok := dec.pool.Get().(*DecodeResult) + if !ok { + // This will only happen if the decoder was initialized outside of NewDecoder + return nil, fmt.Errorf("invalid decoder") + } + err := res.decode(data) + if err != nil { + // call res.Close() on error to clean up field data + _ = res.Close() + return nil, err + } + return res, nil } -// getOrAddFieldData is a helper to consolidate the logic of checking if a given tag exists in the -// field data map and adding it if not. 
-func (r *DecodeResult) getOrAddFieldData(tag int, wt csproto.WireType) (*FieldData, error) { - // first key: add a new entry and return - if len(r.m) == 0 { - fd := &FieldData{ - wt: wt, +// newBaseResult creates a new DecodeResult object based on the given definition +// all other initialization of this DecodeResult is done by cloning the resulting object +func (dec *Decoder) newBaseResult(def Def) (*DecodeResult, error) { + err := def.Validate() + if err != nil { + return nil, fmt.Errorf("invalid definition: %w", err) + } + result := &DecodeResult{ + pool: dec.pool, + filter: dec.filter, + maxBuffer: dec.maxBuffer, + unsafe: dec.mode != csproto.DecoderModeSafe, + } + + // iterate over all of the flat/raw tags and any nested tags in the definition + for k, v := range def { + if k < 0 { + k *= -1 } - r.m = fieldDataMapPool.Get().(map[int]*FieldData) - r.m[tag] = fd - return fd, nil - } - // if the key doesn't exist, add a new entry - fd, exists := r.m[tag] - if !exists { - fd = &FieldData{ - wt: wt, + result.flatTags = append(result.flatTags, k) + if v == nil { + continue } - r.m[tag] = fd - } - // double-check wire type - if fd.wt != wt { - return nil, fmt.Errorf("invalid message data - repeated tag %d w/ different wire types (prev=%v, current=%v)", tag, fd.wt, wt) + result.nestedTags = append(result.nestedTags, k) } - return fd, nil + // sort and deduplicate the tags + slices.Sort(result.flatTags) + result.flatTags = slices.Compact(result.flatTags) + slices.Sort(result.nestedTags) + result.nestedTags = slices.Compact(result.nestedTags) + + // initialize the slice length + // (we don't need to fill the slice because that will be done when it's cloned) + result.flatData = make([]*FieldData, len(result.flatTags)) + + // create decoders for any nested results + result.nestedDecoders = make([]*Decoder, len(result.nestedTags)) + for i, tag := range result.nestedTags { + nestedDec, err := NewDecoder(def[tag], func(d *Decoder) error { + d.filter = dec.filter +
d.maxBuffer = dec.maxBuffer + d.mode = dec.mode + return nil + }) + if err != nil { + return nil, fmt.Errorf("invalid definition on tag: %d", tag) + } + result.nestedDecoders[i] = nestedDec + } + return result, nil } diff --git a/lazyproto/decode_result.go b/lazyproto/decode_result.go new file mode 100644 index 0000000..5fd9f65 --- /dev/null +++ b/lazyproto/decode_result.go @@ -0,0 +1,685 @@ +package lazyproto + +import ( + "fmt" + "slices" + "sync" + + "github.com/CrowdStrike/csproto" +) + +// DecodeResult holds a (possibly nested) mapping of integer field tags to FieldData instances +// which can be used to retrieve typed values for specific Protobuf message fields. +type DecodeResult struct { + pool *sync.Pool + filter func(int) int + + // flatTags and flatData are equal length, binary searches are used as an optimization + // but this is essentially equal to a map[int]*FieldData but ranging over a slice is faster than a map. + flatTags []int + flatData []*FieldData + + // like flatTags/flatData, nestedTags/nestedData uses binary searching to treat this as + // a map[int]*Decoder but ranging over a slice is faster than a map. 
+ nestedTags []int + nestedDecoders []*Decoder + + // closers is used to return any nested DecodeResults to their respective pool + // calling (DecodeResult).Close on a nested DecodeResult has no effect + closers []*DecodeResult + + maxBuffer int + skipClose bool + unsafe bool +} + +// decode parses the data and adds it to the DecodeResult +func (r *DecodeResult) decode(data []byte) error { + dec := csproto.NewDecoder(data) + dec.SetMode(csproto.DecoderModeFast) + for dec.More() { + tag, wt, err := dec.DecodeTag() + if err != nil { + return err + } + + flatIdx, hasFlat := slices.BinarySearch(r.flatTags, tag) + if !hasFlat { + if _, err := dec.Skip(tag, wt); err != nil { + return err + } + continue + } + fd := r.flatData[flatIdx] + if len(fd.data) > 0 && fd.wt != wt { + return fmt.Errorf("invalid message data - repeated tag %d w/ different wire types (prev=%v, current=%v)", tag, fd.wt, wt) + } + switch wt { + case csproto.WireTypeVarint, csproto.WireTypeFixed32, csproto.WireTypeFixed64: + // varint, fixed32, and fixed64 could be multiple Go types so + // grab the raw bytes and defer interpreting them to the consumer/caller + // . varint -> int32, int64, uint32, uint64, sint32, sint64, bool, enum + // . fixed32 -> int32, uint32, float32 + // . 
fixed64 -> int64, uint64, float64 + val, err := dec.Skip(tag, wt) + if err != nil { + return err + } + + // Skip() returns the entire field contents, both the tag and the value, so we need to skip past the tag + val = val[csproto.SizeOfTagKey(tag):] + fd.wt = wt + fd.data = append(fd.data, val) + case csproto.WireTypeLengthDelimited: + val, err := dec.DecodeBytes() + if err != nil { + return err + } + + fd.wt = wt + fd.data = append(fd.data, val) + default: + return fmt.Errorf("read unknown/unsupported protobuf wire type (%v)", wt) + } + } + + return nil +} + +// clone will create a copy of the DecodeResult with the same tag and decoder information +// but a new fieldData slice +func (r *DecodeResult) clone() *DecodeResult { + if r == nil { + return nil + } + res := &DecodeResult{ + pool: r.pool, + filter: r.filter, + flatTags: r.flatTags, + flatData: make([]*FieldData, len(r.flatData)), + nestedTags: r.nestedTags, + nestedDecoders: r.nestedDecoders, + maxBuffer: r.maxBuffer, + unsafe: r.unsafe, + } + for i := range r.flatData { + res.flatData[i] = &FieldData{ + unsafe: r.unsafe, + } + } + return res +} + +// Close releases all internal resources held by r. +// +// Consumers should always call Close() on instances returned by [Decode] to ensure that internal +// resources are cleaned up. +// +// When using with csproto.DecoderModeFast it is important that any strings, bytes, etc that were generated +// using any of the DecodeResult/FieldData methods have moved out of scope before closing the DecodeResult.
+func (r *DecodeResult) Close() error { + if r == nil || r.skipClose { + return nil + } + r.close() + return nil +} + +// close will recursively close the nested DecodeResults and return them to their respective pools +func (r *DecodeResult) close() { + for i := range r.flatData { + if r.flatData[i] == nil { + continue + } + r.flatData[i].data = r.flatData[i].data[:0] + } + + if r.closers != nil { + for i := range r.closers { + r.closers[i].close() + } + r.closers = r.closers[:0] + } + + if r.pool != nil { + if r.maxBuffer >= 0 { + r.trunc(r.maxBuffer) + } + if r.filter != nil { + if n := r.filter(r.cap()); n >= 0 { + r.trunc(n) + } + } + r.pool.Put(r) + } +} + +// trunc will reallocate slices that have a capacity greater than n +func (r *DecodeResult) trunc(n int) { + if r == nil || n < 0 { + return + } + if cap(r.closers) > n { + r.closers = make([]*DecodeResult, n) + } + for i := range r.flatData { + r.flatData[i].trunc(n) + } +} + +// cap will return the largest capacity of any slice in the DecodeResult +func (r *DecodeResult) cap() int { + if r == nil { + return 0 + } + c := cap(r.closers) + for i := range r.flatData { + if cc := r.flatData[i].cap(); cc > c { + c = cc + } + } + return c +} + +// FieldData returns a FieldData instance for the specified tag "path", if it exists. +// +// The tags parameter is a list of one or more integer field tags that act as a "path" to a particular +// field to support retrieving fields from nested messages. Each value is used to retrieve the field +// data at the corresponding level of nesting, i.e. a value of [1, 2] would return the field data for +// tag 2 within the nested data for tag 1 at the root.
+func (r *DecodeResult) FieldData(tags ...int) (*FieldData, error) { + if r == nil || (len(r.flatTags) == 0 && len(r.nestedTags) == 0) { + return nil, ErrTagNotDefined + } + switch n := len(tags); n { + case 0: + return nil, fmt.Errorf("at least one tag key must be specified") + case 1: + return r.GetFieldData(tags[0]) + default: + var err error + for i := range tags[:n-1] { + r, err = r.NestedResult(tags[i]) + if err != nil { + return nil, err + } + } + return r.GetFieldData(tags[n-1]) + } +} + +// NestedResult will return the last/only DecodeResult located at the given tag path +// +// The tag parameter acts as a "path" to a particular field to support retrieving DecodeResult +// from nested messages. The value is used to retrieve the field data at the corresponding protonumber +func (r *DecodeResult) NestedResult(tag int) (*DecodeResult, error) { + if r == nil || len(r.nestedTags) == 0 { + return nil, ErrTagNotDefined + } + if tag < 0 { + tag *= -1 + } + flatIdx, hasFlat := slices.BinarySearch(r.flatTags, tag) + if !hasFlat { + return nil, ErrTagNotDefined + } + nestedIdx, hasNested := slices.BinarySearch(r.nestedTags, tag) + if !hasNested { + return nil, ErrNestingNotDefined + } + + // get the raw []byte slice for this tag + b, err := scalarValue(r.flatData[flatIdx], csproto.WireTypeLengthDelimited, func(b []byte) ([]byte, error) { + return b, nil + }) + if err != nil { + return nil, err + } + tmp, err := r.nestedDecoders[nestedIdx].decodeWithPool(b) + if err != nil { + return nil, err + } + tmp.skipClose = true + r.closers = append(r.closers, tmp) + return tmp, nil +} + +// NestedResults will return all of the DecodeResults located at the given tag path +// +// The tag parameter acts as a "path" to a particular field to support retrieving DecodeResult +// from nested messages.
The value is used to retrieve the field data at the corresponding protonumber +func (r *DecodeResult) NestedResults(tag int) ([]*DecodeResult, error) { + if r == nil || len(r.nestedTags) == 0 { + return nil, ErrTagNotDefined + } + if tag < 0 { + tag *= -1 + } + nestedIdx, hasNested := slices.BinarySearch(r.nestedTags, tag) + if !hasNested { + return nil, ErrTagNotDefined + } + flatIdx, hasFlat := slices.BinarySearch(r.flatTags, tag) + if !hasFlat { + return nil, ErrTagNotDefined + } + fd := r.flatData[flatIdx] + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + results := make([]*DecodeResult, 0, len(fd.data)) + dec := r.nestedDecoders[nestedIdx] + for _, b := range fd.data { + res, err := dec.decodeWithPool(b) + if err != nil { + return nil, err + } + res.skipClose = true + results = append(results, res) + } + r.closers = append(r.closers, results...) + return results, nil +} + +// Range will iterate over all tags in the DecodeResult. +// +// If a field was not present in the original message fn will be called with the tag and a nil field. +// +// Currently, Range will iterate in the order of the tags, but this is not guaranteed for future use.
+func (r *DecodeResult) Range(fn func(tag int, field *FieldData) bool) { + for idx, tag := range r.flatTags { + field := r.flatData[idx] + if field == nil || len(field.data) == 0 { + if !fn(tag, nil) { + return + } + continue + } + if !fn(tag, field) { + return + } + } +} + +// GetFieldData returns the raw field data object at the given tag +func (r *DecodeResult) GetFieldData(tag int) (*FieldData, error) { + if r == nil || (len(r.flatTags) == 0) { + return nil, ErrTagNotDefined + } + if tag < 0 { + tag *= -1 + } + + flatIdx, hasFlat := slices.BinarySearch(r.flatTags, tag) + if !hasFlat { + return nil, ErrTagNotDefined + } + + fd := r.flatData[flatIdx] + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + return fd, nil +} + +// loadFieldDataType will use the provided function to load the field data at the given tag and cast it to the desired type. +func loadFieldDataType[T any](r *DecodeResult, tag int, fn func(*FieldData) (T, error)) (T, error) { + fd, err := r.GetFieldData(tag) + if err != nil { + var zero T + return zero, err + } + return fn(fd) +} + +// BoolValue is a helper method to get the field data at the given tag and return it as a boolean. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.BoolValue() +func (r *DecodeResult) BoolValue(tag int) (bool, error) { + return loadFieldDataType(r, tag, (*FieldData).BoolValue) +} + +// BoolValues is a helper method to get the field data at the given tag and return it as a boolean slice. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.BoolValues() +func (r *DecodeResult) BoolValues(tag int) ([]bool, error) { + return loadFieldDataType(r, tag, (*FieldData).BoolValues) +} + +// BytesValue is a helper method to get the field data at the given tag and return it as a byte slice. 
+// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.BytesValue() +func (r *DecodeResult) BytesValue(tag int) ([]byte, error) { + return loadFieldDataType(r, tag, (*FieldData).BytesValue) +} + +// BytesValues is a helper method to get the field data at the given tag and return it as a []byte slice. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.BytesValues() +func (r *DecodeResult) BytesValues(tag int) ([][]byte, error) { + return loadFieldDataType(r, tag, (*FieldData).BytesValues) +} + +// Fixed32Value is a helper method to get the field data at the given tag and return it as a uint32. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Fixed32Value() +func (r *DecodeResult) Fixed32Value(tag int) (uint32, error) { + return loadFieldDataType(r, tag, (*FieldData).Fixed32Value) +} + +// Fixed32Values is a helper method to get the field data at the given tag and return it as a []uint32. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Fixed32Values() +func (r *DecodeResult) Fixed32Values(tag int) ([]uint32, error) { + return loadFieldDataType(r, tag, (*FieldData).Fixed32Values) +} + +// Fixed64Value is a helper method to get the field data at the given tag and return it as a uint64. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Fixed64Value() +func (r *DecodeResult) Fixed64Value(tag int) (uint64, error) { + return loadFieldDataType(r, tag, (*FieldData).Fixed64Value) +} + +// Fixed64Values is a helper method to get the field data at the given tag and return it as a []uint64. 
+// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Fixed64Values() +func (r *DecodeResult) Fixed64Values(tag int) ([]uint64, error) { + return loadFieldDataType(r, tag, (*FieldData).Fixed64Values) +} + +// Float32Value is a helper method to get the field data at the given tag and return it as a float32. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Float32Value() +func (r *DecodeResult) Float32Value(tag int) (float32, error) { + return loadFieldDataType(r, tag, (*FieldData).Float32Value) +} + +// Float32Values is a helper method to get the field data at the given tag and return it as a []float32. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Float32Values() +func (r *DecodeResult) Float32Values(tag int) ([]float32, error) { + return loadFieldDataType(r, tag, (*FieldData).Float32Values) +} + +// Float64Value is a helper method to get the field data at the given tag and return it as a float64. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Float64Value() +func (r *DecodeResult) Float64Value(tag int) (float64, error) { + return loadFieldDataType(r, tag, (*FieldData).Float64Value) +} + +// Float64Values is a helper method to get the field data at the given tag and return it as a []float64. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Float64Values() +func (r *DecodeResult) Float64Values(tag int) ([]float64, error) { + return loadFieldDataType(r, tag, (*FieldData).Float64Values) +} + +// Int32Value is a helper method to get the field data at the given tag and return it as an int32. +// +// Use this method to retrieve values that are defined as int32 in the Protobuf message. 
Fields that +// are defined as sint32 (and so use the [Protobuf ZigZag encoding]) should be retrieved using +// SInt32Value() instead. +// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Int32Value() +func (r *DecodeResult) Int32Value(tag int) (int32, error) { + return loadFieldDataType(r, tag, (*FieldData).Int32Value) +} + +// Int32Values is a helper method to get the field data at the given tag and return it as a []int32. +// +// Use this method to retrieve values that are defined as repeated int32 in the Protobuf message. Fields that +// are defined as repeated sint32 (and so use the [Protobuf ZigZag encoding]) should be retrieved using +// SInt32Values() instead. +// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Int32Values() +func (r *DecodeResult) Int32Values(tag int) ([]int32, error) { + return loadFieldDataType(r, tag, (*FieldData).Int32Values) +} + +// Int64Value is a helper method to get the field data at the given tag and return it as an int64. +// +// Use this method to retrieve values that are defined as int64 in the Protobuf message. Fields that +// are defined as sint64 (and so use the [Protobuf ZigZag encoding]) should be retrieved using +// SInt64Value() instead. +// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Int64Value() +func (r *DecodeResult) Int64Value(tag int) (int64, error) { + return loadFieldDataType(r, tag, (*FieldData).Int64Value) +} + +// Int64Values is a helper method to get the field data at the given tag and return it as a []int64. +// +// Use this method to retrieve values that are defined as repeated int64 in the Protobuf message. Fields that +// are defined as repeated sint64 (and so use the [Protobuf ZigZag encoding]) should be retrieved using +// SInt64Values() instead. 
+// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.Int64Values() +func (r *DecodeResult) Int64Values(tag int) ([]int64, error) { + return loadFieldDataType(r, tag, (*FieldData).Int64Values) +} + +// SInt32Value is a helper method to get the field data at the given tag and return it as an int32. +// +// Use this method to retrieve values that are defined as sint32 in the Protobuf message. Fields that +// are defined as int32 (and so use the [Protobuf base128 varint encoding]) should be retrieved using +// Int32Value() instead. +// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.SInt32Value() +func (r *DecodeResult) SInt32Value(tag int) (int32, error) { + return loadFieldDataType(r, tag, (*FieldData).SInt32Value) +} + +// SInt32Values is a helper method to get the field data at the given tag and return it as a []int32. +// +// Use this method to retrieve values that are defined as repeated sint32 in the Protobuf message. Fields that +// are defined as repeated int32 (and so use the [Protobuf base128 varint encoding]) should be retrieved using +// Int32Values() instead. +// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.SInt32Values() +func (r *DecodeResult) SInt32Values(tag int) ([]int32, error) { + return loadFieldDataType(r, tag, (*FieldData).SInt32Values) +} + +// SInt64Value is a helper method to get the field data at the given tag and return it as an int64. +// +// Use this method to retrieve values that are defined as sint64 in the Protobuf message. Fields that +// are defined as int64 (and so use the [Protobuf base128 varint encoding]) should be retrieved using +// Int64Value() instead. 
+// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.SInt64Value() +func (r *DecodeResult) SInt64Value(tag int) (int64, error) { + return loadFieldDataType(r, tag, (*FieldData).SInt64Value) +} + +// SInt64Values is a helper method to get the field data at the given tag and return it as a []int64. +// +// Use this method to retrieve values that are defined as repeated sint64 in the Protobuf message. Fields that +// are defined as repeated int64 (and so use the [Protobuf base128 varint encoding]) should be retrieved using +// Int64Values() instead. +// +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.SInt64Values() +func (r *DecodeResult) SInt64Values(tag int) ([]int64, error) { + return loadFieldDataType(r, tag, (*FieldData).SInt64Values) +} + +// StringValue is a helper method to get the field data at the given tag and return it as a string. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.StringValue() +func (r *DecodeResult) StringValue(tag int) (string, error) { + return loadFieldDataType(r, tag, (*FieldData).StringValue) +} + +// StringValues is a helper method to get the field data at the given tag and return it as a []string. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.StringValues() +func (r *DecodeResult) StringValues(tag int) ([]string, error) { + return loadFieldDataType(r, tag, (*FieldData).StringValues) +} + +// UInt32Value is a helper method to get the field data at the given tag and return it as a uint32. 
+// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.UInt32Value() +func (r *DecodeResult) UInt32Value(tag int) (uint32, error) { + return loadFieldDataType(r, tag, (*FieldData).UInt32Value) +} + +// UInt32Values is a helper method to get the field data at the given tag and return it as a []uint32. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.UInt32Values() +func (r *DecodeResult) UInt32Values(tag int) ([]uint32, error) { + return loadFieldDataType(r, tag, (*FieldData).UInt32Values) +} + +// UInt64Value is a helper method to get the field data at the given tag and return it as a uint64. +// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.UInt64Value() +func (r *DecodeResult) UInt64Value(tag int) (uint64, error) { + return loadFieldDataType(r, tag, (*FieldData).UInt64Value) +} + +// UInt64Values is a helper method to get the field data at the given tag and return it as a []uint64. 
+// It is equivalent to +// +// fd, err := r.GetFieldData(tag) +// if err != nil { +// return err +// } +// v, err := fd.UInt64Values() +func (r *DecodeResult) UInt64Values(tag int) ([]uint64, error) { + return loadFieldDataType(r, tag, (*FieldData).UInt64Values) +} diff --git a/lazyproto/decode_test.go b/lazyproto/decode_test.go index e80529f..3b19361 100644 --- a/lazyproto/decode_test.go +++ b/lazyproto/decode_test.go @@ -1,6 +1,7 @@ package lazyproto import ( + "cmp" "fmt" "math" "testing" @@ -36,12 +37,22 @@ func ExampleDecodeResult_FieldData() { def := NewDef() // extract tags 1 and 2 from the nested message at tag 2 in the outer message _ = def.NestedTag(2, 1, 2) - res, err := Decode(data, def) + + // create a new decoder for this definition + dec, err := NewDecoder(def, WithMode(csproto.DecoderModeFast), WithMaxBufferSize(1024)) + if err != nil { + fmt.Println("unable to create new decoder:", err) + return + } + + // decode arbitrary data + res, err := dec.Decode(data) if err != nil { fmt.Println("error from decode:", err) return } defer func() { + // Only close the result after we are completely done using it and any values we retrieved from it if err := res.Close(); err != nil { fmt.Println("error from DecodeResult.Close():", err) } @@ -74,6 +85,205 @@ func ExampleDecodeResult_FieldData() { // description: bar } +func FuzzStrings(f *testing.F) { + for i, s := range []string{"hello", "world", "", " ", "123"} { + f.Add(i, s) + } + const tagID = 123 + dec, err := NewDecoder(Def{tagID: nil}) + require.NoError(f, err) + f.Fuzz(func(t *testing.T, n int, expected string) { + var msg []byte + for i := 0; i < n; i++ { + sz := len(fmt.Sprintf("%s%d", expected, i)) + msg = append(msg, make([]byte, csproto.SizeOfTagKey(tagID)+csproto.SizeOfVarint(uint64(sz))+sz)...) 
+ } + enc := csproto.NewEncoder(msg) + for i := 0; i < n; i++ { + enc.EncodeString(tagID, fmt.Sprintf("%s%d", expected, i)) + } + result, err := dec.Decode(msg) + require.NoError(t, err) + v := checkErr(result.StringValue(tagID))(t, n) + vs := checkErr(result.StringValues(tagID))(t, n) + b := checkErr(result.BytesValue(tagID))(t, n) + bs := checkErr(result.BytesValues(tagID))(t, n) + if n <= 0 { + return + } + + exp := fmt.Sprintf("%s%d", expected, n-1) + assert.EqualValues(t, exp, v) + assert.EqualValues(t, exp, string(b)) + require.Len(t, vs, n) + require.Len(t, bs, n) + for i := 0; i < n; i++ { + exp = fmt.Sprintf("%s%d", expected, i) + assert.EqualValues(t, exp, vs[i]) + assert.EqualValues(t, exp, string(bs[i])) + } + }) +} + +func FuzzBools(f *testing.F) { + f.Add(0) + f.Add(1) + f.Add(2) + f.Add(20) + const tagID = 123 + dec, err := NewDecoder(Def{tagID: nil}) + require.NoError(f, err) + f.Fuzz(func(t *testing.T, n int) { + msg := make([]byte, max(n*(csproto.SizeOfTagKey(tagID)+1), 0)) + var expected []bool + enc := csproto.NewEncoder(msg) + for i := 0; i < n; i++ { + expected = append(expected, n%3 == 0) + enc.EncodeBool(tagID, n%3 == 0) + } + result, err := dec.Decode(msg) + require.NoError(t, err) + v := checkErr(result.BoolValue(tagID))(t, n) + vs := checkErr(result.BoolValues(tagID))(t, n) + if n <= 0 { + return + } + assert.EqualValues(t, expected[len(expected)-1], v) + assert.EqualValues(t, expected, vs) + }) +} + +func FuzzNumbers(f *testing.F) { + for i, v := range []float64{math.MinInt64, -100, -2, -1, 0, 1, 2, math.MaxUint64} { + f.Add(i, v) + } + + const ( + float32Tag = 1 + iota + float64Tag + int32Tag + int64Tag + fixed32Tag + fixed64Tag + sInt32Tag + sInt64Tag + uint32Tag + uint64Tag + ) + def := Def{} + for i := float32Tag; i <= uint64Tag; i++ { + def[i] = nil + } + dec, err := NewDecoder(def) + require.NoError(f, err) + + f.Fuzz(func(t *testing.T, n int, expected float64) { + var sz int + for i := 0; i < n; i++ { + exp := expected * 
float64(i) + sz += csproto.SizeOfTagKey(float32Tag) + 4 + + csproto.SizeOfTagKey(float64Tag) + 8 + + csproto.SizeOfTagKey(int32Tag) + csproto.SizeOfVarint(uint64(int32(exp))) + + csproto.SizeOfTagKey(int64Tag) + csproto.SizeOfVarint(uint64(int64(exp))) + + csproto.SizeOfTagKey(fixed32Tag) + 4 + + csproto.SizeOfTagKey(fixed64Tag) + 8 + + csproto.SizeOfTagKey(sInt32Tag) + csproto.SizeOfZigZag(uint64(int32(exp))) + + csproto.SizeOfTagKey(sInt64Tag) + csproto.SizeOfZigZag(uint64(int64(exp))) + + csproto.SizeOfTagKey(uint32Tag) + csproto.SizeOfVarint(uint64(uint32(exp))) + + csproto.SizeOfTagKey(uint64Tag) + csproto.SizeOfVarint(uint64(exp)) + } + msg := make([]byte, sz) + enc := csproto.NewEncoder(msg) + for i := 0; i < n; i++ { + exp := expected * float64(i) + enc.EncodeFloat32(float32Tag, float32(exp)) + enc.EncodeFloat64(float64Tag, exp) + t.Log("encoding", int32(exp)) + t.Log(msg) + enc.EncodeInt32(int32Tag, int32(exp)) + t.Log(msg) + enc.EncodeInt64(int64Tag, int64(exp)) + enc.EncodeFixed32(fixed32Tag, uint32(exp)) + enc.EncodeFixed64(fixed64Tag, uint64(exp)) + enc.EncodeSInt32(sInt32Tag, int32(exp)) + enc.EncodeSInt64(sInt64Tag, int64(exp)) + enc.EncodeUInt32(uint32Tag, uint32(exp)) + enc.EncodeUInt64(uint64Tag, uint64(exp)) + } + result, err := dec.Decode(msg) + require.NoError(t, err) + exp := expected * float64(n-1) + checkNum(result.Float32Value(float32Tag))(t, n, exp) + checkNum(result.Float64Value(float64Tag))(t, n, exp) + checkNum(result.Int32Value(int32Tag))(t, n, exp) + checkNum(result.Int64Value(int64Tag))(t, n, exp) + checkNum(result.Fixed32Value(fixed32Tag))(t, n, exp) + checkNum(result.Fixed64Value(fixed64Tag))(t, n, exp) + checkNum(result.SInt32Value(sInt32Tag))(t, n, exp) + checkNum(result.SInt64Value(sInt64Tag))(t, n, exp) + checkNum(result.UInt32Value(uint32Tag))(t, n, exp) + checkNum(result.UInt64Value(uint64Tag))(t, n, exp) + if t.Failed() { + return + } + + checkNums(result.Float32Values(float32Tag))(t, n, expected) + 
checkNums(result.Float64Values(float64Tag))(t, n, expected) + checkNums(result.Int32Values(int32Tag))(t, n, expected) + checkNums(result.Int64Values(int64Tag))(t, n, expected) + checkNums(result.Fixed32Values(fixed32Tag))(t, n, expected) + checkNums(result.Fixed64Values(fixed64Tag))(t, n, expected) + checkNums(result.SInt32Values(sInt32Tag))(t, n, expected) + checkNums(result.SInt64Values(sInt64Tag))(t, n, expected) + checkNums(result.UInt32Values(uint32Tag))(t, n, expected) + checkNums(result.UInt64Values(uint64Tag))(t, n, expected) + + }) +} + +type number interface { + ~int32 | ~int64 | ~uint32 | ~uint64 | ~float32 | ~float64 +} + +func checkNum[T number](x T, err error) func(t testing.TB, n int, expected float64) { + return func(t testing.TB, n int, expected float64) { + if n <= 0 { + assert.ErrorIs(t, err, ErrTagNotFound) + assert.EqualValues(t, 0, cmp.Compare(x, T(0))) + return + } + require.NoError(t, err) + assert.EqualValues(t, 0, cmp.Compare(x, T(expected))) + } +} + +func checkNums[T number](x []T, err error) func(t testing.TB, n int, expected float64) { + return func(t testing.TB, n int, expected float64) { + t.Logf("N:%d, Expected:%f, ExpectedT:%+v, T:%T, actual:%+v", n, expected, T(expected), T(expected), x) + if n <= 0 { + assert.ErrorIs(t, err, ErrTagNotFound) + return + } + require.NoError(t, err) + require.Len(t, x, n) + for i := 0; i < n; i++ { + exp := expected * float64(i) + assert.EqualValues(t, 0, cmp.Compare(x[i], T(exp))) + } + } +} + +func checkErr[T any](x T, err error) func(t testing.TB, n int) T { + return func(t testing.TB, n int) T { + if n <= 0 { + assert.ErrorIs(t, err, ErrTagNotFound) + return x + } + require.NoError(t, err) + return x + } +} + func TestDecode(t *testing.T) { var sampleMessage = []byte{ // field 1: varint boolean true @@ -96,7 +306,7 @@ func TestDecode(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) - assert.Empty(t, res.m) + assert.Empty(t, res.flatData) }) t.Run("decode with nil def", 
func(t *testing.T) { t.Parallel() @@ -104,7 +314,7 @@ func TestDecode(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) - assert.Empty(t, res.m) + assert.Empty(t, res.flatData) }) t.Run("decode with empty def", func(t *testing.T) { t.Parallel() @@ -112,7 +322,7 @@ func TestDecode(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) - assert.Empty(t, res.m) + assert.Empty(t, res.flatData) }) t.Run("decode with missing def keys", func(t *testing.T) { t.Parallel() @@ -121,7 +331,7 @@ func TestDecode(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) - assert.Empty(t, res.m) + assert.Empty(t, res.flatData[0].data) }) t.Run("decode with matching def keys", func(t *testing.T) { t.Parallel() @@ -130,7 +340,7 @@ func TestDecode(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) - assert.Len(t, res.m, 4, "should have 4 results") + assert.Len(t, res.flatData, 4, "should have 4 results") }) t.Run("decode with nested def keys", func(t *testing.T) { t.Parallel() @@ -140,8 +350,9 @@ func TestDecode(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) - assert.Len(t, res.m, 4, "should have 4 results") - fd := res.m[3] + assert.Len(t, res.flatData, 4, "should have 4 results") + fd, err := res.FieldData(3, 2) + assert.NoError(t, err) assert.Len(t, fd.data, 1) }) t.Run("get field data with nested def keys", func(t *testing.T) { @@ -162,7 +373,7 @@ func TestDecode(t *testing.T) { def := NewDef(csproto.MaxTagValue + 1) res, err := Decode(sampleMessage, def) defer func() { _ = res.Close() }() - assert.Empty(t, res.m) + assert.Empty(t, res.flatData) assert.Error(t, err) }) } @@ -219,7 +430,9 @@ func TestDecodeResultFieldData(t *testing.T) { fd, err := res.FieldData(1, -1, 1) assert.Nil(t, fd) assert.Error(t, err) - assert.Contains(t, fmt.Sprint(err), "negative tags must be the last (or only) path item") + assert.ErrorIs(t, err, ErrTagNotFound) + // negative tags no longer 
result in a dedicated validation error; the lookup now fails with ErrTagNotFound + //assert.Contains(t, fmt.Sprint(err), "negative tags must be the last (or only) path item") }) t.Run("returns nil and not found error for unmatched path", func(t *testing.T) { t.Parallel() @@ -307,20 +520,6 @@ func TestRawFieldData(t *testing.T) { defer func() { _ = res.Close() }() assert.NoError(t, err) defer res.Close() - - // negative tags must be the last/only value - fd, err := res.FieldData(1, -1, 1) - assert.Error(t, err) - assert.Nil(t, fd) - }) - t.Run("fails for incorrect wire type", func(t *testing.T) { - t.Parallel() - def := NewDef(-1) - res, err := Decode(sampleMessage, def) - defer func() { _ = res.Close() }() - assert.Error(t, err) - assert.Contains(t, fmt.Sprintf("%s", err), "invalid definition: raw mode only supported for length-delimited fields") - assert.Empty(t, res.m) }) } diff --git a/lazyproto/fielddata.go b/lazyproto/fielddata.go index ca07448..c0c5f8e 100644 --- a/lazyproto/fielddata.go +++ b/lazyproto/fielddata.go @@ -3,11 +3,12 @@ package lazyproto import ( "encoding/binary" "fmt" + "io" "math" - "reflect" + "slices" "sort" "strings" - "sync" + "unsafe" "github.com/CrowdStrike/csproto" ) @@ -35,13 +36,28 @@ import ( // // To avoid panics, any method called on a nil instance returns a zero value and [ErrTagNotFound]. type FieldData struct { + // one or more []byte values containing the raw bytes from the decoded message for single or + // repeated scalar values + data [][]byte + + // these are slices that are initialized when creating a new result, + // they will be reused if the result is closed and put back in the sync.Pool + // so we can reduce the overall allocations + boolSlice []bool + uint64Slice []uint64 + int64Slice []int64 + uint32Slice []uint32 + int32Slice []int32 + stringSlice []string + float32Slice []float32 + float64Slice []float64 + // holds the Protobuf wire type from the source data wt csproto.WireType - // holds either: - // . 
one or more []byte values containing the raw bytes from the decoded message for single or - // repeated scalar values - // . a map[int]*FieldData for nested values - data []any + + // holds the maximum capacity of the slice values + maxCap int + unsafe bool } // BoolValue converts the lazily-decoded field data into a bool. @@ -67,13 +83,21 @@ func (fd *FieldData) BoolValue() (bool, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) BoolValues() ([]bool, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (bool, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.boolSlice, func(data []byte) (bool, int, error) { v, n, err := csproto.DecodeVarint(data) if err != nil { return false, 0, err } return v != 0, n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.boolSlice = s + } + return s, err } // StringValue converts the lazily-decoded field data into a string. @@ -81,6 +105,9 @@ func (fd *FieldData) BoolValues() ([]bool, error) { // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) StringValue() (string, error) { return scalarValue(fd, csproto.WireTypeLengthDelimited, func(data []byte) (string, error) { + if fd.unsafe { + return unsafe.String(unsafe.SliceData(data), len(data)), nil + } return string(data), nil }) } @@ -89,9 +116,20 @@ func (fd *FieldData) StringValue() (string, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. 
func (fd *FieldData) StringValues() ([]string, error) { - return sliceValue(fd, csproto.WireTypeLengthDelimited, func(data []byte) (string, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeLengthDelimited, fd.stringSlice, func(data []byte) (string, int, error) { + if fd.unsafe { + return unsafe.String(unsafe.SliceData(data), len(data)), len(data), nil + } return string(data), len(data), nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.stringSlice = s + } + return s, err } // BytesValue converts the lazily-decoded field data into a []byte. @@ -99,7 +137,10 @@ func (fd *FieldData) StringValues() ([]string, error) { // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) BytesValue() ([]byte, error) { return scalarValue(fd, csproto.WireTypeLengthDelimited, func(data []byte) ([]byte, error) { - return data, nil + if fd.unsafe { + return data, nil + } + return slices.Clone(data), nil }) } @@ -107,9 +148,17 @@ func (fd *FieldData) BytesValue() ([]byte, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) BytesValues() ([][]byte, error) { - return sliceValue(fd, csproto.WireTypeLengthDelimited, func(data []byte) ([]byte, int, error) { - return data, len(data), nil - }) + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + if fd.unsafe { + return fd.data, nil + } + output := make([][]byte, len(fd.data)) + for i := range output { + output[i] = slices.Clone(fd.data[i]) + } + return output, nil } // UInt32Value converts the lazily-decoded field data into a uint32. @@ -132,7 +181,10 @@ func (fd *FieldData) UInt32Value() (uint32, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. 
func (fd *FieldData) UInt32Values() ([]uint32, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (uint32, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.uint32Slice, func(data []byte) (uint32, int, error) { value, n, err := csproto.DecodeVarint(data) if err != nil { return 0, 0, err @@ -142,11 +194,16 @@ func (fd *FieldData) UInt32Values() ([]uint32, error) { } return uint32(value), n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.uint32Slice = s + } + return s, err } // Int32Value converts the lazily-decoded field data into an int32. // -// Use this method to retreive values that are defined as int32 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as int32 in the Protobuf message. Fields that // are defined as sint32 (and so use the [Protobuf ZigZag encoding]) should be retrieved using // SInt32Value() instead. // @@ -169,7 +226,7 @@ func (fd *FieldData) Int32Value() (int32, error) { // Int32Values converts the lazily-decoded field data into a []int32. // -// Use this method to retreive values that are defined as int32 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as int32 in the Protobuf message. Fields that // are defined as sint32 (and so use the [Protobuf ZigZag encoding]) should be retrieved using // SInt32Values() instead. 
// @@ -177,22 +234,30 @@ func (fd *FieldData) Int32Value() (int32, error) { // // [Protobuf ZigZag encoding]: https://developers.google.com/protocol-buffers/docs/encoding#signed-ints func (fd *FieldData) Int32Values() ([]int32, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (int32, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.int32Slice, func(data []byte) (int32, int, error) { value, n, err := csproto.DecodeVarint(data) if err != nil { return 0, 0, err } // ensure the result is within [-math.MaxInt32, math.MaxInt32] when converted to a signed value - if value > math.MaxUint32 { + if i64 := int64(value); i64 > math.MaxInt32 || i64 < math.MinInt32 { return 0, 0, csproto.ErrValueOverflow } return int32(value), n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.int32Slice = s + } + return s, err } // SInt32Value converts the lazily-decoded field data into an int32. // -// Use this method to retreive values that are defined as sint32 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as sint32 in the Protobuf message. Fields that // are defined as int32 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int32Value() instead. // @@ -211,7 +276,7 @@ func (fd *FieldData) SInt32Value() (int32, error) { // SInt32Values converts the lazily-decoded field data into a []int32. // -// Use this method to retreive values that are defined as sint32 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as sint32 in the Protobuf message. Fields that // are defined as int32 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int32Values() instead. 
// @@ -219,13 +284,21 @@ func (fd *FieldData) SInt32Value() (int32, error) { // // [Protobuf base128 varint encoding]: https://developers.google.com/protocol-buffers/docs/encoding#varints func (fd *FieldData) SInt32Values() ([]int32, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (int32, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.int32Slice, func(data []byte) (int32, int, error) { value, n, err := csproto.DecodeZigZag32(data) if err != nil { return 0, 0, err } return value, n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.int32Slice = s + } + return s, err } // UInt64Value converts the lazily-decoded field data into a uint64. @@ -245,18 +318,26 @@ func (fd *FieldData) UInt64Value() (uint64, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) UInt64Values() ([]uint64, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (uint64, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.uint64Slice, func(data []byte) (uint64, int, error) { value, n, err := csproto.DecodeVarint(data) if err != nil { return 0, 0, err } return value, n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.uint64Slice = s + } + return s, err } // Int64Value converts the lazily-decoded field data into an int64. // -// Use this method to retreive values that are defined as int64 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as int64 in the Protobuf message. Fields that // are defined as sint64 (and so use the [Protobuf ZigZag encoding]) should be retrieved using // SInt64Value() instead. 
// @@ -275,7 +356,7 @@ func (fd *FieldData) Int64Value() (int64, error) { // Int64Values converts the lazily-decoded field data into a []int64. // -// Use this method to retreive values that are defined as int64 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as int64 in the Protobuf message. Fields that // are defined as sint64 (and so use the [Protobuf ZigZag encoding]) should be retrieved using // SInt64Values() instead. // @@ -283,18 +364,26 @@ func (fd *FieldData) Int64Value() (int64, error) { // // [Protobuf ZigZag encoding]: https://developers.google.com/protocol-buffers/docs/encoding#signed-ints func (fd *FieldData) Int64Values() ([]int64, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (int64, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.int64Slice, func(data []byte) (int64, int, error) { value, n, err := csproto.DecodeVarint(data) if err != nil { return 0, 0, err } return int64(value), n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.int64Slice = s + } + return s, err } // SInt64Value converts the lazily-decoded field data into an int64. // -// Use this method to retreive values that are defined as sint64 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as sint64 in the Protobuf message. Fields that // are defined as int64 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int64Value() instead. // @@ -313,7 +402,7 @@ func (fd *FieldData) SInt64Value() (int64, error) { // SInt64Values converts the lazily-decoded field data into a []int64. // -// Use this method to retreive values that are defined as sint64 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as sint64 in the Protobuf message. 
Fields that // are defined as int64 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int64Values() instead. // @@ -321,18 +410,26 @@ func (fd *FieldData) SInt64Value() (int64, error) { // // [Protobuf base128 varint encoding]: https://developers.google.com/protocol-buffers/docs/encoding#varints func (fd *FieldData) SInt64Values() ([]int64, error) { - return sliceValue(fd, csproto.WireTypeVarint, func(data []byte) (int64, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeVarint, fd.int64Slice, func(data []byte) (int64, int, error) { value, n, err := csproto.DecodeZigZag64(data) if err != nil { return 0, 0, err } return value, n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.int64Slice = s + } + return s, err } // Fixed32Value converts the lazily-decoded field data into a uint32. // -// Use this method to retreive values that are defined as fixed32 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as fixed32 in the Protobuf message. Fields that // are defined as uint32 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int32Value() instead. // @@ -351,7 +448,7 @@ func (fd *FieldData) Fixed32Value() (uint32, error) { // Fixed32Values converts the lazily-decoded field data into a []uint32. // -// Use this method to retreive values that are defined as fixed32 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as fixed32 in the Protobuf message. Fields that // are defined as uint32 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int32Values() instead. 
// @@ -359,18 +456,26 @@ func (fd *FieldData) Fixed32Value() (uint32, error) { // // [Protobuf base128 varint encoding]: https://developers.google.com/protocol-buffers/docs/encoding#varints func (fd *FieldData) Fixed32Values() ([]uint32, error) { - return sliceValue(fd, csproto.WireTypeFixed32, func(data []byte) (uint32, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeFixed32, fd.uint32Slice, func(data []byte) (uint32, int, error) { value, n, err := csproto.DecodeFixed32(data) if err != nil { return 0, 0, err } return value, n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.uint32Slice = s + } + return s, err } // Fixed64Value converts the lazily-decoded field data into a uint64. // -// Use this method to retreive values that are defined as fixed64 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as fixed64 in the Protobuf message. Fields that // are defined as uint64 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int64Value() instead. // @@ -389,7 +494,7 @@ func (fd *FieldData) Fixed64Value() (uint64, error) { // Fixed64Values converts the lazily-decoded field data into a []uint64. // -// Use this method to retreive values that are defined as fixed64 in the Protobuf message. Fields that +// Use this method to retrieve values that are defined as fixed64 in the Protobuf message. Fields that // are defined as uint64 (and so use the [Protobuf base128 varint encoding]) should be retrieved using // Int64Values() instead. 
// @@ -397,13 +502,21 @@ func (fd *FieldData) Fixed64Value() (uint64, error) { // // [Protobuf base128 varint encoding]: https://developers.google.com/protocol-buffers/docs/encoding#varints func (fd *FieldData) Fixed64Values() ([]uint64, error) { - return sliceValue(fd, csproto.WireTypeFixed64, func(data []byte) (uint64, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeFixed64, fd.uint64Slice, func(data []byte) (uint64, int, error) { value, n, err := csproto.DecodeFixed64(data) if err != nil { return 0, 0, err } return value, n, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.uint64Slice = s + } + return s, err } // Float32Value converts the lazily-decoded field data into a float32. @@ -411,6 +524,9 @@ func (fd *FieldData) Fixed64Values() ([]uint64, error) { // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) Float32Value() (float32, error) { return scalarValue(fd, csproto.WireTypeFixed32, func(data []byte) (float32, error) { + if len(data) < 4 { + return 0, io.ErrUnexpectedEOF + } return math.Float32frombits(binary.LittleEndian.Uint32(data)), nil }) } @@ -419,9 +535,20 @@ func (fd *FieldData) Float32Value() (float32, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. 
func (fd *FieldData) Float32Values() ([]float32, error) { - return sliceValue(fd, csproto.WireTypeFixed32, func(data []byte) (float32, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeFixed32, fd.float32Slice, func(data []byte) (float32, int, error) { + if len(data) < 4 { + return 0, 0, io.ErrUnexpectedEOF + } return math.Float32frombits(binary.LittleEndian.Uint32(data)), 4, nil }) + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.float32Slice = s + } + return s, err } // Float64Value converts the lazily-decoded field data into a float64. @@ -429,6 +556,9 @@ func (fd *FieldData) Float32Values() ([]float32, error) { // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) Float64Value() (float64, error) { return scalarValue(fd, csproto.WireTypeFixed64, func(data []byte) (float64, error) { + if len(data) < 8 { + return 0, io.ErrUnexpectedEOF + } return math.Float64frombits(binary.LittleEndian.Uint64(data)), nil }) } @@ -437,36 +567,20 @@ func (fd *FieldData) Float64Value() (float64, error) { // // See the [FieldData] docs for more specific details about interpreting lazily-decoded data. func (fd *FieldData) Float64Values() ([]float64, error) { - return sliceValue(fd, csproto.WireTypeFixed64, func(data []byte) (float64, int, error) { + if fd == nil || len(fd.data) == 0 { + return nil, ErrTagNotFound + } + s, err := sliceValue(fd, csproto.WireTypeFixed64, fd.float64Slice, func(data []byte) (float64, int, error) { + if len(data) < 8 { + return 0, 0, io.ErrUnexpectedEOF + } return math.Float64frombits(binary.LittleEndian.Uint64(data)), 8, nil }) -} - -// close releases all internal resources held by fd. -// -// This is unexported because consumers should not call this method directly. It is called automatically -// by [DecodeResult.Close]. 
-func (fd *FieldData) close() { - for i, d := range fd.data { - if sub, ok := d.(map[int]*FieldData); ok && sub != nil { - for k, v := range sub { - if v != nil { - v.close() - } - delete(sub, k) - } - fieldDataMapPool.Put(sub) - } - fd.data[i] = nil + if fd.unsafe { + fd.maxCap = max(fd.maxCap, cap(s)) + fd.float64Slice = s } - fd.data = nil -} - -// a sync.Pool of field data maps to cut down on repeated small allocations -var fieldDataMapPool = sync.Pool{ - New: func() any { - return make(map[int]*FieldData) - }, + return s, err } // scalarProtoFieldGoType is a generic constraint that defines the Go types that can be created from @@ -486,30 +600,18 @@ func scalarValue[T scalarProtoFieldGoType](fd *FieldData, wt csproto.WireType, c if fd.wt != wt { return zero, wireTypeMismatchError(fd.wt, wt) } - switch data := fd.data[len(fd.data)-1].(type) { - case []byte: - value, err := convertFn(data) - if err != nil { - return zero, err - } - return value, nil - case map[int]*FieldData: - return zero, fmt.Errorf("cannot convert field data for a nested message into %T", zero) - default: - // TODO: should this be a panic? - // . elements of fd.data *SHOULD* always contain either []byte or map[int]*FieldData so this - // is a "just in case" path - return zero, rawValueConversionError[T](data) + data := fd.data[len(fd.data)-1] + value, err := convertFn(data) + if err != nil { + return zero, err } + return value, nil } // sliceValue is a helper to convert the lazily-decoded field data in fd to a slice of values of // concrete type T by successively invoking the provided convertFn to produce each value. The wt parameter // contains the expected Protobuf wire type for a Go value of type T. 
-func sliceValue[T scalarProtoFieldGoType](fd *FieldData, wt csproto.WireType, convertFn func([]byte) (T, int, error)) ([]T, error) { - if fd == nil { - return nil, ErrTagNotFound - } +func sliceValue[T scalarProtoFieldGoType](fd *FieldData, wt csproto.WireType, res []T, convertFn func([]byte) (T, int, error)) ([]T, error) { switch fd.wt { // wt is the wire type for values of type T // packed repeated fields are always WireTypeLengthDelimited @@ -517,31 +619,24 @@ func sliceValue[T scalarProtoFieldGoType](fd *FieldData, wt csproto.WireType, co if len(fd.data) == 0 { return nil, ErrTagNotFound } - var res []T - for _, rv := range fd.data { - switch data := rv.(type) { - case []byte: - // data contains 1 or more encoded values of type T - // . invoke convertFn at each successive offset to extract them - for offset := 0; offset < len(data); { - v, n, err := convertFn(data[offset:]) - if err != nil { - return nil, err - } - if n == 0 { - return nil, csproto.ErrInvalidVarintData - } - res = append(res, v) - offset += n + if res != nil { + res = res[:0] + } else { + res = make([]T, 0, len(fd.data)) + } + for _, data := range fd.data { + // data contains 1 or more encoded values of type T + // . invoke convertFn at each successive offset to extract them + for offset := 0; offset < len(data); { + v, n, err := convertFn(data[offset:]) + if err != nil { + return nil, err } - case map[int]*FieldData: - var zero T - return nil, fmt.Errorf("cannot convert field data for a nested message into a %T", zero) - default: - // TODO: should this be a panic? - // . 
elements of fd.data *SHOULD* always contain either []byte or map[int]*FieldData so this - // is a "just in case" path - return nil, rawValueConversionError[T](data) + if n <= 0 { + return nil, csproto.ErrInvalidVarintData + } + res = append(res, v) + offset += n } } return res, nil @@ -582,16 +677,10 @@ func (e *WireTypeMismatchError) Error() string { return string(*e) } -// rawValueConversionError constructs a new RawValueConversionError -func rawValueConversionError[T any](from any) *RawValueConversionError { - var target T - msg := fmt.Sprintf("unable to convert raw value (Kind = %s) to %T", reflect.ValueOf(from).Kind().String(), target) - err := RawValueConversionError(msg) - return &err -} - // RawValueConversionError is returned when the lazily-decoded value for a Protobuf field could not // be converted to the requested Go type. +// +// Deprecated: RawValueConversionError is no longer possible to return type RawValueConversionError string // Error satisfies the error interface @@ -601,3 +690,55 @@ func (e *RawValueConversionError) Error() string { } return string(*e) } + +// trunc will reduce the capacity of every slice in *FieldData to n +func (fd *FieldData) trunc(n int) { + if fd == nil { + return + } + if n < 0 { + n = 0 + } + if cap(fd.data) > n { + fd.data = make([][]byte, 0, n) + } + + // if the slices aren't large enough we don't need to reallocate + if n > fd.maxCap { + return + } + + if cap(fd.boolSlice) > n { + fd.boolSlice = make([]bool, 0, n) + } + if cap(fd.uint64Slice) > n { + fd.uint64Slice = make([]uint64, 0, n) + } + if cap(fd.int64Slice) > n { + fd.int64Slice = make([]int64, 0, n) + } + if cap(fd.uint32Slice) > n { + fd.uint32Slice = make([]uint32, 0, n) + } + if cap(fd.int32Slice) > n { + fd.int32Slice = make([]int32, 0, n) + } + if cap(fd.stringSlice) > n { + fd.stringSlice = make([]string, 0, n) + } + if cap(fd.float32Slice) > n { + fd.float32Slice = make([]float32, 0, n) + } + if cap(fd.float64Slice) > n { + fd.float64Slice = 
make([]float64, 0, n) + } + fd.maxCap = n +} + +// cap returns the capacity of the largest slice within fd +func (fd *FieldData) cap() int { + if fd == nil { + return 0 + } + return max(cap(fd.data), fd.maxCap) +}