Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented resumable downloads #26

Merged
merged 9 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 159 additions & 0 deletions services/slack/download.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package slack

import (
"bytes"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
)

const defaultOverlap int64 = 512

var ErrOverlapNotEqual = errors.New("download: the downloaded file doesn't match the one on disk")

// downloadInto downloads the contents of a URL into a file. If the file already exists it
// will resume the download. To prevent corrupting the files it downloads a tiny bit of
// overlapping data (512 byte) and compares it to the existing file:
//
// [-----existing local file-----]
// [-------resumed download-------]
// [overlap]
//
// When the check fails, the function returns an error and doesn't silently re-download
// the whole file. If the server doesn't support resumable downloads, the existing file will
// be truncated and re-downloaded.
func downloadInto(filename, url string, size int64) error {
file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0660)
if err != nil {
return fmt.Errorf("download: error opening the destination file: %w", err)
}
defer file.Close()

return resumeDownload(file, size, url)
}

func resumeDownload(existing *os.File, size int64, downloadURL string) error {
existingSize, overlap, err := calculateSize(existing, size)
if err != nil {
return err
}
if existingSize == size {
// the file has already been downloaded
return nil
}

start := existingSize - overlap // calculateSize makes sure this can't be negative
req, err := createRequest(downloadURL, start)
if err != nil {
return err
}

if start != 0 {
log.Printf("Resuming download from %s\n", humanSize(start))
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("download: error during HTTP request: %w", err)
}
defer resp.Body.Close()

switch resp.StatusCode {
case http.StatusPartialContent:
// do nothing, everything is fine
case http.StatusOK:
// server doesn't support Range
overlap = 0
if err = existing.Truncate(0); err != nil {
return fmt.Errorf("download: error emptying file for re-download: %w", err)
}
default:
return fmt.Errorf("download: HTTP request failed with status %q", resp.Status)
}

if overlap != 0 {
err = checkOverlap(existing, resp.Body, overlap)
if err != nil {
return err
}
}

_, err = existing.Seek(0, io.SeekEnd)
agarciamontoro marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("download: error seeking to the end of the existing file: %w", err)
}

_, err = io.Copy(existing, resp.Body)
if err != nil {
return fmt.Errorf("download: error during download: %w", err)
}

return nil
}

func checkOverlap(existing io.ReadSeeker, download io.Reader, overlap int64) error {
bufW := make([]byte, overlap)
bufL := make([]byte, overlap)

_, err := io.ReadFull(download, bufW)
if err != nil {
return fmt.Errorf("download: error downloading the overlapping data: %w", err)
}

_, err = existing.Seek(-overlap, io.SeekEnd)
if err != nil {
return fmt.Errorf("download: error seeking to the start of the existing overlap: %w", err)
}

_, err = io.ReadFull(existing, bufL)
if err != nil {
return fmt.Errorf("download: error reading the local overlapping data: %w", err)
}

if !bytes.Equal(bufW, bufL) {
return ErrOverlapNotEqual
}

return nil
}

func calculateSize(existing *os.File, size int64) (existingSize, overlap int64, err error) {
info, err := existing.Stat()
if err != nil {
return 0, 0, fmt.Errorf("download: error reading file info: %w", err)
}

existingSize = info.Size()
if existingSize == size {
return existingSize, 0, nil
}
if existingSize > size {
err = existing.Truncate(0)
if err != nil {
return 0, 0, fmt.Errorf("download: error emptying file: %w", err)
}
existingSize = 0
}

overlap = defaultOverlap
if overlap > existingSize {
overlap = existingSize
}

return existingSize, overlap, nil
}

func createRequest(url string, start int64) (*http.Request, error) {
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, fmt.Errorf("download: error creating HTTP request: %w", err)
}

req.Header.Set("User-Agent", "mmetl/1.0")
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", start))

return req, nil
}
198 changes: 198 additions & 0 deletions services/slack/download_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package slack

import (
"math/rand"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strconv"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

var mockData []byte

func TestDownload(t *testing.T) {
// set up the test
initializeMockData()
srv, old := mockDefaultHTTPClient()
defer func() {
srv.Close()
http.DefaultClient = old
}()

// run the idividual tests
t.Run("successful download", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)

require.NoError(t, downloadInto(fileName, srv.URL+"/no_resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful resume, empty file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, []byte{}, 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful resume, tiny file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:8], 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful resume, half file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:1024*512], 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful resume, full file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData, 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful re-download, empty file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, []byte{}, 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/no_resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful re-download, tiny file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:8], 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/no_resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful re-download, half file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:1024*512], 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/no_resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("successful re-download, full file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData, 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/no_resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})

t.Run("unsuccessful resume, tiny file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:8], 0660))

require.Error(t, downloadInto(fileName, srv.URL+"/wrong_resume", int64(len(mockData))))
})

t.Run("unsuccessful resume, half file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:1024*512], 0660))

require.Error(t, downloadInto(fileName, srv.URL+"/wrong_resume", int64(len(mockData))))
})

t.Run("successful resume from wrong file with an already downloaded file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData, 0660))

require.NoError(t, downloadInto(fileName, srv.URL+"/wrong_resume", int64(len(mockData))))
tempFile, _ := os.ReadFile(fileName)
require.Equal(t, mockData, tempFile)
})
noxer marked this conversation as resolved.
Show resolved Hide resolved

t.Run("unknown file", func(t *testing.T) {
fileName := filepath.Join(os.TempDir(), "download-test")
defer os.Remove(fileName)
require.NoError(t, os.WriteFile(fileName, mockData[:1024*512], 0660))

require.Error(t, downloadInto(fileName, srv.URL+"/wrong_path", int64(len(mockData))))
})
}

func mockDefaultHTTPClient() (newServer *httptest.Server, oldClient *http.Client) {
noxer marked this conversation as resolved.
Show resolved Hide resolved
mux := http.NewServeMux()

mux.HandleFunc("/no_resume", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write(mockData)
})

mux.HandleFunc("/resume", func(w http.ResponseWriter, r *http.Request) {
rangeHeader := r.Header.Get("Range")
if rangeHeader == "" {
_, _ = w.Write(mockData)
return
}

from, _ := strconv.ParseInt(strings.TrimPrefix(strings.TrimRight(rangeHeader, "-"), "bytes="), 10, 64)

w.WriteHeader(http.StatusPartialContent)
_, _ = w.Write(mockData[from:])
})

mux.HandleFunc("/wrong_resume", func(w http.ResponseWriter, r *http.Request) {
wrongData := make([]byte, 1024*1024)
rand.Read(wrongData) // read different "random" data

rangeHeader := r.Header.Get("Range")
if rangeHeader == "" {
_, _ = w.Write(wrongData)
return
}

from, _ := strconv.ParseInt(strings.TrimPrefix(strings.TrimRight(rangeHeader, "-"), "bytes="), 10, 64)

w.WriteHeader(http.StatusPartialContent)
_, _ = w.Write(wrongData[from:])
})

newServer = httptest.NewServer(mux)
oldClient = http.DefaultClient
http.DefaultClient = newServer.Client()

return
}

func initializeMockData() {
mockData = make([]byte, 1024*1024) // 1 MiB of "random" data
rand.Read(mockData)
}
Loading