Skip to content

Commit

Permalink
loop check
Browse files Browse the repository at this point in the history
  • Loading branch information
mazrean committed Aug 17, 2024
1 parent 002807f commit bfd822e
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 22 deletions.
7 changes: 6 additions & 1 deletion dbdoc/dbdoc.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,12 @@ func Run(conf Config) error {
return fmt.Errorf("failed to build ssa: %w", err)
}

funcs, err := BuildFuncs(ctx, pkgs, ssaProgram)
loopRangeMap, err := BuildLoopRangeMap(ctx)
if err != nil {
return fmt.Errorf("failed to build loop range map: %w", err)
}

funcs, err := BuildFuncs(ctx, pkgs, ssaProgram, loopRangeMap)
if err != nil {
return fmt.Errorf("failed to build funcs: %w", err)
}
Expand Down
47 changes: 39 additions & 8 deletions dbdoc/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"golang.org/x/tools/go/ssa"
)

func BuildFuncs(ctx *Context, pkgs []*packages.Package, ssaProgram *ssa.Program) ([]function, error) {
func BuildFuncs(ctx *Context, pkgs []*packages.Package, ssaProgram *ssa.Program, loopRangeMap LoopRangeMap) ([]function, error) {
var funcs []function
for _, pkg := range pkgs {
for _, def := range pkg.TypesInfo.Defs {
Expand Down Expand Up @@ -52,11 +52,28 @@ func BuildFuncs(ctx *Context, pkgs []*packages.Package, ssaProgram *ssa.Program)
queries = append(queries, newQueries...)
}

loopRanges := loopRangeMap[ssaFunc.Name()]
queriesInLoop := make([]inLoop[query], 0, len(queries))
for _, q := range queries {
queriesInLoop = append(queriesInLoop, inLoop[query]{
value: q,
inLoop: loopRanges.Search(ctx.FileSet, q.pos),
})
}

callsInLoop := make([]inLoop[string], 0, len(calls))
for _, call := range calls {
callsInLoop = append(callsInLoop, inLoop[string]{
value: call.id,
inLoop: loopRanges.Search(ctx.FileSet, call.pos),
})
}

funcs = append(funcs, function{
id: def.Id(),
name: def.Name(),
queries: queries,
calls: calls,
queries: queriesInLoop,
calls: callsInLoop,
})
}
}
Expand All @@ -65,13 +82,18 @@ func BuildFuncs(ctx *Context, pkgs []*packages.Package, ssaProgram *ssa.Program)
return funcs, nil
}

