From a9cf43677c3ecec80ebd65c1bf0639d14ad45ab5 Mon Sep 17 00:00:00 2001 From: Lasse Martin Jakobsen Date: Tue, 29 Nov 2022 16:18:54 +0100 Subject: [PATCH] feat: enable iteration of extensions of proto messages Proposal - Range Extensions (#86) * feat: enable iteration of extensions of proto messages * fix lint issue and remove unnecessary extension fields * remove extension struct * add message type to unknown message type return error in range extensions. * add warning to godocs on RangeExtensions Co-authored-by: Lasse Jakobsen --- extensions.go | 51 ++++++++++++++++++++++++++++++++++++++++++++++++++- json.go | 3 ++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/extensions.go b/extensions.go index 87b1bb9..2e45099 100644 --- a/extensions.go +++ b/extensions.go @@ -9,6 +9,55 @@ import ( "google.golang.org/protobuf/reflect/protoreflect" ) +// RangeExtensions iterates through all extension descriptors of a given proto message, calling fn +// on each iteration. It returns immediately on any error encountered. +// WARNING: RangeExtensions ranges over all registered extensions and therefore has a very high performance +// cost. Please consider using individual calls to GetExtension, if possible. +func RangeExtensions(msg interface{}, fn func(value interface{}, name string, field int32) error) error { + msgType := MsgType(msg) + + switch msgType { + case MessageTypeGogo: + exts, err := gogo.ExtensionDescs(msg.(gogo.Message)) + if err != nil { + return err + } + for _, ext := range exts { + if err = fn(ext, ext.Name, ext.Field); err != nil { + return err + } + } + return nil + case MessageTypeGoogleV1: + exts, err := google.ExtensionDescs(msg.(google.Message)) + if err != nil { + return err + } + for _, ext := range exts { + if err = fn(ext, + string(ext.TypeDescriptor().FullName()), + int32(ext.TypeDescriptor().Descriptor().Number()), + ); err != nil { + return err + } + } + return nil + case MessageTypeGoogle: + var err error + googlev2.RangeExtensions(msg.(googlev2.Message), func(t protoreflect.ExtensionType, v interface{}) bool { + err = fn(v, + string(t.TypeDescriptor().FullName()), + int32(t.TypeDescriptor().Descriptor().Number()), + ) + return err != nil + }) + return err + case MessageTypeUnknown: + return fmt.Errorf("unsupported message type: %T", msg) + } + return nil +} + // HasExtension returns true if msg contains the specified proto2 extension field, delegating to the // appropriate underlying Protobuf API based on the concrete type of msg. func HasExtension(msg interface{}, ext interface{}) bool { @@ -89,7 +138,7 @@ func GetExtension(msg interface{}, ext interface{}) (interface{}, error) { } return gogo.GetExtension(msg.(gogo.Message), ed) default: - return nil, fmt.Errorf("unsupported message type %T", ext) + return nil, fmt.Errorf("unsupported message type %T", msg) } } diff --git a/json.go b/json.go index ecef5a1..5e668fd 100644 --- a/json.go +++ b/json.go @@ -42,7 +42,8 @@ var _ json.Marshaler = (*jsonMarshaler)(nil) // this method calls the appropriate underlying runtime (Gogo vs Google V1 vs Google V2) based on // the message's actual type. func (m *jsonMarshaler) MarshalJSON() ([]byte, error) { - if m.msg == nil || reflect.ValueOf(m.msg).IsNil() { + value := reflect.ValueOf(m.msg) + if m.msg == nil || value.Kind() == reflect.Ptr && value.IsNil() { return nil, nil }