diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml
index b7919f9..307a8c8 100644
--- a/.github/workflows/go.yml
+++ b/.github/workflows/go.yml
@@ -38,7 +38,7 @@ jobs:
       run: pre-commit run --all-files
 
     - name: Build
-      run: cd cmd && go build -v main.go
+      run: cd cmd/filesystem && go build -v main.go
 
     - name: Test
       env:
diff --git a/cmd/db/main.go b/cmd/db/main.go
new file mode 100644
index 0000000..943cffa
--- /dev/null
+++ b/cmd/db/main.go
@@ -0,0 +1,90 @@
+package main
+
+import (
+    datago "datago/pkg"
+    "flag"
+    "fmt"
+    "os"
+    "runtime/pprof"
+    "runtime/trace"
+    "time"
+)
+
+func main() {
+
+    cropAndResize := flag.Bool("crop_and_resize", false, "Whether to crop and resize the images and masks")
+    itemFetchBuffer := flag.Int("item_fetch_buffer", 256, "The number of items to pre-load")
+    itemReadyBuffer := flag.Int("item_ready_buffer", 128, "The number of items ready to be served")
+    limit := flag.Int("limit", 2000, "The number of items to fetch")
+    profile := flag.Bool("profile", false, "Whether to profile the code")
+    source := flag.String("source", os.Getenv("DATAGO_TEST_DB"), "The data source to select on")
+
+    // Parse the flags before setting the configuration values
+    flag.Parse()
+
+    // Initialize the configuration
+    config := datago.GetDatagoConfig()
+
+    sourceConfig := datago.GetDefaultSourceDBConfig()
+    sourceConfig.Sources = *source
+
+    config.ImageConfig = datago.GetDefaultImageTransformConfig()
+    config.ImageConfig.CropAndResize = *cropAndResize
+
+    config.SourceConfig = sourceConfig
+    config.PrefetchBufferSize = int32(*itemFetchBuffer)
+    config.SamplesBufferSize = int32(*itemReadyBuffer)
+    config.Limit = *limit
+
+    dataroom_client := datago.GetClient(config)
+
+    // Go-routine which will feed the sample data to the workers
+    // and fetch the next page
+    startTime := time.Now() // Record the start time
+
+    if *profile {
+        fmt.Println("Profiling the code")
+        {
+            f, _ := os.Create("trace.out")
+            // read with go tool trace trace.out
+
+            err := trace.Start(f)
+            if err != nil {
+                panic(err)
+            }
+            defer trace.Stop()
+        }
+        {
+            f, _ := os.Create("cpu.prof")
+            // read with go tool pprof cpu.prof
+            err := pprof.StartCPUProfile(f)
+            if err != nil {
+                panic(err)
+            }
+            defer pprof.StopCPUProfile()
+        }
+    }
+
+    dataroom_client.Start()
+
+    // Fetch all of the binary payloads as they become available
+    // NOTE: the payloads are not used here, we just drain the channel
+    n_samples := 0
+    for {
+        sample := dataroom_client.GetSample()
+        if sample.ID == "" {
+            fmt.Println("No more samples")
+            break
+        }
+        n_samples++
+    }
+
+    // Cancel the context to kill the goroutines
+    dataroom_client.Stop()
+
+    // Calculate the elapsed time
+    elapsedTime := time.Since(startTime)
+    fps := float64(config.Limit) / elapsedTime.Seconds()
+    fmt.Printf("Total execution time: %.2f seconds. Samples %d \n", elapsedTime.Seconds(), n_samples)
+    fmt.Printf("Average throughput: %.2f samples per second \n", fps)
+}
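Note on the profiling block in cmd/db/main.go above: both `os.Create` errors are discarded, so a failure to create `trace.out` or `cpu.prof` only surfaces later through `trace.Start` or `pprof.StartCPUProfile`. A minimal, self-contained sketch of the same trace setup with the creation error checked — a suggestion, not what the PR ships:

```go
package main

import (
	"os"
	"runtime/trace"
)

func main() {
	// Same trace setup as cmd/db/main.go, but the os.Create error is checked
	f, err := os.Create("trace.out")
	if err != nil {
		panic(err)
	}
	// read with go tool trace trace.out
	if err := trace.Start(f); err != nil {
		panic(err)
	}
	defer trace.Stop()
}
```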
diff --git a/cmd/main.go b/cmd/filesystem/main.go
similarity index 94%
rename from cmd/main.go
rename to cmd/filesystem/main.go
index 3989ace..16d367d 100644
--- a/cmd/main.go
+++ b/cmd/filesystem/main.go
@@ -29,11 +29,9 @@ func main() {
 	sourceConfig.Rank = 0
 	sourceConfig.WorldSize = 1
 
-	config.ImageConfig = datago.ImageTransformConfig{
-		DefaultImageSize:  1024,
-		DownsamplingRatio: 32,
-		CropAndResize:     *cropAndResize,
-	}
+	config.ImageConfig = datago.GetDefaultImageTransformConfig()
+	config.ImageConfig.CropAndResize = *cropAndResize
+
 	config.SourceConfig = sourceConfig
 	config.PrefetchBufferSize = int32(*itemFetchBuffer)
 	config.SamplesBufferSize = int32(*itemReadyBuffer)
diff --git a/pkg/client.go b/pkg/client.go
index 8c2e5c2..27a8235 100644
--- a/pkg/client.go
+++ b/pkg/client.go
@@ -51,6 +51,12 @@ func (c *ImageTransformConfig) setDefaults() {
 	c.PreEncodeImages = false
 }
 
+func GetDefaultImageTransformConfig() ImageTransformConfig {
+	config := ImageTransformConfig{}
+	config.setDefaults()
+	return config
+}
+
 // DatagoConfig is the main configuration structure for the datago client
 type DatagoConfig struct {
 	SourceType         DatagoSourceType `json:"source_type"`
@@ -161,12 +167,8 @@ type DatagoClient struct {
 
 // GetClient is a constructor for the DatagoClient, given a JSON configuration string
 func GetClient(config DatagoConfig) *DatagoClient {
-	// Make sure that the GC is run more often than usual
-	// VIPS will allocate a lot of memory and we want to make sure that it's released as soon as possible
-	os.Setenv("GOGC", "10") // Default is 100, we're running it when heap is 10% larger than the last GC
-
 	// Initialize the vips library
-	err := os.Setenv("VIPS_DISC_THRESHOLD", "5g")
+	err := os.Setenv("VIPS_DISC_THRESHOLD", "10g")
 	if err != nil {
 		log.Panicf("Error setting VIPS_DISC_THRESHOLD: %v", err)
 	}
@@ -255,7 +257,7 @@ func (c *DatagoClient) Start() {
 	if c.imageConfig.CropAndResize {
 		fmt.Println("Cropping and resizing images")
 		fmt.Println("Base image size | downsampling ratio | min | max:", c.imageConfig.DefaultImageSize, c.imageConfig.DownsamplingRatio, c.imageConfig.MinAspectRatio, c.imageConfig.MaxAspectRatio)
-		arAwareTransform = newARAwareTransform(c.imageConfig)
+		arAwareTransform = GetArAwareTransform(c.imageConfig)
 	}
 
 	if c.imageConfig.PreEncodeImages {
diff --git a/pkg/generator_db.go b/pkg/generator_db.go
index 7fee9f5..e845fa6 100644
--- a/pkg/generator_db.go
+++ b/pkg/generator_db.go
@@ -135,6 +135,12 @@ func (c *SourceDBConfig) setDefaults() {
 	c.DuplicateState = -1
 }
 
+func GetDefaultSourceDBConfig() SourceDBConfig {
+	config := SourceDBConfig{}
+	config.setDefaults()
+	return config
+}
+
 func (c *SourceDBConfig) getDbRequest() dbRequest {
 
 	fields := "attributes,image_direct_url,source"
@@ -210,12 +216,6 @@ func (c *SourceDBConfig) getDbRequest() dbRequest {
 	}
 }
 
-func GetSourceDBConfig() SourceDBConfig {
-	config := SourceDBConfig{}
-	config.setDefaults()
-	return config
-}
-
 type datagoGeneratorDB struct {
 	baseRequest http.Request
 	config      SourceDBConfig
diff --git a/pkg/generator_filesystem.go b/pkg/generator_filesystem.go
index 400ea4e..3add694 100644
--- a/pkg/generator_filesystem.go
+++ b/pkg/generator_filesystem.go
@@ -31,7 +31,7 @@ func (c *SourceFileSystemConfig) setDefaults() {
 	c.RootPath = os.Getenv("DATAGO_TEST_FILESYSTEM")
 }
 
-func GetSourceFileSystemConfig() SourceFileSystemConfig {
+func GetDefaultSourceFileSystemConfig() SourceFileSystemConfig {
 	config := SourceFileSystemConfig{}
 	config.setDefaults()
 	return config
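Since `GetSourceDBConfig` and `GetSourceFileSystemConfig` are renamed to `GetDefaultSourceDBConfig` / `GetDefaultSourceFileSystemConfig`, and `GetDefaultImageTransformConfig` is new, downstream Go callers need the updated names. A minimal sketch of the new call pattern; the source string is a placeholder, not part of the diff:

```go
package main

import (
	datago "datago/pkg"
)

func main() {
	// Library-wide defaults, then override what you need
	config := datago.GetDatagoConfig()
	config.ImageConfig = datago.GetDefaultImageTransformConfig()
	config.ImageConfig.CropAndResize = true

	// DB-backed source: GetSourceDBConfig() is now GetDefaultSourceDBConfig()
	dbSource := datago.GetDefaultSourceDBConfig()
	dbSource.Sources = "SOME_SOURCE" // placeholder
	config.SourceConfig = dbSource

	// A filesystem source would use datago.GetDefaultSourceFileSystemConfig() instead

	client := datago.GetClient(config)
	client.Start()
	defer client.Stop()
}
```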
diff --git a/pkg/serdes.go b/pkg/serdes.go
index 06c4fc3..e94cb2d 100644
--- a/pkg/serdes.go
+++ b/pkg/serdes.go
@@ -1,7 +1,6 @@
 package datago
 
 import (
-	"bytes"
 	"fmt"
 	"io"
 	"net/http"
@@ -11,41 +10,67 @@ import (
 	"github.com/davidbyttow/govips/v2/vips"
 )
 
-func readBodyBuffered(resp *http.Response) ([]byte, error) {
-	// Use a bytes.Buffer to accumulate the response body
-	// Faster than the default ioutil.ReadAll which reallocates
-	var body bytes.Buffer
+func sanitizeImage(img *vips.ImageRef, isMask bool) error {
 
-	bufferSize := 2048 * 1024 // 2MB
-
-	// Create a fixed-size buffer for reading
-	localBuffer := make([]byte, bufferSize)
+	// Catch possible crash in libvips and recover from it
+	defer func() {
+		if r := recover(); r != nil {
+			fmt.Printf("caught crash: %v", r)
+		}
+	}()
 
-	for {
-		n, err := resp.Body.Read(localBuffer)
-		if err != nil && err != io.EOF {
-			return nil, err
+	// If the image is 4 channels, we need to drop the alpha channel
+	if img.Bands() == 4 {
+		err := img.Flatten(&vips.Color{R: 255, G: 255, B: 255}) // Flatten with white background
+		if err != nil {
+			return fmt.Errorf("error flattening image: %w", err)
 		}
-		if n > 0 {
-			body.Write(localBuffer[:n])
+		fmt.Println("Image flattened")
+	}
+
+	// If the image is not a mask but is 1 channel, we want to convert it to 3 channels
+	if (img.Bands() == 1) && !isMask {
+		if img.Metadata().Format == vips.ImageTypeJPEG {
+			err := img.ToColorSpace(vips.InterpretationSRGB)
+			if err != nil {
+				return fmt.Errorf("error converting to sRGB: %w", err)
+			}
+		} else {
+			// FIXME: we could maybe recover these; by default vips errors out, the sRGB conversion is not supported for 1-channel PNGs
+			return fmt.Errorf("1 channel PNG image not supported")
 		}
-		if err == io.EOF {
-			break
+	}
+
+	// If the image is 2 channels, that's gray+alpha and we flatten it
+	if img.Bands() == 2 {
+		err := img.ExtractBand(1, 1)
+		fmt.Println("Gray+alpha image, removing alpha")
+		if err != nil {
+			return fmt.Errorf("error extracting band: %w", err)
 		}
 	}
-	return body.Bytes(), nil
+
+	return nil
 }
 
 func imageFromBuffer(buffer []byte, transform *ARAwareTransform, aspectRatio float64, encodeImage bool, isMask bool) (*ImagePayload, float64, error) {
-	// Decode the image payload using vips
-	img, err := vips.NewImageFromBuffer(buffer)
+	// Decode the image payload using vips, using bulletproof settings
+	importParams := vips.NewImportParams()
+	importParams.AutoRotate.Set(true)
+	importParams.FailOnError.Set(true)
+	importParams.Page.Set(0)
+	importParams.NumPages.Set(1)
+	importParams.HeifThumbnail.Set(false)
+	importParams.SvgUnlimited.Set(false)
+
+	img, err := vips.LoadImageFromBuffer(buffer, importParams)
 	if err != nil {
-		return nil, -1., err
+		return nil, -1., fmt.Errorf("error loading image: %w", err)
 	}
-	err = img.AutoRotate()
+
+	err = sanitizeImage(img, isMask)
 	if err != nil {
-		return nil, -1., err
+		return nil, -1., fmt.Errorf("error processing image: %w", err)
 	}
 
 	// Optionally crop and resize the image on the fly. Save the aspect ratio in the process for future use
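sanitizeImage above (and safeCrop in pkg/transforms.go further down) recover from libvips panics, but the recovered value is only printed: the function then returns its zero error value, so callers treat the image as successfully processed. If the panic should propagate as an error instead, a named return value does it. A minimal sketch under that assumption — safeVipsCall is a hypothetical helper, not part of the PR:

```go
package datago

import "fmt"

// safeVipsCall runs fn and converts a recovered libvips panic into a regular
// error through the named return value.
func safeVipsCall(fn func() error) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("recovered libvips crash: %v", r)
		}
	}()
	return fn()
}
```

It could then wrap individual calls, e.g. `err := safeVipsCall(func() error { return img.Flatten(&vips.Color{R: 255, G: 255, B: 255}) })`.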
@@ -54,66 +79,50 @@ func imageFromBuffer(buffer []byte, transform *ARAwareTransform, aspectRatio flo
 	if transform != nil {
 		aspectRatio, err = transform.cropAndResizeToClosestAspectRatio(img, aspectRatio)
 		if err != nil {
-			return nil, -1., err
+			return nil, -1., fmt.Errorf("error cropping and resizing image: %w", err)
 		}
 	}
 
 	width, height := img.Width(), img.Height()
 
-	// If the image is 4 channels, we need to drop the alpha channel
-	if img.Bands() == 4 {
-		err = img.Flatten(&vips.Color{R: 255, G: 255, B: 255}) // Flatten with white background
-		if err != nil {
-			fmt.Println("Error flattening image:", err)
-			return nil, -1., err
-		}
-		fmt.Println("Image flattened")
-	}
-
-	// If the image is not a mask but is 1 channel, we want to convert it to 3 channels
-	if (img.Bands() == 1) && !isMask {
-		err = img.ToColorSpace(vips.InterpretationSRGB)
-		if err != nil {
-			fmt.Println("Error converting to sRGB:", err)
-			return nil, -1., err
-		}
-	}
-
 	// If requested, re-encode the image to a jpg or png
 	var imgBytes []byte
 	var channels int
 	var bitDepth int
 
 	if encodeImage {
-		if err != nil {
-			return nil, -1., err
-		}
-
 		if img.Bands() == 3 {
 			// Re-encode the image to a jpg
 			imgBytes, _, err = img.ExportJpeg(&vips.JpegExportParams{Quality: 95})
-			if err != nil {
-				return nil, -1., err
-			}
 		} else {
 			// Re-encode the image to a png
-			imgBytes, _, err = img.ExportPng(vips.NewPngExportParams())
-			if err != nil {
-				return nil, -1., err
-			}
+			imgBytes, _, err = img.ExportPng(&vips.PngExportParams{
+				Compression: 6,
+				Filter:      vips.PngFilterNone,
+				Interlace:   false,
+				Palette:     false,
+				Bitdepth:    8, // force 8 bit depth
+			})
+		}
+
+		if err != nil {
+			return nil, -1., err
 		}
+		channels = -1 // Signal that we have encoded the image
 	} else {
+		channels = img.Bands()
 		imgBytes, err = img.ToBytes()
+
 		if err != nil {
 			return nil, -1., err
 		}
-		channels = img.Bands()
-
 		// Define bit depth de facto, not exposed in the vips interface
 		bitDepth = len(imgBytes) / (width * height * channels) * 8 // 8 bits per byte
 	}
 
+	defer img.Close() // release vips buffers when done
+
 	if bitDepth == 0 && !encodeImage {
 		panic("Bit depth not set")
 	}
@@ -145,14 +154,9 @@ func fetchURL(client *http.Client, url string, retries int) (urlPayload, error)
 			continue
 		}
 
-		defer func() {
-			err := resp.Body.Close()
-			if err != nil {
-				fmt.Print(err)
-			}
-		}()
+		body_bytes, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
 
-		body_bytes, err := readBodyBuffered(resp)
 		if err != nil {
 			// Renew the http client, not a shared resource
 			client = &http.Client{Timeout: 30 * time.Second}
@@ -181,31 +185,29 @@ func fetchImage(client *http.Client, url string, retries int, transform *ARAware
 			continue
 		}
 
-		defer func() {
-			err := resp.Body.Close()
-			if err != nil {
-				fmt.Print(err)
-			}
-		}()
+		body_bytes, err := io.ReadAll(resp.Body)
+		resp.Body.Close()
 
-		body_bytes, err := readBodyBuffered(resp)
 		if err != nil {
 			errReport = err
 			exponentialBackoffWait(i)
 			continue
 		}
 
-		// Decode into a flat buffer using vips
-		imgPayload_ptr, aspectRatio, err := imageFromBuffer(body_bytes, transform, aspectRatio, encodeImage, isMask)
-		if err != nil {
-			break
+		// Decode into a flat buffer using vips. Note that this can fail on its own
+		{
+			imgPayload_ptr, aspectRatio, err := imageFromBuffer(body_bytes, transform, aspectRatio, encodeImage, isMask)
+			if err != nil {
+				errReport = err
+				continue
+			}
+			return imgPayload_ptr, aspectRatio, nil
 		}
-		return imgPayload_ptr, aspectRatio, nil
 	}
 
 	return nil, -1., errReport
 }
 
-func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult dbSampleMetadata, transform *ARAwareTransform, encodeImage bool) *Sample {
+func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult dbSampleMetadata, transform *ARAwareTransform, encodeImage bool) (*Sample, error) {
 	// Per sample work:
 	// - fetch the raw payloads
 	// - deserialize / decode, depending on the types
@@ -219,10 +221,8 @@ func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult d
 	// Base image
 	if config.RequireImages {
 		baseImage, newAspectRatio, err := fetchImage(httpClient, sampleResult.ImageDirectURL, retries, transform, aspectRatio, encodeImage, false)
-
 		if err != nil {
-			fmt.Println("Error fetching image:", sampleResult.Id)
-			return nil
+			return nil, fmt.Errorf("error fetching image: %v", sampleResult.Id)
 		} else {
 			imgPayload = baseImage
 			aspectRatio = newAspectRatio
@@ -239,8 +239,7 @@ func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult d
 				// Image types, registered as latents but they need to be jpg-decoded
 				new_image, _, err := fetchImage(httpClient, latent.URL, retries, transform, aspectRatio, encodeImage, false)
 				if err != nil {
-					fmt.Println("Error fetching masked image:", sampleResult.Id, latent.LatentType)
-					return nil
+					return nil, fmt.Errorf("error fetching masked image: %v %v", sampleResult.Id, latent.LatentType)
 				}
 				extraImages[latent.LatentType] = *new_image
 
@@ -248,16 +247,16 @@ func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult d
 				// Mask types, registered as latents but they need to be png-decoded
 				mask_ptr, _, err := fetchImage(httpClient, latent.URL, retries, transform, aspectRatio, encodeImage, true)
 				if err != nil {
-					fmt.Println("Error fetching mask:", sampleResult.Id, latent.LatentType)
-					return nil
+					return nil, fmt.Errorf("error fetching mask: %v %v", sampleResult.Id, latent.LatentType)
 				}
 				masks[latent.LatentType] = *mask_ptr
 			} else {
+				fmt.Println("Loading latents ", latent.URL)
+
 				// Vanilla latents, pure binary payloads
 				latentPayload, err := fetchURL(httpClient, latent.URL, retries)
 				if err != nil {
-					fmt.Println("Error fetching latent:", err)
-					return nil
+					return nil, fmt.Errorf("error fetching latent: %v", err)
 				}
 
 				latents[latent.LatentType] = LatentPayload{
@@ -282,5 +281,5 @@ func fetchSample(config *SourceDBConfig, httpClient *http.Client, sampleResult d
 		Masks:            masks,
 		AdditionalImages: extraImages,
 		Tags:             sampleResult.Tags,
-		CocaEmbedding:    cocaEmbedding}
+		CocaEmbedding:    cocaEmbedding}, nil
 }
diff --git a/pkg/transforms.go b/pkg/transforms.go
index 4c08b09..ad79266 100644
--- a/pkg/transforms.go
+++ b/pkg/transforms.go
@@ -28,7 +28,7 @@ type ARAwareTransform struct {
 	PreEncodeImages   bool
 }
 
-func buildImageSizeList(defaultImageSize int, downsamplingRatio int, minAspectRatio float64, maxAspectRatio float64) []ImageSize {
+func BuildImageSizeList(defaultImageSize int, downsamplingRatio int, minAspectRatio float64, maxAspectRatio float64) []ImageSize {
 	patchSize := defaultImageSize / downsamplingRatio
 	patchSizeSq := float64(patchSize * patchSize)
 	var imgSizes []ImageSize
@@ -53,9 +53,9 @@ func buildImageSizeList(defaultImageSize int, downsamplingRatio int, minAspectRa
 	return imgSizes
 }
 
-func newARAwareTransform(imageConfig ImageTransformConfig) *ARAwareTransform {
+func GetArAwareTransform(imageConfig ImageTransformConfig) *ARAwareTransform {
 	// Build the image size list
-	imgSizes := buildImageSizeList(imageConfig.DefaultImageSize, imageConfig.DownsamplingRatio, imageConfig.MinAspectRatio, imageConfig.MaxAspectRatio)
+	imgSizes := BuildImageSizeList(imageConfig.DefaultImageSize, imageConfig.DownsamplingRatio, imageConfig.MinAspectRatio, imageConfig.MaxAspectRatio)
 
 	// Fill in the map table to match aspect ratios and image sizes
 	aspectRatioToSize := make(map[float64]ImageSize)
@@ -74,7 +74,7 @@ func newARAwareTransform(imageConfig ImageTransformConfig) *ARAwareTransform {
 	}
 }
 
-func (t *ARAwareTransform) getClosestAspectRatio(imageWidth int, imageHeight int) float64 {
+func (t *ARAwareTransform) GetClosestAspectRatio(imageWidth int, imageHeight int) float64 {
 	// Find the closest aspect ratio to the given aspect ratio
 	if len(t.aspectRatioToSize) == 0 {
 		fmt.Println("Aspect ratio to size map is empty")
@@ -98,17 +98,27 @@ func (t *ARAwareTransform) getClosestAspectRatio(imageWidth int, imageHeight int
 	return closestAspectRatio
 }
 
-func (t *ARAwareTransform) cropAndResizeToClosestAspectRatio(image *vips.ImageRef, referenceAR float64) (float64, error) {
+func safeCrop(image *vips.ImageRef, width, height int) error {
+	// Catch possible crash in libvips and recover from it
+	defer func() {
+		if r := recover(); r != nil {
+			fmt.Printf("caught crash: %v", r)
+		}
+	}()
+	err := image.ThumbnailWithSize(width, height, vips.InterestingCentre, vips.SizeBoth)
+	return err
+}
+
+func (t *ARAwareTransform) cropAndResizeToClosestAspectRatio(image *vips.ImageRef, referenceAR float64) (float64, error) {
 	// Get the closest aspect ratio
 	if referenceAR <= 0. {
-		referenceAR = t.getClosestAspectRatio(image.Width(), image.Height())
+		referenceAR = t.GetClosestAspectRatio(image.Width(), image.Height())
 	}
 
 	// Desired target size is a lookup away, this is pre-computed/bucketed
 	targetSize := t.aspectRatioToSize[referenceAR]
 
 	// Trust libvips to do resize and crop in one go. Note that jpg decoding happens here and can fail
-	err := image.ThumbnailWithSize(targetSize.Width, targetSize.Height, vips.InterestingCentre, vips.SizeBoth)
+	err := safeCrop(image, targetSize.Width, targetSize.Height)
 
 	return referenceAR, err
 }
diff --git a/pkg/worker_filesystem.go b/pkg/worker_filesystem.go
index 5063f5d..5113ebf 100644
--- a/pkg/worker_filesystem.go
+++ b/pkg/worker_filesystem.go
@@ -9,37 +9,37 @@ type BackendFileSystem struct {
 	config *DatagoConfig
 }
 
-func loadFromDisk(fsSample fsSampleMetadata, transform *ARAwareTransform, encodeImage bool) *Sample {
+func loadFromDisk(fsSample fsSampleMetadata, transform *ARAwareTransform, encodeImage bool) (*Sample, error) {
 	// Load the file into []bytes
 	bytesBuffer, err := os.ReadFile(fsSample.FilePath)
 	if err != nil {
-		fmt.Println("Error reading file:", fsSample.FilePath)
-		return nil
+		return nil, fmt.Errorf("error reading file: %v", fsSample.FilePath)
 	}
 
 	// Slightly faster take, requires Go 1.21+ which breaks gopy speed for now
-	// // Using mmap to put the file directly into memory, removes buffering needs
-	// r, err := mmap.Open(fsSample.FilePath)
-	// if err != nil {
-	// 	panic(err)
-	// }
+	{
+		// // Using mmap to put the file directly into memory, removes buffering needs
+		// r, err := mmap.Open(fsSample.FilePath)
+		// if err != nil {
+		// 	panic(err)
+		// }
 
-	// bytesBuffer := make([]byte, r.Len())
-	// _, err = r.ReadAt(bytesBuffer, 0)
-	// if err != nil {
-	// 	panic(err)
-	// }
+		// bytesBuffer := make([]byte, r.Len())
+		// _, err = r.ReadAt(bytesBuffer, 0)
+		// if err != nil {
+		// 	panic(err)
+		// }
+	}
 
 	// Decode the image, can error out here also, and return the sample
 	imgPayload, _, err := imageFromBuffer(bytesBuffer, transform, -1., encodeImage, false)
 	if err != nil {
-		fmt.Println("Error loading image:", fsSample.FileName)
-		return nil
+		return nil, fmt.Errorf("error loading image: %v", fsSample.FileName)
 	}
 
 	return &Sample{ID: fsSample.FileName,
 		Image: *imgPayload,
-	}
+	}, nil
 }
 
 func (b BackendFileSystem) collectSamples(chanSampleMetadata chan SampleDataPointers, chanSamples chan Sample, transform *ARAwareTransform, encodeImages bool) {
@@ -62,8 +62,8 @@ func (b BackendFileSystem) collectSamples(chanSampleMetadata chan SampleDataPoin
 			panic("Failed to cast the item to fetch to dbSampleMetadata. This worker is probably misconfigured")
 		}
 
-		sample := loadFromDisk(fsSample, transform, encodeImages)
-		if sample != nil {
+		sample, err := loadFromDisk(fsSample, transform, encodeImages)
+		if err == nil && sample != nil {
 			chanSamples <- *sample
 		}
 	}
diff --git a/pkg/worker_http.go b/pkg/worker_http.go
index 373384f..f66e5f9 100644
--- a/pkg/worker_http.go
+++ b/pkg/worker_http.go
@@ -31,8 +31,8 @@ func (b BackendHTTP) collectSamples(chanSampleMetadata chan SampleDataPointers,
 			panic("Failed to cast the item to fetch to dbSampleMetadata. This worker is probably misconfigured")
 		}
 
-		sample := fetchSample(b.config, &httpClient, httpSample, transform, encodeImages)
-		if sample != nil {
+		sample, err := fetchSample(b.config, &httpClient, httpSample, transform, encodeImages)
+		if err == nil && sample != nil {
 			chanSamples <- *sample
 		}
 	}
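Both workers now check `err == nil && sample != nil` and silently skip failed samples, so fetch or decode errors disappear without a trace. A small self-contained sketch of the same skip-on-error pattern with a log line, as one possible alternative; `loadOne` is a stand-in for loadFromDisk / fetchSample, not a real function in the library:

```go
package main

import (
	"errors"
	"fmt"
)

// loadOne stands in for loadFromDisk / fetchSample in this sketch.
func loadOne(name string) (string, error) {
	if name == "" {
		return "", errors.New("empty file name")
	}
	return "sample:" + name, nil
}

func main() {
	chanSamples := make(chan string, 4)
	for _, name := range []string{"a.jpg", "", "b.jpg"} {
		sample, err := loadOne(name)
		if err != nil {
			// Surface the failure instead of dropping it silently
			fmt.Println("skipping sample:", err)
			continue
		}
		chanSamples <- sample
	}
	close(chanSamples)
	for s := range chanSamples {
		fmt.Println("received", s)
	}
}
```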
This worker is probably misconfigured") } - sample := fetchSample(b.config, &httpClient, httpSample, transform, encodeImages) - if sample != nil { + sample, err := fetchSample(b.config, &httpClient, httpSample, transform, encodeImages) + if err == nil && sample != nil { chanSamples <- *sample } } diff --git a/python/benchmark_db.py b/python/benchmark_db.py index 3b26d5c..8eaa323 100644 --- a/python/benchmark_db.py +++ b/python/benchmark_db.py @@ -4,6 +4,7 @@ import numpy as np from go_types import go_array_to_pil_image, go_array_to_numpy import typer +import json def benchmark( @@ -15,26 +16,33 @@ def benchmark( require_images: bool = typer.Option(True, help="Request the original images"), require_embeddings: bool = typer.Option(False, help="Request embeddings"), test_masks: bool = typer.Option(True, help="Test masks"), - test_latents: bool = typer.Option(True, help="Test latents"), ): print(f"Running benchmark for {source} - {limit} samples") - - # Get a generic client config - client_config = datago.GetDatagoConfig() - client_config.ImageConfig.CropAndResize = crop_and_resize - - # Specify the source parameters as you see fit - source_config = datago.GetSourceDBConfig() - source_config.Sources = source - source_config.RequireImages = require_images - source_config.RequireEmbeddings = require_embeddings - source_config.HasMasks = "segmentation_mask" if test_masks else "" - source_config.HasLatents = "caption_latent_t5xxl" if test_latents else "" - - # Get a new client instance, happy benchmarking - client_config.SourceConfig = source_config - client = datago.GetClient(client_config) - + client_config = { + "source_type": datago.SourceTypeDB, + "source_config": { + "page_size": 512, + "sources": source, + "require_images": require_images, + "require_embeddings": require_embeddings, + "has_masks": "segmentation_mask" if test_masks else "", + "rank": 0, + "world_size": 1, + }, + "image_config": { + "crop_and_resize": crop_and_resize, + "default_image_size": 512, + "downsampling_ratio": 16, + "min_aspect_ratio": 0.5, + "max_aspect_ratio": 2.0, + "pre_encode_images": False, + }, + "prefetch_buffer_size": 128, + "samples_buffer_size": 64, + "limit": limit, + } + + client = datago.GetClientFromJSON(json.dumps(client_config)) client.Start() # Optional, but good practice to start the client to reduce latency to first sample (while you're instantiating models for instance) start = time.time() diff --git a/python/go_types.py b/python/go_types.py index 41a2d4b..6892ad8 100644 --- a/python/go_types.py +++ b/python/go_types.py @@ -64,5 +64,5 @@ def go_array_to_pil_image(go_array) -> Optional[Image.Image]: if c == 4: return Image.frombuffer("RGBA", (w, h), np_array, "raw", "RGBA", 0, 1) - assert c == 3, "Expected 3 channels" + assert c == 3, f"Expected 3 channels, got {c}" return Image.fromarray(np_array) diff --git a/python/test_datago_db.py b/python/test_datago_db.py index 9f10686..b61683b 100644 --- a/python/test_datago_db.py +++ b/python/test_datago_db.py @@ -45,7 +45,7 @@ def test_get_sample_db(): client_config = datago.GetDatagoConfig() client_config.SamplesBufferSize = 10 - source_config = datago.GetSourceDBConfig() + source_config = datago.GetDefaultSourceDBConfig() source_config.Sources = get_test_source() client_config.SourceConfig = source_config diff --git a/tests/client_db_test.go b/tests/client_db_test.go index d172fb4..17060bb 100644 --- a/tests/client_db_test.go +++ b/tests/client_db_test.go @@ -17,7 +17,7 @@ func get_test_source() string { func get_default_test_config() 
 	config := datago.GetDatagoConfig()
-	db_config := datago.GetSourceDBConfig()
+	db_config := datago.GetDefaultSourceDBConfig()
 	db_config.Sources = get_test_source()
 	db_config.PageSize = 32
 	config.SourceConfig = db_config
diff --git a/tests/client_filesystem_test.go b/tests/client_filesystem_test.go
index eeccb96..444a49f 100644
--- a/tests/client_filesystem_test.go
+++ b/tests/client_filesystem_test.go
@@ -77,7 +77,7 @@ func TestFilesystemLoad(t *testing.T) {
 	// Run the tests
 	config := datago.GetDatagoConfig()
 
-	fs_config := datago.GetSourceFileSystemConfig()
+	fs_config := datago.GetDefaultSourceFileSystemConfig()
 	fs_config.RootPath = test_dir
 	config.SourceConfig = fs_config
diff --git a/tests/transform_test.go b/tests/transform_test.go
new file mode 100644
index 0000000..100c026
--- /dev/null
+++ b/tests/transform_test.go
@@ -0,0 +1,63 @@
+package datago_test
+
+import (
+    "math"
+    "math/rand"
+    "testing"
+
+    datago "datago/pkg"
+)
+
+func TestArAwareTransform(t *testing.T) {
+    imageConfig := datago.ImageTransformConfig{
+        DefaultImageSize:  1024,
+        DownsamplingRatio: 32,
+        MinAspectRatio:    0.5,
+        MaxAspectRatio:    2.0,
+    }
+
+    transform := datago.GetArAwareTransform(imageConfig)
+
+    // For a couple of image sizes, check that we get the expected aspect ratio
+    sizes := map[string][2]int{
+        "1024x1024": {1024, 1024},
+        "704x1440":  {704, 1440},
+        "736x1408":  {736, 1408},
+        "736x1376":  {736, 1376},
+        "768x1344":  {768, 1344},
+        "768x1312":  {768, 1312},
+        "800x1280":  {800, 1280},
+        "832x1248":  {832, 1248},
+        "832x1216":  {832, 1216},
+        "864x1184":  {864, 1184},
+        "896x1152":  {896, 1152},
+        "928x1120":  {928, 1120},
+        "960x1088":  {960, 1088},
+        "992x1056":  {992, 1056},
+        "1056x992":  {1056, 992},
+        "1088x960":  {1088, 960},
+        "1120x928":  {1120, 928},
+        "1152x896":  {1152, 896},
+        "1184x864":  {1184, 864},
+        "1216x832":  {1216, 832},
+        "1248x832":  {1248, 832},
+        "1280x800":  {1280, 800},
+        "1312x768":  {1312, 768},
+        "1344x768":  {1344, 768},
+        "1376x736":  {1376, 736},
+        "1408x736":  {1408, 736},
+        "1440x704":  {1440, 704},
+    }
+
+    for size, dimensions := range sizes {
+        // Fuzz the sizes by a random factor
+        fuz := rand.Intn(100) + 1
+        fuz2 := rand.Intn(5)
+
+        transformedSize := transform.GetClosestAspectRatio(dimensions[0]*fuz+fuz2, dimensions[1]*fuz)
+        if math.Abs(transformedSize-float64(dimensions[0])/float64(dimensions[1])) > 1e-3 {
+            t.Error("Aspect ratio mismatch")
+            t.Logf("Size: %s, Aspect Ratio: %f", size, transformedSize)
+        }
+    }
+}
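With BuildImageSizeList, GetArAwareTransform and GetClosestAspectRatio now exported, the bucketing exercised by the test above can also be inspected from regular application code. A minimal sketch using the same 1024/32 settings as the test; any valid config would work:

```go
package main

import (
	"fmt"

	datago "datago/pkg"
)

func main() {
	imageConfig := datago.ImageTransformConfig{
		DefaultImageSize:  1024,
		DownsamplingRatio: 32,
		MinAspectRatio:    0.5,
		MaxAspectRatio:    2.0,
	}

	// Pre-computed aspect-ratio buckets
	sizes := datago.BuildImageSizeList(imageConfig.DefaultImageSize, imageConfig.DownsamplingRatio, imageConfig.MinAspectRatio, imageConfig.MaxAspectRatio)
	fmt.Println("number of size buckets:", len(sizes))

	// Bucketed aspect ratio for an arbitrary input image
	transform := datago.GetArAwareTransform(imageConfig)
	fmt.Println("closest aspect ratio for 900x1200:", transform.GetClosestAspectRatio(900, 1200))
}
```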