Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix handling of creating faux transaction for recovered outputs #3959

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions base_layer/wallet/src/output_manager_service/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,8 @@ pub enum OutputManagerRequest {
num_kernels: usize,
num_outputs: usize,
},
ScanForRecoverableOutputs {
outputs: Vec<TransactionOutput>,
tx_id: TxId,
},
ScanOutputs {
outputs: Vec<TransactionOutput>,
tx_id: TxId,
},
ScanForRecoverableOutputs(Vec<TransactionOutput>),
ScanOutputs(Vec<TransactionOutput>),
AddKnownOneSidedPaymentScript(KnownOneSidedPaymentScript),
CreateOutputWithFeatures {
value: MicroTari,
Expand Down Expand Up @@ -194,8 +188,8 @@ impl fmt::Display for OutputManagerRequest {
"FeeEstimate(amount: {}, fee_per_gram: {}, num_kernels: {}, num_outputs: {})",
amount, fee_per_gram, num_kernels, num_outputs
),
ScanForRecoverableOutputs { .. } => write!(f, "ScanForRecoverableOutputs"),
ScanOutputs { .. } => write!(f, "ScanOutputs"),
ScanForRecoverableOutputs(_) => write!(f, "ScanForRecoverableOutputs"),
ScanOutputs(_) => write!(f, "ScanOutputs"),
AddKnownOneSidedPaymentScript(_) => write!(f, "AddKnownOneSidedPaymentScript"),
CreateOutputWithFeatures { value, features } => {
write!(f, "CreateOutputWithFeatures({}, {})", value, features,)
Expand Down Expand Up @@ -247,8 +241,8 @@ pub enum OutputManagerResponse {
PublicRewindKeys(Box<PublicRewindKeys>),
RecoveryByte(u8),
FeeEstimate(MicroTari),
RewoundOutputs(Vec<UnblindedOutput>),
ScanOutputs(Vec<UnblindedOutput>),
RewoundOutputs(Vec<RecoveredOutput>),
ScanOutputs(Vec<RecoveredOutput>),
AddKnownOneSidedPaymentScript,
CreateOutputWithFeatures {
output: Box<UnblindedOutputBuilder>,
Expand Down Expand Up @@ -300,6 +294,12 @@ pub struct PublicRewindKeys {
pub rewind_blinding_public_key: PublicKey,
}

#[derive(Debug, Clone)]
pub struct RecoveredOutput {
pub tx_id: TxId,
pub output: UnblindedOutput,
}

#[derive(Clone)]
pub struct OutputManagerHandle {
handle: SenderService<OutputManagerRequest, Result<OutputManagerResponse, OutputManagerError>>,
Expand Down Expand Up @@ -728,11 +728,10 @@ impl OutputManagerHandle {
pub async fn scan_for_recoverable_outputs(
&mut self,
outputs: Vec<TransactionOutput>,
tx_id: TxId,
) -> Result<Vec<UnblindedOutput>, OutputManagerError> {
) -> Result<Vec<RecoveredOutput>, OutputManagerError> {
match self
.handle
.call(OutputManagerRequest::ScanForRecoverableOutputs { outputs, tx_id })
.call(OutputManagerRequest::ScanForRecoverableOutputs(outputs))
.await??
{
OutputManagerResponse::RewoundOutputs(outputs) => Ok(outputs),
Expand All @@ -743,13 +742,8 @@ impl OutputManagerHandle {
pub async fn scan_outputs_for_one_sided_payments(
&mut self,
outputs: Vec<TransactionOutput>,
tx_id: TxId,
) -> Result<Vec<UnblindedOutput>, OutputManagerError> {
match self
.handle
.call(OutputManagerRequest::ScanOutputs { outputs, tx_id })
.await??
{
) -> Result<Vec<RecoveredOutput>, OutputManagerError> {
match self.handle.call(OutputManagerRequest::ScanOutputs(outputs)).await?? {
OutputManagerResponse::ScanOutputs(outputs) => Ok(outputs),
_ => Err(OutputManagerError::UnexpectedApiResponse),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ use crate::{
key_manager_service::KeyManagerInterface,
output_manager_service::{
error::{OutputManagerError, OutputManagerStorageError},
handle::RecoveredOutput,
resources::OutputManagerKeyManagerBranch,
storage::{
database::{OutputManagerBackend, OutputManagerDatabase},
Expand Down Expand Up @@ -84,8 +85,7 @@ where
pub async fn scan_and_recover_outputs(
&mut self,
outputs: Vec<TransactionOutput>,
tx_id: TxId,
) -> Result<Vec<UnblindedOutput>, OutputManagerError> {
) -> Result<Vec<RecoveredOutput>, OutputManagerError> {
let start = Instant::now();
let outputs_length = outputs.len();
let mut rewound_outputs: Vec<(UnblindedOutput, BulletRangeProof)> = outputs
Expand Down Expand Up @@ -131,7 +131,7 @@ where
rewind_time.as_millis(),
);

let mut rewound_outputs_to_return = Vec::new();
let mut rewound_outputs_with_tx_id: Vec<RecoveredOutput> = Vec::new();
for (output, proof) in rewound_outputs.iter_mut() {
let db_output = DbUnblindedOutput::rewindable_from_unblinded_output(
output.clone(),
Expand All @@ -140,6 +140,7 @@ where
None,
Some(proof),
)?;
let tx_id = TxId::new_random();
let output_hex = db_output.commitment.to_hex();
if let Err(e) = self.db.add_unspent_output_with_tx_id(tx_id, db_output).await {
match e {
Expand All @@ -153,6 +154,11 @@ where
_ => return Err(OutputManagerError::from(e)),
}
}

rewound_outputs_with_tx_id.push(RecoveredOutput {
output: output.clone(),
tx_id,
});
self.update_outputs_script_private_key_and_update_key_manager_index(output)
.await?;
trace!(
Expand All @@ -162,9 +168,9 @@ where
output.value,
output.features,
);
rewound_outputs_to_return.push(output.clone());
}
Ok(rewound_outputs_to_return)

Ok(rewound_outputs_with_tx_id)
}

/// Find the key manager index that corresponds to the spending key in the rewound output, if found then modify
Expand Down
21 changes: 13 additions & 8 deletions base_layer/wallet/src/output_manager_service/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ use crate::{
OutputManagerRequest,
OutputManagerResponse,
PublicRewindKeys,
RecoveredOutput,
},
recovery::StandardUtxoRecoverer,
resources::{OutputManagerKeyManagerBranch, OutputManagerResources},
Expand Down Expand Up @@ -420,17 +421,17 @@ where
OutputManagerRequest::CalculateRecoveryByte { spending_key, value } => Ok(
OutputManagerResponse::RecoveryByte(self.calculate_recovery_byte(spending_key, value)?),
),
OutputManagerRequest::ScanForRecoverableOutputs { outputs, tx_id } => StandardUtxoRecoverer::new(
OutputManagerRequest::ScanForRecoverableOutputs(outputs) => StandardUtxoRecoverer::new(
self.resources.master_key_manager.clone(),
self.resources.rewind_data.clone(),
self.resources.factories.clone(),
self.resources.db.clone(),
)
.scan_and_recover_outputs(outputs, tx_id)
.scan_and_recover_outputs(outputs)
.await
.map(OutputManagerResponse::RewoundOutputs),
OutputManagerRequest::ScanOutputs { outputs, tx_id } => self
.scan_outputs_for_one_sided_payments(outputs, tx_id)
OutputManagerRequest::ScanOutputs(outputs) => self
.scan_outputs_for_one_sided_payments(outputs)
.await
.map(OutputManagerResponse::ScanOutputs),
OutputManagerRequest::AddKnownOneSidedPaymentScript(known_script) => self
Expand Down Expand Up @@ -2006,12 +2007,11 @@ where
async fn scan_outputs_for_one_sided_payments(
&mut self,
outputs: Vec<TransactionOutput>,
tx_id: TxId,
) -> Result<Vec<UnblindedOutput>, OutputManagerError> {
) -> Result<Vec<RecoveredOutput>, OutputManagerError> {
let known_one_sided_payment_scripts: Vec<KnownOneSidedPaymentScript> =
self.resources.db.get_all_known_one_sided_payment_scripts().await?;

let mut rewound_outputs: Vec<UnblindedOutput> = Vec::new();
let mut rewound_outputs: Vec<RecoveredOutput> = Vec::new();
for output in outputs {
let position = known_one_sided_payment_scripts
.iter()
Expand Down Expand Up @@ -2062,9 +2062,14 @@ where
)?;

let output_hex = output.commitment.to_hex();
let tx_id = TxId::new_random();

match self.resources.db.add_unspent_output_with_tx_id(tx_id, db_output).await {
Ok(_) => {
rewound_outputs.push(rewound_output);
rewound_outputs.push(RecoveredOutput {
output: rewound_output,
tx_id,
});
},
Err(OutputManagerStorageError::DuplicateOutput) => {
warn!(
Expand Down
24 changes: 11 additions & 13 deletions base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,11 @@ where TBackend: WalletBackend + 'static
total_scanned += outputs.len();

let start = Instant::now();
let (tx_id, found_outputs) = self.scan_for_outputs(outputs).await?;
let found_outputs = self.scan_for_outputs(outputs).await?;
scan_for_outputs_profiling.push(start.elapsed());

let (count, amount) = self
.import_utxos_to_transaction_service(found_outputs, tx_id, current_height)
.import_utxos_to_transaction_service(found_outputs, current_height)
.await?;

self.resources
Expand Down Expand Up @@ -492,38 +492,36 @@ where TBackend: WalletBackend + 'static
async fn scan_for_outputs(
&mut self,
outputs: Vec<TransactionOutput>,
) -> Result<(TxId, Vec<(UnblindedOutput, String)>), UtxoScannerError> {
let mut found_outputs: Vec<(UnblindedOutput, String)> = Vec::new();
let tx_id = TxId::new_random();
) -> Result<Vec<(UnblindedOutput, String, TxId)>, UtxoScannerError> {
let mut found_outputs: Vec<(UnblindedOutput, String, TxId)> = Vec::new();
if self.mode == UtxoScannerMode::Recovery {
found_outputs.append(
&mut self
.resources
.output_manager_service
.scan_for_recoverable_outputs(outputs.clone(), tx_id)
.scan_for_recoverable_outputs(outputs.clone())
.await?
.into_iter()
.map(|uo| (uo, self.resources.recovery_message.clone()))
.map(|ro| (ro.output, self.resources.recovery_message.clone(), ro.tx_id))
.collect(),
);
};
found_outputs.append(
&mut self
.resources
.output_manager_service
.scan_outputs_for_one_sided_payments(outputs.clone(), tx_id)
.scan_outputs_for_one_sided_payments(outputs.clone())
.await?
.into_iter()
.map(|uo| (uo, self.resources.one_sided_payment_message.clone()))
.map(|ro| (ro.output, self.resources.one_sided_payment_message.clone(), ro.tx_id))
.collect(),
);
Ok((tx_id, found_outputs))
Ok(found_outputs)
}

async fn import_utxos_to_transaction_service(
&mut self,
utxos: Vec<(UnblindedOutput, String)>,
tx_id: TxId,
utxos: Vec<(UnblindedOutput, String, TxId)>,
current_height: u64,
) -> Result<(u64, MicroTari), UtxoScannerError> {
let mut num_recovered = 0u64;
Expand All @@ -532,7 +530,7 @@ where TBackend: WalletBackend + 'static
// value is a placeholder.
let source_public_key = CommsPublicKey::default();

for (uo, message) in utxos {
for (uo, message, tx_id) in utxos {
match self
.import_unblinded_utxo_to_transaction_service(
uo.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2136,7 +2136,6 @@ async fn scan_for_recovery_test() {
.into_iter()
.chain(non_rewindable_outputs.clone().into_iter())
.collect::<Vec<TransactionOutput>>(),
TxId::from(1u64),
)
.await
.unwrap();
Expand All @@ -2146,7 +2145,9 @@ async fn scan_for_recovery_test() {

assert_eq!(recovered_outputs.len(), NUM_REWINDABLE - 1);
for o in rewindable_unblinded_outputs.iter().skip(1) {
assert!(recovered_outputs.iter().any(|ro| ro.spending_key == o.spending_key));
assert!(recovered_outputs
.iter()
.any(|ro| ro.output.spending_key == o.spending_key));
}
}

Expand All @@ -2172,7 +2173,7 @@ async fn recovered_output_key_not_in_keychain() {

let result = oms
.output_manager_handle
.scan_for_recoverable_outputs(vec![rewindable_output], TxId::from(1u64))
.scan_for_recoverable_outputs(vec![rewindable_output])
.await;

assert!(matches!(
Expand Down
23 changes: 12 additions & 11 deletions base_layer/wallet/tests/support/output_manager_service_mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ use std::sync::{Arc, Mutex};

use futures::StreamExt;
use log::*;
use tari_common_types::transaction::TxId;
use tari_service_framework::{reply_channel, reply_channel::Receiver};
use tari_shutdown::ShutdownSignal;
use tari_wallet::output_manager_service::{
error::OutputManagerError,
handle::{OutputManagerEvent, OutputManagerHandle, OutputManagerRequest, OutputManagerResponse},
handle::{OutputManagerEvent, OutputManagerHandle, OutputManagerRequest, OutputManagerResponse, RecoveredOutput},
storage::models::DbUnblindedOutput,
};
use tokio::sync::{broadcast, broadcast::Sender, oneshot};
Expand Down Expand Up @@ -96,17 +97,17 @@ impl OutputManagerServiceMock {
) {
info!(target: LOG_TARGET, "Handling Request: {}", request);
match request {
OutputManagerRequest::ScanForRecoverableOutputs {
outputs: requested_outputs,
tx_id: _tx_id,
} => {
OutputManagerRequest::ScanForRecoverableOutputs(requested_outputs) => {
let lock = acquire_lock!(self.state.recoverable_outputs);
let outputs = (*lock)
.clone()
.into_iter()
.filter_map(|dbuo| {
if requested_outputs.iter().any(|ro| dbuo.commitment == ro.commitment) {
Some(dbuo.unblinded_output)
Some(RecoveredOutput {
output: dbuo.unblinded_output,
tx_id: TxId::new_random(),
})
} else {
None
}
Expand All @@ -120,17 +121,17 @@ impl OutputManagerServiceMock {
e
});
},
OutputManagerRequest::ScanOutputs {
outputs: requested_outputs,
tx_id: _tx_id,
} => {
OutputManagerRequest::ScanOutputs(requested_outputs) => {
let lock = acquire_lock!(self.state.one_sided_payments);
let outputs = (*lock)
.clone()
.into_iter()
.filter_map(|dbuo| {
if requested_outputs.iter().any(|ro| dbuo.commitment == ro.commitment) {
Some(dbuo.unblinded_output)
Some(RecoveredOutput {
output: dbuo.unblinded_output,
tx_id: TxId::new_random(),
})
} else {
None
}
Expand Down
9 changes: 3 additions & 6 deletions base_layer/wallet/tests/transaction_service_tests/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -925,18 +925,15 @@ fn recover_one_sided_transaction() {
let outputs = completed_tx.transaction.body.outputs().clone();

let unblinded = bob_oms
.scan_outputs_for_one_sided_payments(outputs.clone(), TxId::new_random())
.scan_outputs_for_one_sided_payments(outputs.clone())
.await
.unwrap();
// Bob should be able to claim 1 output.
assert_eq!(1, unblinded.len());
assert_eq!(value, unblinded[0].value);
assert_eq!(value, unblinded[0].output.value);

// Should ignore already existing outputs
let unblinded = bob_oms
.scan_outputs_for_one_sided_payments(outputs, TxId::new_random())
.await
.unwrap();
let unblinded = bob_oms.scan_outputs_for_one_sided_payments(outputs).await.unwrap();
assert!(unblinded.is_empty());
});
}
Expand Down