Skip to content

Commit

Permalink
sqlite: use map for swap
Browse files Browse the repository at this point in the history
  • Loading branch information
n8maninger committed Feb 6, 2024
1 parent 48570ee commit 48422cb
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions persist/sqlite/contracts.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ func (s *Store) ReviseContract(revision contracts.SignedRevision, roots []types.
return fmt.Errorf("failed to swap sectors: %w", err)
}
oldA, oldB := roots[change.A], roots[change.B]
if swapped[0] != oldA {
return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldA, swapped[0])
} else if swapped[1] != oldB {
return fmt.Errorf("inconsistent sector swap: expected %s, got %s", oldB, swapped[1])
for root := range swapped {
if root != oldA && root != oldB {
return fmt.Errorf("inconsistent sector swap: expected %s or %s, got %s", oldA, oldB, root)
}
}
roots[change.A], roots[change.B] = roots[change.B], roots[change.A]
}
Expand Down Expand Up @@ -602,9 +602,9 @@ WHERE contract_id=$1 AND root_index=$2`, contractID, index)
return ref.root, nil
}

func swapSectors(tx txn, contractID int64, i, j uint64) ([2]types.Hash256, error) {
func swapSectors(tx txn, contractID int64, i, j uint64) (map[types.Hash256]bool, error) {
if i == j {
return [2]types.Hash256{}, nil
return nil, nil
}

var records []contractSectorRootRef
Expand All @@ -614,43 +614,46 @@ INNER JOIN stored_sectors ss ON (ss.id = csr.sector_id)
WHERE contract_id=$1 AND root_index IN ($2, $3)
ORDER BY root_index ASC;`, contractID, i, j)
if err != nil {
return [2]types.Hash256{}, fmt.Errorf("failed to query sector IDs: %w", err)
return nil, fmt.Errorf("failed to query sector IDs: %w", err)
}
defer rows.Close()
for rows.Next() {
ref, err := scanContractSectorRootRef(rows)
if err != nil {
return [2]types.Hash256{}, fmt.Errorf("failed to scan sector ref: %w", err)
return nil, fmt.Errorf("failed to scan sector ref: %w", err)
}
records = append(records, ref)
}

if len(records) != 2 {
return [2]types.Hash256{}, errors.New("failed to find both sectors")
return nil, errors.New("failed to find both sectors")
}

stmt, err := tx.Prepare(`UPDATE contract_sector_roots SET sector_id=$1 WHERE id=$2 RETURNING sector_id;`)
if err != nil {
return [2]types.Hash256{}, fmt.Errorf("failed to prepare update statement: %w", err)
return nil, fmt.Errorf("failed to prepare update statement: %w", err)
}
defer stmt.Close()

var newSectorID int64
err = stmt.QueryRow(records[1].sectorID, records[0].dbID).Scan(&newSectorID)
if err != nil {
return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err)
return nil, fmt.Errorf("failed to update sector ID: %w", err)
} else if newSectorID != records[1].sectorID {
return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID)
return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID)
}

err = stmt.QueryRow(records[0].sectorID, records[1].dbID).Scan(&newSectorID)
if err != nil {
return [2]types.Hash256{}, fmt.Errorf("failed to update sector ID: %w", err)
return nil, fmt.Errorf("failed to update sector ID: %w", err)
} else if newSectorID != records[0].sectorID {
return [2]types.Hash256{}, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID)
return nil, fmt.Errorf("expected sector ID %v, got %v", records[0].sectorID, newSectorID)
}

return [2]types.Hash256{records[0].root, records[1].root}, nil
return map[types.Hash256]bool{
records[0].root: true,
records[1].root: true,
}, nil
}

// lastContractSectors returns the last n sector IDs for a contract.
Expand Down

0 comments on commit 48422cb

Please sign in to comment.