Skip to content

Commit

Permalink
change counters to support encryption
Browse files Browse the repository at this point in the history
  • Loading branch information
fredcarle committed Jun 7, 2024
1 parent 0c134e5 commit 1e35cd2
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 114 deletions.
8 changes: 3 additions & 5 deletions internal/core/block/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func init() {
&crdt.CRDT{},
&crdt.LWWRegDelta{},
&crdt.CompositeDAGDelta{},
&crdt.CounterDelta[int64]{}, // Only need to call one of the CounterDelta types.
&crdt.CounterDelta{}, // Only need to call one of the CounterDelta types.
)
}

Expand Down Expand Up @@ -149,10 +149,8 @@ func New(delta core.Delta, links []DAGLink, heads ...cid.Cid) *Block {
crdtDelta = crdt.CRDT{LWWRegDelta: delta}
case *crdt.CompositeDAGDelta:
crdtDelta = crdt.CRDT{CompositeDAGDelta: delta}
case *crdt.CounterDelta[int64]:
crdtDelta = crdt.CRDT{CounterDeltaInt: delta}
case *crdt.CounterDelta[float64]:
crdtDelta = crdt.CRDT{CounterDeltaFloat: delta}
case *crdt.CounterDelta:
crdtDelta = crdt.CRDT{CounterDelta: delta}
}

