Skip to content

Commit

Permalink
Replace sparsefile.Copy with sparsefile.Overwrite
Browse files Browse the repository at this point in the history
Copy produces incorrect results in the rare case where part of an already
existing file has to be overwritten with a buffer consisting entirely of zeros.
  • Loading branch information
tomekjarosik committed Oct 9, 2024
1 parent ca7374a commit 847cc05
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 9 deletions.
26 changes: 17 additions & 9 deletions pkg/dirimage/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func writeToSegment(destinationDir string, segment *filesegment.Descriptor, src
}
}(f)

written, skipped, err = sparsefile.Copy(f, src)
written, skipped, err = sparsefile.Overwrite(f, src)
if written+skipped != segment.Length() {
return written, skipped, fmt.Errorf("invalid numer of bytes written+skipped: segment length: %d, written+skipped: %d", segment.Length(), written+skipped)
}
Expand Down Expand Up @@ -62,7 +62,13 @@ func truncateFiles(destinationDir string, segmentDescriptors []*filesegment.Desc
}

for filename, size := range fileSizesMap {
err := os.Truncate(filepath.Join(destinationDir, filename), size)
fpath := filepath.Join(destinationDir, filename)
f, err := os.OpenFile(fpath, os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return fmt.Errorf("error opening file '%s': %w", filename, err)
}
defer f.Close()
err = os.Truncate(fpath, size)
if err != nil {
return fmt.Errorf("error while truncating file '%v': %w", filename, err)
}
Expand Down Expand Up @@ -99,6 +105,13 @@ func (di *DirImage) Write(ctx context.Context, destinationDir string, opt ...Opt
}
bytesTotal := di.Length()
sendProgressUpdate(opts.progress, 0, bytesTotal)

// Create & truncate the files to correct sizes, so we only have to overwrite parts that are different
err := truncateFiles(destinationDir, di.segmentDescriptors)
if err != nil {
return err
}

jobs := make(chan Job, opts.workersCount)
g, ctx := errgroup.WithContext(ctx)
layerOpts := []filesegment.LayerOpt{filesegment.WithLogFunction(opts.printf)}
Expand All @@ -108,7 +121,7 @@ func (di *DirImage) Write(ctx context.Context, destinationDir string, opt ...Opt
atomic.AddInt64(&di.BytesReadCount, job.Descriptor.Length())
sendProgressUpdate(opts.progress, di.BytesReadCount, bytesTotal)
if filesegment.Matches(&job.Descriptor, destinationDir, layerOpts...) {
opts.printf("existing layer: %v\n", &job.Descriptor)
opts.printf("existing layer: %v matches %v\n", &job.Descriptor, job.Descriptor)
continue
}

Expand Down Expand Up @@ -147,12 +160,7 @@ func (di *DirImage) Write(ctx context.Context, destinationDir string, opt ...Opt
return nil
})

err := g.Wait()
if err != nil {
return err
}

err = truncateFiles(destinationDir, di.segmentDescriptors)
err = g.Wait()
if err != nil {
return err
}
Expand Down
58 changes: 58 additions & 0 deletions pkg/sparsefile/overwrite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package sparsefile

import (
"bytes"
"fmt"
"io"
)

// maxBufSize is the length, in bytes, of each scratch buffer that Overwrite
// allocates (one for data read from src, one for data read back from dst).
const maxBufSize = 64 * 1024

// Overwrite copies src into dst starting at dst's current offset, delegating
// to overwriteBuffer with two freshly allocated maxBufSize-byte buffers.
//
// Data is processed chunk by chunk: bytes at the start of a chunk that dst
// already holds are left untouched and reported in skipped, while the bytes
// actually written to dst are reported in written.
func Overwrite(dst io.ReadWriteSeeker, src io.Reader) (written int64, skipped int64, err error) {
	return overwriteBuffer(dst, src, make([]byte, maxBufSize), make([]byte, maxBufSize))
}

// overwriteBuffer copies src into dst beginning at dst's current offset,
// writing only the data that differs from what dst already contains.
//
// src is consumed in chunks no larger than the smaller of srcBuf and dstBuf.
// For each chunk, the same number of bytes is read back from dst. When the
// bytes read back equal the start of the chunk they are counted as skipped and
// only the remainder of the chunk (if any) is written; when they differ, the
// whole chunk is written. A chunk of zeros is therefore still written whenever
// dst currently holds different data at that offset.
//
// written is the number of bytes written to dst and skipped is the number of
// bytes left untouched because they already matched; on success their sum
// equals the number of bytes read from src.
//
// It returns io.ErrShortBuffer if either scratch buffer is empty, and
// otherwise the first seek, write, or non-EOF source read error encountered.
func overwriteBuffer(dst io.ReadWriteSeeker, src io.Reader, srcBuf, dstBuf []byte) (written int64, skipped int64, err error) {
	// Every chunk read from src has to fit into dstBuf for the comparison, so
	// the chunk size is bounded by the smaller of the two buffers.
	chunkSize := len(srcBuf)
	if len(dstBuf) < chunkSize {
		chunkSize = len(dstBuf)
	}
	if chunkSize == 0 {
		// A zero-length read buffer could never make progress.
		return 0, 0, io.ErrShortBuffer
	}
	srcBuf, dstBuf = srcBuf[:chunkSize], dstBuf[:chunkSize]

	dstPos, err := dst.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, 0, fmt.Errorf("unable to seek current: %w", err)
	}

	for {
		nrSrc, srcErr := src.Read(srcBuf)
		if nrSrc > 0 {
			pending := srcBuf[:nrSrc]

			// Best effort: the read error is deliberately ignored. If dst is
			// shorter than the chunk or cannot be read, nrDst is simply smaller
			// and the uncompared bytes are written unconditionally below.
			nrDst, _ := dst.Read(dstBuf[:nrSrc])
			if bytes.Equal(dstBuf[:nrDst], pending[:nrDst]) {
				dstPos += int64(nrDst)
				skipped += int64(nrDst)
				pending = pending[nrDst:]
			}

			// When the whole chunk matched, the read above already left dst at
			// dstPos, so neither a seek nor a write is needed.
			if len(pending) > 0 {
				// Rewind dst to the first byte that has to be replaced.
				if _, err = dst.Seek(dstPos, io.SeekStart); err != nil {
					return written, skipped, err
				}
				nw, writeErr := dst.Write(pending)
				dstPos += int64(nw)
				written += int64(nw)
				if writeErr != nil {
					return written, skipped, writeErr
				}
			}
		}

		if srcErr != nil {
			if srcErr != io.EOF {
				return written, skipped, srcErr
			}
			return written, skipped, nil
		}
	}
}
173 changes: 173 additions & 0 deletions pkg/sparsefile/overwrite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package sparsefile

