Skip to content

Commit

Permalink
Call Unalias() in IsNilable() to support gotypesalias=1 (#3304)
Browse files Browse the repository at this point in the history
* Call `Unalias()` in `IsNilable()` to support `gotypesalias=1`

Co-authored-by: Gilad Maymon <[email protected]>

* Use `types.Unalias`

* add nested alias tests

---------

Co-authored-by: Gilad Maymon <[email protected]>
  • Loading branch information
noamcohen97 and gmwiz authored Sep 27, 2024
1 parent 8fcf704 commit 4cdeaa2
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
4 changes: 4 additions & 0 deletions codegen/config/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,10 @@ func (b *Binder) CopyModifiersFromAst(t *ast.Type, base types.Type) types.Type {
}

func IsNilable(t types.Type) bool {
// Note that we use types.Unalias rather than code.Unalias here
// because we want to always check the underlying type.
// code.Unalias only unwraps aliases in Go 1.23
t = types.Unalias(t)
if namedType, isNamed := t.(*types.Named); isNamed {
return IsNilable(namedType.Underlying())
}
Expand Down
39 changes: 39 additions & 0 deletions codegen/config/binder_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package config

import (
"fmt"
"go/token"
"go/types"
"testing"

Expand Down Expand Up @@ -264,3 +266,40 @@ func TestEnumBinding(t *testing.T) {
require.Equal(t, bazTwo, baz.EnumValues[1].Object)
require.Equal(t, cf.Schema.Types["Baz"].EnumValues[1], baz.EnumValues[1].Definition)
}

func createTypeAlias(name string, t types.Type) *types.Alias {
var nopos token.Pos
return types.NewAlias(types.NewTypeName(nopos, nil, name, nil), t)
}

func TestIsNilable(t *testing.T) {
type aTest struct {
input types.Type
expected bool
}

theTests := []aTest{
{types.Universe.Lookup("any").Type(), true},
{types.Universe.Lookup("rune").Type(), false},
{types.Universe.Lookup("byte").Type(), false},
{types.Universe.Lookup("error").Type(), true},
{types.Typ[types.Int], false},
{types.Typ[types.String], false},
{types.NewChan(types.SendOnly, types.Typ[types.Int]), true},
{types.NewPointer(types.Typ[types.Int]), true},
{types.NewPointer(types.Typ[types.String]), true},
{types.NewMap(types.Typ[types.Int], types.Typ[types.Int]), true},
{types.NewSlice(types.Typ[types.Int]), true},
{types.NewInterfaceType(nil, nil), true},
{createTypeAlias("interfaceAlias", types.NewInterfaceType(nil, nil)), true},
{createTypeAlias("interfaceNestedAlias", createTypeAlias("interfaceAlias", types.NewInterfaceType(nil, nil))), true},
{createTypeAlias("intAlias", types.Typ[types.Int]), false},
{createTypeAlias("intNestedAlias", createTypeAlias("intAlias", types.Typ[types.Int])), false},
}

for _, at := range theTests {
t.Run(fmt.Sprintf("nilable-%s", at.input.String()), func(t *testing.T) {
require.Equal(t, at.expected, IsNilable(at.input))
})
}
}

0 comments on commit 4cdeaa2

Please sign in to comment.