From 26f8e1321cf66395a76e4671e6e67f4cc37f84cf Mon Sep 17 00:00:00 2001 From: Greg Date: Fri, 15 Dec 2017 00:31:55 +0900 Subject: [PATCH] float: return error when marshaling NaN or Inf --- float.go | 7 +++++++ float_test.go | 15 +++++++++++++++ zero/float.go | 7 +++++++ zero/float_test.go | 15 +++++++++++++++ 4 files changed, 44 insertions(+) diff --git a/float.go b/float.go index 5d55229..8c39e3b 100644 --- a/float.go +++ b/float.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "math" "reflect" "strconv" ) @@ -91,6 +92,12 @@ func (f Float) MarshalJSON() ([]byte, error) { if !f.Valid { return []byte("null"), nil } + if math.IsInf(f.Float64, 0) || math.IsNaN(f.Float64) { + return nil, &json.UnsupportedValueError{ + Value: reflect.ValueOf(f.Float64), + Str: strconv.FormatFloat(f.Float64, 'g', -1, 64), + } + } return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil } diff --git a/float_test.go b/float_test.go index 6be025a..bf9c833 100644 --- a/float_test.go +++ b/float_test.go @@ -2,6 +2,7 @@ package null import ( "encoding/json" + "math" "testing" ) @@ -170,6 +171,20 @@ func TestFloatScan(t *testing.T) { assertNullFloat(t, null, "scanned null") } +func TestFloatInfNaN(t *testing.T) { + nan := NewFloat(math.NaN(), true) + _, err := nan.MarshalJSON() + if err == nil { + t.Error("expected error for NaN, got nil") + } + + inf := NewFloat(math.Inf(1), true) + _, err = inf.MarshalJSON() + if err == nil { + t.Error("expected error for Inf, got nil") + } +} + func assertFloat(t *testing.T, f Float, from string) { if f.Float64 != 1.2345 { t.Errorf("bad %s float: %f ≠ %f\n", from, f.Float64, 1.2345) diff --git a/zero/float.go b/zero/float.go index ccf6ef6..e998543 100644 --- a/zero/float.go +++ b/zero/float.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "fmt" + "math" "reflect" "strconv" ) @@ -92,6 +93,12 @@ func (f Float) MarshalJSON() ([]byte, error) { if !f.Valid { n = 0 } + if math.IsInf(f.Float64, 0) || math.IsNaN(f.Float64) { + return nil, &json.UnsupportedValueError{ + Value: reflect.ValueOf(f.Float64), + Str: strconv.FormatFloat(f.Float64, 'g', -1, 64), + } + } return []byte(strconv.FormatFloat(n, 'f', -1, 64)), nil } diff --git a/zero/float_test.go b/zero/float_test.go index 9b94cff..6c30ee1 100644 --- a/zero/float_test.go +++ b/zero/float_test.go @@ -2,6 +2,7 @@ package zero import ( "encoding/json" + "math" "testing" ) @@ -176,6 +177,20 @@ func TestFloatScan(t *testing.T) { assertNullFloat(t, null, "scanned null") } +func TestFloatInfNaN(t *testing.T) { + nan := NewFloat(math.NaN(), true) + _, err := nan.MarshalJSON() + if err == nil { + t.Error("expected error for NaN, got nil") + } + + inf := NewFloat(math.Inf(1), true) + _, err = inf.MarshalJSON() + if err == nil { + t.Error("expected error for Inf, got nil") + } +} + func assertFloat(t *testing.T, f Float, from string) { if f.Float64 != 1.2345 { t.Errorf("bad %s float: %f ≠ %f\n", from, f.Float64, 1.2345)