Skip to content

Commit

Permalink
Wrap POST /pair in a transaction (#259)
Browse files Browse the repository at this point in the history
* Wrap POST /pair in a transaction

* Tighten up err returns and column whitelists
  • Loading branch information
elffjs authored Jan 13, 2024
1 parent 58c8f1a commit 41c2b55
Showing 1 changed file with 21 additions and 16 deletions.
37 changes: 21 additions & 16 deletions internal/controllers/user_integrations_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ func (udc *UserDevicesController) GetAutoPiPairMessage(c *fiber.Ctx) error {
// We also had a legacy mode for web2-paired devices. This was never used in production.
externalID := c.Query("external_id")

vnft, ad, err := udc.checkPairable(c.Context(), userDeviceID, externalID)
vnft, ad, err := udc.checkPairable(c.Context(), udc.DBS().Reader, userDeviceID, externalID)
if err != nil {
return err
}
Expand Down Expand Up @@ -807,7 +807,13 @@ func (udc *UserDevicesController) PostPairAutoPi(c *fiber.Ctx) error {

logger.Info().Interface("request", pairReq).Msg("Pairing request body.")

vnft, ad, err := udc.checkPairable(c.Context(), userDeviceID, pairReq.ExternalID)
tx, err := udc.DBS().Writer.BeginTx(c.Context(), &sql.TxOptions{Isolation: sql.LevelSerializable})
if err != nil {
return err
}
defer tx.Rollback() //nolint

vnft, ad, err := udc.checkPairable(c.Context(), tx, userDeviceID, pairReq.ExternalID)
if err != nil {
return err
}
Expand Down Expand Up @@ -880,23 +886,23 @@ func (udc *UserDevicesController) PostPairAutoPi(c *fiber.Ctx) error {
ID: requestID,
Status: models.MetaTransactionRequestStatusUnsubmitted,
}
err = mtr.Insert(c.Context(), udc.DBS().Writer, boil.Infer())
err = mtr.Insert(c.Context(), tx, boil.Infer())
if err != nil {
return err
}

ad.UnpairRequestID = null.String{}
ad.PairRequestID = null.StringFrom(requestID)
_, err = ad.Update(c.Context(), udc.DBS().Writer, boil.Infer())
_, err = ad.Update(c.Context(), tx, boil.Whitelist(models.AftermarketDeviceColumns.UnpairRequestID, models.AftermarketDeviceColumns.PairRequestID, models.AftermarketDeviceColumns.UpdatedAt))
if err != nil {
return err
}
err = client.PairAftermarketDeviceSignTwoOwners(requestID, apToken, vehicleToken, aftermarketDeviceSig, vehicleOwnerSig)
if err != nil {

if err := tx.Commit(); err != nil {
return err
}

return nil
return client.PairAftermarketDeviceSignTwoOwners(requestID, apToken, vehicleToken, aftermarketDeviceSig, vehicleOwnerSig)
}

// Yes, this is ugly, we'll fix it.
Expand All @@ -906,31 +912,30 @@ func (udc *UserDevicesController) PostPairAutoPi(c *fiber.Ctx) error {
ID: requestID,
Status: models.MetaTransactionRequestStatusUnsubmitted,
}
err = mtr.Insert(c.Context(), udc.DBS().Writer, boil.Infer())
err = mtr.Insert(c.Context(), tx, boil.Infer())
if err != nil {
return err
}

ad.UnpairRequestID = null.String{}
ad.PairRequestID = null.StringFrom(requestID)
_, err = ad.Update(c.Context(), udc.DBS().Writer, boil.Infer())
_, err = ad.Update(c.Context(), tx, boil.Whitelist(models.AftermarketDeviceColumns.UnpairRequestID, models.AftermarketDeviceColumns.PairRequestID, models.AftermarketDeviceColumns.UpdatedAt))
if err != nil {
return err
}

err = client.PairAftermarketDeviceSignSameOwner(requestID, apToken, vehicleToken, vehicleOwnerSig)
if err != nil {
if err := tx.Commit(); err != nil {
return err
}

return nil
return client.PairAftermarketDeviceSignSameOwner(requestID, apToken, vehicleToken, vehicleOwnerSig)
}

func (udc *UserDevicesController) checkPairable(ctx context.Context, userDeviceID, serial string) (*models.VehicleNFT, *models.AftermarketDevice, error) {
func (udc *UserDevicesController) checkPairable(ctx context.Context, exec boil.ContextExecutor, userDeviceID, serial string) (*models.VehicleNFT, *models.AftermarketDevice, error) {
ud, err := models.UserDevices(
models.UserDeviceWhere.ID.EQ(userDeviceID),
qm.Load(qm.Rels(models.UserDeviceRels.VehicleNFT, models.VehicleNFTRels.VehicleTokenAftermarketDevice)),
).One(ctx, udc.DBS().Reader)
).One(ctx, exec)
if err != nil {
// Access middleware will catch "not found".
return nil, nil, err
Expand All @@ -952,7 +957,7 @@ func (udc *UserDevicesController) checkPairable(ctx context.Context, userDeviceI
ad, err := models.AftermarketDevices(
models.AftermarketDeviceWhere.Serial.EQ(serial),
qm.Load(models.AftermarketDeviceRels.PairRequest),
).One(ctx, udc.DBS().Reader)
).One(ctx, exec)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil, fiber.NewError(fiber.StatusBadRequest, fmt.Sprintf("No aftermarket device with serial %q known.", serial))
Expand Down Expand Up @@ -1152,7 +1157,7 @@ func (udc *UserDevicesController) UnpairAutoPi(c *fiber.Ctx) error {
}

apnft.UnpairRequestID = null.StringFrom(requestID)
_, err = apnft.Update(c.Context(), tx, boil.Whitelist(models.AftermarketDeviceColumns.UnpairRequestID))
_, err = apnft.Update(c.Context(), tx, boil.Whitelist(models.AftermarketDeviceColumns.UnpairRequestID, models.AftermarketDeviceColumns.UpdatedAt))
if err != nil {
return err
}
Expand Down

0 comments on commit 41c2b55

Please sign in to comment.