import (
"github.com/stretchr/testify/require"
"io"
"os"
"path"
"testing"
)

// TestOverwrite drives overwriteBuffer with small, explicitly sized buffers
// against temporary files and checks the resulting file content together with
// the reported written and skipped byte counts.
func TestOverwrite(t *testing.T) {
	tests := []struct {
		name        string // Sub-test name
		dstInitial  string // Initial content of dst
		src         string // Content to overwrite dst with
		srcBufSize  int    // Length of the buffer used to read src
		dstBufSize  int    // Length of the buffer used to read dst
		wantWritten int64  // Expected number of bytes actually written
		wantSkipped int64  // Expected number of bytes skipped because they are the same
		wantFinal   string // Expected final content of dst
		wantErr     bool   // Whether an error is expected
	}{
		{
			name:        "overwrite first chunk",
			dstInitial:  "Hello, World!",
			src:         "Greetings!",
			srcBufSize:  8,
			dstBufSize:  8,
			wantFinal:   "Greetings!ld!",
			wantWritten: 10,
			wantSkipped: 0,
			wantErr:     false,
		},
		{
			// The fixtures deliberately contain TWO spaces after the comma so
			// that the first 8-byte chunk of src and dst is identical; with a
			// single space the chunks differ at byte 8 and nothing is skipped.
			name:        "Partial overwrite, some content same",
			dstInitial:  "Hello,  World!",
			src:         "Hello,  Go!",
			wantFinal:   "Hello,  Go!ld!",
			srcBufSize:  8,
			dstBufSize:  8,
			wantWritten: 3,
			wantSkipped: 8, // "Hello,  " (8 bytes) is the same
			wantErr:     false,
		},
		{
			name:        "Complete match, all skipped",
			dstInitial:  "Hello, World!",
			src:         "Hello, World!",
			srcBufSize:  8,
			dstBufSize:  8,
			wantWritten: 0,
			wantSkipped: 13,
			wantFinal:   "Hello, World!",
			wantErr:     false,
		},
		{
			name:        "Complete match, dst buffer larger",
			dstInitial:  "Hello, World!",
			src:         "Hello, World!",
			srcBufSize:  8,
			dstBufSize:  10,
			wantWritten: 0,
			wantSkipped: 13,
			wantFinal:   "Hello, World!",
			wantErr:     false,
		},
		{
			name:        "Partial match in the middle",
			dstInitial:  "Hello, W12345678World!",
			src:         "123456781234567812345678",
			srcBufSize:  8,
			dstBufSize:  10,
			wantWritten: 16,
			wantSkipped: 8,
			wantFinal:   "123456781234567812345678",
			wantErr:     false,
		},
		{
			name:        "Partial match at the end",
			dstInitial:  "Hello, W123456",
			src:         "12345678123456",
			srcBufSize:  8,
			dstBufSize:  30,
			wantWritten: 8,
			wantSkipped: 6,
			wantFinal:   "12345678123456",
			wantErr:     false,
		},
		{
			name:        "dst initial is empty",
			dstInitial:  "",
			src:         "12345678123456",
			wantFinal:   "12345678123456",
			srcBufSize:  8,
			dstBufSize:  30,
			wantWritten: 14,
			wantSkipped: 0,
			wantErr:     false,
		},
		{
			name:        "dst initial is larger than src",
			dstInitial:  "000000000000000000000000000",
			src:         "12345678123456",
			wantFinal:   "123456781234560000000000000",
			srcBufSize:  8,
			dstBufSize:  30,
			wantWritten: 14,
			wantSkipped: 0,
			wantErr:     false,
		},
		{
			name:        "input is multiplier of src buffer",
			dstInitial:  "1234567812345678",
			src:         "1234567812345678",
			wantFinal:   "1234567812345678",
			srcBufSize:  8,
			dstBufSize:  8,
			wantWritten: 0,
			wantSkipped: 16,
			wantErr:     false,
		},
	}

	// createFileWithContent creates the named file, fills it with content and
	// rewinds it to the start; closer closes the file and fails the test if
	// closing reports an error.
	createFileWithContent := func(t *testing.T, name string, content string) (f *os.File, closer func()) {
		t.Helper()

		f, err := os.Create(name)
		require.NoError(t, err)
		_, err = io.WriteString(f, content)
		require.NoError(t, err)
		_, err = f.Seek(0, io.SeekStart)
		require.NoError(t, err)
		return f, func() {
			require.NoError(t, f.Close())
		}
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tmp := t.TempDir()
			src, closerSrc := createFileWithContent(t, path.Join(tmp, "src.tmp"), tt.src)
			dst, closerDst := createFileWithContent(t, path.Join(tmp, "dst.tmp"), tt.dstInitial)
			defer closerDst()
			defer closerSrc()

			written, skipped, err := overwriteBuffer(dst, src, make([]byte, tt.srcBufSize), make([]byte, tt.dstBufSize))
			if (err != nil) != tt.wantErr {
				t.Errorf("Overwrite() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			// Check the final content of dst
			_, err = dst.Seek(0, io.SeekStart)
			require.NoError(t, err)
			finalDstContent, err := io.ReadAll(dst)
			require.NoError(t, err)
			if string(finalDstContent) != tt.wantFinal {
				t.Errorf("Final dst content = %v, want %v", string(finalDstContent), tt.wantFinal)
			}
			if written != tt.wantWritten {
				t.Errorf("Overwrite() written = %v, want %v", written, tt.wantWritten)
			}
			if skipped != tt.wantSkipped {
				t.Errorf("Overwrite() skipped = %v, want %v", skipped, tt.wantSkipped)
			}
		})
	}
}

0 comments on commit 847cc05

Please sign in to comment.