Skip to content
This repository has been archived by the owner on Aug 16, 2022. It is now read-only.

Commit

Permalink
feat: Parallelize Sagemaker Training Jobs
Browse files Browse the repository at this point in the history
  • Loading branch information
bbernays authored May 20, 2022
1 parent 80bad8e commit c925608
Showing 1 changed file with 34 additions and 14 deletions.
48 changes: 34 additions & 14 deletions resources/services/sagemaker/training_jobs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@ import (
"github.com/cloudquery/cq-provider-aws/client"
"github.com/cloudquery/cq-provider-sdk/provider/diag"
"github.com/cloudquery/cq-provider-sdk/provider/schema"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)

const MAX_GOROUTINES = 10

func SagemakerTrainingJobs() *schema.Table {
return &schema.Table{
Name: "aws_sagemaker_training_jobs",
Expand Down Expand Up @@ -562,34 +566,50 @@ func SagemakerTrainingJobs() *schema.Table {
// Table Resolver Functions
// ====================================================================================================================

func fetchTrainingJobDefinition(ctx context.Context, res chan<- interface{}, svc client.SageMakerClient, region string, n types.TrainingJobSummary) error {
config := sagemaker.DescribeTrainingJobInput{
TrainingJobName: n.TrainingJobName,
}
response, err := svc.DescribeTrainingJob(ctx, &config, func(options *sagemaker.Options) {
options.Region = region
})
if err != nil {
return diag.WrapError(err)
}

res <- response
return nil
}

func fetchSagemakerTrainingJobs(ctx context.Context, meta schema.ClientMeta, _ *schema.Resource, res chan<- interface{}) error {
c := meta.(*client.Client)
svc := c.Services().SageMaker
config := sagemaker.ListTrainingJobsInput{}
var sem = semaphore.NewWeighted(int64(MAX_GOROUTINES))

for {
response, err := svc.ListTrainingJobs(ctx, &config, func(options *sagemaker.Options) {
options.Region = c.Region
})
if err != nil {
return diag.WrapError(err)
}

// get more details about the notebook instance
for _, n := range response.TrainingJobSummaries {

config := sagemaker.DescribeTrainingJobInput{
TrainingJobName: n.TrainingJobName,
}
response, err := svc.DescribeTrainingJob(ctx, &config, func(options *sagemaker.Options) {
options.Region = c.Region
})
if err != nil {
errs, ctx := errgroup.WithContext(ctx)
for _, d := range response.TrainingJobSummaries {
if err := sem.Acquire(ctx, 1); err != nil {
return diag.WrapError(err)
}

res <- response
func(summary types.TrainingJobSummary) {
errs.Go(func() error {
defer sem.Release(1)
return fetchTrainingJobDefinition(ctx, res, svc, c.Region, summary)
})
}(d)
}
err = errs.Wait()
if err != nil {
return diag.WrapError(err)
}

if aws.ToString(response.NextToken) == "" {
break
}
Expand Down

0 comments on commit c925608

Please sign in to comment.