func analyzeFuncBody(ctx *Context, blocks []*ssa.BasicBlock, pos token.Pos) ([]stringLiteral, []string) {
type funcCall struct {
id string
pos token.Pos
}

func analyzeFuncBody(ctx *Context, blocks []*ssa.BasicBlock, pos token.Pos) ([]stringLiteral, []funcCall) {
type ssaValue struct {
value ssa.Value
pos token.Pos
}
var ssaValues []ssaValue
var calls []string
var calls []funcCall
for _, block := range blocks {
for _, instr := range block.Instrs {
switch instr := instr.(type) {
Expand Down Expand Up @@ -131,7 +153,10 @@ func analyzeFuncBody(ctx *Context, blocks []*ssa.BasicBlock, pos token.Pos) ([]s
if f.Object() == nil {
continue
}
calls = append(calls, f.Object().Id())
calls = append(calls, funcCall{
id: f.Object().Id(),
pos: getPos(f.Pos(), instr.Pos(), pos),
})
}

for _, arg := range instr.Call.Args {
Expand All @@ -147,7 +172,10 @@ func analyzeFuncBody(ctx *Context, blocks []*ssa.BasicBlock, pos token.Pos) ([]s
if f.Object() == nil {
continue
}
calls = append(calls, f.Object().Id())
calls = append(calls, funcCall{
id: f.Object().Id(),
pos: getPos(instr.Call.Pos(), instr.Pos(), pos),
})
}

for _, arg := range instr.Call.Args {
Expand All @@ -163,7 +191,10 @@ func analyzeFuncBody(ctx *Context, blocks []*ssa.BasicBlock, pos token.Pos) ([]s
if f.Object() == nil {
continue
}
calls = append(calls, f.Object().Id())
calls = append(calls, funcCall{
id: f.Object().Id(),
pos: getPos(instr.Call.Pos(), instr.Pos(), pos),
})
}

for _, arg := range instr.Call.Args {
Expand Down
24 changes: 13 additions & 11 deletions dbdoc/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ func BuildGraph(funcs []function, ignoreFuncs, ignoreFuncPrefixes []string, igno
label string
edgeType edgeType
childID string
inLoop bool
}
type tmpNode struct {
*node
Expand Down Expand Up @@ -42,17 +43,17 @@ FUNC_LOOP:

var edges []tmpEdge
for _, q := range f.queries {
id := tableID(q.table)
id := tableID(q.value.table)
tmpNodeMap[id] = tmpNode{
node: &node{
id: id,
label: q.table,
label: q.value.table,
nodeType: nodeTypeTable,
},
}

var edgeType edgeType
switch q.queryType {
switch q.value.queryType {
case queryTypeSelect:
edgeType = edgeTypeSelect
case queryTypeInsert:
Expand All @@ -62,23 +63,25 @@ FUNC_LOOP:
case queryTypeDelete:
edgeType = edgeTypeDelete
default:
log.Printf("unknown query type: %v\n", q.queryType)
log.Printf("unknown query type: %v\n", q.value.queryType)
continue
}

edges = append(edges, tmpEdge{
label: "",
edgeType: edgeType,
childID: tableID(q.table),
childID: tableID(q.value.table),
inLoop: q.inLoop,
})
}

for _, c := range f.calls {
id := funcID(c)
id := funcID(c.value)
edges = append(edges, tmpEdge{
label: "",
edgeType: edgeTypeCall,
childID: id,
inLoop: c.inLoop,
})
}

Expand Down Expand Up @@ -109,6 +112,7 @@ FUNC_LOOP:
label string
edgeType edgeType
parentID string
inLoop bool
}
revEdgeMap := make(map[string][]revEdge)
for _, tmpNode := range tmpNodeMap {
Expand All @@ -117,6 +121,7 @@ FUNC_LOOP:
label: tmpEdge.label,
edgeType: tmpEdge.edgeType,
parentID: tmpNode.id,
inLoop: tmpEdge.inLoop,
})
}
}
Expand All @@ -132,11 +137,7 @@ FUNC_LOOP:
}
}

for {
element := nodeQueue.Front()
if element == nil {
break
}
for element := nodeQueue.Front(); element != nil; element = nodeQueue.Front() {
nodeQueue.Remove(element)

node := element.Value.(tmpNode)
Expand All @@ -161,6 +162,7 @@ FUNC_LOOP:
label: tmpEdge.label,
node: child.node,
edgeType: tmpEdge.edgeType,
inLoop: tmpEdge.inLoop,
})
}
nodes = append(nodes, node)
Expand Down
54 changes: 54 additions & 0 deletions dbdoc/loopmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dbdoc

import (
"fmt"
"go/ast"
"go/parser"
)

func BuildLoopRangeMap(ctx *Context) (LoopRangeMap, error) {
astPkgs, err := parser.ParseDir(ctx.FileSet, ctx.WorkDir, nil, 0)
if err != nil {
return nil, fmt.Errorf("failed to parse dir: %w", err)
}

lrm := make(LoopRangeMap)
for _, astPkg := range astPkgs {
for _, astFile := range astPkg.Files {
for _, decl := range astFile.Decls {
if f, ok := decl.(*ast.FuncDecl); ok {
v := &loopRangeVisitor{
lr: nil,
}
ast.Walk(v, f)
lrm[f.Name.Name] = v.lr
}
}
}
}

return lrm, nil
}

type loopRangeVisitor struct {
lr LoopRanges
}

func (v *loopRangeVisitor) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) {
case *ast.ForStmt:
v.lr = append(v.lr, LoopRange{
start: n.Body.Lbrace,
end: n.Body.Rbrace,
})
return nil
case *ast.RangeStmt:
v.lr = append(v.lr, LoopRange{
start: n.Body.Lbrace,
end: n.Body.Rbrace,
})
return nil
}

return v
}
3 changes: 3 additions & 0 deletions dbdoc/mermaid.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ func writeMermaid(w io.StringWriter, nodes []*node) error {
}

line := "--"
if edge.inLoop {
line = "=="
}

var edgeExpr string
if edge.label == "" {
Expand Down
35 changes: 33 additions & 2 deletions dbdoc/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,41 @@ type Context struct {
WorkDir string
}

type LoopRange struct {
start token.Pos
end token.Pos
}
type LoopRanges []LoopRange
type LoopRangeMap map[string]LoopRanges

func (lr LoopRanges) Search(fset *token.FileSet, pos token.Pos) bool {
position := fset.Position(pos)
for _, r := range lr {
start := fset.Position(r.start)
end := fset.Position(r.end)
if position.Filename != start.Filename || position.Filename != end.Filename ||
position.Line < start.Line || position.Line > end.Line ||
(position.Line == start.Line && position.Column < start.Column) ||
(position.Line == end.Line && position.Column > end.Column) {
continue
}

return true
}

return false
}

type inLoop[T any] struct {
value T
inLoop bool
}

type function struct {
id string
name string
queries []query
calls []string
queries []inLoop[query]
calls []inLoop[string]
}

type stringLiteral struct {
Expand Down Expand Up @@ -70,6 +100,7 @@ type edge struct {
label string
node *node
edgeType edgeType
inLoop bool
}

type edgeType uint8
Expand Down

0 comments on commit bfd822e

Please sign in to comment.