Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MIDRC-768 Check if instance type is available in AZ when creating subnets for Nextflow #111

Merged
merged 10 commits into from
Dec 16, 2024
3 changes: 2 additions & 1 deletion hatchery/authz.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,8 @@ var isUserAuthorizedForPayModels = func(userName string, allowedPayModels []stri
}
currentPayModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(fmt.Sprintf("Failed to get current pay model for user '%s', unable to check if user is authorized to launch container. Error: %v", userName, err))
Config.Logger.Printf("Failed to get current pay model for user '%s', unable to check if user is authorized to launch container. Error: %v", userName, err)

return false, nil
}

Expand Down
12 changes: 6 additions & 6 deletions hatchery/hatchery.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ func launch(w http.ResponseWriter, r *http.Request) {
dbconfig := initializeDbConfig()
activeGen3LicenseUsers, err := getActiveGen3LicenseUserMaps(dbconfig, Config.ContainersMap[hash])
if err != nil {
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
}
// Check for config max
nextLicenseId := getNextLicenseId(activeGen3LicenseUsers, Config.ContainersMap[hash].License.MaxLicenseIds)
Expand All @@ -425,15 +425,15 @@ func launch(w http.ResponseWriter, r *http.Request) {
}
newItem, err := createGen3LicenseUserMap(dbconfig, userName, nextLicenseId, Config.ContainersMap[hash])
if err != nil {
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
}
Config.Logger.Printf("Created new license-user-map item: %v", newItem)

}

