Skip to content

Commit

Permalink
client/asset/dcr: add package-level func for setting custom wallets
Browse files Browse the repository at this point in the history
  • Loading branch information
itswisdomagain committed Oct 22, 2021
1 parent 89fa6ea commit 9e52b2a
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 125 deletions.
35 changes: 27 additions & 8 deletions client/asset/dcr/dcr.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ var _ asset.Driver = (*Driver)(nil)

// Open creates the DCR exchange wallet. Start the wallet with its Run method.
func (d *Driver) Open(cfg *asset.WalletConfig, logger dex.Logger, network dex.Network) (asset.Wallet, error) {
return NewWallet(cfg, logger, network, &rpcWallet{})
return NewWallet(cfg, logger, network)
}

// DecodeCoinID creates a human-readable representation of a coin ID for Decred.
Expand Down Expand Up @@ -416,9 +416,9 @@ var _ asset.Wallet = (*ExchangeWallet)(nil)

// NewWallet is the exported constructor by which the DEX will import the
// exchange wallet.
func NewWallet(cfg *asset.WalletConfig, logger dex.Logger, network dex.Network, wallet Wallet) (*ExchangeWallet, error) {
func NewWallet(cfg *asset.WalletConfig, logger dex.Logger, network dex.Network) (*ExchangeWallet, error) {
// loadConfig will set fields if defaults are used and set the chainParams
// package variable.
// variable.
walletCfg, chainParams, err := loadConfig(cfg.Settings, network)
if err != nil {
return nil, err
Expand All @@ -429,10 +429,17 @@ func NewWallet(cfg *asset.WalletConfig, logger dex.Logger, network dex.Network,
return nil, err
}

dcr.wallet = wallet
err = dcr.wallet.Initialize(cfg, walletCfg, chainParams, logger)
if err != nil {
return nil, err
// Set dcr.wallet using either the default rpcWallet or a custom wallet.
if customWalletConstructor == nil {
dcr.wallet, err = newRPCWallet(walletCfg, chainParams, logger)
if err != nil {
return nil, err
}
} else {
dcr.wallet, err = customWalletConstructor(walletCfg, chainParams, logger)
if err != nil {
return nil, fmt.Errorf("custom wallet setup error: %v", err)
}
}

return dcr, nil
Expand Down Expand Up @@ -520,6 +527,14 @@ func (dcr *ExchangeWallet) Connect(ctx context.Context) (*sync.WaitGroup, error)
}
}()

curnet, err := dcr.wallet.Network(ctx)
if err != nil {
return nil, fmt.Errorf("unable to fetch wallet network: %w", err)
}
if curnet != dcr.chainParams.Net {
return nil, fmt.Errorf("unexpected wallet network %s, expected %s", curnet, dcr.chainParams.Net)
}

// Initialize the best block.
dcr.tipMtx.Lock()
dcr.currentTip, err = dcr.getBestBlock(ctx)
Expand Down Expand Up @@ -2441,7 +2456,11 @@ func msgTxToHex(msgTx *wire.MsgTx) (string, error) {
// signTx attempts to sign all transaction inputs. If it fails to completely
// sign the transaction, it is an error and a nil *wire.MsgTx is returned.
func (dcr *ExchangeWallet) signTx(baseTx *wire.MsgTx) (*wire.MsgTx, error) {
res, err := dcr.wallet.SignRawTransaction(dcr.ctx, baseTx)
txHex, err := msgTxToHex(baseTx)
if err != nil {
return nil, fmt.Errorf("failed to encode MsgTx: %w", err)
}
res, err := dcr.wallet.SignRawTransaction(dcr.ctx, txHex)
if err != nil {
return nil, fmt.Errorf("signrawtransaction error: %w", err)
}
Expand Down
8 changes: 6 additions & 2 deletions client/asset/dcr/dcr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ func tNewWallet() (*ExchangeWallet, *tRPCClient, func(), error) {
return nil, nil, nil, err
}
wallet.wallet = &rpcWallet{
node: client,
rpcClient: client,
}
wallet.ctx = walletCtx

Expand Down Expand Up @@ -327,6 +327,10 @@ func newTRPCClient() *tRPCClient {
}
}

func (c *tRPCClient) GetCurrentNet(context.Context) (wire.CurrencyNet, error) {
return tChainParams.Net, nil
}

func (c *tRPCClient) EstimateSmartFee(_ context.Context, confirmations int64, mode chainjson.EstimateSmartFeeMode) (*chainjson.EstimateSmartFeeResult, error) {
if c.estFeeErr != nil {
return nil, c.estFeeErr
Expand Down Expand Up @@ -2296,7 +2300,7 @@ func TestSyncStatus(t *testing.T) {
node.blockchainInfoErr = nil

wallet.wallet = &rpcWallet{
node: node,
rpcClient: node,
tipAtConnect: 100,
}
node.blockchainInfo = &chainjson.GetBlockChainInfoResult{
Expand Down
Loading

0 comments on commit 9e52b2a

Please sign in to comment.