Skip to content

Commit

Permalink
tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
alex-richards committed Jul 22, 2024
1 parent 4679220 commit 87c0256
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 134 deletions.
18 changes: 13 additions & 5 deletions authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,24 @@ func (ra *ReaderAuth) UnmarshalCBOR(data []byte) error {
return cbor.Unmarshal(data, (*cose.UntaggedSign1Message)(ra))
}

type ReaderAuthenticationBytes TaggedEncodedCBOR
type ReaderAuthentication struct {
_ struct{} `cbor:",toarray"`
ReaderAuthentication string
SessionTranscript SessionTranscript
ItemsRequestBytes TaggedEncodedCBOR
}

func NewReaderAuthentication(
sessionTranscript SessionTranscript,
itemsRequestBytes TaggedEncodedCBOR,
) *ReaderAuthentication {
return &ReaderAuthentication{
ReaderAuthentication: "ReaderAuthentication",
SessionTranscript: sessionTranscript,
ItemsRequestBytes: itemsRequestBytes,
}
}

type IssuerAuth cose.UntaggedSign1Message

func (ia *IssuerAuth) MarshalCBOR() ([]byte, error) {
Expand All @@ -40,8 +50,7 @@ type DeviceAuth struct {

func (ia *IssuerAuth) MobileSecurityObjectBytes() (*TaggedEncodedCBOR, error) {
mobileSecurityObjectBytes := new(TaggedEncodedCBOR)
err := cbor.Unmarshal(ia.Payload, mobileSecurityObjectBytes)
if err != nil {
if err := cbor.Unmarshal(ia.Payload, mobileSecurityObjectBytes); err != nil {
return nil, err
}

Expand All @@ -60,8 +69,7 @@ func (ia *IssuerAuth) MobileSecurityObject() (*MobileSecurityObject, error) {
}

mobileSecurityObject := new(MobileSecurityObject)
err = cbor.Unmarshal(mobileSecurityObjectBytesUntagged, mobileSecurityObject)
if err != nil {
if err = cbor.Unmarshal(mobileSecurityObjectBytesUntagged, mobileSecurityObject); err != nil {
return nil, err
}

Expand Down
11 changes: 8 additions & 3 deletions cbor.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,16 @@ var (
decodeModeTaggedEncodedCBOR cbor.DecMode
)

var (
ErrorEmptyTaggedValue = errors.New("empty tagged value")
ErrorEmptyUntaggedValue = errors.New("empty untagged value")
)

func init() {
ts := cbor.NewTagSet()
ts.Add(
cbor.TagOptions{DecTag: cbor.DecTagRequired, EncTag: cbor.EncTagRequired},
reflect.TypeOf(bstr{}),
reflect.TypeOf(bstr(nil)),
TagEncodedCBOR,
)

Expand All @@ -50,7 +55,7 @@ func (tec *TaggedEncodedCBOR) TaggedValue() ([]byte, error) {
return encodeModeTaggedEncodedCBOR.Marshal(tec.untaggedValue)
}

return nil, errors.New("TODO - TaggedValue - empty")
return nil, ErrorEmptyTaggedValue
}

func (tec *TaggedEncodedCBOR) UntaggedValue() ([]byte, error) {
Expand All @@ -67,7 +72,7 @@ func (tec *TaggedEncodedCBOR) UntaggedValue() ([]byte, error) {
return untaggedValue, nil
}

return nil, errors.New("TODO - UntaggedValue - empty")
return nil, ErrorEmptyUntaggedValue
}

func (tec *TaggedEncodedCBOR) MarshalCBOR() ([]byte, error) {
Expand Down
5 changes: 2 additions & 3 deletions cbor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ func TestEncodedCBORTagged(t *testing.T) {
t.Fatal(err)
}

errUntagged := cbor.Unmarshal(testStructBytes, &TaggedEncodedCBOR{})
if errUntagged == nil {
t.Fatal()
if errUntagged := cbor.Unmarshal(testStructBytes, &TaggedEncodedCBOR{}); errUntagged == nil {
t.Fatal("expected error")
}

taggedEncodedCBOR, err := NewTaggedEncodedCBOR(testStructBytes)
Expand Down
2 changes: 1 addition & 1 deletion curves.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/veraison/go-cose"
)

var ErrorUnsupportedCurve = errors.New("Unsupported Curve")
var ErrorUnsupportedCurve = errors.New("unsupported curve")

func NewCOSEKeyFromECDHPublicKey(key ecdh.PublicKey) (*cose.Key, error) {
var coseAlg cose.Algorithm
Expand Down
2 changes: 1 addition & 1 deletion curves_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ func TestKeyConversions(t *testing.T) {
}

if !bytes.Equal(privateKey.PublicKey().Bytes(), publicKey.Bytes()) {
t.Fail()
t.Fatal()
}
}
97 changes: 41 additions & 56 deletions device_engagement.go
Original file line number Diff line number Diff line change
@@ -1,30 +1,57 @@
package mdoc

import (
"fmt"
"errors"

"github.com/fxamacker/cbor/v2"
"github.com/google/uuid"
"github.com/veraison/go-cose"
)

const (
DeviceRetrievalMethodTypeNFC = 1
DeviceRetrievalMethodTypeBLE = 2
DeviceRetrievalMethodTypeWiFiAware = 3
)

var ErrorUnreccognisedReterevalMethod = errors.New("unreccognised retreival method")

type DeviceEngagement struct {
Version string `cbor:"0,keyasint"`
Security Security `cbor:"1,keyasint"`
DeviceRetrievalMethods []DeviceRetrievalMethod `cbor:"2,keyasint,omitempty"`
}

func NewDeviceEngagement(eDeviceKey *cose.Key) (*DeviceEngagement, error) {
security, err := newSecurity(eDeviceKey)
eDeviceKeyBytesUntagged, err := cbor.Marshal(eDeviceKey)
if err != nil {
return nil, err
}

eDeviceKeyBytes, err := NewTaggedEncodedCBOR(eDeviceKeyBytesUntagged)
if err != nil {
return nil, err
}

peripheralServerUUID := uuid.New()
centralClientUUID := uuid.New()
return &DeviceEngagement{
"1.0",
*security,
Security{
CipherSuiteIdentifier: 1,
EDeviceKeyBytes: *eDeviceKeyBytes,
},
[]DeviceRetrievalMethod{
newBleDeviceRetrievalMethod(),
{
Type: DeviceRetrievalMethodTypeBLE,
Version: 1,
RetrievalOptions: BleOptions{
SupportsPeripheralServer: true,
SupportsCentralClient: true,
PeripheralServerUUID: &peripheralServerUUID,
CentralClientUUID: &centralClientUUID,
},
},
},
}, nil
}
Expand All @@ -49,23 +76,6 @@ type Security struct {
EDeviceKeyBytes TaggedEncodedCBOR
}

func newSecurity(eDeviceKey *cose.Key) (*Security, error) {
eDeviceKeyBytesUntagged, err := cbor.Marshal(eDeviceKey)
if err != nil {
return nil, err
}

eDeviceKeyBytes, err := NewTaggedEncodedCBOR(eDeviceKeyBytesUntagged)
if err != nil {
return nil, err
}

return &Security{
CipherSuiteIdentifier: 1,
EDeviceKeyBytes: *eDeviceKeyBytes,
}, nil
}

type DeviceRetrievalMethod struct {
_ struct{} `cbor:",toarray"`
Type uint
Expand All @@ -80,55 +90,30 @@ type intermediateDeviceRetreievalMethod struct {
RetrievalOptions cbor.RawMessage
}

func newBleDeviceRetrievalMethod() DeviceRetrievalMethod {
peripheralServerUUID := uuid.New()
centralClientUUID := uuid.New()
return DeviceRetrievalMethod{
Type: 2,
Version: 1,
RetrievalOptions: BleOptions{
SupportsPeripheralServer: true,
SupportsCentralClient: true,
PeripheralServerUUID: &peripheralServerUUID,
CentralClientUUID: &centralClientUUID,
},
}
}

func (deviceRetrievalMethod *DeviceRetrievalMethod) UnmarshalCBOR(data []byte) error {
intermediateDeviceRetreievalMethod := new(intermediateDeviceRetreievalMethod)
err := cbor.Unmarshal(data, intermediateDeviceRetreievalMethod)
if err != nil {
func (drm *DeviceRetrievalMethod) UnmarshalCBOR(data []byte) error {
var intermediateDeviceRetreievalMethod intermediateDeviceRetreievalMethod
if err := cbor.Unmarshal(data, &intermediateDeviceRetreievalMethod); err != nil {
return err
}

switch intermediateDeviceRetreievalMethod.Type {
case 2:
bleOptions := BleOptions{}
err = cbor.Unmarshal(intermediateDeviceRetreievalMethod.RetrievalOptions, &bleOptions)
if err != nil {
case DeviceRetrievalMethodTypeBLE:
var bleOptions BleOptions
if err := cbor.Unmarshal(intermediateDeviceRetreievalMethod.RetrievalOptions, &bleOptions); err != nil {
return err
}
deviceRetrievalMethod.RetrievalOptions = bleOptions
drm.RetrievalOptions = bleOptions

default:
return &ErrorUnreccognisedReterevalMethod{Type: intermediateDeviceRetreievalMethod.Type}
return ErrorUnreccognisedReterevalMethod
}

deviceRetrievalMethod.Type = intermediateDeviceRetreievalMethod.Type
deviceRetrievalMethod.Version = intermediateDeviceRetreievalMethod.Version
drm.Type = intermediateDeviceRetreievalMethod.Type
drm.Version = intermediateDeviceRetreievalMethod.Version

return nil
}

type ErrorUnreccognisedReterevalMethod struct {
Type uint
}

func (err *ErrorUnreccognisedReterevalMethod) Error() string {
return fmt.Sprintf("DeviceRetrievalMethod - no unmashaller for type %d", err.Type)
}

type RetrievalOptions interface{}

type BleOptions struct {
Expand Down
26 changes: 8 additions & 18 deletions device_engagement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@ import (
)

func TestNewDeviceEngagement(t *testing.T) {
deviceEngagement, err := NewDeviceEngagement(EDeviceKeyPublic)
_, err := NewDeviceEngagement(EDeviceKeyPublic)
if err != nil {
t.Fatal(err)
}

if deviceEngagement == nil {
t.Fatal()
}
}

func TestDeviceEngagementCBORRoundTrip(t *testing.T) {
Expand All @@ -28,15 +24,15 @@ func TestDeviceEngagementCBORRoundTrip(t *testing.T) {

peripheralServerUUID := uuid.New()

deviceEngagement := &DeviceEngagement{
deviceEngagement := DeviceEngagement{
Version: "1.0",
Security: Security{
CipherSuiteIdentifier: 1,
EDeviceKeyBytes: *eDeviceKeyBytes,
},
DeviceRetrievalMethods: []DeviceRetrievalMethod{
{
Type: 2,
Type: DeviceRetrievalMethodTypeBLE,
Version: 1,
RetrievalOptions: BleOptions{
SupportsPeripheralServer: true,
Expand All @@ -49,19 +45,19 @@ func TestDeviceEngagementCBORRoundTrip(t *testing.T) {
},
}

deviceEngagementBytes, err := cbor.Marshal(deviceEngagement)
deviceEngagementBytes, err := cbor.Marshal(&deviceEngagement)
if err != nil {
t.Fatal(err)
}

deviceEngagementUnmarshalled := new(DeviceEngagement)
if err = cbor.Unmarshal(deviceEngagementBytes, deviceEngagementUnmarshalled); err != nil {
var deviceEngagementUnmarshalled DeviceEngagement
if err = cbor.Unmarshal(deviceEngagementBytes, &deviceEngagementUnmarshalled); err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(
deviceEngagement,
deviceEngagementUnmarshalled,
&deviceEngagement,
&deviceEngagementUnmarshalled,
cmp.FilterPath(func(p cmp.Path) bool {
return p.Last().Type() == reflect.TypeOf(TaggedEncodedCBOR{})
}, cmp.Ignore()),
Expand Down Expand Up @@ -99,10 +95,4 @@ func TestDeviceEngagementUnknownMethod(t *testing.T) {
if err = cbor.Unmarshal(deviceEngagementBytes, &deviceEngagementUnmarshalled); err == nil {
t.Fatal("expected error")
}

errUnreccognisedReterevalMethod := err.(*ErrorUnreccognisedReterevalMethod)

if errUnreccognisedReterevalMethod.Type != 123 {
t.Fatal(errUnreccognisedReterevalMethod)
}
}
12 changes: 3 additions & 9 deletions device_response.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package mdoc

import (
"errors"

"github.com/fxamacker/cbor/v2"
)

Expand Down Expand Up @@ -54,10 +52,6 @@ type IssuerSignedItems map[NameSpace][]IssuerSignedItem
func (ins IssuerNameSpaces) IssuerSignedItems() (IssuerSignedItems, error) {
issuerSignedItemss := make(IssuerSignedItems)
for nameSpace, issuerSignedItemBytess := range ins {
if issuerSignedItemBytess == nil {
return nil, errors.New("TODO")
}

issuerSignedItems := make([]IssuerSignedItem, len(issuerSignedItemBytess))
for i, issuerSignedItemBytes := range issuerSignedItemBytess {
issuerSignedItemBytesUntagged, err := issuerSignedItemBytes.UntaggedValue()
Expand Down Expand Up @@ -89,14 +83,14 @@ type DeviceSigned struct {
DeviceAuth DeviceAuth `cbor:"deviceAuth"`
}

func (ds *DeviceSigned) NameSpaces() (DeviceNameSpaces, error) {
func (ds *DeviceSigned) NameSpaces() (*DeviceNameSpaces, error) {
nameSpacesBytesUntagged, err := ds.NameSpacesBytes.UntaggedValue()
if err != nil {
return nil, err
}

deviceNameSpaces := make(DeviceNameSpaces)
if err = cbor.Unmarshal(nameSpacesBytesUntagged, &deviceNameSpaces); err != nil {
deviceNameSpaces := new(DeviceNameSpaces)
if err = cbor.Unmarshal(nameSpacesBytesUntagged, deviceNameSpaces); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit 87c0256

Please sign in to comment.