From 012c69540a675e92b8a950d615b8ada8df911dfe Mon Sep 17 00:00:00 2001 From: Benjamin Lefaudeux Date: Thu, 23 Jan 2025 07:50:45 +0000 Subject: [PATCH] - bench tools for both paths - now reproed the issue, looks like it's around the autorotate part - fixing this hopefully --- .github/workflows/go.yml | 2 +- cmd/db/main.go | 90 ++++++++++++++++++++++++++ cmd/{ => filesystem}/main.go | 8 +-- pkg/client.go | 12 ++-- pkg/generator_db.go | 12 ++-- pkg/generator_filesystem.go | 2 +- pkg/serdes.go | 111 ++++++++++++++++++++------------ python/benchmark_db.py | 44 +++++++------ python/go_types.py | 2 +- python/test_datago_db.py | 2 +- tests/client_db_test.go | 2 +- tests/client_filesystem_test.go | 2 +- 12 files changed, 209 insertions(+), 80 deletions(-) create mode 100644 cmd/db/main.go rename cmd/{ => filesystem}/main.go (94%) diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index b7919f9..307a8c8 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -38,7 +38,7 @@ jobs: run: pre-commit run --all-files - name: Build - run: cd cmd && go build -v main.go + run: cd cmd/filesystem && go build -v main.go - name: Test env: diff --git a/cmd/db/main.go b/cmd/db/main.go new file mode 100644 index 0000000..943cffa --- /dev/null +++ b/cmd/db/main.go @@ -0,0 +1,90 @@ +package main + +import ( + datago "datago/pkg" + "flag" + "fmt" + "os" + "runtime/pprof" + "runtime/trace" + "time" +) + +func main() { + + cropAndResize := flag.Bool("crop_and_resize", false, "Whether to crop and resize the images and masks") + itemFetchBuffer := flag.Int("item_fetch_buffer", 256, "The number of items to pre-load") + itemReadyBuffer := flag.Int("item_ready_buffer", 128, "The number of items ready to be served") + limit := flag.Int("limit", 2000, "The number of items to fetch") + profile := flag.Bool("profile", false, "Whether to profile the code") + source := flag.String("source", os.Getenv("DATAGO_TEST_DB"), "The data source to select on") + + // Parse the flags before setting the configuration values + flag.Parse() + + // Initialize the configuration + config := datago.GetDatagoConfig() + + sourceConfig := datago.GetDefaultSourceDBConfig() + sourceConfig.Sources = *source + + config.ImageConfig = datago.GetDefaultImageTransformConfig() + config.ImageConfig.CropAndResize = *cropAndResize + + config.SourceConfig = sourceConfig + config.PrefetchBufferSize = int32(*itemFetchBuffer) + config.SamplesBufferSize = int32(*itemReadyBuffer) + config.Limit = *limit + + dataroom_client := datago.GetClient(config) + + // Go-routine which will feed the sample data to the workers + // and fetch the next page + startTime := time.Now() // Record the start time + + if *profile { + fmt.Println("Profiling the code") + { + f, _ := os.Create("trace.out") + // read with go tool trace trace.out + + err := trace.Start(f) + if err != nil { + panic(err) + } + defer trace.Stop() + } + { + f, _ := os.Create("cpu.prof") + // read with go tool pprof cpu.prof + err := pprof.StartCPUProfile(f) + if err != nil { + panic(err) + } + defer pprof.StopCPUProfile() + } + } + + dataroom_client.Start() + + // Fetch all of the binary payloads as they become available + // NOTE: This is useless, just making sure that we empty the payloads channel + n_samples := 0 + for { + sample := dataroom_client.GetSample() + if sample.ID == "" { + fmt.Println("No more samples") + break + } + n_samples++ + } + + // Cancel the context to kill the goroutines + dataroom_client.Stop() + + // Calculate the elapsed time + elapsedTime := time.Since(startTime) + fps := float64(config.Limit) / elapsedTime.Seconds() + fmt.Printf("Total execution time: %.2f seconds. Samples %d \n", elapsedTime.Seconds(), n_samples) + fmt.Printf("Average throughput: %.2f samples per second \n", fps) +} diff --git a/cmd/main.go b/cmd/filesystem/main.go similarity index 94% rename from cmd/main.go rename to cmd/filesystem/main.go index 3989ace..16d367d 100644 --- a/cmd/main.go +++ b/cmd/filesystem/main.go @@ -29,11 +29,9 @@ func main() { sourceConfig.Rank = 0 sourceConfig.WorldSize = 1 - config.ImageConfig = datago.ImageTransformConfig{ - DefaultImageSize: 1024, - DownsamplingRatio: 32, - CropAndResize: *cropAndResize, - } + config.ImageConfig = datago.GetDefaultImageTransformConfig() + config.ImageConfig.CropAndResize = *cropAndResize + config.SourceConfig = sourceConfig config.PrefetchBufferSize = int32(*itemFetchBuffer) config.SamplesBufferSize = int32(*itemReadyBuffer) diff --git a/pkg/client.go b/pkg/client.go index 8c2e5c2..d953043 100644 --- a/pkg/client.go +++ b/pkg/client.go @@ -51,6 +51,12 @@ func (c *ImageTransformConfig) setDefaults() { c.PreEncodeImages = false } +func GetDefaultImageTransformConfig() ImageTransformConfig { + config := ImageTransformConfig{} + config.setDefaults() + return config +} + // DatagoConfig is the main configuration structure for the datago client type DatagoConfig struct { SourceType DatagoSourceType `json:"source_type"` @@ -161,12 +167,8 @@ type DatagoClient struct { // GetClient is a constructor for the DatagoClient, given a JSON configuration string func GetClient(config DatagoConfig) *DatagoClient { - // Make sure that the GC is run more often than usual - // VIPS will allocate a lot of memory and we want to make sure that it's released as soon as possible - os.Setenv("GOGC", "10") // Default is 100, we're running it when heap is 10% larger than the last GC - // Initialize the vips library - err := os.Setenv("VIPS_DISC_THRESHOLD", "5g") + err := os.Setenv("VIPS_DISC_THRESHOLD", "10g") if err != nil { log.Panicf("Error setting VIPS_DISC_THRESHOLD: %v", err) } diff --git a/pkg/generator_db.go b/pkg/generator_db.go index 7fee9f5..e845fa6 100644 --- a/pkg/generator_db.go +++ b/pkg/generator_db.go @@ -135,6 +135,12 @@ func (c *SourceDBConfig) setDefaults() { c.DuplicateState = -1 } +func GetDefaultSourceDBConfig() SourceDBConfig { + config := SourceDBConfig{} + config.setDefaults() + return config +} + func (c *SourceDBConfig) getDbRequest() dbRequest { fields := "attributes,image_direct_url,source" @@ -210,12 +216,6 @@ func (c *SourceDBConfig) getDbRequest() dbRequest { } } -func GetSourceDBConfig() SourceDBConfig { - config := SourceDBConfig{} - config.setDefaults() - return config -} - type datagoGeneratorDB struct { baseRequest http.Request config SourceDBConfig diff --git a/pkg/generator_filesystem.go b/pkg/generator_filesystem.go index 400ea4e..3add694 100644 --- a/pkg/generator_filesystem.go +++ b/pkg/generator_filesystem.go @@ -31,7 +31,7 @@ func (c *SourceFileSystemConfig) setDefaults() { c.RootPath = os.Getenv("DATAGO_TEST_FILESYSTEM") } -func GetSourceFileSystemConfig() SourceFileSystemConfig { +func GetDefaultSourceFileSystemConfig() SourceFileSystemConfig { config := SourceFileSystemConfig{} config.setDefaults() return config diff --git a/pkg/serdes.go b/pkg/serdes.go index 06c4fc3..5d6b189 100644 --- a/pkg/serdes.go +++ b/pkg/serdes.go @@ -37,83 +37,112 @@ func readBodyBuffered(resp *http.Response) ([]byte, error) { } func imageFromBuffer(buffer []byte, transform *ARAwareTransform, aspectRatio float64, encodeImage bool, isMask bool) (*ImagePayload, float64, error) { - // Decode the image payload using vips - img, err := vips.NewImageFromBuffer(buffer) + // Decode the image payload using vips, using bulletproof settings + importParams := vips.NewImportParams() + importParams.AutoRotate.Set(true) + importParams.FailOnError.Set(true) + importParams.Page.Set(0) + importParams.NumPages.Set(1) + importParams.HeifThumbnail.Set(false) + importParams.SvgUnlimited.Set(false) + + img, err := vips.LoadImageFromBuffer(buffer, importParams) if err != nil { - return nil, -1., err + return nil, -1., fmt.Errorf("error loading image: %w", err) } - err = img.AutoRotate() - if err != nil { - return nil, -1., err - } - - // Optionally crop and resize the image on the fly. Save the aspect ratio in the process for future use - originalWidth, originalHeight := img.Width(), img.Height() - - if transform != nil { - aspectRatio, err = transform.cropAndResizeToClosestAspectRatio(img, aspectRatio) - if err != nil { - return nil, -1., err - } - } - - width, height := img.Width(), img.Height() - // If the image is 4 channels, we need to drop the alpha channel if img.Bands() == 4 { err = img.Flatten(&vips.Color{R: 255, G: 255, B: 255}) // Flatten with white background if err != nil { - fmt.Println("Error flattening image:", err) - return nil, -1., err + return nil, -1., fmt.Errorf("error flattening image: %w", err) } fmt.Println("Image flattened") } // If the image is not a mask but is 1 channel, we want to convert it to 3 channels if (img.Bands() == 1) && !isMask { - err = img.ToColorSpace(vips.InterpretationSRGB) + if img.Metadata().Format == vips.ImageTypeJPEG { + err = img.ToColorSpace(vips.InterpretationSRGB) + if err != nil { + return nil, -1., fmt.Errorf("error converting to sRGB: %w", err) + } + } else { + // // FIXME: maybe that we could recover these still. By default throws an error, sRGB and PNG not supported + return nil, -1., fmt.Errorf("1 channel PNG image not supported") + } + } + + // If the image is 2 channels, that's gray+alpha and we flatten it + if img.Bands() == 2 { + err = img.ExtractBand(1, 1) + fmt.Println("Gray+alpha image, removing alpha") if err != nil { - fmt.Println("Error converting to sRGB:", err) - return nil, -1., err + return nil, -1., fmt.Errorf("error extracting band: %w", err) + } + } + + // Optionally crop and resize the image on the fly. Save the aspect ratio in the process for future use + originalWidth, originalHeight := img.Width(), img.Height() + + if transform != nil { + // Catch possible SIGABRT in libvips and recover from it + defer func() { + if r := recover(); r != nil { + if strings.Contains(fmt.Sprint(r), "SIGABRT") || strings.Contains(fmt.Sprint(r), "SIGSEV") { + err = fmt.Errorf("caught SIGABRT or SIGSEV: %v", r) + } else { + panic(r) // re-throw the panic if it's not SIGABRT + } + } + }() + + aspectRatio, err = transform.cropAndResizeToClosestAspectRatio(img, aspectRatio) + if err != nil { + return nil, -1., fmt.Errorf("error cropping and resizing image: %w", err) } } + width, height := img.Width(), img.Height() + // If requested, re-encode the image to a jpg or png var imgBytes []byte var channels int var bitDepth int if encodeImage { - if err != nil { - return nil, -1., err - } - if img.Bands() == 3 { // Re-encode the image to a jpg imgBytes, _, err = img.ExportJpeg(&vips.JpegExportParams{Quality: 95}) - if err != nil { - return nil, -1., err - } } else { // Re-encode the image to a png - imgBytes, _, err = img.ExportPng(vips.NewPngExportParams()) - if err != nil { - return nil, -1., err - } + imgBytes, _, err = img.ExportPng(&vips.PngExportParams{ + Compression: 6, + Filter: vips.PngFilterNone, + Interlace: false, + Palette: false, + Bitdepth: 8, // force 8 bit depth + }) } + + if err != nil { + return nil, -1., err + } + channels = -1 // Signal that we have encoded the image } else { + channels = img.Bands() imgBytes, err = img.ToBytes() + if err != nil { return nil, -1., err } - channels = img.Bands() - // Define bit depth de facto, not exposed in the vips interface bitDepth = len(imgBytes) / (width * height * channels) * 8 // 8 bits per byte } + defer img.Close() // release vips buffers when done + if bitDepth == 0 && !encodeImage { panic("Bit depth not set") } @@ -195,10 +224,10 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware continue } - // Decode into a flat buffer using vips + // Decode into a flat buffer using vips. Note that this can fail on its own imgPayload_ptr, aspectRatio, err := imageFromBuffer(body_bytes, transform, aspectRatio, encodeImage, isMask) if err != nil { - break + fmt.Print(err) } return imgPayload_ptr, aspectRatio, nil } @@ -253,6 +282,8 @@ func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult d } masks[latent.LatentType] = *mask_ptr } else { + fmt.Println("Loading latents ", latent.URL) + // Vanilla latents, pure binary payloads latentPayload, err := fetchURL(httpClient, latent.URL, retries) if err != nil { diff --git a/python/benchmark_db.py b/python/benchmark_db.py index 3b26d5c..8eaa323 100644 --- a/python/benchmark_db.py +++ b/python/benchmark_db.py @@ -4,6 +4,7 @@ import numpy as np from go_types import go_array_to_pil_image, go_array_to_numpy import typer +import json def benchmark( @@ -15,26 +16,33 @@ def benchmark( require_images: bool = typer.Option(True, help="Request the original images"), require_embeddings: bool = typer.Option(False, help="Request embeddings"), test_masks: bool = typer.Option(True, help="Test masks"), - test_latents: bool = typer.Option(True, help="Test latents"), ): print(f"Running benchmark for {source} - {limit} samples") - - # Get a generic client config - client_config = datago.GetDatagoConfig() - client_config.ImageConfig.CropAndResize = crop_and_resize - - # Specify the source parameters as you see fit - source_config = datago.GetSourceDBConfig() - source_config.Sources = source - source_config.RequireImages = require_images - source_config.RequireEmbeddings = require_embeddings - source_config.HasMasks = "segmentation_mask" if test_masks else "" - source_config.HasLatents = "caption_latent_t5xxl" if test_latents else "" - - # Get a new client instance, happy benchmarking - client_config.SourceConfig = source_config - client = datago.GetClient(client_config) - + client_config = { + "source_type": datago.SourceTypeDB, + "source_config": { + "page_size": 512, + "sources": source, + "require_images": require_images, + "require_embeddings": require_embeddings, + "has_masks": "segmentation_mask" if test_masks else "", + "rank": 0, + "world_size": 1, + }, + "image_config": { + "crop_and_resize": crop_and_resize, + "default_image_size": 512, + "downsampling_ratio": 16, + "min_aspect_ratio": 0.5, + "max_aspect_ratio": 2.0, + "pre_encode_images": False, + }, + "prefetch_buffer_size": 128, + "samples_buffer_size": 64, + "limit": limit, + } + + client = datago.GetClientFromJSON(json.dumps(client_config)) client.Start() # Optional, but good practice to start the client to reduce latency to first sample (while you're instantiating models for instance) start = time.time() diff --git a/python/go_types.py b/python/go_types.py index 41a2d4b..6892ad8 100644 --- a/python/go_types.py +++ b/python/go_types.py @@ -64,5 +64,5 @@ def go_array_to_pil_image(go_array) -> Optional[Image.Image]: if c == 4: return Image.frombuffer("RGBA", (w, h), np_array, "raw", "RGBA", 0, 1) - assert c == 3, "Expected 3 channels" + assert c == 3, f"Expected 3 channels, got {c}" return Image.fromarray(np_array) diff --git a/python/test_datago_db.py b/python/test_datago_db.py index 9f10686..b61683b 100644 --- a/python/test_datago_db.py +++ b/python/test_datago_db.py @@ -45,7 +45,7 @@ def test_get_sample_db(): client_config = datago.GetDatagoConfig() client_config.SamplesBufferSize = 10 - source_config = datago.GetSourceDBConfig() + source_config = datago.GetDefaultSourceDBConfig() source_config.Sources = get_test_source() client_config.SourceConfig = source_config diff --git a/tests/client_db_test.go b/tests/client_db_test.go index d172fb4..17060bb 100644 --- a/tests/client_db_test.go +++ b/tests/client_db_test.go @@ -17,7 +17,7 @@ func get_test_source() string { func get_default_test_config() datago.DatagoConfig { config := datago.GetDatagoConfig() - db_config := datago.GetSourceDBConfig() + db_config := datago.GetDefaultSourceDBConfig() db_config.Sources = get_test_source() db_config.PageSize = 32 config.SourceConfig = db_config diff --git a/tests/client_filesystem_test.go b/tests/client_filesystem_test.go index eeccb96..444a49f 100644 --- a/tests/client_filesystem_test.go +++ b/tests/client_filesystem_test.go @@ -77,7 +77,7 @@ func TestFilesystemLoad(t *testing.T) { // Run the tests config := datago.GetDatagoConfig() - fs_config := datago.GetSourceFileSystemConfig() + fs_config := datago.GetDefaultSourceFileSystemConfig() fs_config.RootPath = test_dir config.SourceConfig = fs_config