return &Block{
Expand Down
129 changes: 78 additions & 51 deletions internal/core/crdt/counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Incrementable interface {
}

// CounterDelta is a single delta operation for a Counter
type CounterDelta[T Incrementable] struct {
type CounterDelta struct {
DocID []byte
FieldName string
Priority uint64
Expand All @@ -44,69 +44,60 @@ type CounterDelta[T Incrementable] struct {
//
// It can be used to identify the collection datastructure state at the time of commit.
SchemaVersionID string
Data T
Data []byte
}

var _ core.Delta = (*CounterDelta[float64])(nil)
var _ core.Delta = (*CounterDelta[int64])(nil)
var _ core.Delta = (*CounterDelta)(nil)

// IPLDSchemaBytes returns the IPLD schema representation for the type.
//
// This needs to match the [CounterDelta[T]] struct or [coreblock.mustSetSchema] will panic on init.
func (delta *CounterDelta[T]) IPLDSchemaBytes() []byte {
// This needs to match the [CounterDelta] struct or [coreblock.mustSetSchema] will panic on init.
func (delta *CounterDelta) IPLDSchemaBytes() []byte {
return []byte(`
type CounterDeltaFloat struct {
type CounterDelta struct {
docID Bytes
fieldName String
priority Int
nonce Int
schemaVersionID String
data Float
}
type CounterDeltaInt struct {
docID Bytes
fieldName String
priority Int
nonce Int
schemaVersionID String
data Int
data Bytes
}`)
}

// GetPriority gets the current priority for this delta.
func (delta *CounterDelta[T]) GetPriority() uint64 {
func (delta *CounterDelta) GetPriority() uint64 {
return delta.Priority
}

// SetPriority will set the priority for this delta.
func (delta *CounterDelta[T]) SetPriority(prio uint64) {
func (delta *CounterDelta) SetPriority(prio uint64) {
delta.Priority = prio
}

// Counter, is a simple CRDT type that allows increment/decrement
// of an Int and Float data types that ensures convergence.
type Counter[T Incrementable] struct {
type Counter struct {
baseCRDT
AllowDecrement bool
Kind client.ScalarKind
}

var _ core.ReplicatedData = (*Counter[float64])(nil)
var _ core.ReplicatedData = (*Counter[int64])(nil)
var _ core.ReplicatedData = (*Counter)(nil)

// NewCounter returns a new instance of the Counter with the given ID.
func NewCounter[T Incrementable](
func NewCounter(
store datastore.DSReaderWriter,
schemaVersionKey core.CollectionSchemaVersionKey,
key core.DataStoreKey,
fieldName string,
allowDecrement bool,
) Counter[T] {
return Counter[T]{newBaseCRDT(store, key, schemaVersionKey, fieldName), allowDecrement}
kind client.ScalarKind,
) Counter {
return Counter{newBaseCRDT(store, key, schemaVersionKey, fieldName), allowDecrement, kind}
}

// Value gets the current counter value
func (c Counter[T]) Value(ctx context.Context) ([]byte, error) {
func (c Counter) Value(ctx context.Context) ([]byte, error) {
valueK := c.key.WithValueFlag()
buf, err := c.store.Get(ctx, valueK.ToDS())
if err != nil {
Expand All @@ -120,7 +111,7 @@ func (c Counter[T]) Value(ctx context.Context) ([]byte, error) {
// WARNING: Incrementing an integer and causing it to overflow the int64 max value
// will cause the value to roll over to the int64 min value. Incremeting a float and
// causing it to overflow the float64 max value will act like a no-op.
func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], error) {
func (c Counter) Increment(ctx context.Context, value []byte) (*CounterDelta, error) {
// To ensure that the dag block is unique, we add a random number to the delta.
// This is done only on update (if the doc doesn't already exist) to ensure that the
// initial dag block of a document can be reproducible.
Expand All @@ -137,7 +128,7 @@ func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], e
nonce = r.Int64()
}

return &CounterDelta[T]{
return &CounterDelta{
DocID: []byte(c.key.DocID),
FieldName: c.fieldName,
Data: value,
Expand All @@ -148,19 +139,20 @@ func (c Counter[T]) Increment(ctx context.Context, value T) (*CounterDelta[T], e

// Merge implements ReplicatedData interface.
// It merges two CounterRegisty by adding the values together.
func (c Counter[T]) Merge(ctx context.Context, delta core.Delta) error {
d, ok := delta.(*CounterDelta[T])
func (c Counter) Merge(ctx context.Context, delta core.Delta) error {
d, ok := delta.(*CounterDelta)
if !ok {
return ErrMismatchedMergeType
}

return c.incrementValue(ctx, d.Data, d.GetPriority())
}

func (c Counter[T]) incrementValue(ctx context.Context, value T, priority uint64) error {
if !c.AllowDecrement && value < 0 {
return NewErrNegativeValue(value)
}
func (c Counter) incrementValue(
ctx context.Context,
valueAsBytes []byte,
priority uint64,
) error {
key := c.key.WithValueFlag()
marker, err := c.store.Get(ctx, c.key.ToPrimaryDataStoreKey().ToDS())
if err != nil && !errors.Is(err, ds.ErrNotFound) {
Expand All @@ -170,44 +162,79 @@ func (c Counter[T]) incrementValue(ctx context.Context, value T, priority uint64
key = key.WithDeletedFlag()
}

curValue, err := c.getCurrentValue(ctx, key)
if err != nil {
return err
}
var resultAsBytes []byte

newValue := curValue + value
b, err := cbor.Marshal(newValue)
if err != nil {
return err
switch c.Kind {
case client.FieldKind_NILLABLE_INT:
resultAsBytes, err = validateAndIncrement[int64](ctx, c.store, key, valueAsBytes, c.AllowDecrement)
if err != nil {
return err
}
case client.FieldKind_NILLABLE_FLOAT:
resultAsBytes, err = validateAndIncrement[float64](ctx, c.store, key, valueAsBytes, c.AllowDecrement)
if err != nil {
return err
}
default:
return NewErrUnsupportedCounterType(c.Kind)
}

err = c.store.Put(ctx, key.ToDS(), b)
err = c.store.Put(ctx, key.ToDS(), resultAsBytes)
if err != nil {
return NewErrFailedToStoreValue(err)
}

return c.setPriority(ctx, c.key, priority)
}

func (c Counter[T]) getCurrentValue(ctx context.Context, key core.DataStoreKey) (T, error) {
curValue, err := c.store.Get(ctx, key.ToDS())
func validateAndIncrement[T Incrementable](
ctx context.Context,
store datastore.DSReaderWriter,
key core.DataStoreKey,
valueAsBytes []byte,
allowDecrement bool,
) ([]byte, error) {
value, err := getNumericFromBytes[T](valueAsBytes)
if err != nil {
if errors.Is(err, ds.ErrNotFound) {
return 0, nil
}
return 0, err
return nil, err
}

return getNumericFromBytes[T](curValue)
if !allowDecrement && value < 0 {
return nil, NewErrNegativeValue(value)
}

curValue, err := getCurrentValue[T](ctx, store, key)
if err != nil {
return nil, err
}

newValue := curValue + value
return cbor.Marshal(newValue)
}

func (c Counter[T]) CType() client.CType {
func (c Counter) CType() client.CType {
if c.AllowDecrement {
return client.PN_COUNTER
}
return client.P_COUNTER
}

func getCurrentValue[T Incrementable](
ctx context.Context,
store datastore.DSReaderWriter,
key core.DataStoreKey,
) (T, error) {
curValue, err := store.Get(ctx, key.ToDS())
if err != nil {
if errors.Is(err, ds.ErrNotFound) {
return 0, nil
}
return 0, err
}

return getNumericFromBytes[T](curValue)
}

func getNumericFromBytes[T Incrementable](b []byte) (T, error) {
var val T
err := cbor.Unmarshal(b, &val)
Expand Down
15 changes: 11 additions & 4 deletions internal/core/crdt/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
package crdt

import (
"github.com/sourcenetwork/defradb/client"
"github.com/sourcenetwork/defradb/errors"
)

const (
errFailedToGetPriority string = "failed to get priority"
errFailedToStoreValue string = "failed to store value"
errNegativeValue string = "value cannot be negative"
errFailedToGetPriority string = "failed to get priority"
errFailedToStoreValue string = "failed to store value"
errNegativeValue string = "value cannot be negative"
errUnsupportedCounterType string = "unsupported counter typee. Valid types are int64 and float64"
)

// Errors returnable from this package.
Expand All @@ -31,7 +33,8 @@ var (
ErrEncodingPriority = errors.New("error encoding priority")
ErrDecodingPriority = errors.New("error decoding priority")
// ErrMismatchedMergeType - Tying to merge two ReplicatedData of different types
ErrMismatchedMergeType = errors.New("given type to merge does not match source")
ErrMismatchedMergeType = errors.New("given type to merge does not match source")
ErrUnsupportedCounterType = errors.New(errUnsupportedCounterType)
)

// NewErrFailedToGetPriority returns an error indicating that the priority could not be retrieved.
Expand All @@ -47,3 +50,7 @@ func NewErrFailedToStoreValue(inner error) error {
func NewErrNegativeValue[T Incrementable](value T) error {
return errors.New(errNegativeValue, errors.NewKV("Value", value))
}

func NewErrUnsupportedCounterType(valueType client.ScalarKind) error {
return errors.New(errUnsupportedCounterType, errors.NewKV("Type", valueType))
}
36 changes: 12 additions & 24 deletions internal/core/crdt/ipld_union.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import "github.com/sourcenetwork/defradb/internal/core"
type CRDT struct {
LWWRegDelta *LWWRegDelta
CompositeDAGDelta *CompositeDAGDelta
CounterDeltaInt *CounterDelta[int64]
CounterDeltaFloat *CounterDelta[float64]
CounterDelta *CounterDelta
}

// IPLDSchemaBytes returns the IPLD schema representation for the CRDT.
Expand All @@ -28,8 +27,7 @@ func (c CRDT) IPLDSchemaBytes() []byte {
type CRDT union {
| LWWRegDelta "lww"
| CompositeDAGDelta "composite"
| CounterDeltaInt "counterInt"
| CounterDeltaFloat "counterFloat"
| CounterDelta "counter"
} representation keyed`)
}

Expand All @@ -40,10 +38,8 @@ func (c CRDT) GetDelta() core.Delta {
return c.LWWRegDelta
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt
case c.CounterDelta != nil:
return c.CounterDelta
}
return nil
}
Expand All @@ -55,10 +51,8 @@ func (c CRDT) GetPriority() uint64 {
return c.LWWRegDelta.GetPriority()
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.GetPriority()
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.GetPriority()
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.GetPriority()
case c.CounterDelta != nil:
return c.CounterDelta.GetPriority()
}
return 0
}
Expand All @@ -70,10 +64,8 @@ func (c CRDT) GetFieldName() string {
return c.LWWRegDelta.FieldName
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.FieldName
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.FieldName
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.FieldName
case c.CounterDelta != nil:
return c.CounterDelta.FieldName
}
return ""
}
Expand All @@ -85,10 +77,8 @@ func (c CRDT) GetDocID() []byte {
return c.LWWRegDelta.DocID
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.DocID
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.DocID
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.DocID
case c.CounterDelta != nil:
return c.CounterDelta.DocID
}
return nil
}
Expand All @@ -100,10 +90,8 @@ func (c CRDT) GetSchemaVersionID() string {
return c.LWWRegDelta.SchemaVersionID
case c.CompositeDAGDelta != nil:
return c.CompositeDAGDelta.SchemaVersionID
case c.CounterDeltaFloat != nil:
return c.CounterDeltaFloat.SchemaVersionID
case c.CounterDeltaInt != nil:
return c.CounterDeltaInt.SchemaVersionID
case c.CounterDelta != nil:
return c.CounterDelta.SchemaVersionID
}
return ""
}
Expand Down
Loading

0 comments on commit 1e35cd2

Please sign in to comment.