diff --git a/crypto/ephemeral.go b/crypto/ephemeral.go new file mode 100644 index 0000000000..87e35b8fa7 --- /dev/null +++ b/crypto/ephemeral.go @@ -0,0 +1,47 @@ +package crypto + +import ( + "crypto/ecdh" + "crypto/rand" + "crypto/sha256" + "fmt" +) + +const ( + // EphemeralKeyLength is the size of the ECDH ephemeral key in bytes. + EphemeralKeyLength = 65 +) + +// EncryptWithEphemeralKey encrypts a key using a randomly generated ephemeral ECDH key and a provided public key. +// It returns the encrypted key prepended with the ephemeral public key. +func EncryptWithEphemeralKey(plainText, publicKeyBytes []byte) ([]byte, error) { + ephemeralPriv, err := ecdh.P256().GenerateKey(rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate ephemeral key: %w", err) + } + + ephPubKeyBytes := ephemeralPriv.PublicKey().Bytes() + sharedSecret := sha256.Sum256(append(ephPubKeyBytes, publicKeyBytes...)) + + return append(ephPubKeyBytes, xorBytes(plainText, sharedSecret[:])...), nil +} + +func xorBytes(data, xor []byte) []byte { + result := make([]byte, len(data)) + for i := range data { + result[i] = data[i] ^ xor[i%len(xor)] + } + return result +} + +// DecryptWithEphemeralKey decrypts data that was encrypted using EncryptWithEphemeralKey. +// It expects the input to be the ephemeral public key followed by the encrypted data. +func DecryptWithEphemeralKey(encryptedData, publicKeyBytes []byte) ([]byte, error) { + ephPubKeyBytes := encryptedData[:EphemeralKeyLength] + cipherText := make([]byte, len(encryptedData)-EphemeralKeyLength) + copy(cipherText, encryptedData[EphemeralKeyLength:]) + + sharedSecret := sha256.Sum256(append(ephPubKeyBytes, publicKeyBytes...)) + + return xorBytes(cipherText, sharedSecret[:]), nil +} diff --git a/net/peer.go b/net/peer.go index 5ffce49bb0..06666d87dc 100644 --- a/net/peer.go +++ b/net/peer.go @@ -461,13 +461,7 @@ func (p *Peer) handleDocUpdateLog(evt event.Update) error { } func (p *Peer) handleEncryptionKeyRequest(evt encryption.RequestKeyEvent) error { - req := &pb.FetchEncryptionKeyRequest{ - DocID: []byte(evt.DocID), - Cid: evt.Cid.Bytes(), - SchemaRoot: []byte(evt.SchemaRoot), - } - - if err := p.server.requestEncryptionKey(p.ctx, req); err != nil { + if err := p.server.requestEncryptionKey(p.ctx, evt.DocID, evt.Cid, evt.SchemaRoot); err != nil { return NewErrPublishingToDocIDTopic(err, evt.Cid.String(), evt.DocID) } diff --git a/net/server.go b/net/server.go index 8dd65b6c5e..a3d326d8a2 100644 --- a/net/server.go +++ b/net/server.go @@ -14,6 +14,7 @@ package net import ( "context" + "crypto/sha256" "fmt" "sync" @@ -28,7 +29,9 @@ import ( grpcpeer "google.golang.org/grpc/peer" "google.golang.org/protobuf/proto" + libp2pCrypto "github.com/libp2p/go-libp2p/core/crypto" "github.com/sourcenetwork/defradb/client" + "github.com/sourcenetwork/defradb/crypto" "github.com/sourcenetwork/defradb/errors" "github.com/sourcenetwork/defradb/event" coreblock "github.com/sourcenetwork/defradb/internal/core/block" @@ -167,32 +170,91 @@ func (s *server) PushLog(ctx context.Context, req *pb.PushLogRequest) (*pb.PushL return &pb.PushLogReply{}, nil } -func (s *server) TryGenEncryptionKey( - ctx context.Context, - req *pb.FetchEncryptionKeyRequest, -) (*pb.FetchEncryptionKeyReply, error) { +func (s *server) TryGenEncryptionKey(ctx context.Context, req *pb.FetchEncryptionKeyRequest) (*pb.FetchEncryptionKeyReply, error) { + isValid, err := s.verifyRequestSignature(req) + if err != nil { + return nil, errors.Wrap("invalid signature", err) + } + + if !isValid { + return nil, errors.New("invalid signature") + } + + pubKey, err := libp2pCrypto.PublicKeyFromProto(req.PublicKey) + if err != nil { + return nil, errors.Wrap("failed to unmarshal public key", err) + } + + if err := s.verifyPeerInfo(libpeer.ID(req.PeerInfo.Id), pubKey); err != nil { + return nil, errors.Wrap("invalid peer info", err) + } + docID, err := client.NewDocIDFromString(string(req.DocID)) if err != nil { return nil, err } encKey, err := encryption.GetDocKey(encryption.ContextWithStore(ctx, s.peer.encstore), docID.String()) + if err != nil || len(encKey) == 0 { + return nil, err + } + pubKeyBytes, err := pubKey.Raw() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get raw Ed25519 public key: %w", err) } - if len(encKey) == 0 { - return nil, nil + encryptedKey, err := crypto.EncryptWithEphemeralKey(encKey, pubKeyBytes) + if err != nil { + return nil, errors.Wrap("failed to encrypt key for requester", err) } res := &pb.FetchEncryptionKeyReply{ - EncryptionKey: encKey, + EncryptedKey: encryptedKey, + Cid: req.Cid, + SchemaRoot: req.SchemaRoot, + } + + res.Signature, err = s.signResponse(res) + if err != nil { + return nil, errors.Wrap("failed to sign response", err) } return res, nil } +func (s *server) verifyRequestSignature(req *pb.FetchEncryptionKeyRequest) (bool, error) { + pubKey, err := libp2pCrypto.PublicKeyFromProto(req.PublicKey) + if err != nil { + return false, err + } + + return pubKey.Verify(hashFetchEncryptionKeyRequest(req), req.Signature) +} + +func (s *server) verifyPeerInfo(peerID libpeer.ID, pubKey libp2pCrypto.PubKey) error { + derivedID, err := peer.IDFromPublicKey(pubKey) + if err != nil { + return err + } + + if peerID != derivedID { + return errors.New("peer ID does not match public key") + } + + return nil +} + +func (s *server) signResponse(res *pb.FetchEncryptionKeyReply) ([]byte, error) { + hash := sha256.New() + hash.Write(res.EncryptedKey) + hash.Write(res.Cid) + hash.Write(res.SchemaRoot) + + privKey := s.peer.host.Peerstore().PrivKey(s.peer.host.ID()) + return privKey.Sign(hash.Sum(nil)) +} + // GetHeadLog receives a get head log request func (s *server) GetHeadLog( ctx context.Context, @@ -332,7 +394,7 @@ func (s *server) publishLog(ctx context.Context, topic string, req *pb.PushLogRe data, err := req.MarshalVT() if err != nil { - return errors.Wrap("failed marshling pubsub message", err) + return errors.Wrap("failed to marshal pubsub message", err) } if _, err := t.Publish(ctx, data, rpc.WithIgnoreResponse(true)); err != nil { @@ -341,55 +403,139 @@ func (s *server) publishLog(ctx context.Context, topic string, req *pb.PushLogRe return nil } +func toProtoPeerInfo(peerInfo libpeer.AddrInfo) *pb.PeerInfo { + protoPeerInfo := new(pb.PeerInfo) + protoPeerInfo.Id = []byte(peerInfo.ID) + protoPeerInfo.Addresses = make([]string, len(peerInfo.Addrs)) + for i, addr := range peerInfo.Addrs { + protoPeerInfo.Addresses[i] = addr.String() + } + return protoPeerInfo +} + +func (s *server) prepareFetchEncryptionKeyRequest(docID string, cid cid.Cid, schemaRoot string) (*pb.FetchEncryptionKeyRequest, error) { + publicKey := s.peer.host.Peerstore().PubKey(s.peer.host.ID()) + protoPublicKey, err := libp2pCrypto.PublicKeyToProto(publicKey) + if err != nil { + return nil, errors.Wrap("failed to marshal public key", err) + } + + req := &pb.FetchEncryptionKeyRequest{ + DocID: []byte(docID), + Cid: cid.Bytes(), + SchemaRoot: []byte(schemaRoot), + PublicKey: protoPublicKey, + PeerInfo: toProtoPeerInfo(s.peer.PeerInfo()), + } + + req.Signature, err = s.signRequest(req) + if err != nil { + return nil, errors.Wrap("failed to sign request", err) + } + + return req, nil +} + // requestEncryptionKey publishes the given FetchEncryptionKeyRequest object on the PubSub network -func (s *server) requestEncryptionKey(ctx context.Context, req *pb.FetchEncryptionKeyRequest) error { +func (s *server) requestEncryptionKey(ctx context.Context, docID string, cid cid.Cid, schemaRoot string) error { if s.peer.ps == nil { // skip if we aren't running with a pubsub net return nil } - s.mu.Lock() - t, ok := s.topics[encryptionTopic] - s.mu.Unlock() - if !ok { - err := s.addPubSubTopic(encryptionTopic, false) - if err != nil { - return errors.Wrap(fmt.Sprintf("failed to created single use topic %s", encryptionTopic), err) - } - return s.requestEncryptionKey(ctx, req) + + req, err := s.prepareFetchEncryptionKeyRequest(docID, cid, schemaRoot) + if err != nil { + return err } data, err := req.MarshalVT() if err != nil { - return errors.Wrap("failed marshling pubsub message", err) + return errors.Wrap("failed to marshal pubsub message", err) } + s.mu.Lock() + t := s.topics[encryptionTopic] + s.mu.Unlock() respChan, err := t.Publish(ctx, data) if err != nil { return errors.Wrap(fmt.Sprintf("failed publishing to thread %s", encryptionTopic), err) } + go func() { s.handleFetchEncryptionKeyResponse(<-respChan, req) }() + return nil } +func hashFetchEncryptionKeyRequest(req *pb.FetchEncryptionKeyRequest) []byte { + hash := sha256.New() + hash.Write(req.DocID) + hash.Write(req.Cid) + hash.Write(req.SchemaRoot) + hash.Write([]byte(req.PublicKey.Type.String())) + hash.Write(req.PublicKey.Data) + hash.Write([]byte(req.PeerInfo.String())) + return hash.Sum(nil) +} + +func (s *server) signRequest(req *pb.FetchEncryptionKeyRequest) ([]byte, error) { + privKey := s.peer.host.Peerstore().PrivKey(s.peer.host.ID()) + return privKey.Sign(hashFetchEncryptionKeyRequest(req)) +} + // handleFetchEncryptionKeyResponse handles incoming FetchEncryptionKeyResponse messages -func (s *server) handleFetchEncryptionKeyResponse( - resp rpc.Response, - req *pb.FetchEncryptionKeyRequest, -) { +func (s *server) handleFetchEncryptionKeyResponse(resp rpc.Response, req *pb.FetchEncryptionKeyRequest) { var keyResp pb.FetchEncryptionKeyReply - err := proto.Unmarshal(resp.Data, &keyResp) - if err != nil { + if err := proto.Unmarshal(resp.Data, &keyResp); err != nil { log.ErrorContextE(s.peer.ctx, "Failed to unmarshal encryption key response", err) return } + + isValid, err := s.verifyResponseSignature(&keyResp, resp.From) + if err != nil { + log.ErrorContextE(s.peer.ctx, "Failed to verify response signature", err) + return + } + + if !isValid { + log.ErrorContext(s.peer.ctx, "Invalid response signature") + return + } + + privateKey := s.peer.host.Peerstore().PrivKey(s.peer.host.ID()) + + // Use the private key to get the public key bytes + ed25519PubKeyBytes, err := privateKey.GetPublic().Raw() + if err != nil { + log.ErrorContextE(s.peer.ctx, "failed to get raw Ed25519 public key", err) + return + } + + decryptedKey, err := crypto.DecryptWithEphemeralKey(keyResp.EncryptedKey, ed25519PubKeyBytes) + if err != nil { + log.ErrorContextE(s.peer.ctx, "Failed to decrypt encryption key", err) + return + } + cid, err := cid.Cast(req.Cid) if err != nil { log.ErrorContextE(s.peer.ctx, "Failed to parse CID", err) return } + s.peer.bus.Publish(encryption.NewKeyRetrievedMessage( - string(req.DocID), cid, string(req.SchemaRoot), keyResp.EncryptionKey)) + string(req.DocID), cid, string(req.SchemaRoot), decryptedKey)) +} + +func (s *server) verifyResponseSignature(res *pb.FetchEncryptionKeyReply, fromPeer peer.ID) (bool, error) { + pubKey := s.peer.host.Peerstore().PubKey(fromPeer) + + hash := sha256.New() + hash.Write(res.EncryptedKey) + hash.Write(res.Cid) + hash.Write(res.SchemaRoot) + + return pubKey.Verify(hash.Sum(nil), res.Signature) } // pubSubMessageHandler handles incoming PushLog messages from the pubsub network. @@ -426,7 +572,7 @@ func (s *server) pubSubEncryptionMessageHandler(from libpeer.ID, topic string, m }) res, err := s.TryGenEncryptionKey(ctx, req) if err != nil { - return nil, errors.Wrap(fmt.Sprintf("Failed pushing log for doc %s", topic), err) + return nil, errors.Wrap("Failed attempt to get encryption key", err) } return res.MarshalVT() //return proto.Marshal(res)