From 3b77b696050613d2d30eca33224da7fe743bc4c2 Mon Sep 17 00:00:00 2001 From: mazrean Date: Thu, 23 Nov 2023 10:27:19 +0900 Subject: [PATCH] =?UTF-8?q?isutools=E3=82=88=E3=82=8A=E5=88=86=E9=9B=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/dependabot.yaml | 16 ++ .github/workflows/ci.yaml | 58 ++++++ .github/workflows/release.yaml | 34 ++++ .gitignore | 1 + .goreleaser.yaml | 41 +++++ flagType.go | 14 ++ go.mod | 10 + go.sum | 8 + internal/dbdoc/dbdoc.go | 66 +++++++ internal/dbdoc/funcs.go | 216 ++++++++++++++++++++++ internal/dbdoc/graph.go | 187 +++++++++++++++++++ internal/dbdoc/mermaid.go | 130 +++++++++++++ internal/dbdoc/sql.go | 233 ++++++++++++++++++++++++ internal/dbdoc/types.go | 84 +++++++++ internal/pkg/analyze/initialize_func.go | 35 ++++ internal/pkg/list/list.go | 98 ++++++++++ internal/pkg/list/queue.go | 44 +++++ internal/pkg/list/stack.go | 44 +++++ main.go | 52 ++++++ 19 files changed, 1371 insertions(+) create mode 100644 .github/dependabot.yaml create mode 100644 .github/workflows/ci.yaml create mode 100644 .github/workflows/release.yaml create mode 100644 .gitignore create mode 100644 .goreleaser.yaml create mode 100644 flagType.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/dbdoc/dbdoc.go create mode 100644 internal/dbdoc/funcs.go create mode 100644 internal/dbdoc/graph.go create mode 100644 internal/dbdoc/mermaid.go create mode 100644 internal/dbdoc/sql.go create mode 100644 internal/dbdoc/types.go create mode 100644 internal/pkg/analyze/initialize_func.go create mode 100644 internal/pkg/list/list.go create mode 100644 internal/pkg/list/queue.go create mode 100644 internal/pkg/list/stack.go create mode 100644 main.go diff --git a/.github/dependabot.yaml b/.github/dependabot.yaml new file mode 100644 index 0000000..9a56402 --- /dev/null +++ b/.github/dependabot.yaml @@ -0,0 +1,16 @@ +version: 2 +updates: +- package-ecosystem: gomod + directory: "/" + schedule: + interval: weekly + day: saturday + time: "00:00" + timezone: Asia/Tokyo +- package-ecosystem: github-actions + directory: "/" + schedule: + interval: weekly + day: saturday + time: "00:00" + timezone: Asia/Tokyo \ No newline at end of file diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..fdbf920 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,58 @@ +name: CI + +on: + push: + branches: + - "main" + pull_request: + +jobs: + build: + name: Build + runs-on: ubuntu-latest + env: + GOCACHE: "/tmp/go/cache" + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: go.mod + cache: true + - uses: actions/cache@v3 + with: + path: /tmp/go/cache + key: ${{ runner.os }}-go-build-${{ github.ref }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-go-build-${{ github.ref }}- + ${{ runner.os }}-go-build- + - run: go build -o isucrud . + test: + name: Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v4 + with: + go-version-file: go.mod + cache: true + - run: go test ./... -v -coverprofile=./coverage.txt -race -vet=off + - name: Upload coverage data + uses: codecov/codecov-action@v3.1.4 + with: + file: ./coverage.txt + - uses: actions/upload-artifact@v3 + with: + name: coverage.txt + path: coverage.txt + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: golangci-lint + uses: reviewdog/action-golangci-lint@v2.4 + with: + go_version_file: go.mod + reporter: github-pr-check + github_token: ${{ secrets.GITHUB_TOKEN }} + fail_on_error: true diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml new file mode 100644 index 0000000..ea0e040 --- /dev/null +++ b/.github/workflows/release.yaml @@ -0,0 +1,34 @@ +name: Release + +on: + push: + tags: + - "v*" + +env: + APP_NAME: isucrud + +jobs: + build: + name: Release + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-go@v4 + with: + go-version-file: go.mod + - name: Run GoReleaser + uses: goreleaser/goreleaser-action@v5 + with: + args: release --rm-dist + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Upload assets + uses: actions/upload-artifact@v3 + with: + name: assets + path: ./dist/* diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..53c37a1 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +dist \ No newline at end of file diff --git a/.goreleaser.yaml b/.goreleaser.yaml new file mode 100644 index 0000000..d634e2a --- /dev/null +++ b/.goreleaser.yaml @@ -0,0 +1,41 @@ +before: + hooks: + - go mod download +builds: + - env: + - CGO_ENABLED=0 + ldflags: + - -s + - -w + - -X main.version={{.Version}} + - -X main.revision={{.ShortCommit}} + goos: + - linux + - windows + - darwin + main: ./ + +archives: + - format: tar.gz + # this name template makes the OS and Arch compatible with the results of uname. + name_template: >- + {{ .ProjectName }}_ + {{- title .Os }}_ + {{- if eq .Arch "amd64" }}x86_64 + {{- else if eq .Arch "386" }}i386 + {{- else }}{{ .Arch }}{{ end }} + {{- if .Arm }}v{{ .Arm }}{{ end }} + # use zip for windows archives + format_overrides: + - goos: windows + format: zip +checksum: + name_template: 'checksums.txt' +snapshot: + name_template: "{{ incpatch .Version }}-next" +changelog: + sort: asc + filters: + exclude: + - '^docs:' + - '^test:' diff --git a/flagType.go b/flagType.go new file mode 100644 index 0000000..aa69200 --- /dev/null +++ b/flagType.go @@ -0,0 +1,14 @@ +package main + +import "fmt" + +type sliceString []string + +func (ss *sliceString) String() string { + return fmt.Sprintf("%s", *ss) +} + +func (ss *sliceString) Set(value string) error { + *ss = append(*ss, value) + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..008e441 --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module github.com/mazrean/isucrud + +go 1.21.3 + +require golang.org/x/tools v0.15.0 + +require ( + golang.org/x/mod v0.14.0 // indirect + golang.org/x/sys v0.14.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..4778135 --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= +golang.org/x/sync v0.5.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= +golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.15.0 h1:zdAyfUGbYmuVokhzVmghFl2ZJh5QhcfebBgmVPFYA+8= +golang.org/x/tools v0.15.0/go.mod h1:hpksKq4dtpQWS1uQ61JkdqWM3LscIS6Slf+VVkm+wQk= diff --git a/internal/dbdoc/dbdoc.go b/internal/dbdoc/dbdoc.go new file mode 100644 index 0000000..bc0686f --- /dev/null +++ b/internal/dbdoc/dbdoc.go @@ -0,0 +1,66 @@ +package dbdoc + +import ( + "fmt" + "go/token" + "os" + + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" + "golang.org/x/tools/go/ssa/ssautil" +) + +type Config struct { + WorkDir string + BuildArgs []string + IgnoreFuncs []string + IgnoreFuncPrefixes []string + DestinationFilePath string +} + +func Run(conf Config) error { + ctx := &context{ + fileSet: token.NewFileSet(), + workDir: conf.WorkDir, + } + + ssaProgram, pkgs, err := buildSSA(ctx, conf.BuildArgs) + if err != nil { + return fmt.Errorf("failed to build ssa: %w", err) + } + + funcs, err := buildFuncs(ctx, pkgs, ssaProgram) + if err != nil { + return fmt.Errorf("failed to build funcs: %w", err) + } + + nodes := buildGraph(funcs, conf.IgnoreFuncs, conf.IgnoreFuncPrefixes) + + f, err := os.Create(conf.DestinationFilePath) + if err != nil { + return fmt.Errorf("failed to make directory: %w", err) + } + defer f.Close() + + err = writeMermaid(f, nodes) + if err != nil { + return fmt.Errorf("failed to write mermaid: %w", err) + } + + return nil +} + +func buildSSA(ctx *context, args []string) (*ssa.Program, []*packages.Package, error) { + pkgs, err := packages.Load(&packages.Config{ + Fset: ctx.fileSet, + Mode: packages.NeedFiles | packages.NeedSyntax | packages.NeedTypes | packages.NeedImports | packages.NeedTypesInfo | packages.NeedName | packages.NeedModule, + }, args...) + if err != nil { + return nil, nil, fmt.Errorf("failed to load packages: %w", err) + } + + ssaProgram, _ := ssautil.AllPackages(pkgs, ssa.BareInits) + ssaProgram.Build() + + return ssaProgram, pkgs, nil +} diff --git a/internal/dbdoc/funcs.go b/internal/dbdoc/funcs.go new file mode 100644 index 0000000..945c5cc --- /dev/null +++ b/internal/dbdoc/funcs.go @@ -0,0 +1,216 @@ +package dbdoc + +import ( + "go/constant" + "go/token" + "go/types" + + "github.com/mazrean/isucrud/internal/pkg/list" + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/ssa" +) + +func buildFuncs(ctx *context, pkgs []*packages.Package, ssaProgram *ssa.Program) ([]function, error) { + var funcs []function + for _, pkg := range pkgs { + for _, def := range pkg.TypesInfo.Defs { + if def == nil { + continue + } + + switch def := def.(type) { + case *types.Func: + ssaFunc := ssaProgram.FuncValue(def) + if ssaFunc == nil { + continue + } + + stringLiterals, calls := analyzeFuncBody(ctx, ssaFunc.Blocks, getPos(ssaFunc.Pos(), def.Pos())) + + anonFuncQueue := list.NewQueue[*ssa.Function]() + for _, anonFunc := range ssaFunc.AnonFuncs { + anonFuncQueue.Push(anonFunc) + } + + for anonFunc, ok := anonFuncQueue.Pop(); ok; anonFunc, ok = anonFuncQueue.Pop() { + anonQueries, anonCalls := analyzeFuncBody(ctx, anonFunc.Blocks, getPos(anonFunc.Pos(), ssaFunc.Pos(), def.Pos())) + stringLiterals = append(stringLiterals, anonQueries...) + calls = append(calls, anonCalls...) + + for _, anonFunc := range anonFunc.AnonFuncs { + anonFuncQueue.Push(anonFunc) + } + } + + if len(stringLiterals) == 0 && len(calls) == 0 { + continue + } + + queries := make([]query, 0, len(stringLiterals)) + for _, strLiteral := range stringLiterals { + newQueries := AnalyzeSQL(ctx, strLiteral) + queries = append(queries, newQueries...) + } + + funcs = append(funcs, function{ + id: def.Id(), + name: def.Name(), + queries: queries, + calls: calls, + }) + } + } + } + + return funcs, nil +} + +func analyzeFuncBody(ctx *context, blocks []*ssa.BasicBlock, pos token.Pos) ([]stringLiteral, []string) { + type ssaValue struct { + value ssa.Value + pos token.Pos + } + var ssaValues []ssaValue + var calls []string + for _, block := range blocks { + for _, instr := range block.Instrs { + switch instr := instr.(type) { + case *ssa.BinOp: + if instr.X != nil { + ssaValues = append(ssaValues, ssaValue{ + value: instr.X, + pos: getPos(instr.X.Pos(), instr.Pos(), pos), + }) + } + + if instr.Y != nil { + ssaValues = append(ssaValues, ssaValue{ + value: instr.Y, + pos: getPos(instr.Y.Pos(), instr.Pos(), pos), + }) + } + case *ssa.ChangeType: + if instr.X != nil { + ssaValues = append(ssaValues, ssaValue{ + value: instr.X, + pos: getPos(instr.X.Pos(), instr.Pos(), pos), + }) + } + case *ssa.Convert: + if instr.X != nil { + ssaValues = append(ssaValues, ssaValue{ + value: instr.X, + pos: getPos(instr.X.Pos(), instr.Pos(), pos), + }) + } + case *ssa.MakeClosure: + for _, bind := range instr.Bindings { + if bind == nil { + ssaValues = append(ssaValues, ssaValue{ + value: bind, + pos: getPos(bind.Pos(), instr.Pos(), pos), + }) + } + } + case *ssa.MultiConvert: + if instr.X != nil { + ssaValues = append(ssaValues, ssaValue{ + value: instr.X, + pos: getPos(instr.X.Pos(), instr.Pos(), pos), + }) + } + case *ssa.Store: + if instr.Val != nil { + ssaValues = append(ssaValues, ssaValue{ + value: instr.Val, + pos: getPos(instr.Val.Pos(), instr.Pos(), pos), + }) + } + case *ssa.Call: + if f, ok := instr.Call.Value.(*ssa.Function); ok { + if f.Object() == nil { + continue + } + calls = append(calls, f.Object().Id()) + } + + for _, arg := range instr.Call.Args { + if arg != nil { + ssaValues = append(ssaValues, ssaValue{ + value: arg, + pos: getPos(arg.Pos(), instr.Pos(), pos), + }) + } + } + case *ssa.Defer: + if f, ok := instr.Call.Value.(*ssa.Function); ok { + if f.Object() == nil { + continue + } + calls = append(calls, f.Object().Id()) + } + + for _, arg := range instr.Call.Args { + if arg != nil { + ssaValues = append(ssaValues, ssaValue{ + value: arg, + pos: getPos(arg.Pos(), instr.Pos(), pos), + }) + } + } + case *ssa.Go: + if f, ok := instr.Call.Value.(*ssa.Function); ok { + if f.Object() == nil { + continue + } + calls = append(calls, f.Object().Id()) + } + + for _, arg := range instr.Call.Args { + if arg != nil { + ssaValues = append(ssaValues, ssaValue{ + value: arg, + pos: getPos(arg.Pos(), instr.Pos(), pos), + }) + } + } + } + } + } + + queries := make([]stringLiteral, 0, len(ssaValues)) + for _, ssaValue := range ssaValues { + strValue, ok := checkValue(ctx, ssaValue.value) + if ok { + queries = append(queries, stringLiteral{ + value: strValue, + pos: ssaValue.pos, + }) + } + } + + return queries, calls +} + +func getPos(posList ...token.Pos) token.Pos { + for _, pos := range posList { + if pos.IsValid() { + return pos + } + } + + return token.NoPos +} + +func checkValue(ctx *context, v ssa.Value) (string, bool) { + constValue, ok := v.(*ssa.Const) + if !ok || constValue == nil || constValue.Value == nil { + return "", false + } + + if constValue.Value.Kind() != constant.String { + return "", false + } + + return constant.StringVal(constValue.Value), true +} diff --git a/internal/dbdoc/graph.go b/internal/dbdoc/graph.go new file mode 100644 index 0000000..351105d --- /dev/null +++ b/internal/dbdoc/graph.go @@ -0,0 +1,187 @@ +package dbdoc + +import ( + "container/list" + "fmt" + "log" + "slices" + "strings" + + "github.com/mazrean/isucrud/internal/pkg/analyze" +) + +func buildGraph(funcs []function, ignoreFuncs, ignoreFuncPrefixes []string) []*node { + type tmpEdge struct { + label string + edgeType edgeType + childID string + } + type tmpNode struct { + *node + edges []tmpEdge + } + tmpNodeMap := make(map[string]tmpNode, len(funcs)) +FUNC_LOOP: + for _, f := range funcs { + if f.name == "main" || analyze.IsInitializeFuncName(f.name) { + continue + } + + for _, ignore := range ignoreFuncs { + if f.name == ignore { + continue FUNC_LOOP + } + } + + for _, ignorePrefix := range ignoreFuncPrefixes { + if strings.HasPrefix(f.name, ignorePrefix) { + continue FUNC_LOOP + } + } + + var edges []tmpEdge + for _, q := range f.queries { + id := tableID(q.table) + tmpNodeMap[id] = tmpNode{ + node: &node{ + id: id, + label: q.table, + nodeType: nodeTypeTable, + }, + } + + var edgeType edgeType + switch q.queryType { + case queryTypeSelect: + edgeType = edgeTypeSelect + case queryTypeInsert: + edgeType = edgeTypeInsert + case queryTypeUpdate: + edgeType = edgeTypeUpdate + case queryTypeDelete: + edgeType = edgeTypeDelete + default: + log.Printf("unknown query type: %v\n", q.queryType) + continue + } + + edges = append(edges, tmpEdge{ + label: "", + edgeType: edgeType, + childID: tableID(q.table), + }) + } + + for _, c := range f.calls { + id := funcID(c) + edges = append(edges, tmpEdge{ + label: "", + edgeType: edgeTypeCall, + childID: id, + }) + } + + slices.SortFunc(edges, func(a, b tmpEdge) int { + switch { + case a.childID < b.childID: + return -1 + case a.childID > b.childID: + return 1 + default: + return 0 + } + }) + edges = slices.Compact(edges) + + id := funcID(f.id) + tmpNodeMap[id] = tmpNode{ + node: &node{ + id: id, + label: f.name, + nodeType: nodeTypeFunction, + }, + edges: edges, + } + } + + type revEdge struct { + label string + edgeType edgeType + parentID string + } + revEdgeMap := make(map[string][]revEdge) + for _, tmpNode := range tmpNodeMap { + for _, tmpEdge := range tmpNode.edges { + revEdgeMap[tmpEdge.childID] = append(revEdgeMap[tmpEdge.childID], revEdge{ + label: tmpEdge.label, + edgeType: tmpEdge.edgeType, + parentID: tmpNode.id, + }) + } + } + + newNodeMap := make(map[string]tmpNode, len(tmpNodeMap)) + nodeQueue := list.New() + for id, node := range tmpNodeMap { + if node.nodeType == nodeTypeTable { + newNodeMap[id] = node + nodeQueue.PushBack(node) + delete(tmpNodeMap, id) + continue + } + } + + for { + element := nodeQueue.Front() + if element == nil { + break + } + nodeQueue.Remove(element) + + node := element.Value.(tmpNode) + for _, edge := range revEdgeMap[node.id] { + parent := tmpNodeMap[edge.parentID] + newNodeMap[edge.parentID] = parent + nodeQueue.PushBack(parent) + } + delete(revEdgeMap, node.id) + } + + var nodes []*node + for _, tmpNode := range newNodeMap { + node := tmpNode.node + for _, tmpEdge := range tmpNode.edges { + child, ok := newNodeMap[tmpEdge.childID] + if !ok { + continue + } + + node.edges = append(node.edges, edge{ + label: tmpEdge.label, + node: child.node, + edgeType: tmpEdge.edgeType, + }) + } + nodes = append(nodes, node) + } + + return nodes +} + +func funcID(functionID string) string { + functionID = strings.Replace(functionID, "(", "", -1) + functionID = strings.Replace(functionID, ")", "", -1) + functionID = strings.Replace(functionID, "[", "", -1) + functionID = strings.Replace(functionID, "]", "", -1) + + return fmt.Sprintf("func:%s", functionID) +} + +func tableID(table string) string { + table = strings.Replace(table, "(", "", -1) + table = strings.Replace(table, ")", "", -1) + table = strings.Replace(table, "[", "", -1) + table = strings.Replace(table, "]", "", -1) + + return fmt.Sprintf("table:%s", table) +} diff --git a/internal/dbdoc/mermaid.go b/internal/dbdoc/mermaid.go new file mode 100644 index 0000000..4d0752f --- /dev/null +++ b/internal/dbdoc/mermaid.go @@ -0,0 +1,130 @@ +package dbdoc + +import ( + "fmt" + "io" + "log" + "strconv" + "strings" +) + +const ( + mermaidHeader = "# DB Graph\n" + + "```mermaid\n" + + "graph LR\n" + + " classDef func fill:" + funcNodeColor + ",fill-opacity:0.5\n" + + " classDef table fill:" + tableNodeColor + ",fill-opacity:0.5\n" + mermaidFooter = "```" + + funcNodeColor = "#1976D2" + tableNodeColor = "#795548" + insertLinkColor = "#CDDC39" + deleteLinkColor = "#F44336" + selectLinkColor = "#78909C" + updateLinkColor = "#FF9800" + callLinkColor = "#BBDEFB" +) + +func writeMermaid(w io.StringWriter, nodes []*node) error { + _, err := w.WriteString(mermaidHeader) + if err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + + edgeID := 0 + var insertLinks, deleteLinks, selectLinks, updateLinks, callLinks []string + for _, node := range nodes { + var src string + switch node.nodeType { + case nodeTypeTable: + src = fmt.Sprintf("%s[%s]:::table", node.id, node.label) + case nodeTypeFunction: + src = fmt.Sprintf("%s[%s]:::func", node.id, node.label) + default: + log.Printf("unknown node type: %v\n", node.nodeType) + src = fmt.Sprintf("%s[%s]", node.id, node.label) + } + + for _, edge := range node.edges { + var dst, line string + switch edge.node.nodeType { + case nodeTypeTable: + dst = fmt.Sprintf("%s[%s]:::table", edge.node.id, edge.node.label) + case nodeTypeFunction: + dst = fmt.Sprintf("%s[%s]:::func", edge.node.id, edge.node.label) + default: + log.Printf("unknown node type: %v\n", edge.node.nodeType) + dst = fmt.Sprintf("%s[%s]", edge.node.id, edge.node.label) + } + + line = "--" + + if edge.label == "" { + _, err = w.WriteString(fmt.Sprintf(" %s %s> %s\n", src, line, dst)) + if err != nil { + return fmt.Errorf("failed to write edge: %w\n", err) + } + } else { + _, err = w.WriteString(fmt.Sprintf(" %s %s %s %s> %s\n", src, line, edge.label, line, dst)) + if err != nil { + return fmt.Errorf("failed to write edge: %w\n", err) + } + } + + switch edge.edgeType { + case edgeTypeInsert: + insertLinks = append(insertLinks, strconv.Itoa(edgeID)) + case edgeTypeDelete: + deleteLinks = append(deleteLinks, strconv.Itoa(edgeID)) + case edgeTypeSelect: + selectLinks = append(selectLinks, strconv.Itoa(edgeID)) + case edgeTypeUpdate: + updateLinks = append(updateLinks, strconv.Itoa(edgeID)) + case edgeTypeCall: + callLinks = append(callLinks, strconv.Itoa(edgeID)) + default: + log.Printf("unknown edge type: %v\n", edge.edgeType) + } + + edgeID++ + } + } + + if len(insertLinks) > 0 { + _, err = w.WriteString(fmt.Sprintf(" linkStyle %s stroke:%s,stroke-width:2px\n", strings.Join(insertLinks, ","), insertLinkColor)) + if err != nil { + return fmt.Errorf("failed to write link style: %w\n", err) + } + } + if len(deleteLinks) > 0 { + _, err = w.WriteString(fmt.Sprintf(" linkStyle %s stroke:%s,stroke-width:2px\n", strings.Join(deleteLinks, ","), deleteLinkColor)) + if err != nil { + return fmt.Errorf("failed to write link style: %w\n", err) + } + } + if len(selectLinks) > 0 { + _, err = w.WriteString(fmt.Sprintf(" linkStyle %s stroke:%s,stroke-width:2px\n", strings.Join(selectLinks, ","), selectLinkColor)) + if err != nil { + return fmt.Errorf("failed to write link style: %w\n", err) + } + } + if len(updateLinks) > 0 { + _, err = w.WriteString(fmt.Sprintf(" linkStyle %s stroke:%s,stroke-width:2px\n", strings.Join(updateLinks, ","), updateLinkColor)) + if err != nil { + return fmt.Errorf("failed to write link style: %w\n", err) + } + } + if len(callLinks) > 0 { + _, err = w.WriteString(fmt.Sprintf(" linkStyle %s stroke:%s,stroke-width:2px\n", strings.Join(callLinks, ","), callLinkColor)) + if err != nil { + return fmt.Errorf("failed to write link style: %w\n", err) + } + } + + _, err = w.WriteString(mermaidFooter) + if err != nil { + return fmt.Errorf("failed to write footer: %w", err) + } + + return nil +} diff --git a/internal/dbdoc/sql.go b/internal/dbdoc/sql.go new file mode 100644 index 0000000..27d27ad --- /dev/null +++ b/internal/dbdoc/sql.go @@ -0,0 +1,233 @@ +package dbdoc + +import ( + "fmt" + "go/token" + "log" + "path/filepath" + "regexp" + "strings" + + "github.com/mazrean/isucrud/internal/pkg/list" +) + +var ( + tableRe = regexp.MustCompile("^\\s*[\\[\"'`]?(?P\\w+)[\\]\"'`]?\\s*") + insertRe = regexp.MustCompile("^insert\\s+(ignore\\s+)?(into\\s+)?[\\[\"'`]?(?P
\\w+)[\\]\"'`]?\\s*") + deleteRe = regexp.MustCompile("^delete\\s+from\\s+[\\[\"'`]?(?P
\\w+)[\\]\"'`]?\\s*") + selectKeywords = []string{" where ", " group by ", " having ", " window ", " order by ", "limit ", " for "} +) + +func AnalyzeSQL(ctx *context, sql stringLiteral) []query { + sqlValue := strings.ToLower(sql.value) + + strQueries := extractSubQueries(ctx, sqlValue) + + var queries []query + for _, sqlValue := range strQueries { + newQueries := analyzeSQLWithoutSubQuery(ctx, sqlValue, sql.pos) + for _, query := range newQueries { + fmt.Printf("%s(%s): %s\n", query.queryType, query.table, sqlValue) + } + queries = append(queries, newQueries...) + } + + return queries +} + +var ( + subQueryPrefixRe = regexp.MustCompile(`^\s*\(\s*select\s+`) +) + +func extractSubQueries(ctx *context, sql string) []string { + var subQueries []string + + type subQuery struct { + query string + bracketCount uint + } + + rootQuery := "" + subQueryStack := list.NewStack[*subQuery]() + for i := 0; i < len(sql); i++ { + r := sql[i] + switch r { + case '(': + if subQuery, ok := subQueryStack.Peek(); ok { + subQuery.bracketCount++ + subQuery.query += string(r) + } else { + rootQuery += string(r) + } + + match := subQueryPrefixRe.FindString(sql[i:]) + if len(match) != 0 { + subQueryStack.Push(&subQuery{ + query: match, + bracketCount: 0, + }) + i += len(match) + continue + } + case ')': + if subQuery, ok := subQueryStack.Peek(); ok && subQuery.bracketCount == 0 { + subQueries = append(subQueries, subQuery.query) + subQueryStack.Pop() + } + + if subQuery, ok := subQueryStack.Peek(); ok { + subQuery.bracketCount-- + subQuery.query += string(r) + } else { + rootQuery += string(r) + } + default: + if subQuery, ok := subQueryStack.Peek(); ok { + subQuery.query += string(r) + } else { + rootQuery += string(r) + } + } + } + + for subQuery, ok := subQueryStack.Pop(); ok; subQuery, ok = subQueryStack.Pop() { + subQueries = append(subQueries, subQuery.query) + } + + if rootQuery != "" { + subQueries = append(subQueries, rootQuery) + } + + return subQueries +} + +func analyzeSQLWithoutSubQuery(ctx *context, sqlValue string, pos token.Pos) []query { + var queries []query + switch { + case strings.HasPrefix(sqlValue, "select"): + _, after, found := strings.Cut(sqlValue, " from ") + if !found { + tableNames := tableForm(ctx, sqlValue, pos) + + for _, tableName := range tableNames { + queries = append(queries, query{ + queryType: queryTypeSelect, + table: tableName, + pos: pos, + }) + } + break + } + + tmpTableNames := strings.Split(after, ",") + var tableNames []string + TABLE_LOOP: + for _, tableName := range tmpTableNames { + tableNames = append(tableNames, strings.Split(tableName, " join ")...) + + for _, keyword := range selectKeywords { + if strings.Contains(tableName, keyword) { + break TABLE_LOOP + } + } + } + + for _, tableName := range tableNames { + matches := tableRe.FindStringSubmatch(tableName) + if len(matches) == 0 { + continue + } + + for i, name := range tableRe.SubexpNames() { + if name == "Table" { + queries = append(queries, query{ + queryType: queryTypeSelect, + table: matches[i], + pos: pos, + }) + } + } + } + case strings.HasPrefix(sqlValue, "insert"): + matches := insertRe.FindStringSubmatch(sqlValue) + + for i, name := range insertRe.SubexpNames() { + if name == "Table" { + queries = append(queries, query{ + queryType: queryTypeInsert, + table: matches[i], + pos: pos, + }) + } + } + case strings.HasPrefix(sqlValue, "update"): + afterUpdate := strings.TrimPrefix(sqlValue, "update ") + before, _, found := strings.Cut(afterUpdate, " set ") + if !found { + before = afterUpdate + } + + tmpTableNames := strings.Split(before, ",") + var tableNames []string + for _, tableName := range tmpTableNames { + tableNames = append(tableNames, strings.Split(tableName, " join ")...) + } + + for _, tableName := range tableNames { + matches := tableRe.FindStringSubmatch(tableName) + if len(matches) == 0 { + continue + } + + for i, name := range tableRe.SubexpNames() { + if name == "Table" { + queries = append(queries, query{ + queryType: queryTypeUpdate, + table: matches[i], + pos: pos, + }) + } + } + } + case strings.HasPrefix(sqlValue, "delete"): + matches := deleteRe.FindStringSubmatch(sqlValue) + + for i, name := range deleteRe.SubexpNames() { + if name == "Table" { + queries = append(queries, query{ + queryType: queryTypeDelete, + table: matches[i], + pos: pos, + }) + } + } + } + + return queries +} + +func tableForm(ctx *context, sqlValue string, pos token.Pos) []string { + position := ctx.fileSet.Position(pos) + filename, err := filepath.Rel(ctx.workDir, position.Filename) + if err != nil { + log.Printf("failed to get relative path: %v", err) + return nil + } + + fmt.Printf("query:%s\n", sqlValue) + fmt.Printf("position: %s:%d:%d\n", filename, position.Line, position.Column) + fmt.Print("table name?: ") + var input string + _, err = fmt.Scanln(&input) + if err != nil { + return nil + } + + if input == "" { + return nil + } + + tableNames := strings.Split(input, ",") + + return tableNames +} diff --git a/internal/dbdoc/types.go b/internal/dbdoc/types.go new file mode 100644 index 0000000..b319b87 --- /dev/null +++ b/internal/dbdoc/types.go @@ -0,0 +1,84 @@ +package dbdoc + +import ( + "go/token" +) + +type context struct { + fileSet *token.FileSet + workDir string +} + +type function struct { + id string + name string + queries []query + calls []string +} + +type stringLiteral struct { + value string + pos token.Pos +} + +type query struct { + queryType queryType + table string + pos token.Pos +} + +type queryType uint8 + +const ( + queryTypeSelect queryType = iota + 1 + queryTypeInsert + queryTypeUpdate + queryTypeDelete +) + +func (qt queryType) String() string { + switch qt { + case queryTypeSelect: + return "select" + case queryTypeInsert: + return "insert" + case queryTypeUpdate: + return "update" + case queryTypeDelete: + return "delete" + } + + return "" +} + +type node struct { + id string + label string + nodeType nodeType + edges []edge +} + +type nodeType uint8 + +const ( + nodeTypeUnknown nodeType = iota + nodeTypeTable + nodeTypeFunction +) + +type edge struct { + label string + node *node + edgeType edgeType +} + +type edgeType uint8 + +const ( + edgeTypeUnknown edgeType = iota + edgeTypeInsert + edgeTypeUpdate + edgeTypeDelete + edgeTypeSelect + edgeTypeCall +) diff --git a/internal/pkg/analyze/initialize_func.go b/internal/pkg/analyze/initialize_func.go new file mode 100644 index 0000000..a033e97 --- /dev/null +++ b/internal/pkg/analyze/initialize_func.go @@ -0,0 +1,35 @@ +package analyze + +import ( + "strings" + "unicode" +) + +const ( + initializeKeyword = "initialize" +) + +func IsInitializeFuncName(name string) bool { + words := camelCaseSplit(name) + for _, word := range words { + if strings.ToLower(word) == initializeKeyword { + return true + } + } + + return false +} + +func camelCaseSplit(s string) []string { + var result []string + start := 0 + for i, r := range s { + if unicode.IsUpper(r) { + result = append(result, s[start:i]) + start = i + } + } + result = append(result, s[start:]) + + return result +} diff --git a/internal/pkg/list/list.go b/internal/pkg/list/list.go new file mode 100644 index 0000000..f0793b1 --- /dev/null +++ b/internal/pkg/list/list.go @@ -0,0 +1,98 @@ +package list + +import "container/list" + +type List[T any] struct { + l *list.List +} + +type Element[T any] struct { + e *list.Element +} + +func (e Element[T]) Value() (T, bool) { + if e.e == nil { + var v T + return v, false + } + + return e.e.Value.(T), true +} + +func New[T any]() List[T] { + return List[T]{ + l: list.New(), + } +} + +func (l List[T]) Back() Element[T] { + return Element[T]{ + e: l.l.Back(), + } +} + +func (l List[T]) Front() Element[T] { + return Element[T]{ + e: l.l.Front(), + } +} + +func (l List[T]) Init() List[T] { + l.l.Init() + return l +} + +func (l List[T]) Len() int { + return l.l.Len() +} + +func (l List[T]) PushBack(v T) Element[T] { + return Element[T]{ + e: l.l.PushBack(v), + } +} + +func (l List[T]) PushBackList(other List[T]) { + l.l.PushBackList(other.l) +} +func (l List[T]) PushFront(v T) Element[T] { + return Element[T]{ + e: l.l.PushFront(v), + } +} + +func (l List[T]) PushFrontList(other List[T]) { + l.l.PushFrontList(other.l) +} + +func (l List[T]) Remove(e Element[T]) T { + return l.l.Remove(e.e).(T) +} + +func (l List[T]) InsertAfter(v T, mark Element[T]) Element[T] { + return Element[T]{ + e: l.l.InsertAfter(v, mark.e), + } +} + +func (l List[T]) InsertBefore(v T, mark Element[T]) Element[T] { + return Element[T]{ + e: l.l.InsertBefore(v, mark.e), + } +} + +func (l List[T]) MoveAfter(e, mark Element[T]) { + l.l.MoveAfter(e.e, mark.e) +} + +func (l List[T]) MoveBefore(e Element[T], mark Element[T]) { + l.l.MoveBefore(e.e, mark.e) +} + +func (l List[T]) MoveToBack(e Element[T]) { + l.l.MoveToBack(e.e) +} + +func (l List[T]) MoveToFront(e Element[T]) { + l.l.MoveToFront(e.e) +} diff --git a/internal/pkg/list/queue.go b/internal/pkg/list/queue.go new file mode 100644 index 0000000..0c13853 --- /dev/null +++ b/internal/pkg/list/queue.go @@ -0,0 +1,44 @@ +package list + +type Queue[T any] struct { + l List[T] +} + +func NewQueue[T any]() Queue[T] { + return Queue[T]{ + l: New[T](), + } +} + +func (q Queue[T]) Len() int { + return q.l.Len() +} + +func (q Queue[T]) Push(v T) { + q.l.PushBack(v) +} + +func (q Queue[T]) Pop() (T, bool) { + e := q.l.Front() + if e.e == nil { + var v T + return v, false + } + + q.l.Remove(e) + return e.Value() +} + +func (q Queue[T]) Peek() (T, bool) { + e := q.l.Front() + if e.e == nil { + var v T + return v, false + } + + return e.Value() +} + +func (q Queue[T]) Clear() { + q.l.Init() +} diff --git a/internal/pkg/list/stack.go b/internal/pkg/list/stack.go new file mode 100644 index 0000000..6291aa5 --- /dev/null +++ b/internal/pkg/list/stack.go @@ -0,0 +1,44 @@ +package list + +type Stack[T any] struct { + l List[T] +} + +func NewStack[T any]() Stack[T] { + return Stack[T]{ + l: New[T](), + } +} + +func (s Stack[T]) Len() int { + return s.l.Len() +} + +func (s Stack[T]) Push(v T) { + s.l.PushBack(v) +} + +func (s Stack[T]) Pop() (T, bool) { + e := s.l.Back() + if e.e == nil { + var v T + return v, false + } + + s.l.Remove(e) + return e.Value() +} + +func (s Stack[T]) Peek() (T, bool) { + e := s.l.Back() + if e.e == nil { + var v T + return v, false + } + + return e.Value() +} + +func (s Stack[T]) Clear() { + s.l.Init() +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..aa52eca --- /dev/null +++ b/main.go @@ -0,0 +1,52 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/mazrean/isucrud/internal/dbdoc" +) + +var ( + version = "Unknown" + revision = "Unknown" + + versionFlag bool + dst string + ignores sliceString + ignorePrefixes sliceString +) + +func init() { + flag.BoolVar(&versionFlag, "version", false, "show version") + + flag.StringVar(&dst, "dst", "./dbdoc.md", "destination file") + flag.Var(&ignores, "ignore", "ignore function") + flag.Var(&ignorePrefixes, "ignorePrefix", "ignore function") +} + +func main() { + flag.Parse() + + if versionFlag { + fmt.Printf("iwrapper %s (revision: %s)\n", version, revision) + return + } + + wd, err := os.Getwd() + if err != nil { + panic(fmt.Errorf("failed to get working directory: %w", err)) + } + + err = dbdoc.Run(dbdoc.Config{ + WorkDir: wd, + BuildArgs: flag.Args(), + IgnoreFuncs: ignores, + IgnoreFuncPrefixes: ignorePrefixes, + DestinationFilePath: dst, + }) + if err != nil { + panic(fmt.Errorf("failed to run dbdoc: %w", err)) + } +}