diff --git a/json_raw_message.go b/json_raw_message.go index d213be2..82c428c 100644 --- a/json_raw_message.go +++ b/json_raw_message.go @@ -4,7 +4,6 @@ import ( "database/sql/driver" "encoding/json" "fmt" - "reflect" ) // JSONRawMessage holds a json.RawMessage. Keep in mind, that the JSON NULL @@ -53,9 +52,14 @@ func (rm *JSONRawMessage) Scan(src any) error { } rm.Valid = true // Copy bytes. - srcBytes, ok := src.([]byte) - if !ok { - return fmt.Errorf("cannot convert to byte slice: %s", reflect.TypeOf(src).String()) + var srcBytes []byte + switch src := src.(type) { + case []byte: + srcBytes = src + case string: + srcBytes = []byte(src) + default: + return fmt.Errorf("unsupported source value type: %T", src) } b := make([]byte, len(srcBytes)) copy(b, srcBytes) diff --git a/json_raw_message_test.go b/json_raw_message_test.go index e602db8..177618c 100644 --- a/json_raw_message_test.go +++ b/json_raw_message_test.go @@ -110,11 +110,11 @@ func (suite *JSONRawMessageScanSuite) TestJSONNull() { func (suite *JSONRawMessageScanSuite) TestUnexpectedValue() { var rm JSONRawMessage - err := rm.Scan("I'm not a byte slice.") + err := rm.Scan(1234) suite.Error(err, "should fail") } -func (suite *JSONRawMessageScanSuite) TestOK() { +func (suite *JSONRawMessageScanSuite) TestOKByteSlice() { v := json.RawMessage(`{"meow":"woof"}`) var rm JSONRawMessage err := rm.Scan([]byte(v)) @@ -123,6 +123,15 @@ func (suite *JSONRawMessageScanSuite) TestOK() { suite.Equal(v, rm.RawMessage, "should scan correct value") } +func (suite *JSONRawMessageScanSuite) TestOKString() { + v := json.RawMessage(`{"meow":"woof"}`) + var rm JSONRawMessage + err := rm.Scan(string(v)) + suite.Require().NoError(err, "should not fail") + suite.True(rm.Valid, "should be valid") + suite.Equal(v, rm.RawMessage, "should scan correct value") +} + func TestJSONRawMessage_Scan(t *testing.T) { suite.Run(t, new(JSONRawMessageScanSuite)) }