Skip to content

Commit

Permalink
sql/postgres: add support for sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
a8m committed Dec 7, 2023
1 parent 386c859 commit 3e804f0
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 69 deletions.
6 changes: 3 additions & 3 deletions cmd/atlas/internal/cmdlog/cmdlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,15 +704,15 @@ func sqlInspect(report *SchemaInspect, indent ...string) (string, error) {
if report.Client.URL.Schema == "" {
changes = append(changes, &schema.AddSchema{S: s})
}
for _, o := range s.Objects {
changes = append(changes, &schema.AddObject{O: o})
}
for _, t := range s.Tables {
changes = append(changes, &schema.AddTable{T: t})
}
for _, v := range s.Views {
changes = append(changes, &schema.AddView{V: v})
}
for _, o := range s.Objects {
changes = append(changes, &schema.AddObject{O: o})
}
for _, f := range s.Funcs {
changes = append(changes, &schema.AddFunc{F: f})
}
Expand Down
28 changes: 13 additions & 15 deletions schemahcl/schemahcl.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,17 +480,9 @@ func (s *State) mayScopeContext(ctx *hcl.EvalContext, scope []string) *hcl.EvalC
return ctx
}
nctx := &hcl.EvalContext{
Variables: make(map[string]cty.Value),
Functions: make(map[string]function.Function),
Variables: make(map[string]cty.Value, len(vars)),
Functions: make(map[string]function.Function, len(funcs)),
}
for p := ctx; p != nil; p = p.Parent() {
for k, v := range p.Variables {
if isRef(v) {
nctx.Variables[k] = v
}
}
}
// Override the parent context with the scoped variables and functions.
for n, v := range vars {
nctx.Variables[n] = v
}
Expand All @@ -500,11 +492,18 @@ func (s *State) mayScopeContext(ctx *hcl.EvalContext, scope []string) *hcl.EvalC
// A patch from the past. Should be moved
// to specific scopes in the future.
nctx.Functions["sql"] = rawExprFunc
for p := ctx; p != nil; p = p.Parent() {
for k, v := range p.Variables {
if isRef(v) {
nctx.Variables[k] = v
}
}
}
return nctx
}

