Skip to content

Commit

Permalink
- bench tools for both paths
Browse files Browse the repository at this point in the history
- now reproduced the issue; it looks like it comes from the autorotate part
- fixing this hopefully
  • Loading branch information
Benjamin Lefaudeux committed Jan 24, 2025
1 parent bc3a3c0 commit 012c695
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 80 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: pre-commit run --all-files

- name: Build
run: cd cmd && go build -v main.go
run: cd cmd/filesystem && go build -v main.go

- name: Test
env:
Expand Down
90 changes: 90 additions & 0 deletions cmd/db/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package main

import (
datago "datago/pkg"
"flag"
"fmt"
"os"
"runtime/pprof"
"runtime/trace"
"time"
)

// main is a small benchmark driver for the datago client using the DB source.
//
// It builds a configuration from the command-line flags, optionally enables
// execution tracing and CPU profiling, starts the client, drains samples until
// an empty sample ID is returned, and finally prints the elapsed time and the
// measured throughput.
func main() {
	cropAndResize := flag.Bool("crop_and_resize", false, "Whether to crop and resize the images and masks")
	itemFetchBuffer := flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
	itemReadyBuffer := flag.Int("item_ready_buffer", 128, "The number of items ready to be served")
	limit := flag.Int("limit", 2000, "The number of items to fetch")
	profile := flag.Bool("profile", false, "Whether to profile the code")
	source := flag.String("source", os.Getenv("DATAGO_TEST_DB"), "The data source to select on")

	// Parse the flags before setting the configuration values
	flag.Parse()

	// Initialize the configuration
	config := datago.GetDatagoConfig()

	sourceConfig := datago.GetDefaultSourceDBConfig()
	sourceConfig.Sources = *source

	config.ImageConfig = datago.GetDefaultImageTransformConfig()
	config.ImageConfig.CropAndResize = *cropAndResize

	config.SourceConfig = sourceConfig
	config.PrefetchBufferSize = int32(*itemFetchBuffer)
	config.SamplesBufferSize = int32(*itemReadyBuffer)
	config.Limit = *limit

	client := datago.GetClient(config)

	startTime := time.Now() // Record the start time

	if *profile {
		fmt.Println("Profiling the code")

		// Execution trace, read with: go tool trace trace.out
		traceFile, err := os.Create("trace.out")
		if err != nil {
			panic(err)
		}
		// Deferred calls run in LIFO order, so the trace is stopped
		// before its output file is closed.
		defer traceFile.Close()
		if err := trace.Start(traceFile); err != nil {
			panic(err)
		}
		defer trace.Stop()

		// CPU profile, read with: go tool pprof cpu.prof
		cpuFile, err := os.Create("cpu.prof")
		if err != nil {
			panic(err)
		}
		defer cpuFile.Close()
		if err := pprof.StartCPUProfile(cpuFile); err != nil {
			panic(err)
		}
		defer pprof.StopCPUProfile()
	}

	client.Start()

	// Drain the samples as they become available. The payloads are discarded:
	// the only purpose is to keep the client's output flowing and count items.
	nSamples := 0
	for {
		sample := client.GetSample()
		if sample.ID == "" {
			// An empty ID is the end-of-stream marker used by this loop.
			fmt.Println("No more samples")
			break
		}
		nSamples++
	}

	// Cancel the context to kill the goroutines
	client.Stop()

	// Report throughput from the number of samples actually received, so the
	// figure stays correct when fewer than config.Limit samples are served.
	elapsedTime := time.Since(startTime)
	fps := float64(nSamples) / elapsedTime.Seconds()
	fmt.Printf("Total execution time: %.2f seconds. Samples %d \n", elapsedTime.Seconds(), nSamples)
	fmt.Printf("Average throughput: %.2f samples per second \n", fps)
}
8 changes: 3 additions & 5 deletions cmd/main.go → cmd/filesystem/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ func main() {
sourceConfig.Rank = 0
sourceConfig.WorldSize = 1

config.ImageConfig = datago.ImageTransformConfig{
DefaultImageSize: 1024,
DownsamplingRatio: 32,
CropAndResize: *cropAndResize,
}
config.ImageConfig = datago.GetDefaultImageTransformConfig()
config.ImageConfig.CropAndResize = *cropAndResize

config.SourceConfig = sourceConfig
config.PrefetchBufferSize = int32(*itemFetchBuffer)
config.SamplesBufferSize = int32(*itemReadyBuffer)
Expand Down
12 changes: 7 additions & 5 deletions pkg/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func (c *ImageTransformConfig) setDefaults() {
c.PreEncodeImages = false
}

// GetDefaultImageTransformConfig returns a new ImageTransformConfig on which
// setDefaults has been applied.
func GetDefaultImageTransformConfig() ImageTransformConfig {
	var defaults ImageTransformConfig
	defaults.setDefaults()
	return defaults
}

// DatagoConfig is the main configuration structure for the datago client
type DatagoConfig struct {
SourceType DatagoSourceType `json:"source_type"`
Expand Down Expand Up @@ -161,12 +167,8 @@ type DatagoClient struct {

// GetClient is a constructor for the DatagoClient, given a JSON configuration string
func GetClient(config DatagoConfig) *DatagoClient {
// Make sure that the GC is run more often than usual
// VIPS will allocate a lot of memory and we want to make sure that it's released as soon as possible
os.Setenv("GOGC", "10") // Default is 100, we're running it when heap is 10% larger than the last GC

// Initialize the vips library
err := os.Setenv("VIPS_DISC_THRESHOLD", "5g")
err := os.Setenv("VIPS_DISC_THRESHOLD", "10g")
if err != nil {
log.Panicf("Error setting VIPS_DISC_THRESHOLD: %v", err)
}
Expand Down
12 changes: 6 additions & 6 deletions pkg/generator_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ func (c *SourceDBConfig) setDefaults() {
c.DuplicateState = -1
}

// GetDefaultSourceDBConfig returns a new SourceDBConfig on which setDefaults
// has been applied.
func GetDefaultSourceDBConfig() SourceDBConfig {
	var defaults SourceDBConfig
	defaults.setDefaults()
	return defaults
}

func (c *SourceDBConfig) getDbRequest() dbRequest {

fields := "attributes,image_direct_url,source"
Expand Down Expand Up @@ -210,12 +216,6 @@ func (c *SourceDBConfig) getDbRequest() dbRequest {
}
}

// GetSourceDBConfig returns a new SourceDBConfig on which setDefaults has been
// applied.
func GetSourceDBConfig() SourceDBConfig {
	var defaults SourceDBConfig
	defaults.setDefaults()
	return defaults
}

type datagoGeneratorDB struct {
baseRequest http.Request
config SourceDBConfig
Expand Down
2 changes: 1 addition & 1 deletion pkg/generator_filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (c *SourceFileSystemConfig) setDefaults() {
c.RootPath = os.Getenv("DATAGO_TEST_FILESYSTEM")
}

func GetSourceFileSystemConfig() SourceFileSystemConfig {
func GetDefaultSourceFileSystemConfig() SourceFileSystemConfig {
config := SourceFileSystemConfig{}
config.setDefaults()
return config
Expand Down
111 changes: 71 additions & 40 deletions pkg/serdes.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,83 +37,112 @@ func readBodyBuffered(resp *http.Response) ([]byte, error) {
}

func imageFromBuffer(buffer []byte, transform *ARAwareTransform, aspectRatio float64, encodeImage bool, isMask bool) (*ImagePayload, float64, error) {
// Decode the image payload using vips
img, err := vips.NewImageFromBuffer(buffer)
// Decode the image payload using vips, using bulletproof settings
importParams := vips.NewImportParams()
importParams.AutoRotate.Set(true)
importParams.FailOnError.Set(true)
importParams.Page.Set(0)
importParams.NumPages.Set(1)
importParams.HeifThumbnail.Set(false)
importParams.SvgUnlimited.Set(false)

img, err := vips.LoadImageFromBuffer(buffer, importParams)
if err != nil {
return nil, -1., err
return nil, -1., fmt.Errorf("error loading image: %w", err)
}

err = img.AutoRotate()
if err != nil {
return nil, -1., err
}

// Optionally crop and resize the image on the fly. Save the aspect ratio in the process for future use
originalWidth, originalHeight := img.Width(), img.Height()

if transform != nil {
aspectRatio, err = transform.cropAndResizeToClosestAspectRatio(img, aspectRatio)
if err != nil {
return nil, -1., err
}
}

width, height := img.Width(), img.Height()

// If the image is 4 channels, we need to drop the alpha channel
if img.Bands() == 4 {
err = img.Flatten(&vips.Color{R: 255, G: 255, B: 255}) // Flatten with white background
if err != nil {
fmt.Println("Error flattening image:", err)
return nil, -1., err
return nil, -1., fmt.Errorf("error flattening image: %w", err)
}
fmt.Println("Image flattened")
}

// If the image is not a mask but is 1 channel, we want to convert it to 3 channels
if (img.Bands() == 1) && !isMask {
err = img.ToColorSpace(vips.InterpretationSRGB)
if img.Metadata().Format == vips.ImageTypeJPEG {
err = img.ToColorSpace(vips.InterpretationSRGB)
if err != nil {
return nil, -1., fmt.Errorf("error converting to sRGB: %w", err)
}
} else {
// FIXME: maybe we could still recover these. By default this throws an error, sRGB and PNG not supported
return nil, -1., fmt.Errorf("1 channel PNG image not supported")
}
}

// If the image is 2 channels, that's gray+alpha and we flatten it
if img.Bands() == 2 {
err = img.ExtractBand(1, 1)
fmt.Println("Gray+alpha image, removing alpha")
if err != nil {
fmt.Println("Error converting to sRGB:", err)
return nil, -1., err
return nil, -1., fmt.Errorf("error extracting band: %w", err)
}
}

// Optionally crop and resize the image on the fly. Save the aspect ratio in the process for future use
originalWidth, originalHeight := img.Width(), img.Height()

if transform != nil {
// Catch possible SIGABRT in libvips and recover from it
defer func() {
if r := recover(); r != nil {
if strings.Contains(fmt.Sprint(r), "SIGABRT") || strings.Contains(fmt.Sprint(r), "SIGSEV") {
err = fmt.Errorf("caught SIGABRT or SIGSEV: %v", r)
} else {
panic(r) // re-throw the panic if it's not SIGABRT
}
}
}()

aspectRatio, err = transform.cropAndResizeToClosestAspectRatio(img, aspectRatio)
if err != nil {
return nil, -1., fmt.Errorf("error cropping and resizing image: %w", err)
}
}

width, height := img.Width(), img.Height()

// If requested, re-encode the image to a jpg or png
var imgBytes []byte
var channels int
var bitDepth int

if encodeImage {
if err != nil {
return nil, -1., err
}

if img.Bands() == 3 {
// Re-encode the image to a jpg
imgBytes, _, err = img.ExportJpeg(&vips.JpegExportParams{Quality: 95})
if err != nil {
return nil, -1., err
}
} else {
// Re-encode the image to a png
imgBytes, _, err = img.ExportPng(vips.NewPngExportParams())
if err != nil {
return nil, -1., err
}
imgBytes, _, err = img.ExportPng(&vips.PngExportParams{
Compression: 6,
Filter: vips.PngFilterNone,
Interlace: false,
Palette: false,
Bitdepth: 8, // force 8 bit depth
})
}

if err != nil {
return nil, -1., err
}

channels = -1 // Signal that we have encoded the image
} else {
channels = img.Bands()
imgBytes, err = img.ToBytes()

if err != nil {
return nil, -1., err
}
channels = img.Bands()

// Define bit depth de facto, not exposed in the vips interface
bitDepth = len(imgBytes) / (width * height * channels) * 8 // 8 bits per byte
}

defer img.Close() // release vips buffers when done

if bitDepth == 0 && !encodeImage {
panic("Bit depth not set")
}
Expand Down Expand Up @@ -195,10 +224,10 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware
continue
}

// Decode into a flat buffer using vips
// Decode into a flat buffer using vips. Note that this can fail on its own
imgPayload_ptr, aspectRatio, err := imageFromBuffer(body_bytes, transform, aspectRatio, encodeImage, isMask)
if err != nil {
break
fmt.Print(err)
}
return imgPayload_ptr, aspectRatio, nil
}
Expand Down Expand Up @@ -253,6 +282,8 @@ func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult d
}
masks[latent.LatentType] = *mask_ptr
} else {
fmt.Println("Loading latents ", latent.URL)

// Vanilla latents, pure binary payloads
latentPayload, err := fetchURL(httpClient, latent.URL, retries)
if err != nil {
Expand Down
Loading

0 comments on commit 012c695

Please sign in to comment.