Skip to content

Commit

Permalink
Support batch query protos (#215)
Browse files Browse the repository at this point in the history
These are the minimal accompanying code changes for
xmtp/proto#218, which are a prerequisite for
#126.

- Flattens query requests into a simpler structure
- Publishes and queries are now *batch* publishes and queries. For now,
we operate against the *first* item in the batch
- I added tasks to properly implement the batching and unit tests later,
but don't treat it as high priority for now
- Adds a common method to validate batch query requests (refactored out
from the subscribe worker)
  • Loading branch information
richardhuaaa authored Oct 11, 2024
1 parent cc8dbbe commit 4047800
Show file tree
Hide file tree
Showing 11 changed files with 528 additions and 699 deletions.
35 changes: 21 additions & 14 deletions pkg/api/publish_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@ func TestPublishEnvelope(t *testing.T) {
api, db, cleanup := apiTestUtils.NewTestAPIClient(t)
defer cleanup()

resp, err := api.PublishEnvelope(
resp, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: testutils.CreatePayerEnvelope(t),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{testutils.CreatePayerEnvelope(t)},
},
)
require.NoError(t, err)
Expand All @@ -29,7 +29,10 @@ func TestPublishEnvelope(t *testing.T) {
unsignedEnv := &message_api.UnsignedOriginatorEnvelope{}
require.NoError(
t,
proto.Unmarshal(resp.GetOriginatorEnvelope().GetUnsignedOriginatorEnvelope(), unsignedEnv),
proto.Unmarshal(
resp.GetOriginatorEnvelopes()[0].GetUnsignedOriginatorEnvelope(),
unsignedEnv,
),
)
clientEnv := &message_api.ClientEnvelope{}
require.NoError(
Expand All @@ -50,7 +53,7 @@ func TestPublishEnvelope(t *testing.T) {

originatorEnv := &message_api.OriginatorEnvelope{}
require.NoError(t, proto.Unmarshal(envs[0].OriginatorEnvelope, originatorEnv))
return proto.Equal(originatorEnv, resp.GetOriginatorEnvelope())
return proto.Equal(originatorEnv, resp.GetOriginatorEnvelopes()[0])
}, 500*time.Millisecond, 50*time.Millisecond)
}

Expand All @@ -60,10 +63,10 @@ func TestUnmarshalErrorOnPublish(t *testing.T) {

envelope := testutils.CreatePayerEnvelope(t)
envelope.UnsignedClientEnvelope = []byte("invalidbytes")
_, err := api.PublishEnvelope(
_, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: envelope,
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{envelope},
},
)
require.ErrorContains(t, err, "unmarshal")
Expand All @@ -75,10 +78,12 @@ func TestMismatchingOriginatorOnPublish(t *testing.T) {

clientEnv := testutils.CreateClientEnvelope()
clientEnv.Aad.TargetOriginator = 2
_, err := api.PublishEnvelope(
_, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: testutils.CreatePayerEnvelope(t, clientEnv),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{
testutils.CreatePayerEnvelope(t, clientEnv),
},
},
)
require.ErrorContains(t, err, "originator")
Expand All @@ -90,10 +95,12 @@ func TestMissingTopicOnPublish(t *testing.T) {

clientEnv := testutils.CreateClientEnvelope()
clientEnv.Aad.TargetTopic = nil
_, err := api.PublishEnvelope(
_, err := api.PublishEnvelopes(
context.Background(),
&message_api.PublishEnvelopeRequest{
PayerEnvelope: testutils.CreatePayerEnvelope(t, clientEnv),
&message_api.PublishEnvelopesRequest{
PayerEnvelopes: []*message_api.PayerEnvelope{
testutils.CreatePayerEnvelope(t, clientEnv),
},
},
)
require.ErrorContains(t, err, "topic")
Expand Down
13 changes: 4 additions & 9 deletions pkg/api/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,8 @@ func TestQueryEnvelopesByOriginator(t *testing.T) {
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Filter: &message_api.EnvelopesQuery_OriginatorNodeId{
OriginatorNodeId: 2,
},
LastSeen: nil,
OriginatorNodeIds: []uint32{2},
LastSeen: nil,
},
Limit: 0,
},
Expand All @@ -126,7 +124,7 @@ func TestQueryEnvelopesByTopic(t *testing.T) {
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Filter: &message_api.EnvelopesQuery_Topic{Topic: []byte("topicA")},
Topics: [][]byte{[]byte("topicA")},
LastSeen: nil,
},
Limit: 0,
Expand All @@ -145,7 +143,6 @@ func TestQueryEnvelopesFromLastSeen(t *testing.T) {
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Filter: nil,
LastSeen: &message_api.VectorClock{NodeIdToSequenceId: map[uint32]uint64{1: 2}},
},
Limit: 0,
Expand All @@ -164,9 +161,7 @@ func TestQueryEnvelopesWithEmptyResult(t *testing.T) {
context.Background(),
&message_api.QueryEnvelopesRequest{
Query: &message_api.EnvelopesQuery{
Filter: &message_api.EnvelopesQuery_Topic{
Topic: []byte("topicC"),
},
Topics: [][]byte{[]byte("topicC")},
},
Limit: 0,
},
Expand Down
93 changes: 63 additions & 30 deletions pkg/api/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package api
import (
"context"
"database/sql"
"fmt"

"github.com/xmtp/xmtpd/pkg/blockchain"
"github.com/xmtp/xmtpd/pkg/db"
Expand All @@ -21,6 +22,8 @@ import (

const (
maxRequestedRows uint32 = 1000
maxQueriesPerRequest int = 10000
maxTopicLength int = 128
maxVectorClockLength int = 100
)

Expand Down Expand Up @@ -68,11 +71,11 @@ func (s *Service) Close() {
s.log.Info("closed")
}

func (s *Service) BatchSubscribeEnvelopes(
req *message_api.BatchSubscribeEnvelopesRequest,
stream message_api.ReplicationApi_BatchSubscribeEnvelopesServer,
func (s *Service) SubscribeEnvelopes(
req *message_api.SubscribeEnvelopesRequest,
stream message_api.ReplicationApi_SubscribeEnvelopesServer,
) error {
log := s.log.With(zap.String("method", "batchSubscribe"))
log := s.log.With(zap.String("method", "subscribe"))

// Send a header (any header) to fix an issue with Tonic based GRPC clients.
// See: https://github.com/xmtp/libxmtp/pull/58
Expand All @@ -81,21 +84,17 @@ func (s *Service) BatchSubscribeEnvelopes(
return status.Errorf(codes.Internal, "could not send header: %v", err)
}

requests := req.GetRequests()
if len(requests) == 0 {
return status.Errorf(codes.InvalidArgument, "missing requests")
}

ch, err := s.subscribeWorker.listen(stream.Context(), requests)
if err != nil {
query := req.GetQuery()
if err := s.validateQuery(query); err != nil {
return status.Errorf(codes.InvalidArgument, "invalid subscription request: %v", err)
}

ch := s.subscribeWorker.listen(stream.Context(), query)
for {
select {
case envs, open := <-ch:
if open {
err := stream.Send(&message_api.BatchSubscribeEnvelopesResponse{
err := stream.Send(&message_api.SubscribeEnvelopesResponse{
Envelopes: envs,
})
if err != nil {
Expand Down Expand Up @@ -148,6 +147,38 @@ func (s *Service) QueryEnvelopes(
}, nil
}

func (s *Service) validateQuery(
query *message_api.EnvelopesQuery,
) error {
if query == nil {
return fmt.Errorf("missing query")
}

topics := query.GetTopics()
originators := query.GetOriginatorNodeIds()
if len(topics) != 0 && len(originators) != 0 {
return fmt.Errorf(
"cannot filter by both topic and originator in same subscription request",
)
}

numQueries := len(topics) + len(originators)
if numQueries > maxQueriesPerRequest {
return fmt.Errorf(
"too many subscriptions: %d, consider subscribing to fewer topics or subscribing without a filter",
numQueries,
)
}

for _, topic := range topics {
if len(topic) == 0 || len(topic) > maxTopicLength {
return fmt.Errorf("invalid topic: %s", topic)
}
}

return nil
}

func (s *Service) queryReqToDBParams(
req *message_api.QueryEnvelopesRequest,
) (*queries.SelectGatewayEnvelopesParams, error) {
Expand All @@ -160,19 +191,15 @@ func (s *Service) queryReqToDBParams(
}

query := req.GetQuery()
if query == nil {
return nil, status.Errorf(codes.InvalidArgument, "missing query")
if err := s.validateQuery(query); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid query: %v", err)
}

switch filter := query.GetFilter().(type) {
case *message_api.EnvelopesQuery_Topic:
if len(filter.Topic) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing topic")
}
params.Topic = filter.Topic
case *message_api.EnvelopesQuery_OriginatorNodeId:
params.OriginatorNodeID = db.NullInt32(int32(filter.OriginatorNodeId))
default:
// TODO(rich): Properly support batch queries
if len(query.GetTopics()) > 0 {
params.Topic = query.GetTopics()[0]
} else if len(query.GetOriginatorNodeIds()) > 0 {
params.OriginatorNodeID = db.NullInt32(int32(query.GetOriginatorNodeIds()[0]))
}

vc := query.GetLastSeen().GetNodeIdToSequenceId()
Expand All @@ -193,11 +220,14 @@ func (s *Service) queryReqToDBParams(
return &params, nil
}

func (s *Service) PublishEnvelope(
func (s *Service) PublishEnvelopes(
ctx context.Context,
req *message_api.PublishEnvelopeRequest,
) (*message_api.PublishEnvelopeResponse, error) {
clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelope())
req *message_api.PublishEnvelopesRequest,
) (*message_api.PublishEnvelopesResponse, error) {
if len(req.GetPayerEnvelopes()) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "missing payer envelope")
}
clientEnv, err := s.validatePayerInfo(req.GetPayerEnvelopes()[0])
if err != nil {
return nil, err
}
Expand All @@ -212,10 +242,11 @@ func (s *Service) PublishEnvelope(
return nil, err
}
if didPublish {
return &message_api.PublishEnvelopeResponse{}, nil
return &message_api.PublishEnvelopesResponse{}, nil
}

payerBytes, err := proto.Marshal(req.GetPayerEnvelope())
// TODO(rich): Properly support batch publishing
payerBytes, err := proto.Marshal(req.GetPayerEnvelopes()[0])
if err != nil {
return nil, status.Errorf(codes.Internal, "could not marshal envelope: %v", err)
}
Expand All @@ -235,7 +266,9 @@ func (s *Service) PublishEnvelope(
return nil, status.Errorf(codes.Internal, "could not sign envelope: %v", err)
}

return &message_api.PublishEnvelopeResponse{OriginatorEnvelope: originatorEnv}, nil
return &message_api.PublishEnvelopesResponse{
OriginatorEnvelopes: []*message_api.OriginatorEnvelope{originatorEnv},
}, nil
}

func (s *Service) maybePublishToBlockchain(
Expand Down
Loading

0 comments on commit 4047800

Please sign in to comment.