diff --git a/go/test/endtoend/vreplication/cluster_test.go b/go/test/endtoend/vreplication/cluster_test.go
index 7d22d063945..7f06dc87680 100644
--- a/go/test/endtoend/vreplication/cluster_test.go
+++ b/go/test/endtoend/vreplication/cluster_test.go
@@ -886,7 +886,7 @@ func (vc *VitessCluster) getVttabletsInKeyspace(t *testing.T, cell *Cell, ksName
 	tablets := make(map[string]*cluster.VttabletProcess)
 	for _, shard := range keyspace.Shards {
 		for _, tablet := range shard.Tablets {
-			if tablet.Vttablet.GetTabletStatus() == "SERVING" {
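+			// An empty tabletType matches any type; otherwise the tablet type must match (case-insensitively).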
+			if tablet.Vttablet.GetTabletStatus() == "SERVING" && (tabletType == "" || strings.EqualFold(tablet.Vttablet.GetTabletType(), tabletType)) {
 				log.Infof("Serving status of tablet %s is %s, %s", tablet.Name, tablet.Vttablet.ServingStatus, tablet.Vttablet.GetTabletStatus())
 				tablets[tablet.Name] = tablet.Vttablet
 			}
diff --git a/go/test/endtoend/vreplication/vreplication_test.go b/go/test/endtoend/vreplication/vreplication_test.go
index 4e50ea12af3..7e3f93b6b20 100644
--- a/go/test/endtoend/vreplication/vreplication_test.go
+++ b/go/test/endtoend/vreplication/vreplication_test.go
@@ -762,11 +762,16 @@ func shardCustomer(t *testing.T, testReverse bool, cells []*Cell, sourceCellOrAl
 		switchReads(t, workflowType, cellNames, ksWorkflow, false)
 		assertQueryExecutesOnTablet(t, vtgateConn, productTab, "customer", query, query)
 
+		switchWritesDryRun(t, workflowType, ksWorkflow, dryRunResultsSwitchWritesCustomerShard)
+
+		testSwitchWritesErrorHandling(t, []*cluster.VttabletProcess{productTab}, []*cluster.VttabletProcess{customerTab1, customerTab2},
+			workflow, workflowType)
+
 		var commit func(t *testing.T)
 		if withOpenTx {
 			commit, _ = vc.startQuery(t, openTxQuery)
 		}
-		switchWritesDryRun(t, workflowType, ksWorkflow, dryRunResultsSwitchWritesCustomerShard)
+		// Now that we've tested the error handling, let's confirm that the actual switch works as expected.
 		switchWrites(t, workflowType, ksWorkflow, false)
 
 		checkThatVDiffFails(t, targetKs, workflow)
@@ -998,6 +1003,7 @@ func reshard(t *testing.T, ksName string, tableName string, workflow string, sou
 		require.NoError(t, vc.AddShards(t, cells, keyspace, targetShards, defaultReplicas, defaultRdonly, tabletIDBase, targetKsOpts))
 
 		tablets := vc.getVttabletsInKeyspace(t, defaultCell, ksName, "primary")
+		var sourceTablets, targetTablets []*cluster.VttabletProcess
 
 		// Test multi-primary setups, like a Galera cluster, which have auto increment steps > 1.
 		for _, tablet := range tablets {
@@ -1010,9 +1016,11 @@ func reshard(t *testing.T, ksName string, tableName string, workflow string, sou
 		targetShards = "," + targetShards + ","
 		for _, tab := range tablets {
 			if strings.Contains(targetShards, ","+tab.Shard+",") {
+				targetTablets = append(targetTablets, tab)
 				log.Infof("Waiting for vrepl to catch up on %s since it IS a target shard", tab.Shard)
 				catchup(t, tab, workflow, "Reshard")
 			} else {
+				sourceTablets = append(sourceTablets, tab)
 				log.Infof("Not waiting for vrepl to catch up on %s since it is NOT a target shard", tab.Shard)
 				continue
 			}
@@ -1026,6 +1034,10 @@ func reshard(t *testing.T, ksName string, tableName string, workflow string, sou
 		if dryRunResultSwitchWrites != nil {
 			reshardAction(t, "SwitchTraffic", workflow, ksName, "", "", callNames, "primary", "--dry-run")
 		}
+		if tableName == "customer" {
+			testSwitchWritesErrorHandling(t, sourceTablets, targetTablets, workflow, "reshard")
+		}
+		// Now let's confirm that the actual switch works as expected.
 		reshardAction(t, "SwitchTraffic", workflow, ksName, "", "", callNames, "primary")
 		reshardAction(t, "Complete", workflow, ksName, "", "", "", "")
 		for tabletName, count := range counts {
@@ -1534,6 +1546,140 @@ func switchWritesDryRun(t *testing.T, workflowType, ksWorkflow string, dryRunRes
 	validateDryRunResults(t, output, dryRunResults)
 }
 
+// testSwitchWritesErrorHandling confirms that switching writes works as expected
+// in the face of vreplication lag (canSwitch() precheck) and when canceling the
+// switch due to replication failing to catch up in time.
+// The workflow MUST be migrating the customer table from the source to the
+// target keyspace AND the workflow must currently have reads switched but not
+// writes.
+func testSwitchWritesErrorHandling(t *testing.T, sourceTablets, targetTablets []*cluster.VttabletProcess, workflow, workflowType string) {
+	t.Run("validate switch writes error handling", func(t *testing.T) {
+		vtgateConn := getConnection(t, vc.ClusterConfig.hostname, vc.ClusterConfig.vtgateMySQLPort)
+		defer vtgateConn.Close()
+		require.NotZero(t, len(sourceTablets), "no source tablets provided")
+		require.NotZero(t, len(targetTablets), "no target tablets provided")
+		sourceKs := sourceTablets[0].Keyspace
+		targetKs := targetTablets[0].Keyspace
+		ksWorkflow := fmt.Sprintf("%s.%s", targetKs, workflow)
+		var err error
+		sourceConns := make([]*mysql.Conn, len(sourceTablets))
+		for i, tablet := range sourceTablets {
+			sourceConns[i], err = tablet.TabletConn(tablet.Keyspace, true)
+			require.NoError(t, err)
+			defer sourceConns[i].Close()
+		}
+		targetConns := make([]*mysql.Conn, len(targetTablets))
+		for i, tablet := range targetTablets {
+			targetConns[i], err = tablet.TabletConn(tablet.Keyspace, true)
+			require.NoError(t, err)
+			defer targetConns[i].Close()
+		}
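+		// Use a high cid range so that the test rows do not collide with existing customer rows.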
+		startingTestRowID := 10000000
+		numTestRows := 100
+		addTestRows := func() {
+			for i := 0; i < numTestRows; i++ {
+				execVtgateQuery(t, vtgateConn, sourceTablets[0].Keyspace, fmt.Sprintf("insert into customer (cid, name) values (%d, 'laggingCustomer')",
+					startingTestRowID+i))
+			}
+		}
+		deleteTestRows := func() {
+			execVtgateQuery(t, vtgateConn, sourceTablets[0].Keyspace, fmt.Sprintf("delete from customer where cid >= %d", startingTestRowID))
+		}
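+		// A unique index on name makes the replicated test rows -- which all share the same name -- fail with duplicate key errors on the target.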
+		addIndex := func() {
+			for _, targetConn := range targetConns {
+				execQuery(t, targetConn, "set session sql_mode=''")
+				execQuery(t, targetConn, "alter table customer add unique index name_idx (name)")
+			}
+		}
+		dropIndex := func() {
+			for _, targetConn := range targetConns {
+				execQuery(t, targetConn, "alter table customer drop index name_idx")
+			}
+		}
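+		// A read lock on the customer table blocks the vplayer's writes so that replication cannot catch up.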
+		lockTargetTable := func() {
+			for _, targetConn := range targetConns {
+				execQuery(t, targetConn, "lock table customer read")
+			}
+		}
+		unlockTargetTable := func() {
+			for _, targetConn := range targetConns {
+				execQuery(t, targetConn, "unlock tables")
+			}
+		}
+		cleanupTestData := func() {
+			dropIndex()
+			deleteTestRows()
+		}
+		restartWorkflow := func() {
+			err = vc.VtctldClient.ExecuteCommand("workflow", "--keyspace", targetKs, "start", "--workflow", workflow)
+			require.NoError(t, err, "failed to start workflow: %v", err)
+		}
+		waitForTargetToCatchup := func() {
+			waitForWorkflowState(t, vc, ksWorkflow, binlogdatapb.VReplicationWorkflowState_Running.String())
+			waitForNoWorkflowLag(t, vc, targetKs, workflow)
+		}
+
+		// First let's test that the prechecks work as expected. We ALTER
+		// the table on the target shards to add a unique index on the name
+		// field.
+		addIndex()
+		// Then we replicate some test rows across the target shards by
+		// inserting them in the source keyspace.
+		addTestRows()
+		// Now the workflow should go into the error state and the lag should
+		// start to climb. So we sleep for three times the max lag duration
+		// that we will set for the SwitchTraffic call.
+		lagDuration := 3 * time.Second
+		time.Sleep(lagDuration * 3)
+		out, err := vc.VtctldClient.ExecuteCommandWithOutput(workflowType, "--workflow", workflow, "--target-keyspace", targetKs,
+			"SwitchTraffic", "--tablet-types=primary", "--timeout=30s", "--max-replication-lag-allowed", lagDuration.String())
+		// It should fail in the canSwitch() precheck.
+		require.Error(t, err)
+		require.Regexp(t, fmt.Sprintf(".*cannot switch traffic for workflow %s at this time: replication lag [0-9]+s is higher than allowed lag %s.*",
+			workflow, lagDuration.String()), out)
+		require.NotContains(t, out, "cancel migration failed")
+		// Confirm that queries still work fine.
+		execVtgateQuery(t, vtgateConn, sourceKs, "select * from customer limit 1")
+		cleanupTestData()
+		// We have to restart the workflow since the duplicate key error
+		// is a permanent/terminal one.
+		restartWorkflow()
+		waitForTargetToCatchup()
+
+		// Now let's test that the cancel works by setting the command timeout
+		// to a fraction (6s) of the default max repl lag duration (30s). First
+		// we lock the customer table on the target tablets so that we cannot
+		// apply the INSERTs and catch up.
+		lockTargetTable()
+		addTestRows()
+		timeout := lagDuration * 2 // 6s
+		// Use the default max-replication-lag-allowed value of 30s.
+		// We run the command in a goroutine so that we can unblock things
+		// after the timeout is reached -- as the vplayer query is blocking
+		// on the table lock in the MySQL layer.
+		wg := sync.WaitGroup{}
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			out, err = vc.VtctldClient.ExecuteCommandWithOutput(workflowType, "--workflow", workflow, "--target-keyspace", targetKs,
+				"SwitchTraffic", "--tablet-types=primary", "--timeout", timeout.String())
+		}()
+		time.Sleep(timeout)
+		// Now we can unblock things and let it continue.
+		unlockTargetTable()
+		wg.Wait()
+		// It should fail due to the command context timeout and we should
+		// successfully cancel.
+		require.Error(t, err)
+		require.Contains(t, out, "failed to sync up replication between the source and target")
+		require.NotContains(t, out, "cancel migration failed")
+		// Confirm that queries still work fine.
+		execVtgateQuery(t, vtgateConn, sourceKs, "select * from customer limit 1")
+		deleteTestRows()
+		waitForTargetToCatchup()
+	})
+}
+
 // restartWorkflow confirms that a workflow can be successfully
 // stopped and started.
 func restartWorkflow(t *testing.T, ksWorkflow string) {
diff --git a/go/vt/vtctl/workflow/server.go b/go/vt/vtctl/workflow/server.go
index f4c761e8b5a..a4f5ba58364 100644
--- a/go/vt/vtctl/workflow/server.go
+++ b/go/vt/vtctl/workflow/server.go
@@ -688,11 +688,10 @@ func (s *Server) GetWorkflows(ctx context.Context, req *vtctldatapb.GetWorkflows
 
 		targetKeyspaceByWorkflow[workflow.Name] = tablet.Keyspace
 
-		timeUpdated := time.Unix(timeUpdatedSeconds, 0)
-		vreplicationLag := time.Since(timeUpdated)
-
 		// MaxVReplicationLag represents the time since we last processed any event
 		// in the workflow.
+		timeUpdated := time.Unix(timeUpdatedSeconds, 0)
+		vreplicationLag := time.Since(timeUpdated)
 		if currentMaxLag, ok := maxVReplicationLagByWorkflow[workflow.Name]; ok {
 			if vreplicationLag.Seconds() > currentMaxLag {
 				maxVReplicationLagByWorkflow[workflow.Name] = vreplicationLag.Seconds()
@@ -701,32 +700,18 @@ func (s *Server) GetWorkflows(ctx context.Context, req *vtctldatapb.GetWorkflows
 			maxVReplicationLagByWorkflow[workflow.Name] = vreplicationLag.Seconds()
 		}
 
-		// MaxVReplicationTransactionLag estimates the actual statement processing lag
-		// between the source and the target. If we are still processing source events it
-		// is the difference b/w current time and the timestamp of the last event. If
-		// heartbeats are more recent than the last event, then the lag is the time since
-		// the last heartbeat as there can be an actual event immediately after the
-		// heartbeat, but which has not yet been processed on the target.
-		// We don't allow switching during the copy phase, so in that case we just return
-		// a large lag. All timestamps are in seconds since epoch.
+		// MaxVReplicationTransactionLag estimates the max statement processing lag
+		// between the source and the target across all of the workflow streams.
 		if _, ok := maxVReplicationTransactionLagByWorkflow[workflow.Name]; !ok {
 			maxVReplicationTransactionLagByWorkflow[workflow.Name] = 0
 		}
-		lastTransactionTime := transactionTimeSeconds
-		lastHeartbeatTime := timeHeartbeat
-		if stream.State == binlogdatapb.VReplicationWorkflowState_Copying.String() {
-			maxVReplicationTransactionLagByWorkflow[workflow.Name] = math.MaxInt64
-		} else {
-			if lastTransactionTime == 0 /* no new events after copy */ ||
-				lastHeartbeatTime > lastTransactionTime /* no recent transactions, so all caught up */ {
-
-				lastTransactionTime = lastHeartbeatTime
-			}
-			now := time.Now().Unix() /* seconds since epoch */
-			transactionReplicationLag := float64(now - lastTransactionTime)
-			if transactionReplicationLag > maxVReplicationTransactionLagByWorkflow[workflow.Name] {
-				maxVReplicationTransactionLagByWorkflow[workflow.Name] = transactionReplicationLag
-			}
+		heartbeatTimestamp := &vttimepb.Time{
+			Seconds: timeHeartbeat,
+		}
+		transactionReplicationLag := getVReplicationTrxLag(stream.TransactionTimestamp, stream.TimeUpdated, heartbeatTimestamp,
+			binlogdatapb.VReplicationWorkflowState(binlogdatapb.VReplicationWorkflowState_value[stream.State]))
+		if transactionReplicationLag > maxVReplicationTransactionLagByWorkflow[workflow.Name] {
+			maxVReplicationTransactionLagByWorkflow[workflow.Name] = transactionReplicationLag
 		}
 
 		return nil
@@ -3242,8 +3227,10 @@ func (s *Server) switchWrites(ctx context.Context, req *vtctldatapb.WorkflowSwit
 			return handleError("failed to migrate the workflow streams", err)
 		}
 		if cancel {
-			sw.cancelMigration(ctx, sm)
-			return 0, sw.logs(), nil
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
+			return 0, sw.logs(), err
 		}
 
 		ts.Logger().Infof("Stopping streams")
@@ -3254,13 +3241,17 @@ func (s *Server) switchWrites(ctx context.Context, req *vtctldatapb.WorkflowSwit
 					ts.Logger().Errorf("stream in stopStreams: key %s shard %s stream %+v", key, stream.BinlogSource.Shard, stream.BinlogSource)
 				}
 			}
-			sw.cancelMigration(ctx, sm)
-			return handleError("failed to stop the workflow streams", err)
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
+			return handleError(fmt.Sprintf("failed to stop the workflow streams in the %s keyspace", ts.SourceKeyspaceName()), err)
 		}
 
 		ts.Logger().Infof("Stopping source writes")
 		if err := sw.stopSourceWrites(ctx); err != nil {
-			sw.cancelMigration(ctx, sm)
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
 			return handleError(fmt.Sprintf("failed to stop writes in the %s keyspace", ts.SourceKeyspaceName()), err)
 		}
 
@@ -3270,7 +3261,9 @@ func (s *Server) switchWrites(ctx context.Context, req *vtctldatapb.WorkflowSwit
 			// the tablet's deny list check and the first mysqld side table lock.
 			for cnt := 1; cnt <= lockTablesCycles; cnt++ {
 				if err := ts.executeLockTablesOnSource(ctx); err != nil {
-					sw.cancelMigration(ctx, sm)
+					if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+						err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+					}
 					return handleError(fmt.Sprintf("failed to execute LOCK TABLES (attempt %d of %d) on sources", cnt, lockTablesCycles), err)
 				}
 				// No need to UNLOCK the tables as the connection was closed once the locks were acquired
@@ -3281,25 +3274,33 @@ func (s *Server) switchWrites(ctx context.Context, req *vtctldatapb.WorkflowSwit
 
 		ts.Logger().Infof("Waiting for streams to catchup")
 		if err := sw.waitForCatchup(ctx, timeout); err != nil {
-			sw.cancelMigration(ctx, sm)
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
 			return handleError("failed to sync up replication between the source and target", err)
 		}
 
 		ts.Logger().Infof("Migrating streams")
 		if err := sw.migrateStreams(ctx, sm); err != nil {
-			sw.cancelMigration(ctx, sm)
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
 			return handleError("failed to migrate the workflow streams", err)
 		}
 
 		ts.Logger().Infof("Resetting sequences")
 		if err := sw.resetSequences(ctx); err != nil {
-			sw.cancelMigration(ctx, sm)
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
 			return handleError("failed to reset the sequences", err)
 		}
 
 		ts.Logger().Infof("Creating reverse streams")
 		if err := sw.createReverseVReplication(ctx); err != nil {
-			sw.cancelMigration(ctx, sm)
+			if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+				err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+			}
 			return handleError("failed to create the reverse vreplication streams", err)
 		}
 
@@ -3312,7 +3313,9 @@ func (s *Server) switchWrites(ctx context.Context, req *vtctldatapb.WorkflowSwit
 			initSeqCtx, cancel := context.WithTimeout(ctx, timeout/2)
 			defer cancel()
 			if err := sw.initializeTargetSequences(initSeqCtx, sequenceMetadata); err != nil {
-				sw.cancelMigration(ctx, sm)
+				if cerr := sw.cancelMigration(ctx, sm); cerr != nil {
+					err = vterrors.Errorf(vtrpcpb.Code_CANCELED, "%v\n\n%v", err, cerr)
+				}
 				return handleError(fmt.Sprintf("failed to initialize the sequences used in the %s keyspace", ts.TargetKeyspaceName()), err)
 			}
 		}
@@ -3365,15 +3368,14 @@ func (s *Server) canSwitch(ctx context.Context, ts *trafficSwitcher, state *Stat
 	if err != nil {
 		return "", err
 	}
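+	// The workflow's max transaction lag -- the estimated lag across all of its streams -- must be within the allowed limit.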
+	if wf.MaxVReplicationTransactionLag > maxAllowedReplLagSecs {
+		return fmt.Sprintf(cannotSwitchHighLag, wf.MaxVReplicationTransactionLag, maxAllowedReplLagSecs), nil
+	}
 	for _, stream := range wf.ShardStreams {
 		for _, st := range stream.GetStreams() {
 			if st.Message == Frozen {
 				return cannotSwitchFrozen, nil
 			}
-			// If no new events have been replicated after the copy phase then it will be 0.
-			if vreplLag := time.Now().Unix() - st.TimeUpdated.Seconds; vreplLag > maxAllowedReplLagSecs {
-				return fmt.Sprintf(cannotSwitchHighLag, vreplLag, maxAllowedReplLagSecs), nil
-			}
 			switch st.State {
 			case binlogdatapb.VReplicationWorkflowState_Copying.String():
 				return cannotSwitchCopyIncomplete, nil
@@ -3899,3 +3901,41 @@ func (s *Server) MigrateCreate(ctx context.Context, req *vtctldatapb.MigrateCrea
 	}
 	return s.moveTablesCreate(ctx, moveTablesCreateRequest, binlogdatapb.VReplicationWorkflowType_Migrate)
 }
+
+// getVReplicationTrxLag estimates the actual statement processing lag between the
+// source and the target. If we are still processing source events it is the
+// difference between current time and the timestamp of the last event. If
+// heartbeats are more recent than the last event, then the lag is the time since
+// the last heartbeat as there can be an actual event immediately after the
+// heartbeat, but which has not yet been processed on the target. We don't allow
+// switching during the copy phase, so in that case we just return a large lag.
+// All timestamps are in seconds since epoch.
+func getVReplicationTrxLag(trxTs, updatedTs, heartbeatTs *vttimepb.Time, state binlogdatapb.VReplicationWorkflowState) float64 {
+	if state == binlogdatapb.VReplicationWorkflowState_Copying {
+		return math.MaxInt64
+	}
+	if trxTs == nil {
+		trxTs = &vttimepb.Time{}
+	}
+	lastTransactionTime := trxTs.Seconds
+	if updatedTs == nil {
+		updatedTs = &vttimepb.Time{}
+	}
+	lastUpdatedTime := updatedTs.Seconds
+	if heartbeatTs == nil {
+		heartbeatTs = &vttimepb.Time{}
+	}
+	lastHeartbeatTime := heartbeatTs.Seconds
+	// We do NOT update the heartbeat timestamp when we are regularly updating the
+	// position as we replicate transactions (GTIDs).
+	// When we DO record a heartbeat, we set the updated time to the same value.
+	// When recording that we are throttled, we update the updated time but NOT
+	// the heartbeat time.
+	if lastTransactionTime == 0 /* No replicated events after copy */ ||
+		(lastUpdatedTime == lastHeartbeatTime && /* The last update was from a heartbeat */
+			lastUpdatedTime > lastTransactionTime /* No recent transactions, only heartbeats, so all caught up */) {
+		lastTransactionTime = lastUpdatedTime
+	}
+	now := time.Now().Unix() // Seconds since epoch
+	return float64(now - lastTransactionTime)
+}
diff --git a/go/vt/vtctl/workflow/stream_migrator.go b/go/vt/vtctl/workflow/stream_migrator.go
index 7d225f6dd9f..1a7ffc71f24 100644
--- a/go/vt/vtctl/workflow/stream_migrator.go
+++ b/go/vt/vtctl/workflow/stream_migrator.go
@@ -158,12 +158,15 @@ func (sm *StreamMigrator) Templates() []*VReplicationStream {
 }
 
 // CancelStreamMigrations cancels the stream migrations.
-func (sm *StreamMigrator) CancelStreamMigrations(ctx context.Context) {
+func (sm *StreamMigrator) CancelStreamMigrations(ctx context.Context) error {
 	if sm.streams == nil {
-		return
+		return nil
 	}
+	errs := &concurrency.AllErrorRecorder{}
 
-	_ = sm.deleteTargetStreams(ctx)
+	if err := sm.deleteTargetStreams(ctx); err != nil {
+		errs.RecordError(fmt.Errorf("could not delete target streams: %v", err))
+	}
 
 	// Restart the source streams, but leave the Reshard workflow's reverse
 	// variant stopped.
@@ -176,8 +179,13 @@ func (sm *StreamMigrator) CancelStreamMigrations(ctx context.Context) {
 		return err
 	})
 	if err != nil {
+		errs.RecordError(fmt.Errorf("could not restart source streams: %v", err))
 		sm.logger.Errorf("Cancel stream migrations failed: could not restart source streams: %v", err)
 	}
+	if errs.HasErrors() {
+		return errs.AggrError(vterrors.Aggregate)
+	}
+	return nil
 }
 
 // MigrateStreams migrates N streams
diff --git a/go/vt/vtctl/workflow/switcher.go b/go/vt/vtctl/workflow/switcher.go
index 0cbdce164dc..5e95e648299 100644
--- a/go/vt/vtctl/workflow/switcher.go
+++ b/go/vt/vtctl/workflow/switcher.go
@@ -110,8 +110,8 @@ func (r *switcher) stopStreams(ctx context.Context, sm *StreamMigrator) ([]strin
 	return sm.StopStreams(ctx)
 }
 
-func (r *switcher) cancelMigration(ctx context.Context, sm *StreamMigrator) {
-	r.ts.cancelMigration(ctx, sm)
+func (r *switcher) cancelMigration(ctx context.Context, sm *StreamMigrator) error {
+	return r.ts.cancelMigration(ctx, sm)
 }
 
 func (r *switcher) lockKeyspace(ctx context.Context, keyspace, action string) (context.Context, func(*error), error) {
diff --git a/go/vt/vtctl/workflow/switcher_dry_run.go b/go/vt/vtctl/workflow/switcher_dry_run.go
index 21b975a0d6b..b7ad8207574 100644
--- a/go/vt/vtctl/workflow/switcher_dry_run.go
+++ b/go/vt/vtctl/workflow/switcher_dry_run.go
@@ -214,8 +214,9 @@ func (dr *switcherDryRun) stopStreams(ctx context.Context, sm *StreamMigrator) (
 	return nil, nil
 }
 
-func (dr *switcherDryRun) cancelMigration(ctx context.Context, sm *StreamMigrator) {
+func (dr *switcherDryRun) cancelMigration(ctx context.Context, sm *StreamMigrator) error {
 	dr.drLog.Log("Cancel migration as requested")
+	return nil
 }
 
 func (dr *switcherDryRun) lockKeyspace(ctx context.Context, keyspace, _ string) (context.Context, func(*error), error) {
diff --git a/go/vt/vtctl/workflow/switcher_interface.go b/go/vt/vtctl/workflow/switcher_interface.go
index 8d0f9e847be..9f73fd45ad6 100644
--- a/go/vt/vtctl/workflow/switcher_interface.go
+++ b/go/vt/vtctl/workflow/switcher_interface.go
@@ -25,7 +25,7 @@ import (
 
 type iswitcher interface {
 	lockKeyspace(ctx context.Context, keyspace, action string) (context.Context, func(*error), error)
-	cancelMigration(ctx context.Context, sm *StreamMigrator)
+	cancelMigration(ctx context.Context, sm *StreamMigrator) error
 	stopStreams(ctx context.Context, sm *StreamMigrator) ([]string, error)
 	stopSourceWrites(ctx context.Context) error
 	waitForCatchup(ctx context.Context, filteredReplicationWaitTime time.Duration) error
diff --git a/go/vt/vtctl/workflow/traffic_switcher.go b/go/vt/vtctl/workflow/traffic_switcher.go
index f4d8a13054b..79a0492750b 100644
--- a/go/vt/vtctl/workflow/traffic_switcher.go
+++ b/go/vt/vtctl/workflow/traffic_switcher.go
@@ -996,8 +996,9 @@ func (ts *trafficSwitcher) changeTableSourceWrites(ctx context.Context, access a
 
 // cancelMigration attempts to revert all changes made during the migration so that we can get back to the
 // state when traffic switching (or reversing) was initiated.
-func (ts *trafficSwitcher) cancelMigration(ctx context.Context, sm *StreamMigrator) {
+func (ts *trafficSwitcher) cancelMigration(ctx context.Context, sm *StreamMigrator) error {
 	var err error
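+	// Record every failure encountered during the cancellation so that the caller knows manual cleanup may be needed.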
+	cancelErrs := &concurrency.AllErrorRecorder{}
 
 	if ctx.Err() != nil {
 		// Even though we create a new context later on we still record any context error:
@@ -1006,9 +1007,13 @@ func (ts *trafficSwitcher) cancelMigration(ctx context.Context, sm *StreamMigrat
 	}
 
 	// We create a new context while canceling the migration, so that we are independent of the original
-	// context being cancelled prior to or during the cancel operation.
-	cmTimeout := 60 * time.Second
-	cmCtx, cmCancel := context.WithTimeout(context.Background(), cmTimeout)
+	// context being canceled prior to or during the cancel operation itself.
+	// First we create a copy of the parent context, so that we maintain the locks, but which cannot be
+	// canceled by the parent context.
+	wcCtx := context.WithoutCancel(ctx)
+	// Now we create a child context from that which has a timeout.
+	cmTimeout := 2 * time.Minute
+	cmCtx, cmCancel := context.WithTimeout(wcCtx, cmTimeout)
 	defer cmCancel()
 
 	if ts.MigrationType() == binlogdatapb.MigrationType_TABLES {
@@ -1017,10 +1022,14 @@ func (ts *trafficSwitcher) cancelMigration(ctx context.Context, sm *StreamMigrat
 		err = ts.changeShardsAccess(cmCtx, ts.SourceKeyspaceName(), ts.SourceShards(), allowWrites)
 	}
 	if err != nil {
+		cancelErrs.RecordError(fmt.Errorf("could not revert denied tables / shard access: %v", err))
 		ts.Logger().Errorf("Cancel migration failed: could not revert denied tables / shard access: %v", err)
 	}
 
-	sm.CancelStreamMigrations(cmCtx)
+	if err := sm.CancelStreamMigrations(cmCtx); err != nil {
+		cancelErrs.RecordError(fmt.Errorf("could not cancel stream migrations: %v", err))
+		ts.Logger().Errorf("Cancel migration failed: could not cancel stream migrations: %v", err)
+	}
 
 	err = ts.ForAllTargets(func(target *MigrationTarget) error {
 		query := fmt.Sprintf("update _vt.vreplication set state='Running', message='' where db_name=%s and workflow=%s",
@@ -1029,13 +1038,19 @@ func (ts *trafficSwitcher) cancelMigration(ctx context.Context, sm *StreamMigrat
 		return err
 	})
 	if err != nil {
+		cancelErrs.RecordError(fmt.Errorf("could not restart vreplication: %v", err))
 		ts.Logger().Errorf("Cancel migration failed: could not restart vreplication: %v", err)
 	}
 
-	err = ts.deleteReverseVReplication(cmCtx)
-	if err != nil {
-		ts.Logger().Errorf("Cancel migration failed: could not delete revers vreplication entries: %v", err)
+	if err := ts.deleteReverseVReplication(cmCtx); err != nil {
+		cancelErrs.RecordError(fmt.Errorf("could not delete reverse vreplication streams: %v", err))
+		ts.Logger().Errorf("Cancel migration failed: could not delete reverse vreplication streams: %v", err)
 	}
+
+	if cancelErrs.HasErrors() {
+		return vterrors.Wrap(cancelErrs.AggrError(vterrors.Aggregate), "cancel migration failed, manual cleanup work may be necessary")
+	}
+	return nil
 }
 
 func (ts *trafficSwitcher) freezeTargetVReplication(ctx context.Context) error {