Skip to content

Commit

Permalink
Allow streaming of proto messages over a single network.Stream (#1081)
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak authored Aug 18, 2023
1 parent de52f66 commit ecb6713
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 237 deletions.
21 changes: 21 additions & 0 deletions p2p/starknet/bytereader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package starknet

import (
"io"

"google.golang.org/protobuf/encoding/protodelim"
)

var _ protodelim.Reader = (*byteReader)(nil)

type byteReader struct {
io.Reader
}

func (r *byteReader) ReadByte() (byte, error) {
var b [1]byte
if _, err := r.Read(b[:]); err != nil {
return 0, err
}
return b[0], nil
}
27 changes: 27 additions & 0 deletions p2p/starknet/bytereader_pkg_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package starknet

import (
"bytes"
"io"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestByteReader(t *testing.T) {
buffer := bytes.NewBuffer([]byte{1, 2, 3, 4})
bReader := byteReader{buffer}

read, err := bReader.ReadByte()
require.NoError(t, err)
assert.Equal(t, byte(1), read)

read, err = bReader.ReadByte()
require.NoError(t, err)
assert.Equal(t, byte(2), read)

readAll, err := io.ReadAll(bReader)
require.NoError(t, err)
assert.Equal(t, []byte{3, 4}, readAll)
}
46 changes: 29 additions & 17 deletions p2p/starknet/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/proto"
)

Expand All @@ -26,13 +27,7 @@ func NewClient(newStream NewStreamFunc, protocolID protocol.ID, log utils.Logger
}
}

