From 5ddf7f0489d5c9e03307b8cbe013ac8e427d5f37 Mon Sep 17 00:00:00 2001 From: Connor Kuehl Date: Fri, 6 Sep 2024 11:35:27 -0700 Subject: [PATCH] refactor error surface for the API The functions were too opaque and didn't add much. Now callers should be able to do something simple like var txerr *nbd.TransmissionError if errors.As(err, &txerr) { switch txerr.Code { case nbd.ErrShuttingDown: // ... default: fmt.Println("oh no!") } } --- error.go | 168 +++++++++++++++++++++++++++--------------------- negotiation.go | 29 ++------- transmission.go | 45 +++++-------- 3 files changed, 116 insertions(+), 126 deletions(-) diff --git a/error.go b/error.go index d6da6ab..110cf57 100644 --- a/error.go +++ b/error.go @@ -3,102 +3,124 @@ package nbd import ( - "errors" "fmt" + + "github.com/digitalocean/go-nbd/internal/nbdproto" ) -var ( - ErrUnsupported = errors.New("server does not support option") - ErrPolicy = errors.New("server forbids option") - ErrInvalid = errors.New("invalid option") - ErrPlatform = errors.New("server platform does not support option") - ErrTLSReqd = errors.New("server requires TLS for option") - ErrUnknown = errors.New("requested export is not available") - ErrShutdown = errors.New("server is shutting down") - ErrBlockSizeRequired = errors.New("server requires blocksize assurances from client") - ErrTooBig = errors.New("request or reply is too large to process") - ErrExtHeaderRequired = errors.New("extension header required") - ErrUndefined = errors.New("did not understand error type from server") +type OptionErrorCode uint32 - ErrPerm = errors.New("operation not permitted") - ErrIO = errors.New("input/output error") - ErrNoMem = errors.New("cannot allocate memory") - ErrInval = errors.New("invalid argument") - ErrNoSpc = errors.New("no space left on device") - ErrOverflow = errors.New("value too large for defined data type") - ErrNotSupported = errors.New("operation not supported") - ErrTransportShutdown = errors.New("cannot send after transport endpoint shutdown") +const ( + ErrUnsupported = OptionErrorCode(nbdproto.REP_ERR_UNSUPPORTED) + ErrPolicy = OptionErrorCode(nbdproto.REP_ERR_POLICY) + ErrInvalid = OptionErrorCode(nbdproto.REP_ERR_INVALID) + ErrPlatform = OptionErrorCode(nbdproto.REP_ERR_PLATFORM) + ErrTLSRequired = OptionErrorCode(nbdproto.REP_ERR_TLS_REQUIRED) + ErrUnknown = OptionErrorCode(nbdproto.REP_ERR_UNKNOWN) + ErrShutdown = OptionErrorCode(nbdproto.REP_ERR_SHUTDOWN) + ErrBlockSizeRequired = OptionErrorCode(nbdproto.REP_ERR_BLOCK_SIZE_REQUIRED) + ErrTooBig = OptionErrorCode(nbdproto.REP_ERR_TOO_BIG) + ErrExtHeaderRequired = OptionErrorCode(nbdproto.REP_ERR_EXT_HEADER_REQUIRED) ) +func (o OptionErrorCode) Symbol() string { + switch o { + case ErrUnsupported: + return "REP_ERR_UNSUPPORTED" + case ErrPolicy: + return "REP_ERR_POLICY" + case ErrInvalid: + return "REP_ERR_INVALID" + case ErrPlatform: + return "REP_ERR_PLATFORM" + case ErrTLSRequired: + return "REP_ERR_TLS_REQUIRED" + case ErrUnknown: + return "REP_ERR_UNKNOWN" + case ErrShutdown: + return "REP_ERR_SHUTDOWN" + case ErrBlockSizeRequired: + return "REP_ERR_BLOCK_SIZE_REQUIRED" + case ErrTooBig: + return "REP_ERR_TOO_BIG" + case ErrExtHeaderRequired: + return "REP_ERR_EXT_HEADER_REQUIRED" + default: + return fmt.Sprintf("BUG:%d", uint32(o)) + } +} + +// NegotiationError is a protocol-level error returned by the server +// during the option phase. type NegotiationError struct { - Cause error - Message string + Code OptionErrorCode + Message NullErrorMessage } func (e *NegotiationError) Error() string { - if e.Message != "" { - return fmt.Sprintf("%s: %s", e.Cause.Error(), e.Message) + if e.Message.Valid { + return fmt.Sprintf("%s: %s", e.Message.Value, e.Code.Symbol()) } - return e.Cause.Error() + return e.Code.Symbol() } -func IsUnsupportedErr(err error) bool { return isNegotiationErr(err, ErrUnsupported) } -func IsPolicyErr(err error) bool { return isNegotiationErr(err, ErrPolicy) } -func IsInvalidErr(err error) bool { return isNegotiationErr(err, ErrInvalid) } -func IsPlatformErr(err error) bool { return isNegotiationErr(err, ErrPlatform) } -func IsTLSReqdErr(err error) bool { return isNegotiationErr(err, ErrTLSReqd) } -func IsUnknownErr(err error) bool { return isNegotiationErr(err, ErrUnknown) } -func IsShutdownErr(err error) bool { return isNegotiationErr(err, ErrShutdown) } -func IsBlockSizeRequiredErr(err error) bool { return isNegotiationErr(err, ErrBlockSizeRequired) } -func IsErrTooBig(err error) bool { return isNegotiationErr(err, ErrTooBig) } -func IsExtHeaderRequiredErr(err error) bool { return isNegotiationErr(err, ErrExtHeaderRequired) } -func IsUndefinedErr(err error) bool { return isNegotiationErr(err, ErrUndefined) } +type TransmissionErrorCode uint32 -func isNegotiationErr(err, target error) bool { - if errors.Is(err, target) { - return true - } - var e *NegotiationError - if errors.As(err, &e) { - return errors.Is(e.Cause, target) +const ( + ErrNotPermitted = TransmissionErrorCode(nbdproto.EPERM) + ErrIO = TransmissionErrorCode(nbdproto.EIO) + ErrNoMemory = TransmissionErrorCode(nbdproto.ENOMEM) + ErrInvalidArgument = TransmissionErrorCode(nbdproto.EINVAL) + ErrNoSpaceLeft = TransmissionErrorCode(nbdproto.ENOSPC) + ErrOverflow = TransmissionErrorCode(nbdproto.EOVERFLOW) + ErrNotSupported = TransmissionErrorCode(nbdproto.ENOTSUP) + ErrShuttingDown = TransmissionErrorCode(nbdproto.ESHUTDOWN) +) + +func (t TransmissionErrorCode) Symbol() string { + switch t { + case ErrNotPermitted: + return "EPERM" + case ErrIO: + return "EIO" + case ErrNoMemory: + return "ENOMEM" + case ErrInvalidArgument: + return "EINVAL" + case ErrNoSpaceLeft: + return "ENOSPC" + case ErrOverflow: + return "EOVERFLOW" + case ErrNotSupported: + return "ENOTSUP" + case ErrShuttingDown: + return "ESHUTDOWN" + default: + return fmt.Sprintf("BUG:%d", uint32(t)) } - return false } +// TransmissionError is a protocol-level error returned by the server +// during the transmission phase. type TransmissionError struct { - Cause error - Message string - Offset uint64 - HasOffset bool + Code TransmissionErrorCode + Message NullErrorMessage + Offset NullOffset } func (e *TransmissionError) Error() string { - s := "transmission error" - if e.HasOffset { - s = fmt.Sprintf("%s at offset %d: %s", s, e.Offset, e.Cause.Error()) - } - if e.Message != "" { - s = s + fmt.Sprintf(": %s", e.Message) + if e.Message.Valid { + return fmt.Sprintf("%s: %s", e.Message.Value, e.Code.Symbol()) } - return s + return e.Code.Symbol() } -func IsPermErr(err error) bool { return isTransmissionErr(err, ErrPerm) } -func IsIOErr(err error) bool { return isTransmissionErr(err, ErrIO) } -func IsNoMemErr(err error) bool { return isTransmissionErr(err, ErrNoMem) } -func IsInvalErr(err error) bool { return isTransmissionErr(err, ErrInval) } -func IsNoSpcErr(err error) bool { return isTransmissionErr(err, ErrNoSpc) } -func IsOverflowErr(err error) bool { return isTransmissionErr(err, ErrOverflow) } -func IsNotSupportedErr(err error) bool { return isTransmissionErr(err, ErrNotSupported) } -func IsTransportShutdownErr(err error) bool { return isTransmissionErr(err, ErrTransportShutdown) } +type NullErrorMessage struct { + Value string + Valid bool +} -func isTransmissionErr(err, target error) bool { - if errors.Is(err, target) { - return true - } - var e *TransmissionError - if errors.As(err, &e) { - return errors.Is(e.Cause, target) - } - return false +type NullOffset struct { + Value uint64 + Valid bool } diff --git a/negotiation.go b/negotiation.go index 86e7cbb..dc14c7a 100644 --- a/negotiation.go +++ b/negotiation.go @@ -366,28 +366,11 @@ func isOptError(id uint32) bool { func toOptError(id uint32, payload []byte) error { message := string(payload) - switch id { - case nbdproto.REP_ERR_UNSUPPORTED: - return &NegotiationError{ErrUnsupported, message} - case nbdproto.REP_ERR_POLICY: - return &NegotiationError{ErrPolicy, message} - case nbdproto.REP_ERR_INVALID: - return &NegotiationError{ErrInvalid, message} - case nbdproto.REP_ERR_PLATFORM: - return &NegotiationError{ErrPlatform, message} - case nbdproto.REP_ERR_TLS_REQUIRED: - return &NegotiationError{ErrTLSReqd, message} - case nbdproto.REP_ERR_UNKNOWN: - return &NegotiationError{ErrUnknown, message} - case nbdproto.REP_ERR_SHUTDOWN: - return &NegotiationError{ErrShutdown, message} - case nbdproto.REP_ERR_BLOCK_SIZE_REQUIRED: - return &NegotiationError{ErrBlockSizeRequired, message} - case nbdproto.REP_ERR_TOO_BIG: - return &NegotiationError{ErrTooBig, message} - case nbdproto.REP_ERR_EXT_HEADER_REQUIRED: - return &NegotiationError{ErrExtHeaderRequired, message} - default: - return &NegotiationError{ErrUndefined, fmt.Sprintf("id %x", id)} + return &NegotiationError{ + Code: OptionErrorCode(id), + Message: NullErrorMessage{ + Value: message, + Valid: len(message) > 0, + }, } } diff --git a/transmission.go b/transmission.go index febd8b4..02bb119 100644 --- a/transmission.go +++ b/transmission.go @@ -220,10 +220,14 @@ func (c *Conn) demuxReplies() (err error) { if !ok { return } + var err error + if hdr.Error != 0 { + err = &TransmissionError{Code: TransmissionErrorCode(hdr.Error)} + } r := reply{ simple: &hdr, buf: buf, - err: codeToErr(hdr.Error), + err: err, } stream.replies <- r close(stream.replies) @@ -268,11 +272,17 @@ func (c *Conn) demuxReplies() (err error) { return fmt.Errorf("route replies: structured: read offset: %w", err) } } + m := b.String() replyError = &TransmissionError{ - Cause: codeToErr(code), - Message: b.String(), - HasOffset: hdr.Type == nbdproto.REPLY_TYPE_ERROR_OFFSET, - Offset: offset, + Code: TransmissionErrorCode(code), + Message: NullErrorMessage{ + Value: m, + Valid: len(m) > 0, + }, + Offset: NullOffset{ + Value: offset, + Valid: hdr.Type == nbdproto.REPLY_TYPE_ERROR_OFFSET, + }, } } func() { @@ -329,28 +339,3 @@ func requestTransmit(server io.Writer, cflags uint16, ty uint16, cookie uint64, func isTXError(type_ uint16) bool { return type_&(1<<15) != 0 } - -func codeToErr(id uint32) error { - var cause error - switch id { - case nbdproto.EPERM: - cause = ErrPerm - case nbdproto.EIO: - cause = ErrIO - case nbdproto.ENOMEM: - cause = ErrNoMem - case nbdproto.ENOSPC: - cause = ErrNoSpc - case nbdproto.EOVERFLOW: - cause = ErrOverflow - case nbdproto.ENOTSUP: - cause = ErrNotSupported - case nbdproto.ESHUTDOWN: - cause = ErrTransportShutdown - case 0: - return nil - default: - cause = fmt.Errorf("unrecognized error code %d", id) - } - return cause -}