Skip to content

Commit

Permalink
Rename go script; Add tests for updateGoModulePath.
Browse files Browse the repository at this point in the history
  • Loading branch information
Joerger committed Oct 20, 2021
1 parent bf44ca3 commit 30d5c26
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ $(VERSRC): Makefile
# Note: any build flags needed to compile go files (such as build tags) should be provided below.
.PHONY: update-api-module-path
update-api-module-path:
go run build.assets/update-api-module-path/main.go -tags "bpf fips pam roletester desktop_access_beta"
go run build.assets/update_api_module_path/main.go -tags "bpf fips pam roletester desktop_access_beta"
$(MAKE) update-vendor
$(MAKE) grpc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ func main() {
}

// the api module import path should only be updated on releases
if isPreRelease() {
exitWithMessage("the current API version (%v) is not a release, continue without updating", api.Version)
newVersion := api.Version
if isPreRelease(newVersion) {
exitWithMessage("the current API version (%v) is not a release, continue without updating", newVersion)
}

// get the current and new api module import paths
currentPath, newPath, err := getAPIModuleImportPaths()
currentPath, newPath, err := getAPIModuleImportPaths(newVersion)
if err != nil {
exitWithError(trace.Wrap(err, "failed to get mod paths"))
} else if currentPath == newPath {
Expand All @@ -59,11 +60,11 @@ func main() {

// update go files within the teleport/api and teleport modules to use the new import path
log.Info("Updating teleport/api module...")
if err := updateGoModule("./api", currentPath, newPath, buildFlags); err != nil {
if err := updateGoModule("./api", currentPath, newPath, newVersion, buildFlags); err != nil {
exitWithError(trace.Wrap(err, "failed to update teleport/api module"))
}
log.Info("Updating teleport module...")
if err := updateGoModule("./", currentPath, newPath, buildFlags); err != nil {
if err := updateGoModule("./", currentPath, newPath, newVersion, buildFlags); err != nil {
exitWithError(trace.Wrap(err, "failed to update teleport module"))
}

Expand All @@ -75,38 +76,42 @@ func main() {
}

// updateGoModule updates instances of the currentPath with the newPath in the given go module.
func updateGoModule(modulePath, currentPath, newPath string, buildFlags []string) error {
func updateGoModule(modulePath, currentPath, newPath, newVersion string, buildFlags []string) error {
log.Info(" Updating go files...")
if err := updateGoFiles(modulePath, currentPath, newPath, newVersion, buildFlags); err != nil {
return trace.Wrap(err, "failed to update mod file for module")
}

log.Info(" Updating go.mod...")
if err := updateGoModFile(modulePath, currentPath, newPath, newVersion); err != nil {
return trace.Wrap(err, "failed to update mod file for module")
}

return nil
}

// updateGoFiles updates instances of the currentPath with the newPath in go files in the given module.
func updateGoFiles(modulePath, currentPath, newPath, newVersion string, buildFlags []string) error {
mode := packages.NeedTypes | packages.NeedSyntax
cfg := &packages.Config{Mode: mode, Tests: true, Dir: modulePath, BuildFlags: buildFlags}
pkgs, err := packages.Load(cfg, "./...")
if err != nil {
return trace.Wrap(err)
}

log.Info(" Updating go files...")
var errs []error
packages.Visit(pkgs, func(pkg *packages.Package) bool {
if err = updateGoImports(pkg, currentPath, newPath); err != nil {
if err = updateGoImports(pkg, currentPath, newPath, newVersion); err != nil {
errs = append(errs, err)
return false
}
return true
}, nil)

if len(errs) != 0 {
return trace.NewAggregate(errs...)
}

log.Info(" Updating go.mod...")
if err := updateModFile(modulePath, currentPath, newPath); err != nil {
return trace.Wrap(err, "failed to update mod file for module")
}

return nil
return trace.NewAggregate(errs...)
}

// updateGoImports updates instances of the currentPath with the newPath in the given package.
func updateGoImports(p *packages.Package, currentPath, newPath string) error {
func updateGoImports(p *packages.Package, currentPath, newPath, newVersion string) error {
for _, syn := range p.Syntax {
var rewritten bool
for _, i := range syn.Imports {
Expand Down Expand Up @@ -137,9 +142,9 @@ func updateGoImports(p *packages.Package, currentPath, newPath string) error {
return nil
}

// updateModFile updates instances of oldPath to newPath in a go.mod file.
// updateGoModFile updates instances of oldPath to newPath in a go.mod file.
// The modFile is updated in place by updating the syntax fields directly.
func updateModFile(dir, oldPath, newPath string) error {
func updateGoModFile(dir, oldPath, newPath, newVersion string) error {
modFile, err := getModFile(dir)
if err != nil {
return trace.Wrap(err)
Expand All @@ -153,10 +158,10 @@ func updateModFile(dir, oldPath, newPath string) error {
if r.Mod.Path == oldPath {
// Update path and version of require statement.
if r.Syntax.InBlock {
r.Syntax.Token[0], r.Syntax.Token[1] = newPath, "v"+api.Version
r.Syntax.Token[0], r.Syntax.Token[1] = newPath, "v"+newVersion
} else {
// First token in the line is "require", skip to second and third indices
r.Syntax.Token[1], r.Syntax.Token[2] = newPath, "v"+api.Version
r.Syntax.Token[1], r.Syntax.Token[2] = newPath, "v"+newVersion
}
}
}
Expand All @@ -166,9 +171,15 @@ func updateModFile(dir, oldPath, newPath string) error {
// Update path of replace statement.
if r.Syntax.InBlock {
r.Syntax.Token[0] = newPath
if r.Old.Version != "" {
r.Syntax.Token[1] = "v" + newVersion
}
} else {
// First token in the line is "replace", skip to second index
r.Syntax.Token[1] = newPath
if r.Old.Version != "" {
r.Syntax.Token[2] = "v" + newVersion
}
}
}
}
Expand Down Expand Up @@ -236,7 +247,7 @@ func getModFile(dir string) (*modfile.File, error) {
}

// getAPIModuleImportPaths gets the current and new import paths for the api module
func getAPIModuleImportPaths() (current string, new string, err error) {
func getAPIModuleImportPaths(version string) (current string, new string, err error) {
// get the current mod path from `api/go.mod`
currentPath, err := getModImportPath("./api")
if err != nil {
Expand All @@ -245,7 +256,7 @@ func getAPIModuleImportPaths() (current string, new string, err error) {

// get the new major version suffix - e.g "" for v0/v1 or "/vX" for vX where X >= 2
var majVerSuffix string
if ver := semver.New(api.Version); ver.Major >= 2 {
if ver := semver.New(version); ver.Major >= 2 {
majVerSuffix = fmt.Sprintf("/v%d", ver.Major)
}

Expand All @@ -259,8 +270,8 @@ func getAPIModuleImportPaths() (current string, new string, err error) {
}

// returns whether the current api version is a pre-release, e.g "v7.0.0-beta"
func isPreRelease() bool {
return semver.New(api.Version).PreRelease != ""
func isPreRelease(version string) bool {
return semver.New(version).PreRelease != ""
}

func exitWithError(err error) {
Expand Down
101 changes: 101 additions & 0 deletions build.assets/update_api_module_path/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"fmt"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

var (
goMainFile = `
package main
import (
"github.com/gravitational/teleport/api"
)
func main() {
api.func()
}`
)

func TestUpdateGoModulePath(t *testing.T) {
modDir := t.TempDir()
modFilePath := filepath.Join(modDir, "go.mod")

modPath := "github.com/gravitational/teleport/api"
goModFile := `module github.com/gravitational/teleport/api
go 1.15
require github.com/gravitational/teleport/api v0.0.0
require (
github.com/gravitational/teleport/api v0.0.0
github.com/gravitational/teleport/api v0.0.0 // indirect
)
replace github.com/gravitational/teleport/api => ./api
replace github.com/gravitational/teleport/api v0.0.0 => ./api
replace (
github.com/gravitational/teleport/api v0.0.0 => ./api
)
`

newVersion := "2.1.3"
newModPath := modPath + "/v2"
newGoModFile := `module github.com/gravitational/teleport/api/v2
go 1.15
require github.com/gravitational/teleport/api/v2 v2.1.3
require (
github.com/gravitational/teleport/api/v2 v2.1.3
github.com/gravitational/teleport/api/v2 v2.1.3 // indirect
)
replace github.com/gravitational/teleport/api/v2 => ./api
replace github.com/gravitational/teleport/api/v2 v2.1.3 => ./api
replace (
github.com/gravitational/teleport/api/v2 v2.1.3 => ./api
)
`

err := os.WriteFile(modFilePath, []byte(goModFile), 0660)
require.NoError(t, err)

err = updateGoModFile(modDir, modPath, newModPath, newVersion)
require.NoError(t, err)

bytes, err := os.ReadFile(modFilePath)
require.NoError(t, err)

fmt.Println("\n\nBREAK\n\ns")
fmt.Println(newGoModFile)

fmt.Println("\n\nBREAK\n\ns")
fmt.Println(string(bytes))

require.Equal(t, newGoModFile, string(bytes))
}

0 comments on commit 30d5c26

Please sign in to comment.