diff --git a/commons/config.go b/commons/config.go index c257dff..41562c9 100644 --- a/commons/config.go +++ b/commons/config.go @@ -91,6 +91,7 @@ type PubsubConfig struct { MsgValidationTimeout time.Duration `json:"msgValidationTimeout,omitempty" yaml:"msgValidationTimeout,omitempty"` Scoring *ScoringParams `json:"scoring,omitempty" yaml:"scoring,omitempty"` MsgValidator *MsgValidationConfig `json:"msgValidator,omitempty" yaml:"msgValidator,omitempty"` + MsgIDFnConfig *MsgIDFnConfig `json:"msgIDFn,omitempty" yaml:"msgIDFn,omitempty"` Trace bool `json:"trace,omitempty" yaml:"trace,omitempty"` } @@ -103,6 +104,11 @@ func (psc PubsubConfig) GetTopicConfig(name string) (TopicConfig, bool) { return TopicConfig{}, false } +type MsgIDFnConfig struct { + Type string `json:"type,omitempty" yaml:"type,omitempty"` + Size int `json:"size,omitempty" yaml:"size,omitempty"` +} + type MsgValidationConfig struct { Timeout time.Duration `json:"timeout,omitempty" yaml:"timeout,omitempty"` Concurrency int `json:"concurrency,omitempty" yaml:"concurrency,omitempty"` diff --git a/core/gossip/msg_id.go b/core/gossip/msg_id.go index a9ef48e..b64d8f9 100644 --- a/core/gossip/msg_id.go +++ b/core/gossip/msg_id.go @@ -1,6 +1,7 @@ package gossip import ( + "crypto/md5" "crypto/sha256" "encoding/hex" @@ -8,8 +9,29 @@ import ( pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" ) -// msgIDSha256 uses sha256 hash of the message content -func MsgIDSha256(size int) pubsub.MsgIdFunction { +type MsgIDSize int +type MsgIDFuncType string + +const ( + MsgIDSha256Type MsgIDFuncType = "sha256" + MsgIDMD5Type MsgIDFuncType = "md5" +) + +var DefaultMsgIDFn = MsgIDSha256(20) + +func MsgIDFn(tp MsgIDFuncType, size MsgIDSize) pubsub.MsgIdFunction { + switch tp { + case MsgIDSha256Type: + return MsgIDSha256(size) + case MsgIDMD5Type: + return MsgIDMD5(size) + default: + return DefaultMsgIDFn + } +} + +// MsgIDSha256 uses sha256 hash of the message content +func MsgIDSha256(size MsgIDSize) pubsub.MsgIdFunction { return func(pmsg *pubsub_pb.Message) string { msg := pmsg.GetData() if len(msg) == 0 { @@ -17,6 +39,25 @@ func MsgIDSha256(size int) pubsub.MsgIdFunction { } // TODO: optimize, e.g. by using a pool of hashers h := sha256.Sum256(msg) + if msgSize := MsgIDSize(len(h)); size > msgSize { + size = msgSize + } + return hex.EncodeToString(h[:size]) + } +} + +// MsgIDSMD5 uses md5 hash of the message content +func MsgIDMD5(size MsgIDSize) pubsub.MsgIdFunction { + return func(pmsg *pubsub_pb.Message) string { + msg := pmsg.GetData() + if len(msg) == 0 { + return "" + } + // TODO: optimize, e.g. by using a pool of hashers + h := md5.Sum(msg) + if msgSize := MsgIDSize(len(h)); size > msgSize { + size = msgSize + } return hex.EncodeToString(h[:size]) } } diff --git a/core/gossip/msg_id_test.go b/core/gossip/msg_id_test.go new file mode 100644 index 0000000..df5c38c --- /dev/null +++ b/core/gossip/msg_id_test.go @@ -0,0 +1,57 @@ +package gossip + +import ( + "testing" + + pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/stretchr/testify/require" +) + +func TestMsgID(t *testing.T) { + tests := []struct { + name string + msgID MsgIDFuncType + size MsgIDSize + input []byte + want string + }{ + { + name: "sha256", + msgID: MsgIDSha256Type, + size: 20, + input: []byte("hello world"), + want: "b94d27b9934d3e08a52e52d7da7dabfac484efe3", + }, + { + name: "md5", + msgID: MsgIDMD5Type, + size: 10, + input: []byte("hello world"), + want: "5eb63bbbe01eeed093cb", + }, + { + name: "default", + msgID: "", + size: 0, + input: []byte("hello world"), + want: "b94d27b9934d3e08a52e52d7da7dabfac484efe3", + }, + { + name: "size overflow", + msgID: MsgIDMD5Type, + size: 100, + input: []byte("hello world"), + want: "5eb63bbbe01eeed093cb22bb8f5acdc3", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + msgIDFn := MsgIDFn(tc.msgID, tc.size) + got := msgIDFn(&pubsub_pb.Message{ + Data: tc.input, + }) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/core/pubsub.go b/core/pubsub.go index f3294a1..bb8fbf4 100644 --- a/core/pubsub.go +++ b/core/pubsub.go @@ -21,11 +21,12 @@ var ( ) func (c *Controller) setupPubsubRouter(ctx context.Context, cfg commons.Config) error { + msgID := gossip.MsgIDFn(gossip.MsgIDFuncType(cfg.Pubsub.MsgIDFnConfig.Type), gossip.MsgIDSize(cfg.Pubsub.MsgIDFnConfig.Size)) opts := []pubsub.Option{ pubsub.WithMessageSigning(false), pubsub.WithMessageSignaturePolicy(pubsub.StrictNoSign), pubsub.WithGossipSubParams(gossip.GossipSubParams(cfg.Pubsub.Overlay)), - pubsub.WithMessageIdFn(gossip.MsgIDSha256(20)), + pubsub.WithMessageIdFn(msgID), } if cfg.Pubsub.Scoring != nil { diff --git a/core/testutils.go b/core/testutils.go index b1c552c..df48819 100644 --- a/core/testutils.go +++ b/core/testutils.go @@ -64,11 +64,11 @@ func SetupTestControllers(ctx context.Context, t *testing.T, n int, routingFn fu } msgRouter := NewMsgRouter(1024, 4, func(mw *MsgWrapper[error]) { routingFn(mw.Msg) - }, gossip.MsgIDSha256(20)) + }, gossip.DefaultMsgIDFn) valRouter := NewMsgRouter(1024, 4, func(mw *MsgWrapper[pubsub.ValidationResult]) { res := valFn(mw.Peer, mw.Msg) mw.Result = res - }, gossip.MsgIDSha256(20)) + }, gossip.DefaultMsgIDFn) c, err := NewController(ctx, cfg, msgRouter, valRouter, fmt.Sprintf("peer-%d", i+1)) require.NoError(t, err) controllers[i] = c