func (c *Client) sendAndReceiveInto(ctx context.Context, req, res proto.Message) error {
stream, err := c.newStream(ctx, c.protocolID)
if err != nil {
return err
}
defer stream.Close() // todo: dont ignore close errors

func (c *Client) sendAndCloseWrite(stream network.Stream, req proto.Message) error {
reqBytes, err := proto.Marshal(req)
if err != nil {
return err
Expand All @@ -41,33 +36,50 @@ func (c *Client) sendAndReceiveInto(ctx context.Context, req, res proto.Message)
if _, err = stream.Write(reqBytes); err != nil {
return err
}
return stream.CloseWrite()
}

if err = stream.CloseWrite(); err != nil {
func (c *Client) receiveInto(stream network.Stream, res proto.Message) error {
return protodelim.UnmarshalFrom(&byteReader{stream}, res)
}

func (c *Client) sendAndReceiveInto(ctx context.Context, req, res proto.Message) error {
stream, err := c.newStream(ctx, c.protocolID)
if err != nil {
return err
}
defer stream.Close() // todo: dont ignore close errors

buffer := getBuffer()
defer bufferPool.Put(buffer)

if _, err = buffer.ReadFrom(stream); err != nil {
if err = c.sendAndCloseWrite(stream, req); err != nil {
return err
}

return proto.Unmarshal(buffer.Bytes(), res)
return c.receiveInto(stream, res)
}

func (c *Client) GetBlocks(ctx context.Context, req *spec.GetBlocks) (*spec.GetBlocksResponse, error) {
func (c *Client) GetBlocks(ctx context.Context, req *spec.GetBlocks) (Stream[*spec.BlockHeader], error) {
wrappedReq := spec.Request{
Req: &spec.Request_GetBlocks{
GetBlocks: req,
},
}

var res spec.GetBlocksResponse
if err := c.sendAndReceiveInto(ctx, &wrappedReq, &res); err != nil {
stream, err := c.newStream(ctx, c.protocolID)
if err != nil {
return nil, err
}
return &res, nil
if err := c.sendAndCloseWrite(stream, &wrappedReq); err != nil {
return nil, err
}

return func() (*spec.BlockHeader, bool) {
var res spec.BlockHeader
if err := c.receiveInto(stream, &res); err != nil {
stream.Close() // todo: dont ignore close errors
return nil, false
}
return &res, true
}, nil
}

func (c *Client) GetSignatures(ctx context.Context, req *spec.GetSignatures) (*spec.Signatures, error) {
Expand Down
52 changes: 28 additions & 24 deletions p2p/starknet/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/NethermindEth/juno/p2p/starknet/spec"
"github.com/NethermindEth/juno/utils"
"github.com/libp2p/go-libp2p/core/network"
"google.golang.org/protobuf/encoding/protodelim"
"google.golang.org/protobuf/proto"
)

Expand Down Expand Up @@ -65,47 +66,50 @@ func (h *Handler) StreamHandler(stream network.Stream) {
return
}

responseBytes, err := proto.Marshal(response)
if err != nil {
h.log.Debugw("Error marshalling response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err, "response", response)
return
}

if _, err = stream.Write(responseBytes); err != nil {
h.log.Debugw("Error writing response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
return
for msg, valid := response(); valid; msg, valid = response() {
if _, err := protodelim.MarshalTo(stream, msg); err != nil { // todo: figure out if we need buffered io here
h.log.Debugw("Error writing response", "peer", stream.ID(), "protocol", stream.Protocol(), "err", err)
}
}
}

func (h *Handler) reqHandler(req *spec.Request) (proto.Message, error) {
func (h *Handler) reqHandler(req *spec.Request) (Stream[proto.Message], error) {
var singleResponse proto.Message
var err error
switch typedReq := req.GetReq().(type) {
case *spec.Request_GetBlocks:
return h.HandleGetBlocks(typedReq.GetBlocks)
case *spec.Request_GetSignatures:
return h.HandleGetSignatures(typedReq.GetSignatures)
singleResponse, err = h.HandleGetSignatures(typedReq.GetSignatures)
case *spec.Request_GetEvents:
return h.HandleGetEvents(typedReq.GetEvents)
singleResponse, err = h.HandleGetEvents(typedReq.GetEvents)
case *spec.Request_GetReceipts:
return h.HandleGetReceipts(typedReq.GetReceipts)
singleResponse, err = h.HandleGetReceipts(typedReq.GetReceipts)
case *spec.Request_GetTransactions:
return h.HandleGetTransactions(typedReq.GetTransactions)
singleResponse, err = h.HandleGetTransactions(typedReq.GetTransactions)
default:
return nil, fmt.Errorf("unhandled request %T", typedReq)
}

if err != nil {
return nil, err
}
return StaticStream[proto.Message](singleResponse), nil
}

func (h *Handler) HandleGetBlocks(req *spec.GetBlocks) (*spec.GetBlocksResponse, error) {
func (h *Handler) HandleGetBlocks(req *spec.GetBlocks) (Stream[proto.Message], error) {
// todo: read from bcReader and adapt to p2p type
return &spec.GetBlocksResponse{
Blocks: []*spec.HeaderAndStateDiff{
{
Header: &spec.BlockHeader{
State: &spec.Merkle{
NLeaves: 251,
},
},
count := uint32(0)
return func() (proto.Message, bool) {
if count > 3 {
return nil, false
}
count++
return &spec.BlockHeader{
State: &spec.Merkle{
NLeaves: count - 1,
},
},
}, true
}, nil
}

Expand Down
11 changes: 0 additions & 11 deletions p2p/starknet/p2p/proto/requests.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import "p2p/proto/block.proto";
import "p2p/proto/event.proto";
import "p2p/proto/receipt.proto";
import "p2p/proto/transaction.proto";
import "p2p/proto/state.proto";


message Request {
oneof req {
Expand All @@ -16,12 +14,3 @@ message Request {
GetTransactions get_transactions = 6;
}
}

message GetBlocksResponse {
repeated HeaderAndStateDiff blocks = 1;
}

message HeaderAndStateDiff {
BlockHeader header = 1;
StateDiff state_diff = 2;
}
Loading

0 comments on commit ecb6713

Please sign in to comment.