diff --git a/resources/services/sagemaker/training_jobs.go b/resources/services/sagemaker/training_jobs.go index 40025397f..7fe0ceed2 100644 --- a/resources/services/sagemaker/training_jobs.go +++ b/resources/services/sagemaker/training_jobs.go @@ -10,8 +10,12 @@ import ( "github.com/cloudquery/cq-provider-aws/client" "github.com/cloudquery/cq-provider-sdk/provider/diag" "github.com/cloudquery/cq-provider-sdk/provider/schema" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" ) +const MAX_GOROUTINES = 10 + func SagemakerTrainingJobs() *schema.Table { return &schema.Table{ Name: "aws_sagemaker_training_jobs", @@ -562,10 +566,27 @@ func SagemakerTrainingJobs() *schema.Table { // Table Resolver Functions // ==================================================================================================================== +func fetchTrainingJobDefinition(ctx context.Context, res chan<- interface{}, svc client.SageMakerClient, region string, n types.TrainingJobSummary) error { + config := sagemaker.DescribeTrainingJobInput{ + TrainingJobName: n.TrainingJobName, + } + response, err := svc.DescribeTrainingJob(ctx, &config, func(options *sagemaker.Options) { + options.Region = region + }) + if err != nil { + return diag.WrapError(err) + } + + res <- response + return nil +} + func fetchSagemakerTrainingJobs(ctx context.Context, meta schema.ClientMeta, _ *schema.Resource, res chan<- interface{}) error { c := meta.(*client.Client) svc := c.Services().SageMaker config := sagemaker.ListTrainingJobsInput{} + var sem = semaphore.NewWeighted(int64(MAX_GOROUTINES)) + for { response, err := svc.ListTrainingJobs(ctx, &config, func(options *sagemaker.Options) { options.Region = c.Region @@ -573,23 +594,22 @@ func fetchSagemakerTrainingJobs(ctx context.Context, meta schema.ClientMeta, _ * if err != nil { return diag.WrapError(err) } - - // get more details about the notebook instance - for _, n := range response.TrainingJobSummaries { - - config := sagemaker.DescribeTrainingJobInput{ - TrainingJobName: n.TrainingJobName, - } - response, err := svc.DescribeTrainingJob(ctx, &config, func(options *sagemaker.Options) { - options.Region = c.Region - }) - if err != nil { + errs, ctx := errgroup.WithContext(ctx) + for _, d := range response.TrainingJobSummaries { + if err := sem.Acquire(ctx, 1); err != nil { return diag.WrapError(err) } - - res <- response + func(summary types.TrainingJobSummary) { + errs.Go(func() error { + defer sem.Release(1) + return fetchTrainingJobDefinition(ctx, res, svc, c.Region, summary) + }) + }(d) + } + err = errs.Wait() + if err != nil { + return diag.WrapError(err) } - if aws.ToString(response.NextToken) == "" { break }