diff --git a/common/types.go b/common/types.go index aeb9507d5..1fc96ed9b 100644 --- a/common/types.go +++ b/common/types.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "math/big" - "strings" "github.com/jackc/pgx/v5/pgtype" "github.com/oasisprotocol/oasis-core/go/common/quantity" @@ -21,21 +20,35 @@ func NewBigInt(v int64) BigInt { return BigInt{*big.NewInt(v)} } -func (b BigInt) MarshalText() ([]byte, error) { - return []byte(b.String()), nil -} - -func (b *BigInt) UnmarshalText(text []byte) error { - return b.Int.UnmarshalText(text) -} - func (b BigInt) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf(`"%s"`, b.String())), nil + t, err := b.MarshalText() + if err != nil { + return nil, err + } + return json.Marshal(string(t)) } func (b *BigInt) UnmarshalJSON(text []byte) error { - v := strings.Trim(string(text), "\"") - return b.Int.UnmarshalJSON([]byte(v)) + var s string + err := json.Unmarshal(text, &s) + if err != nil { + return err + } + return b.UnmarshalText([]byte(s)) +} + +func (b BigInt) String() string { + // *big.Int does have a String() method. But the way the Go language + // works, that method on a pointer receiver doesn't get included in + // non-pointer BigInt's method set. In some places this doesn't matter, + // because *big.Int's methods are included in pointer *BigInt's method + // set, and a completely different part of the language set says that + // writing b.String() is fine; it's shorthand for (&b).String(). But + // reflection-driven code like fmt.Printf only looks at method sets and + // not shorthand trickery, so we need this method to make + // fmt.Printf("%v\n", b) show a number instead of dumping the internal + // bytes. + return b.Int.String() } func BigIntFromQuantity(q quantity.Quantity) BigInt { diff --git a/common/types_test.go b/common/types_test.go new file mode 100644 index 000000000..5d94bdd97 --- /dev/null +++ b/common/types_test.go @@ -0,0 +1,33 @@ +package common + +import ( + "encoding/json" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBigInt(t *testing.T) { + var v BigInt + + textRef := []byte("11111111111111111111") + err := v.UnmarshalText(textRef) + require.NoError(t, err) + textRoundTrip, err := v.MarshalText() + require.NoError(t, err) + require.Equal(t, textRef, textRoundTrip) + + jsonRef := []byte("\"22222222222222222222\"") + err = json.Unmarshal(jsonRef, &v) + require.NoError(t, err) + jsonRoundTrip, err := json.Marshal(v) + require.NoError(t, err) + require.Equal(t, jsonRef, jsonRoundTrip) + + stringRef := "33333333333333333333" + err = v.Int.UnmarshalText([]byte(stringRef)) + require.NoError(t, err) + stringRoundTrip := fmt.Sprintf("%v", v) + require.Equal(t, stringRef, stringRoundTrip) +}