summaryrefslogtreecommitdiffstats
path: root/build/codeformat/formatimports.go
diff options
context:
space:
mode:
Diffstat (limited to 'build/codeformat/formatimports.go')
-rw-r--r--build/codeformat/formatimports.go195
1 files changed, 195 insertions, 0 deletions
diff --git a/build/codeformat/formatimports.go b/build/codeformat/formatimports.go
new file mode 100644
index 0000000..c9fc2a2
--- /dev/null
+++ b/build/codeformat/formatimports.go
@@ -0,0 +1,195 @@
+// Copyright 2021 The Gitea Authors. All rights reserved.
+// SPDX-License-Identifier: MIT
+
+package codeformat
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "os"
+ "sort"
+ "strings"
+)
+
+var importPackageGroupOrders = map[string]int{
+ "": 1, // internal
+ "code.gitea.io/gitea/": 2,
+}
+
+var errInvalidCommentBetweenImports = errors.New("comments between imported packages are invalid, please move comments to the end of the package line")
+
+var (
+ importBlockBegin = []byte("\nimport (\n")
+ importBlockEnd = []byte("\n)")
+)
+
+type importLineParsed struct {
+ group string
+ pkg string
+ content string
+}
+
+func parseImportLine(line string) (*importLineParsed, error) {
+ il := &importLineParsed{content: line}
+ p1 := strings.IndexRune(line, '"')
+ if p1 == -1 {
+ return nil, errors.New("invalid import line: " + line)
+ }
+ p1++
+ p := strings.IndexRune(line[p1:], '"')
+ if p == -1 {
+ return nil, errors.New("invalid import line: " + line)
+ }
+ p2 := p1 + p
+ il.pkg = line[p1:p2]
+
+ pDot := strings.IndexRune(il.pkg, '.')
+ pSlash := strings.IndexRune(il.pkg, '/')
+ if pDot != -1 && pDot < pSlash {
+ il.group = "domain-package"
+ }
+ for groupName := range importPackageGroupOrders {
+ if groupName == "" {
+ continue // skip internal
+ }
+ if strings.HasPrefix(il.pkg, groupName) {
+ il.group = groupName
+ }
+ }
+ return il, nil
+}
+
+type (
+ importLineGroup []*importLineParsed
+ importLineGroupMap map[string]importLineGroup
+)
+
+func formatGoImports(contentBytes []byte) ([]byte, error) {
+ p1 := bytes.Index(contentBytes, importBlockBegin)
+ if p1 == -1 {
+ return nil, nil
+ }
+ p1 += len(importBlockBegin)
+ p := bytes.Index(contentBytes[p1:], importBlockEnd)
+ if p == -1 {
+ return nil, nil
+ }
+ p2 := p1 + p
+
+ importGroups := importLineGroupMap{}
+ r := bytes.NewBuffer(contentBytes[p1:p2])
+ eof := false
+ for !eof {
+ line, err := r.ReadString('\n')
+ eof = err == io.EOF
+ if err != nil && !eof {
+ return nil, err
+ }
+ line = strings.TrimSpace(line)
+ if line != "" {
+ if strings.HasPrefix(line, "//") || strings.HasPrefix(line, "/*") {
+ return nil, errInvalidCommentBetweenImports
+ }
+ importLine, err := parseImportLine(line)
+ if err != nil {
+ return nil, err
+ }
+ importGroups[importLine.group] = append(importGroups[importLine.group], importLine)
+ }
+ }
+
+ var groupNames []string
+ for groupName, importLines := range importGroups {
+ groupNames = append(groupNames, groupName)
+ sort.Slice(importLines, func(i, j int) bool {
+ return strings.Compare(importLines[i].pkg, importLines[j].pkg) < 0
+ })
+ }
+
+ sort.Slice(groupNames, func(i, j int) bool {
+ n1 := groupNames[i]
+ n2 := groupNames[j]
+ o1 := importPackageGroupOrders[n1]
+ o2 := importPackageGroupOrders[n2]
+ if o1 != 0 && o2 != 0 {
+ return o1 < o2
+ }
+ if o1 == 0 && o2 == 0 {
+ return strings.Compare(n1, n2) < 0
+ }
+ return o1 != 0
+ })
+
+ formattedBlock := bytes.Buffer{}
+ for _, groupName := range groupNames {
+ hasNormalImports := false
+ hasDummyImports := false
+ // non-dummy import comes first
+ for _, importLine := range importGroups[groupName] {
+ if strings.HasPrefix(importLine.content, "_") {
+ hasDummyImports = true
+ } else {
+ formattedBlock.WriteString("\t" + importLine.content + "\n")
+ hasNormalImports = true
+ }
+ }
+ // dummy (_ "pkg") comes later
+ if hasDummyImports {
+ if hasNormalImports {
+ formattedBlock.WriteString("\n")
+ }
+ for _, importLine := range importGroups[groupName] {
+ if strings.HasPrefix(importLine.content, "_") {
+ formattedBlock.WriteString("\t" + importLine.content + "\n")
+ }
+ }
+ }
+ formattedBlock.WriteString("\n")
+ }
+ formattedBlockBytes := bytes.TrimRight(formattedBlock.Bytes(), "\n")
+
+ var formattedBytes []byte
+ formattedBytes = append(formattedBytes, contentBytes[:p1]...)
+ formattedBytes = append(formattedBytes, formattedBlockBytes...)
+ formattedBytes = append(formattedBytes, contentBytes[p2:]...)
+ return formattedBytes, nil
+}
+
+// FormatGoImports format the imports by our rules (see unit tests)
+func FormatGoImports(file string, doWriteFile bool) error {
+ f, err := os.Open(file)
+ if err != nil {
+ return err
+ }
+ var contentBytes []byte
+ {
+ defer f.Close()
+ contentBytes, err = io.ReadAll(f)
+ if err != nil {
+ return err
+ }
+ }
+ formattedBytes, err := formatGoImports(contentBytes)
+ if err != nil {
+ return err
+ }
+ if formattedBytes == nil {
+ return nil
+ }
+ if bytes.Equal(contentBytes, formattedBytes) {
+ return nil
+ }
+
+ if doWriteFile {
+ f, err = os.OpenFile(file, os.O_TRUNC|os.O_WRONLY, 0o644)
+ if err != nil {
+ return err
+ }
+ defer f.Close()
+ _, err = f.Write(formattedBytes)
+ return err
+ }
+
+ return err
+}