From 55fb02300202e33f0197be846a223131dbb85a53 Mon Sep 17 00:00:00 2001 From: Patrice Ferlet Date: Thu, 22 Jun 2023 16:30:30 +0200 Subject: [PATCH] Make better checks and fix ordering Structs - See #6 - in progress to resolve - Zee #4 - in propose to resolve --- ordering/main.go | 60 ++++++++++++-------- ordering/main_test.go | 126 +++++++++++++++++++++++++++++++++++++++++ ordering/parser.go | 8 ++- ordering/sortString.go | 35 ++++++++++++ 4 files changed, 203 insertions(+), 26 deletions(-) create mode 100644 ordering/sortString.go diff --git a/ordering/main.go b/ordering/main.go index 921306b..34e4033 100644 --- a/ordering/main.go +++ b/ordering/main.go @@ -4,8 +4,8 @@ import ( "crypto/sha256" "errors" "fmt" + "go/format" "io/ioutil" - "log" "os" "os/exec" "sort" @@ -62,13 +62,8 @@ func ReorderSource(opt ReorderConfig) (string, error) { }) } - structNames := make([]string, 0, len(info.Methods)) - for _, s := range info.Structs { - log.Println("s.Name", s.Name) - structNames = append(structNames, s.Name) - } if opt.ReorderStructs { - sort.Strings(structNames) + info.StructNames.Sort() } // Get the source code signature - we will use this to mark the lines to remove later @@ -82,7 +77,7 @@ func ReorderSource(opt ReorderConfig) (string, error) { lineNumberWhereInject := 0 removedLines := 0 - for _, typename := range structNames { + for _, typename := range *info.StructNames { if removedLines == 0 { lineNumberWhereInject = info.Structs[typename].OpeningLine } @@ -129,33 +124,50 @@ func ReorderSource(opt ReorderConfig) (string, error) { output := strings.Join(originalContent, "\n") // write in a temporary file and use "gofmt" to format it + newcontent := []byte(output) + switch opt.FormatCommand { + case "gofmt": + // format the temporary file + newcontent, err = format.Source([]byte(output)) + if err != nil { + return string(content), errors.New("Failed to format source: " + err.Error()) + } + default: + if newcontent, err = formatWithCommand(content, output, opt); err != nil { + return string(content), errors.New("Failed to format source: " + err.Error()) + } + } + + if opt.Diff { + return doDiff(content, newcontent, opt.Filename) + } + return string(newcontent), nil +} + +func formatWithCommand(content []byte, output string, opt ReorderConfig) (newcontent []byte, err error) { + // we use the format command given by the user + // on a temporary file we need to create and remove tmpfile, err := ioutil.TempFile("", "") if err != nil { - return string(content), errors.New("Failed to create temp file: " + err.Error()) + return content, errors.New("Failed to create temp file: " + err.Error()) } - defer func() { - // close and remove the temporary file - tmpfile.Close() - os.Remove(tmpfile.Name()) - }() + defer os.Remove(tmpfile.Name()) + // write the temporary file if _, err := tmpfile.Write([]byte(output)); err != nil { - return string(content), errors.New("Failed to write to temporary file: " + err.Error()) + return content, errors.New("Failed to write temp file: " + err.Error()) } + tmpfile.Close() + // format the temporary file cmd := exec.Command(opt.FormatCommand, "-w", tmpfile.Name()) if err := cmd.Run(); err != nil { - return string(content), err + return content, err } - // read the temporary file - newcontent, err := ioutil.ReadFile(tmpfile.Name()) + newcontent, err = ioutil.ReadFile(tmpfile.Name()) if err != nil { - return string(content), errors.New("Read Temporary File error: " + err.Error()) + return content, errors.New("Read Temporary File error: " + err.Error()) } - - if opt.Diff { - return doDiff(content, newcontent, opt.Filename) - } - return string(newcontent), nil + return newcontent, nil } diff --git a/ordering/main_test.go b/ordering/main_test.go index 4a5fa5b..1a44c61 100644 --- a/ordering/main_test.go +++ b/ordering/main_test.go @@ -364,3 +364,129 @@ func (f *Foo) FooMethod1() { t.Error(err) } } + +func TestNoOrderStructs(t *testing.T) { + const source = `package main +type grault struct {} +type xyzzy struct {} +type bar struct {} +type qux struct {} +type quux struct {} +type corge struct {} +type garply struct {} +type baz struct {} +type waldo struct {} +type fred struct {} +type plugh struct {} +type foo struct {} +` + const expected = `package main + +type grault struct{} + +type xyzzy struct{} + +type bar struct{} + +type qux struct{} + +type quux struct{} + +type corge struct{} + +type garply struct{} + +type baz struct{} + +type waldo struct{} + +type fred struct{} + +type plugh struct{} + +type foo struct{} +` + + const orderedSource = `package main + +type bar struct{} + +type baz struct{} + +type corge struct{} + +type foo struct{} + +type fred struct{} + +type garply struct{} + +type grault struct{} + +type plugh struct{} + +type quux struct{} + +type qux struct{} + +type waldo struct{} + +type xyzzy struct{} +` + + content, err := ReorderSource(ReorderConfig{ + Filename: "foo.go", + FormatCommand: "gofmt", + ReorderStructs: false, + Src: []byte(source), + Diff: false, + }) + if err != nil { + t.Error(err) + } + if content != expected { + t.Errorf("Expected UNORDERED:\n%s\nGot:\n%s\n", expected, content) + } + + content, err = ReorderSource(ReorderConfig{ + Filename: "foo.go", + FormatCommand: "gofmt", + ReorderStructs: true, + Src: []byte(source), + Diff: false, + }) + if err != nil { + t.Error(err) + } + if content != orderedSource { + t.Errorf("Expected ORDERED:\n%s\nGot:\n%s\n", orderedSource, content) + } + +} + +func TestBadFormatCommand(t *testing.T) { + const source = `package main + +import ( + "os" + "fmt" +) +type grault struct {} +type xyzzy struct {} +type bar struct {} +` + content, err := ReorderSource(ReorderConfig{ + Filename: "foo.go", + FormatCommand: "wthcommand", + ReorderStructs: false, + Src: []byte(source), + Diff: false, + }) + + if err == nil { + t.Error("Expected error, got nil") + } + if content != source { + t.Errorf("Expected:\n%s\nGot:\n%s\n", source, content) + } +} diff --git a/ordering/parser.go b/ordering/parser.go index 2df2945..84586d5 100644 --- a/ordering/parser.go +++ b/ordering/parser.go @@ -27,6 +27,7 @@ type ParsedInfo struct { Structs map[string]*GoType Constants map[string]*GoType Variables map[string]*GoType + StructNames *StingList } // GetMethodComments returns the comments for the given method. @@ -66,6 +67,7 @@ func Parse(filename string, src interface{}) (*ParsedInfo, error) { methods = make(map[string][]*GoType) constructors = make(map[string][]*GoType) structTypes = make(map[string]*GoType) + structNames = &StingList{} varTypes = make(map[string]*GoType) constTypes = make(map[string]*GoType) sourceCode []byte @@ -87,7 +89,7 @@ func Parse(filename string, src interface{}) (*ParsedInfo, error) { findMethods(d, fset, sourceLines, methods) // find struct declarations case *ast.GenDecl: - findStructs(d, fset, sourceLines, structTypes) + findStructs(d, fset, sourceLines, structNames, structTypes) findGlobalVarsAndConsts(d, fset, sourceLines, varTypes, constTypes) } } @@ -102,6 +104,7 @@ func Parse(filename string, src interface{}) (*ParsedInfo, error) { return &ParsedInfo{ Structs: structTypes, + StructNames: structNames, Methods: methods, Constructors: constructors, Variables: varTypes, @@ -109,7 +112,7 @@ func Parse(filename string, src interface{}) (*ParsedInfo, error) { }, nil } -func findStructs(d *ast.GenDecl, fset *token.FileSet, sourceLines []string, structTypes map[string]*GoType) { +func findStructs(d *ast.GenDecl, fset *token.FileSet, sourceLines []string, stuctNames *StingList, structTypes map[string]*GoType) { if d.Tok != token.TYPE { return } @@ -132,6 +135,7 @@ func findStructs(d *ast.GenDecl, fset *token.FileSet, sourceLines []string, stru typeDef.OpeningLine -= len(comments) structTypes[s.Name.Name] = typeDef + stuctNames.Add(s.Name.Name) } } } diff --git a/ordering/sortString.go b/ordering/sortString.go new file mode 100644 index 0000000..3b020a5 --- /dev/null +++ b/ordering/sortString.go @@ -0,0 +1,35 @@ +package ordering + +import "sort" + +var _ sort.Interface = (*StingList)(nil) + +// StingList is a list of strings that *can* be sorted. +// +// Implement sort.Interface +type StingList []string + +// Len returns the length of the list. +func (s *StingList) Len() int { + return len(*s) +} + +// Swap swaps the elements with indexes i and j. +func (s StingList) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +// Less reports whether the element with index i should sort before the element with index j. +func (s StingList) Less(i, j int) bool { + return s[i] < s[j] +} + +// Sort sorts the list. +func (s *StingList) Sort() { + sort.Sort(s) +} + +// Add adds a string to the list. +func (s *StingList) Add(str string) { + *s = append(*s, str) +}