diff --git a/pkg/astutil/astutil_test.go b/pkg/astutil/astutil_test.go index b8ebb10..06ed8dc 100644 --- a/pkg/astutil/astutil_test.go +++ b/pkg/astutil/astutil_test.go @@ -4,11 +4,9 @@ import ( "fmt" "go/parser" "go/token" - "reflect" - "sort" - "strings" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -154,13 +152,11 @@ func main(){ t.Run(tt.name, func(t *testing.T) { fset := token.NewFileSet() f, err := parser.ParseFile(fset, "", []byte(fileData), parser.ParseComments) - if err != nil { - require.Nil(t, err) - } + require.NoError(t, err) - if got := UsesImport(f, tt.args.packageImports, tt.args.path); got != tt.want { - t.Errorf("UsesImport() = %v, want %v", got, tt.want) - } + got := UsesImport(f, tt.args.packageImports, tt.args.path) + + assert.Equal(t, tt.want, got) }) } } @@ -211,32 +207,16 @@ func TestLoadPackageDeps(t *testing.T) { nil, parser.ParseComments, ) - if err != nil { - t.Errorf("parser.ParseFile() error = %v", err) - return - } + require.NoError(t, err) got, err := LoadPackageDependencies(tt.args.dir, ParseBuildTag(f)) - if (err != nil) != tt.wantErr { - t.Errorf("LoadPackageDependencies() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } - gotList := make([]string, 0, len(got)) - for pkg, name := range got { - gotList = append(gotList, strings.Join([]string{pkg, name}, " - ")) - } - sort.Strings(gotList) - - wantList := make([]string, 0, len(got)) - for pkg, name := range tt.want { - wantList = append(wantList, strings.Join([]string{pkg, name}, " - ")) - } - sort.Strings(wantList) - - if !reflect.DeepEqual(gotList, wantList) { - t.Errorf("LoadPackageDependencies() got = %v, want %v", got, tt.want) - } + assert.NoError(t, err) + assert.EqualValues(t, tt.want, got) }) } } diff --git a/reviser/file_test.go b/reviser/file_test.go index aab1cdb..d39337b 100644 --- a/reviser/file_test.go +++ b/reviser/file_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestSourceFile_Fix(t *testing.T) { @@ -569,18 +570,17 @@ import "C" } for _, tt := range tests { if tt.args.filePath != StandardInput && !strings.Contains(tt.args.filePath, "does-not-exist") { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) } t.Run(tt.name, func(t *testing.T) { got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath).Fix() - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) }) @@ -716,9 +716,7 @@ import ( } for _, tt := range tests { if tt.args.filePath != StandardInput && !strings.Contains(tt.args.filePath, "does-not-exist") { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) } t.Run(tt.name, func(t *testing.T) { @@ -726,11 +724,12 @@ import ( assert.Nil(t, err) got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath). Fix(WithImportsOrder(order)) - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) }) @@ -1010,18 +1009,17 @@ func main() { } for _, tt := range tests { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) t.Run(tt.name, func(t *testing.T) { got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath). Fix(WithRemovingUnusedImports) - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) }) @@ -1160,18 +1158,17 @@ func main() { } for _, tt := range tests { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) t.Run(tt.name, func(t *testing.T) { got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath). Fix(WithUsingAliasForVersionSuffix) - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) }) @@ -1376,18 +1373,17 @@ func main() { } for _, tt := range tests { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) t.Run(tt.name, func(t *testing.T) { got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath). Fix(WithCompanyPackagePrefixes(tt.args.localPkgPrefixes)) - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) }) @@ -1466,17 +1462,16 @@ func test1() {} }, } for _, tt := range tests { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) t.Run(tt.name, func(t *testing.T) { got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath).Fix(WithCodeFormatting) - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) }) @@ -1842,18 +1837,17 @@ import ( } for _, tt := range tests { if tt.args.filePath != StandardInput && !strings.Contains(tt.args.filePath, "does-not-exist") { - if err := os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644); err != nil { - t.Errorf("write test file failed: %s", err) - } + require.NoError(t, os.WriteFile(tt.args.filePath, []byte(tt.args.fileContent), 0644)) } t.Run(tt.name, func(t *testing.T) { got, hasChange, err := NewSourceFile(tt.args.projectName, tt.args.filePath).Fix(WithSkipGeneratedFile) - if (err != nil) != tt.wantErr { - t.Errorf("Fix() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) return } + assert.NoError(t, err) assert.Equal(t, tt.wantChange, hasChange) assert.Equal(t, tt.want, string(got)) })