Skip to content

Commit

Permalink
feat: enable iteration of extensions of proto messages
Browse files Browse the repository at this point in the history
Proposal - Range Extensions (#86)
* feat: enable iteration of extensions of proto messages
* fix lint issue and remove unnecessary extension fields
* remove extension struct
* add message type to unknown message type return error in range extensions.
* add warning to godocs on RangeExtensions

Co-authored-by: Lasse Jakobsen <[email protected]>
  • Loading branch information
Pungyeon and Lasse Jakobsen authored Nov 29, 2022
1 parent 4b67217 commit a9cf436
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
51 changes: 50 additions & 1 deletion extensions.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,55 @@ import (
"google.golang.org/protobuf/reflect/protoreflect"
)

// RangeExtensions iterates through all extension descriptors of a given proto message, calling fn
// on each iteration. It returns immediately on any error encountered.
// WARNING: RangeExtensions ranges over all registered extensions and therefore has a very high performance
// cost. Please consider using individual calls to GetExtension, if possible.
func RangeExtensions(msg interface{}, fn func(value interface{}, name string, field int32) error) error {
msgType := MsgType(msg)

switch msgType {
case MessageTypeGogo:
exts, err := gogo.ExtensionDescs(msg.(gogo.Message))
if err != nil {
return err
}
for _, ext := range exts {
if err = fn(ext, ext.Name, ext.Field); err != nil {
return err
}
}
return nil
case MessageTypeGoogleV1:
exts, err := google.ExtensionDescs(msg.(google.Message))
if err != nil {
return err
}
for _, ext := range exts {
if err = fn(ext,
string(ext.TypeDescriptor().FullName()),
int32(ext.TypeDescriptor().Descriptor().Number()),
); err != nil {
return err
}
}
return nil
case MessageTypeGoogle:
var err error
googlev2.RangeExtensions(msg.(googlev2.Message), func(t protoreflect.ExtensionType, v interface{}) bool {
err = fn(v,
string(t.TypeDescriptor().FullName()),
int32(t.TypeDescriptor().Descriptor().Number()),
)
return err != nil
})
return err
case MessageTypeUnknown:
return fmt.Errorf("unsupported message type: %T", msg)
}
return nil
}

// HasExtension returns true if msg contains the specified proto2 extension field, delegating to the
// appropriate underlying Protobuf API based on the concrete type of msg.
func HasExtension(msg interface{}, ext interface{}) bool {
Expand Down Expand Up @@ -89,7 +138,7 @@ func GetExtension(msg interface{}, ext interface{}) (interface{}, error) {
}
return gogo.GetExtension(msg.(gogo.Message), ed)
default:
return nil, fmt.Errorf("unsupported message type %T", ext)
return nil, fmt.Errorf("unsupported message type %T", msg)
}
}

Expand Down
3 changes: 2 additions & 1 deletion json.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ var _ json.Marshaler = (*jsonMarshaler)(nil)
// this method calls the appropriate underlying runtime (Gogo vs Google V1 vs Google V2) based on
// the message's actual type.
func (m *jsonMarshaler) MarshalJSON() ([]byte, error) {
if m.msg == nil || reflect.ValueOf(m.msg).IsNil() {
value := reflect.ValueOf(m.msg)
if m.msg == nil || value.Kind() == reflect.Ptr && value.IsNil() {
return nil, nil
}

Expand Down

0 comments on commit a9cf436

Please sign in to comment.