func (s *State) toAttrs(ctx *hcl.EvalContext, vr SchemaValidator, hclAttrs hclsyntax.Attributes, scope []string) ([]*Attr, error) {
var attrs []*Attr
attrs := make([]*Attr, 0, len(hclAttrs))
for _, hclAttr := range hclAttrs {
var (
scope = append(scope, hclAttr.Name)
Expand All @@ -520,6 +519,9 @@ func (s *State) toAttrs(ctx *hcl.EvalContext, vr SchemaValidator, hclAttrs hclsy
at := &Attr{K: hclAttr.Name}
switch t := value.Type(); {
case isRef(value):
if !value.Type().HasAttribute("__ref") {
return nil, fmt.Errorf("%s: invaid reference used in %s", hclAttr.SrcRange, hclAttr.Name)
}
at.V = cty.CapsuleVal(ctyRefType, &Ref{V: value.GetAttr("__ref").AsString()})
case (t.IsTupleType() || t.IsListType() || t.IsSetType()) && value.LengthInt() > 0:
var (
Expand Down Expand Up @@ -641,10 +643,6 @@ func (s *State) toResource(ctx *hcl.EvalContext, vr SchemaValidator, block *hcls
}
spec.Attrs = attrs
for _, blk := range block.Body.Blocks {
ctx, err := setBlockVars(ctx.NewChild(), blk.Body)
if err != nil {
return nil, err
}
r, err := s.toResource(ctx, vr, blk, append(scope, blk.Type))
if err != nil {
return nil, err
Expand Down
10 changes: 9 additions & 1 deletion schemahcl/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,15 @@ func (r *Resource) Resource(t string) (*Resource, bool) {

// Attr returns the Attr by the provided name and reports whether it was found.
func (r *Resource) Attr(name string) (*Attr, bool) {
return attrVal(r.Attrs, name)
if at, ok := attrVal(r.Attrs, name); ok {
return at, true
}
for _, r := range r.Children {
if at, ok := attrVal(r.Attrs, name); ok && r.Type == "" {
return at, true // Match on embedded resource.
}
}
return nil, false
}

// SetAttr sets the Attr on the Resource. If r is nil, a zero value Resource
Expand Down
13 changes: 8 additions & 5 deletions sql/internal/specutil/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,7 +937,7 @@ func FromForeignKey(s *schema.ForeignKey) (*sqlspec.ForeignKey, error) {
for _, v := range s.RefColumns {
ref := ColumnRef(v.Name)
if s.Table != s.RefTable {
ref = externalColRef(v.Name, s.RefTable.Name)
ref = ExternalColumnRef(v.Name, s.RefTable.Name)
}
r = append(r, ref)
}
Expand Down Expand Up @@ -992,7 +992,7 @@ func ColumnByRef(t *schema.Table, ref *schemahcl.Ref) (*schema.Column, error) {
}

func externalRef(ref *schemahcl.Ref, sch *schema.Schema) (*schema.Table, *schema.Column, error) {
qualifier, name, err := tableName(ref)
qualifier, name, err := TableName(ref)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -1045,7 +1045,8 @@ func findT[T schema.Object](sch *schema.Schema, qualifier, name string, findT fu
return
}

func tableName(ref *schemahcl.Ref) (string, string, error) {
// TableName returns the qualifier and name from a reference to a table.
func TableName(ref *schemahcl.Ref) (string, string, error) {
return RefName(ref, typeTable)
}

Expand Down Expand Up @@ -1077,14 +1078,16 @@ func ColumnRef(cName string) *schemahcl.Ref {
})
}

func externalColRef(cName string, tName string) *schemahcl.Ref {
// ExternalColumnRef returns the reference of a column by its name and table name.
func ExternalColumnRef(cName string, tName string) *schemahcl.Ref {
return schemahcl.BuildRef([]schemahcl.PathIndex{
{T: typeTable, V: []string{tName}},
{T: typeColumn, V: []string{cName}},
})
}

func qualifiedExternalColRef(cName, tName, sName string) *schemahcl.Ref {
// QualifiedExternalColRef returns the reference of a column by its name and qualified table name.
func QualifiedExternalColRef(cName, tName, sName string) *schemahcl.Ref {
return schemahcl.BuildRef([]schemahcl.PathIndex{
{T: typeTable, V: []string{sName, tName}},
{T: typeColumn, V: []string{cName}},
Expand Down
30 changes: 28 additions & 2 deletions sql/internal/specutil/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,32 @@ func QualifyObjects[T SchemaObject](specs []T) error {
return nil
}

// ColumnRefFinder returns a function that finds a column reference by its table and its schema.
func ColumnRefFinder(specs []*sqlspec.Table) func(s, t, c string) (*schemahcl.Ref, error) {
var refs map[struct{ s, t string }]*sqlspec.Table
return func(s, t, c string) (*schemahcl.Ref, error) {
// Lazily initialize refs.
if refs == nil {
refs = make(map[struct{ s, t string }]*sqlspec.Table, len(specs))
for _, st := range specs {
ns, err := SchemaName(st.Schema)
if err != nil {
return nil, err
}
refs[struct{ s, t string }{s: ns, t: st.Name}] = st
}
}
r, ok := refs[struct{ s, t string }{s: s, t: t}]
if !ok {
return nil, fmt.Errorf("table %q.%q was not found", s, t)
}
if r.Qualifier != "" {
return QualifiedExternalColRef(c, r.Name, r.Qualifier), nil
}
return ExternalColumnRef(c, r.Name), nil
}
}

// QualifyReferences qualifies any reference with qualifier.
func QualifyReferences(tableSpecs []*sqlspec.Table, realm *schema.Realm) error {
type cref struct{ s, t string }
Expand Down Expand Up @@ -183,9 +209,9 @@ func QualifyReferences(tableSpecs []*sqlspec.Table, realm *schema.Realm) error {
}
for i, c := range fk.RefColumns {
if r, ok := byRef[cref{s: fk1.RefTable.Schema.Name, t: fk1.RefTable.Name}]; ok && r.Qualifier != "" {
fk.RefColumns[i] = qualifiedExternalColRef(fk1.RefColumns[i].Name, r.Name, r.Qualifier)
fk.RefColumns[i] = QualifiedExternalColRef(fk1.RefColumns[i].Name, r.Name, r.Qualifier)
} else if r, ok := byRef[cref{t: fk1.RefTable.Name}]; ok && r.Qualifier == "" {
fk.RefColumns[i] = externalColRef(fk1.RefColumns[i].Name, r.Name)
fk.RefColumns[i] = ExternalColumnRef(fk1.RefColumns[i].Name, r.Name)
} else {
return fmt.Errorf("missing reference for column %q in %q.%q.%q", c.V, sname, t.Name, fk.Symbol)
}
Expand Down
67 changes: 67 additions & 0 deletions sql/internal/sqlx/plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,8 +434,20 @@ func SortChanges(changes []schema.Change) []schema.Change {
return planned
}

// Depender can be implemented by an object to determine if a change to it
// depends on other change, or if an other change depends on it. For example:
// A table creation depends on type creation, and a type deletion depends on
// table deletion.
type Depender interface {
DependsOn(change, other schema.Change) bool
DependencyOf(change, other schema.Change) bool
}

// dependsOn reports if the given change depends on the other change.
func dependsOn(c1, c2 schema.Change) bool {
if dependOnOf(c1, c2) {
return true
}
switch c1 := c1.(type) {
case *schema.AddTable:
switch c2 := c2.(type) {
Expand All @@ -449,6 +461,13 @@ func dependsOn(c1, c2 schema.Change) bool {
if refTo(c1.T.ForeignKeys, c2.T) {
return true
}
case *schema.AddObject:
t, ok := c2.O.(schema.Type)
if ok && slices.ContainsFunc(c1.T.Columns, func(c *schema.Column) bool {
return dependsOnT(c.Type.Type, t)
}) {
return true
}
}
return depOfAdd(c1.T.Deps, c2)
case *schema.DropTable:
Expand Down Expand Up @@ -491,6 +510,20 @@ func dependsOn(c1, c2 schema.Change) bool {
return ok && fk.F.RefTable == c2.T && slices.ContainsFunc(fk.F.RefColumns, func(c *schema.Column) bool { return addC[c] })
})
}
case *schema.AddObject:
t, ok := c2.O.(schema.Type)
if ok && slices.ContainsFunc(c1.Changes, func(c schema.Change) bool {
switch c := c.(type) {
case *schema.AddColumn:
return dependsOnT(c.C.Type.Type, t)
case *schema.ModifyColumn:
return dependsOnT(c.To.Type.Type, t)
default:
return false
}
}) {
return true
}
}
return depOfAdd(c1.T.Deps, c2)
case *schema.AddView:
Expand Down Expand Up @@ -561,6 +594,40 @@ func dependsOn(c1, c2 schema.Change) bool {
return false
}

// dependOnOf checks if the given change depends on the other change or
// vice versa based on their underlying object implementation.
func dependOnOf(change, other schema.Change) bool {
switch change := change.(type) {
case *schema.AddObject:
if d, ok := change.O.(Depender); ok && d.DependsOn(change, other) {
return true
}
case *schema.ModifyObject:
if d, ok := change.To.(Depender); ok && d.DependsOn(change, other) {
return true
}
case *schema.DropObject:
if d, ok := change.O.(Depender); ok && d.DependsOn(change, other) {
return true
}
}
switch other := other.(type) {
case *schema.AddObject:
if d, ok := other.O.(Depender); ok && d.DependencyOf(other, change) {
return true
}
case *schema.ModifyObject:
if d, ok := other.To.(Depender); ok && d.DependencyOf(other, change) {
return true
}
case *schema.DropObject:
if d, ok := other.O.(Depender); ok && d.DependencyOf(other, change) {
return true
}
}
return false
}

// depOfDrops checks if the given object is a dependency of the given change.
func depOfDrop(o schema.Object, c schema.Change) bool {
var deps []schema.Object
Expand Down
17 changes: 14 additions & 3 deletions sql/internal/sqlx/sqlx.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,15 @@ func ValuesEqual(v1, v2 []string) bool {
// ModeInspectSchema returns the InspectMode or its default.
func ModeInspectSchema(o *schema.InspectOptions) schema.InspectMode {
if o == nil || o.Mode == 0 {
return schema.InspectSchemas | schema.InspectTables | schema.InspectViews | schema.InspectFuncs | schema.InspectTypes
return schema.InspectSchemas | schema.InspectTables | schema.InspectViews | schema.InspectFuncs | schema.InspectTypes | schema.InspectObjects
}
return o.Mode
}

// ModeInspectRealm returns the InspectMode or its default.
func ModeInspectRealm(o *schema.InspectRealmOption) schema.InspectMode {
if o == nil || o.Mode == 0 {
return schema.InspectSchemas | schema.InspectTables | schema.InspectViews | schema.InspectFuncs | schema.InspectTypes
return schema.InspectSchemas | schema.InspectTables | schema.InspectViews | schema.InspectFuncs | schema.InspectTypes | schema.InspectObjects
}
return o.Mode
}
Expand Down Expand Up @@ -265,6 +265,11 @@ func (b *Builder) P(phrases ...string) *Builder {
return b
}

// Int64 writes the given value to the builder in base 10.
func (b *Builder) Int64(v int64) *Builder {
return b.P(strconv.FormatInt(v, 10))
}

// Ident writes the given string quoted as an SQL identifier.
func (b *Builder) Ident(s string) *Builder {
if s != "" {
Expand All @@ -288,6 +293,12 @@ func (b *Builder) Table(t *schema.Table) *Builder {
return b.mayQualify(t.Schema, t.Name)
}

// TableColumn writes the table's resource identifier to the builder, prefixed
// with the schema name if exists.
func (b *Builder) TableColumn(t *schema.Table, c *schema.Column) *Builder {
return b.mayQualify(t.Schema, t.Name, c.Name)
}

// Func writes the function identifier to the builder, prefixed
// with the schema name if exists.
func (b *Builder) Func(f *schema.Func) *Builder {
Expand All @@ -305,7 +316,7 @@ func (b *Builder) Proc(p *schema.Proc) *Builder {
func (b *Builder) TableResource(t *schema.Table, r any) *Builder {
switch c := r.(type) {
case *schema.Column:
return b.mayQualify(t.Schema, t.Name, c.Name)
return b.TableColumn(t, c)
case *schema.Index:
return b.mayQualify(t.Schema, t.Name, c.Name)
default:
Expand Down
23 changes: 19 additions & 4 deletions sql/postgres/driver_oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ func (*inspect) inspectFuncs(context.Context, *schema.Realm, *schema.InspectOpti
return nil // unimplemented.
}

func (i *inspect) inspectTypes(context.Context, *schema.Realm, *schema.InspectOptions) error {
func (*inspect) inspectTypes(context.Context, *schema.Realm, *schema.InspectOptions) error {
return nil // unimplemented.
}

func (i *inspect) inspectDeps(context.Context, *schema.Realm, *schema.InspectOptions) error {
func (*inspect) inspectSequences(context.Context, *schema.Realm, *schema.InspectOptions) error {
return nil // unimplemented.
}

func (*inspect) inspectDeps(context.Context, *schema.Realm, *schema.InspectOptions) error {
return nil // unimplemented.
}

Expand Down Expand Up @@ -179,18 +183,29 @@ func verifyChanges(context.Context, []schema.Change) error {
return nil // unimplemented.
}

func convertDomains(_ []*sqlspec.Table, domains []*Domain, _ *schema.Realm) error {
func convertDomains(_ []*sqlspec.Table, domains []*domain, _ *schema.Realm) error {
if len(domains) > 0 {
return fmt.Errorf("postgres: domains are not supported by this version. Use: https://atlasgo.io/getting-started")
}
return nil
}

func convertSequences(_ []*sqlspec.Table, seqs []*sequence, _ *schema.Realm) error {
if len(seqs) > 0 {
return fmt.Errorf("postgres: sequences are not supported by this version. Use: https://atlasgo.io/getting-started")
}
return nil
}

func qualifySeqRefs([]*sequence, []*sqlspec.Table, *schema.Realm) error {
return nil // unimplemented.
}

// objectSpec converts from a concrete schema objects into specs.
func objectSpec(d *doc, spec *specutil.SchemaSpec, s *schema.Schema) error {
for _, o := range s.Objects {
if e, ok := o.(*schema.EnumType); ok {
d.Enums = append(d.Enums, &Enum{
d.Enums = append(d.Enums, &enum{
Name: e.T,
Values: e.Values,
Schema: specutil.SchemaRef(spec.Schema.Name),
Expand Down
Loading

0 comments on commit 3e804f0

Please sign in to comment.