allpaymodels, err := getPayModelsForUser(userName)
if err != nil {
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If linting issues are also being addressed, shouldn't line 495 also use Print instead of Printf?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That one has been fixed too

}
if allpaymodels == nil { // Commons with no concept of paymodels
err = createLocalK8sPod(r.Context(), hash, userName, accessToken, envVars)
Expand Down Expand Up @@ -492,7 +492,7 @@ func terminate(w http.ResponseWriter, r *http.Request) {
dbconfig := initializeDbConfig()
activeGen3LicenseUsers, userlicerr := getLicenseUserMapsForUser(dbconfig, userName)
if userlicerr != nil {
Config.Logger.Printf(userlicerr.Error())
Config.Logger.Print(userlicerr.Error())
}
Config.Logger.Printf("Debug: Active gen3 license user maps %v", activeGen3LicenseUsers)
if len(activeGen3LicenseUsers) == 0 {
Expand All @@ -503,7 +503,7 @@ func terminate(w http.ResponseWriter, r *http.Request) {
Config.Logger.Printf("Debug: updating gen3 license user map as inactive for itemId %s", v.ItemId)
_, err := setGen3LicenseUserInactive(dbconfig, v.ItemId)
if err != nil {
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
}
}
}
Expand All @@ -519,7 +519,7 @@ func terminate(w http.ResponseWriter, r *http.Request) {

payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
}
if payModel != nil && payModel.Ecs {
_, err = terminateEcsWorkspace(r.Context(), userName, accessToken, payModel.AWSAccountId)
Expand Down
84 changes: 65 additions & 19 deletions hatchery/nextflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,12 @@ func createNextflowResources(userName string, nextflowGlobalConfig NextflowGloba
}

// create the VPC if it doesn't exist
vpcid, subnetids, err := setupVpcAndSquid(ec2Svc, userName, hostname)
// launch squid
// TODO: read the squid instance type from the hatchery config (would need to change
// `launchSquidInstance` function to update the instance type if the instance already
// exists) (MIDRC-751)
squidInstanceType := "t2.micro"
vpcid, subnetids, err := setupVpcAndSquid(ec2Svc, userName, hostname, nextflowConfig.InstanceType, squidInstanceType)
if err != nil {
Config.Logger.Printf("Unable to setup VPC: %v", err)
return "", "", err
Expand Down Expand Up @@ -456,7 +461,7 @@ var getNextflowAwsSettings = func(sess *session.Session, payModel *PayModel, use
}

// Create VPC for aws batch compute environment
func setupVpcAndSquid(ec2Svc *ec2.EC2, userName string, hostname string) (*string, *[]string, error) {
func setupVpcAndSquid(ec2Svc *ec2.EC2, userName string, hostname string, computeEnvInstanceType string, squidInstanceType string) (*string, *[]string, error) {
// TODO: make base CIDR configurable? (MIDRC-747)
cidrstring := "192.168.0.0/16"
_, IPNet, _ := net.ParseCIDR(cidrstring)
Expand Down Expand Up @@ -522,7 +527,7 @@ func setupVpcAndSquid(ec2Svc *ec2.EC2, userName string, hostname string) (*strin
// create subnets
for i, subnet := range subnets {
subnetName := fmt.Sprintf("%s-nf-subnet-%s-%d", hostname, userName, i)
subnetId, err := setupSubnet(subnetName, subnet, vpcid, ec2Svc)
subnetId, err := setupSubnet(subnetName, subnet, vpcid, ec2Svc, computeEnvInstanceType)
if err != nil {
return nil, nil, err
}
Expand All @@ -548,7 +553,7 @@ func setupVpcAndSquid(ec2Svc *ec2.EC2, userName string, hostname string) (*strin
}

// setup Squid
fwSubnetId, err := setupSquid(hostname, userName, cidrstring, ec2Svc, vpcid, igw, fwRouteTableId, routeTableId)
fwSubnetId, err := setupSquid(hostname, userName, cidrstring, ec2Svc, vpcid, igw, fwRouteTableId, routeTableId, squidInstanceType)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -580,7 +585,7 @@ func ensureLaunchTemplate(ec2Svc *ec2.EC2, userName string, hostname string, job
Config.Logger.Printf("Debug: Launch template '%s' does not exist, creating it", launchTemplateName)
launchTemplate, err := ec2Svc.CreateLaunchTemplate(&ec2.CreateLaunchTemplateInput{
LaunchTemplateName: aws.String(launchTemplateName),
LaunchTemplateData: &ec2.RequestLaunchTemplateData{
LaunchTemplateData: &ec2.RequestLaunchTemplateData{ // if changed, need to update launch template and compute env
UserData: aws.String(userData),
},
})
Expand Down Expand Up @@ -661,6 +666,12 @@ func createBatchComputeEnvironment(nextflowGlobalConfig NextflowGlobalConfig, ne
return "", err
}

subnets := []*string{}
for _, subnet := range subnetids {
s := subnet
subnets = append(subnets, &s)
}

// update any settings that may have changed in the config
// TODO also make sure it is pointing at the correct subnets - if the VPC is deleted,
// we should recreate the compute environment as well because it will be pointing at
Expand All @@ -679,6 +690,7 @@ func createBatchComputeEnvironment(nextflowGlobalConfig NextflowGlobalConfig, ne
MinvCpus: aws.Int64(int64(nextflowConfig.InstanceMinVCpus)),
MaxvCpus: aws.Int64(int64(nextflowConfig.InstanceMaxVCpus)),
InstanceTypes: []*string{aws.String(nextflowConfig.InstanceType)},
Subnets: subnets,
Type: aws.String(nextflowConfig.ComputeEnvironmentType),
Tags: tagsMap,
},
Expand Down Expand Up @@ -997,7 +1009,7 @@ func createS3bucket(nextflowGlobalConfig NextflowGlobalConfig, s3Svc *s3.S3, kms
}

// Function to set up squid and subnets for squid
func setupSquid(hostname string, userName string, cidrstring string, ec2svc *ec2.EC2, vpcid string, igw *string, fwRouteTableId *string, routeTableId *string) (*string, error) {
func setupSquid(hostname string, userName string, cidrstring string, ec2svc *ec2.EC2, vpcid string, igw *string, fwRouteTableId *string, routeTableId *string, instanceType string) (*string, error) {
_, IPNet, _ := net.ParseCIDR(cidrstring)
subnet, err := cidr.Subnet(IPNet, 2, 3)
if err != nil {
Expand All @@ -1008,8 +1020,7 @@ func setupSquid(hostname string, userName string, cidrstring string, ec2svc *ec2
// create subnet
subnetName := fmt.Sprintf("%s-nf-subnet-fw-%s", hostname, userName)
Config.Logger.Printf("Debug: Creating subnet '%s' with name '%s'", subnet, subnetName)

subnetId, err := setupSubnet(subnetName, subnetString, vpcid, ec2svc)
subnetId, err := setupSubnet(subnetName, subnetString, vpcid, ec2svc, instanceType)
if err != nil {
return nil, err
}
Expand All @@ -1035,8 +1046,7 @@ func setupSquid(hostname string, userName string, cidrstring string, ec2svc *ec2
}
Config.Logger.Printf("Debug: Associated route table '%s' to subnet '%s'", *fwRouteTableId, *subnetId)

// launch squid
squidInstanceId, err := launchSquidInstance(hostname, userName, ec2svc, subnetId, vpcid, subnetString)
squidInstanceId, err := launchSquidInstance(hostname, userName, ec2svc, subnetId, vpcid, subnetString, instanceType)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1077,7 +1087,7 @@ func setupSquid(hostname string, userName string, cidrstring string, ec2svc *ec2
}

// Generic function to create subnet, and route table
func setupSubnet(subnetName string, cidr string, vpcid string, ec2Svc *ec2.EC2) (*string, error) {
func setupSubnet(subnetName string, cidr string, vpcid string, ec2Svc *ec2.EC2, instanceType string) (*string, error) {
// Check if subnet exists if not create it
exsubnet, err := ec2Svc.DescribeSubnets(&ec2.DescribeSubnetsInput{
Filters: []*ec2.Filter{
Expand All @@ -1103,11 +1113,52 @@ func setupSubnet(subnetName string, cidr string, vpcid string, ec2Svc *ec2.EC2)
return exsubnet.Subnets[0].SubnetId, nil
}

// Fetch all availability zones
// this is being limited to a region by the ec2svc that gets passed in.
describeZonesOutput, err := ec2Svc.DescribeAvailabilityZones(&ec2.DescribeAvailabilityZonesInput{})
if err != nil {
return nil, fmt.Errorf("failed to describe availability zones: %v", err)
}

// Make sure the selected AZ has the instance type from nextflow configuration available.
var selectedZone string
for _, zone := range describeZonesOutput.AvailabilityZones {
if *zone.State != "available" {
continue
}
result, err := ec2Svc.DescribeInstanceTypeOfferings(&ec2.DescribeInstanceTypeOfferingsInput{
LocationType: aws.String("availability-zone"),
Filters: []*ec2.Filter{
{
Name: aws.String("location"),
Values: []*string{aws.String(*zone.ZoneName)},
},
{
Name: aws.String("instance-type"),
Values: []*string{aws.String(instanceType)},
},
},
})
if err != nil {
return nil, fmt.Errorf("Error describing instance type offerings: %v", err)
}
if len(result.InstanceTypeOfferings) > 0 {
Config.Logger.Printf("Debug: Zone: %v has instance type %v available. Using that for subnet", *zone.ZoneName, instanceType)
selectedZone = *zone.ZoneName
break // Exit the loop if we found a suitable zone
}
}

if selectedZone == "" {
return nil, fmt.Errorf("no suitable availability zone found")
}

// create subnet
Config.Logger.Printf("Debug: Creating subnet '%v' with name '%s'", cidr, subnetName)
sn, err := ec2Svc.CreateSubnet(&ec2.CreateSubnetInput{
CidrBlock: aws.String(cidr),
VpcId: aws.String(vpcid),
CidrBlock: aws.String(cidr),
VpcId: aws.String(vpcid),
AvailabilityZone: aws.String(selectedZone),
TagSpecifications: []*ec2.TagSpecification{
{
// Name
Expand Down Expand Up @@ -1207,7 +1258,7 @@ func associateRouteTablesToSubnets(ec2svc *ec2.EC2, subnets []string, routeTable
return nil
}

func launchSquidInstance(hostname string, userName string, ec2svc *ec2.EC2, subnetId *string, vpcId string, subnet string) (*string, error) {
func launchSquidInstance(hostname string, userName string, ec2svc *ec2.EC2, subnetId *string, vpcId string, subnet string, instanceType string) (*string, error) {
instanceName := fmt.Sprintf("%s-nf-squid-%s", hostname, userName)

// check if instance already exists, if it does start it
Expand Down Expand Up @@ -1320,11 +1371,6 @@ $(command -v docker) run --name squid --restart=always --network=host -d \
return nil, err
}

// instance type
// TODO: we could make this configurable via hatchery config (would need to change this
// function to update the instance type if the instance already exists) (MIDRC-751)
instanceType := "t2.micro"

// Launch EC2 instance
squid, err := ec2svc.RunInstances(&ec2.RunInstancesInput{
ImageId: amiId,
Expand Down
4 changes: 2 additions & 2 deletions hatchery/pods.go
Original file line number Diff line number Diff line change
Expand Up @@ -400,14 +400,14 @@ func buildPod(hatchConfig *FullHatcheryConfig, hatchApp *Container, userName str
//hatchConfig.Logger.Printf("sidecar configured")

var lifeCycle = k8sv1.Lifecycle{}
if hatchApp.LifecyclePreStop != nil && len(hatchApp.LifecyclePreStop) > 0 {
if len(hatchApp.LifecyclePreStop) > 0 {
lifeCycle.PreStop = &k8sv1.LifecycleHandler{
Exec: &k8sv1.ExecAction{
Command: hatchApp.LifecyclePreStop,
},
}
}
if hatchApp.LifecyclePostStart != nil && len(hatchApp.LifecyclePostStart) > 0 {
if len(hatchApp.LifecyclePostStart) > 0 {
lifeCycle.PostStart = &k8sv1.LifecycleHandler{
Exec: &k8sv1.ExecAction{
Command: hatchApp.LifecyclePostStart,
Expand Down
6 changes: 3 additions & 3 deletions hatchery/ram.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func acceptTransitGatewayShare(pm *PayModel, sess *session.Session, ramArn *stri
})
if err != nil {
// Log error
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
return err
}
if len(exResourceShares.ResourceShares) == 0 {
Expand All @@ -39,7 +39,7 @@ func acceptTransitGatewayShare(pm *PayModel, sess *session.Session, ramArn *stri
err := svc.acceptTGWShare(ramArn)
if err != nil {
// Log error
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
return err
}
} else {
Expand All @@ -64,7 +64,7 @@ func (creds *CREDS) acceptTGWShare(ramArn *string) error {
resourceShareInvitation, err := svc.GetResourceShareInvitations(ramInvitationInput)
if err != nil {
// Log error
Config.Logger.Printf(err.Error())
Config.Logger.Print(err.Error())
return err
}

Expand Down
5 changes: 2 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"errors"
"fmt"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -36,12 +35,12 @@ func main() {
logger := log.New(os.Stdout, "", log.LstdFlags)
cleanPath, err := verifyPath(configPath)
if err != nil {
logger.Printf(fmt.Sprintf("Failed to load config - got %v", err))
logger.Printf("Failed to load config - got %v", err)
return
}
config, err := hatchery.LoadConfig(cleanPath, logger)
if err != nil {
config.Logger.Printf(fmt.Sprintf("Failed to load config - got %v", err))
config.Logger.Printf("Failed to load config - got %v", err)
return
}
hatchery.Config = config
Expand Down
Loading