diff --git a/persist/sqlite/contracts.go b/persist/sqlite/contracts.go index 1ccd6820..8599b132 100644 --- a/persist/sqlite/contracts.go +++ b/persist/sqlite/contracts.go @@ -323,10 +323,10 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types. return fmt.Errorf("failed to swap sectors: %w", err) } oldA, oldB := roots[change.A], roots[change.B] - if swapped[0] != oldA { - return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldA, swapped[0]) - } else if swapped[1] != oldB { - return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldB, swapped[1]) + for root := range swapped { + if root != oldA && root != oldB { + return fmt.Errorf("inconsistent sector swap: expected %s or %s, got %s", oldA, oldB, root) + } } roots[change.A], roots[change.B] = roots[change.B], roots[change.A] } @@ -602,9 +602,9 @@ WHERE contract_id=$1 AND root_index=$2`, contractID, index) return ref.root, nil } -func swapSectors(tx txn, contractID int64, i, j uint64) ([2]types.Hash256, error) { +func swapSectors(tx txn, contractID int64, i, j uint64) (map[types.Hash256]bool, error) { if i == j { - return [2]types.Hash256{}, nil + return nil, nil } var records []contractSectorRootRef @@ -614,43 +614,46 @@ INNER JOIN stored_sectors ss ON (ss.id = csr.sector_id) WHERE contract_id=$1 AND root_index IN ($2, $3) ORDER BY root_index ASC;`, contractID, i, j) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to query sector IDs: %w", err) + return nil, fmt.Errorf("failed to query sector IDs: %w", err) } defer rows.Close() for rows.Next() { ref, err := scanContractSectorRootRef(rows) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to scan sector ref: %w", err) + return nil, fmt.Errorf("failed to scan sector ref: %w", err) } records = append(records, ref) } if len(records) != 2 { - return [2]types.Hash256{}, errors.New("failed to find both sectors") + return nil, errors.New("failed to find both sectors") } stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2 RETURNING sector_id;`) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to prepare update statement: %w", err) + return nil, fmt.Errorf("failed to prepare update statement: %w", err) } defer stmt.Close() var newSectorID int64 err = stmt.QueryRow(records[1].sectorID, records[0].dbID).Scan(&newSectorID) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) + return nil, fmt.Errorf("failed to update sector ID: %w", err) } else if newSectorID != records[1].sectorID { - return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) + return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) } err = stmt.QueryRow(records[0].sectorID, records[1].dbID).Scan(&newSectorID) if err != nil { - return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err) + return nil, fmt.Errorf("failed to update sector ID: %w", err) } else if newSectorID != records[0].sectorID { - return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) + return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID) } - return [2]types.Hash256{records[0].root, records[1].root}, nil + return map[types.Hash256]bool{ + records[0].root: true, + records[1].root: true, + }, nil } // lastContractSectors returns the last n sector IDs for a contract.