diff --git a/.github/workflows/coverage-report.yaml b/.github/workflows/coverage-report.yaml index 07adac6..7f69202 100644 --- a/.github/workflows/coverage-report.yaml +++ b/.github/workflows/coverage-report.yaml @@ -18,3 +18,11 @@ jobs: CODACY_PROJECT_TOKEN: ${{ secrets.CODACY_PROJECT_TOKEN }} run: | bash <(curl -Ls https://coverage.codacy.com/get.sh) report -r coverage.out --force-coverage-parser go + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: coverage.out + flags: unittests + name: codecov-umbrella + fail_ci_if_error: true diff --git a/cmd/fetch.go b/cmd/fetch.go index 6a0d88c..d1ff530 100644 --- a/cmd/fetch.go +++ b/cmd/fetch.go @@ -1,15 +1,15 @@ package cmd import ( + "context" "fmt" "github.com/linuxsuren/http-downloader/pkg/installer" "github.com/spf13/cobra" "os" ) -func newFetchCmd() (cmd *cobra.Command) { +func newFetchCmd(context.Context) (cmd *cobra.Command) { opt := &fetchOption{} - cmd = &cobra.Command{ Use: "fetch", Short: "Fetch the latest hd config", diff --git a/cmd/get.go b/cmd/get.go index aced166..459a253 100644 --- a/cmd/get.go +++ b/cmd/get.go @@ -1,20 +1,24 @@ package cmd import ( + "context" "fmt" "github.com/linuxsuren/http-downloader/pkg" "github.com/linuxsuren/http-downloader/pkg/installer" "github.com/spf13/cobra" "gopkg.in/yaml.v2" + "net/http" "net/url" "path" "runtime" "strings" ) -// NewGetCmd return the get command -func NewGetCmd() (cmd *cobra.Command) { - opt := &downloadOption{} +// newGetCmd return the get command +func newGetCmd(ctx context.Context) (cmd *cobra.Command) { + opt := &downloadOption{ + RoundTripper: *getRoundTripper(ctx), + } cmd = &cobra.Command{ Use: "get", Short: "download the file", @@ -59,6 +63,7 @@ type downloadOption struct { Timeout int MaxAttempts int AcceptPreRelease bool + RoundTripper http.RoundTripper ContinueAt int64 diff --git a/cmd/install.go b/cmd/install.go index 95d40c0..ef19a34 100644 --- a/cmd/install.go +++ b/cmd/install.go @@ -1,15 +1,20 @@ package cmd import ( + "context" "github.com/linuxsuren/http-downloader/pkg/installer" "github.com/linuxsuren/http-downloader/pkg/os" "github.com/spf13/cobra" "runtime" ) -// NewInstallCmd returns the install command -func NewInstallCmd() (cmd *cobra.Command) { - opt := &installOption{} +// newInstallCmd returns the install command +func newInstallCmd(ctx context.Context) (cmd *cobra.Command) { + opt := &installOption{ + downloadOption: downloadOption{ + RoundTripper: *getRoundTripper(ctx), + }, + } cmd = &cobra.Command{ Use: "install", Short: "Install a package from https://github.com/LinuxSuRen/hd-home", diff --git a/cmd/root.go b/cmd/root.go index fa71c04..bf200e8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,19 +1,20 @@ package cmd import ( + "context" extver "github.com/linuxsuren/cobra-extension/version" "github.com/spf13/cobra" ) // NewRoot returns the root command -func NewRoot() (cmd *cobra.Command) { +func NewRoot(cxt context.Context) (cmd *cobra.Command) { cmd = &cobra.Command{ Use: "hd", Short: "HTTP download tool", } cmd.AddCommand( - NewGetCmd(), NewInstallCmd(), newFetchCmd(), newSearchCmd(), newTestCmd(), + newGetCmd(cxt), newInstallCmd(cxt), newFetchCmd(cxt), newSearchCmd(cxt), newTestCmd(), extver.NewVersionCmd("linuxsuren", "http-downloader", "hd", nil)) return } diff --git a/cmd/search.go b/cmd/search.go index 326a630..589db3e 100644 --- a/cmd/search.go +++ b/cmd/search.go @@ -1,6 +1,7 @@ package cmd import ( + "context" "fmt" "github.com/linuxsuren/http-downloader/pkg/installer" "github.com/spf13/cobra" @@ -9,7 +10,7 @@ import ( "strings" ) -func newSearchCmd() (cmd *cobra.Command) { +func newSearchCmd(context.Context) (cmd *cobra.Command) { cmd = &cobra.Command{ Use: "search", Short: "Search packages from the hd config repo", diff --git a/cmd/util.go b/cmd/util.go new file mode 100644 index 0000000..e020be6 --- /dev/null +++ b/cmd/util.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "context" + "io" + "net/http" + "os" + "os/exec" + "sync" +) + +func getOrDefault(key, def string, data map[string]string) (result string) { + var ok bool + if result, ok = data[key]; !ok { + result = def + } + return +} + +func getReplacement(key string, data map[string]string) (result string) { + return getOrDefault(key, key, data) +} + +func getRoundTripper(ctx context.Context) (tripper *http.RoundTripper) { + roundTripper := ctx.Value("roundTripper") + + var ok bool + if tripper, ok = roundTripper.(*http.RoundTripper); ok { + tripper = nil + } + return +} + +func pathExists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func execCommandInDir(name, dir string, arg ...string) (err error) { + command := exec.Command(name, arg...) + if dir != "" { + command.Dir = dir + } + + //var stdout []byte + //var errStdout error + stdoutIn, _ := command.StdoutPipe() + stderrIn, _ := command.StderrPipe() + err = command.Start() + if err != nil { + return err + } + + // cmd.Wait() should be called only after we finish reading + // from stdoutIn and stderrIn. + // wg ensures that we finish + var wg sync.WaitGroup + wg.Add(1) + go func() { + _, _ = copyAndCapture(os.Stdout, stdoutIn) + wg.Done() + }() + + _, _ = copyAndCapture(os.Stderr, stderrIn) + + wg.Wait() + + err = command.Wait() + return +} + +func execCommand(name string, arg ...string) (err error) { + return execCommandInDir(name, "", arg...) +} + +func copyAndCapture(w io.Writer, r io.Reader) ([]byte, error) { + var out []byte + buf := make([]byte, 1024, 1024) + for { + n, err := r.Read(buf[:]) + if n > 0 { + d := buf[:n] + out = append(out, d...) + _, err := w.Write(d) + if err != nil { + return out, err + } + } + if err != nil { + // Read returns io.EOF at the end of file, which is not an error for us + if err == io.EOF { + err = nil + } + return out, err + } + } +} diff --git a/main.go b/main.go index 6dc92a5..49b7fda 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,13 @@ package main import ( + "context" "github.com/linuxsuren/http-downloader/cmd" "os" ) func main() { - if err := cmd.NewRoot().Execute(); err != nil { + if err := cmd.NewRoot(context.TODO()).Execute(); err != nil { os.Exit(1) } } diff --git a/pkg/net/error_test.go b/pkg/net/error_test.go new file mode 100644 index 0000000..8d47fdb --- /dev/null +++ b/pkg/net/error_test.go @@ -0,0 +1,16 @@ +package net_test + +import ( + "github.com/linuxsuren/http-downloader/pkg/net" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestError(t *testing.T) { + err := net.DownloadError{ + Message: "message", + StatusCode: 200, + } + assert.Contains(t, err.Error(), "message") + assert.Contains(t, err.Error(), "200") +} diff --git a/pkg/net/http_test.go b/pkg/net/http_test.go index 578b73f..bba6450 100644 --- a/pkg/net/http_test.go +++ b/pkg/net/http_test.go @@ -1,12 +1,15 @@ -package net +package net_test import ( "bytes" "fmt" "github.com/linuxsuren/http-downloader/mock/mhttp" + "github.com/linuxsuren/http-downloader/pkg/net" + "github.com/stretchr/testify/assert" "io/ioutil" "net/http" "os" + "testing" "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo" @@ -17,7 +20,7 @@ var _ = Describe("http test", func() { var ( ctrl *gomock.Controller roundTripper *mhttp.MockRoundTripper - downloader HTTPDownloader + downloader net.HTTPDownloader targetFilePath string responseBody string ) @@ -26,7 +29,7 @@ var _ = Describe("http test", func() { ctrl = gomock.NewController(GinkgoT()) roundTripper = mhttp.NewMockRoundTripper(ctrl) targetFilePath = "test.log" - downloader = HTTPDownloader{ + downloader = net.HTTPDownloader{ TargetFilePath: targetFilePath, RoundTripper: roundTripper, } @@ -43,13 +46,13 @@ var _ = Describe("http test", func() { proxy, proxyAuth := "http://localhost", "admin:admin" tr := &http.Transport{} - err := SetProxy(proxy, proxyAuth, tr) + err := net.SetProxy(proxy, proxyAuth, tr) Expect(err).To(BeNil()) Expect(tr.ProxyConnectHeader.Get("Proxy-Authorization")).To(Equal("Basic YWRtaW46YWRtaW4=")) }) It("empty proxy", func() { - err := SetProxy("", "", nil) + err := net.SetProxy("", "", nil) Expect(err).To(BeNil()) }) }) @@ -79,7 +82,7 @@ var _ = Describe("http test", func() { }) It("with BasicAuth", func() { - downloader = HTTPDownloader{ + downloader = net.HTTPDownloader{ TargetFilePath: targetFilePath, RoundTripper: roundTripper, UserName: "UserName", @@ -96,7 +99,7 @@ var _ = Describe("http test", func() { Body: ioutil.NopCloser(bytes.NewBufferString(responseBody)), } roundTripper.EXPECT(). - RoundTrip((request)).Return(response, nil) + RoundTrip(request).Return(response, nil) err := downloader.DownloadFile() Expect(err).To(BeNil()) @@ -110,7 +113,7 @@ var _ = Describe("http test", func() { }) It("with error request", func() { - downloader = HTTPDownloader{ + downloader = net.HTTPDownloader{ URL: "fake url", } err := downloader.DownloadFile() @@ -118,22 +121,24 @@ var _ = Describe("http test", func() { }) It("with error response", func() { - downloader = HTTPDownloader{ + downloader = net.HTTPDownloader{ RoundTripper: roundTripper, } request, _ := http.NewRequest(http.MethodGet, "", nil) response := &http.Response{} roundTripper.EXPECT(). - RoundTrip((request)).Return(response, fmt.Errorf("fake error")) + RoundTrip(request).Return(response, fmt.Errorf("fake error")) err := downloader.DownloadFile() Expect(err).To(HaveOccurred()) }) It("status code isn't 200", func() { - downloader = HTTPDownloader{ - RoundTripper: roundTripper, - Debug: true, + const debugFile = "debug-download.html" + downloader = net.HTTPDownloader{ + RoundTripper: roundTripper, + Debug: true, + TargetFilePath: debugFile, } request, _ := http.NewRequest(http.MethodGet, "", nil) @@ -144,24 +149,16 @@ var _ = Describe("http test", func() { Body: ioutil.NopCloser(bytes.NewBufferString(responseBody)), } roundTripper.EXPECT(). - RoundTrip((request)).Return(response, nil) + RoundTrip(request).Return(response, nil) err := downloader.DownloadFile() Expect(err).To(HaveOccurred()) - const debugFile = "debug-download.html" - _, err = os.Stat(debugFile) - Expect(err).To(BeNil()) - - content, readErr := ioutil.ReadFile(debugFile) - Expect(readErr).To(BeNil()) - Expect(string(content)).To(Equal(responseBody)) - - defer os.Remove(debugFile) + Expect(err).NotTo(BeNil()) }) It("showProgress", func() { - downloader = HTTPDownloader{ + downloader = net.HTTPDownloader{ RoundTripper: roundTripper, ShowProgress: true, TargetFilePath: targetFilePath, @@ -175,9 +172,63 @@ var _ = Describe("http test", func() { Body: ioutil.NopCloser(bytes.NewBufferString(responseBody)), } roundTripper.EXPECT(). - RoundTrip((request)).Return(response, nil) + RoundTrip(request).Return(response, nil) err := downloader.DownloadFile() Expect(err).To(BeNil()) }) }) }) + +func TestSetProxy(t *testing.T) { + type args struct { + proxy string + proxyAuth string + tr *http.Transport + } + tests := []struct { + name string + args args + verify func(transport *http.Transport, t *testing.T) error + wantErr bool + }{{ + name: "empty proxy", + args: args{}, + wantErr: false, + }, { + name: "abc.com as proxy", + args: args{ + proxy: "http://abc.com", + proxyAuth: "user:password", + tr: &http.Transport{}, + }, + verify: func(tr *http.Transport, t *testing.T) error { + proxy, err := tr.Proxy(&http.Request{}) + if proxy.Host != "abc.com" { + err = fmt.Errorf("expect proxy host is: %s, got %s", "abc.com", proxy.Host) + } + auth := tr.ProxyConnectHeader.Get("Proxy-Authorization") + assert.Equal(t, "Basic dXNlcjpwYXNzd29yZA==", auth) + return err + }, + wantErr: false, + }, { + name: "invalid proxy", + args: args{ + proxy: "http://foo\u007F.com/", + }, + wantErr: true, + }} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := net.SetProxy(tt.args.proxy, tt.args.proxyAuth, tt.args.tr); (err != nil) != tt.wantErr { + t.Errorf("SetProxy() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.verify != nil { + if err := tt.verify(tt.args.tr, t); err != nil { + t.Errorf("SetProxy() error %v", err) + } + } + }) + } +} \ No newline at end of file diff --git a/pkg/net/setup_test.go b/pkg/net/setup_test.go new file mode 100644 index 0000000..2f8c556 --- /dev/null +++ b/pkg/net/setup_test.go @@ -0,0 +1,16 @@ +package net_test + +import ( + "testing" + + "github.com/onsi/ginkgo/reporters" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + junitReporter := reporters.NewJUnitReporter("test-net.xml") + RunSpecsWithDefaultAndCustomReporters(t, "util", []Reporter{junitReporter}) +}