Skip to content

Commit

Permalink
verificationhelper: add VerificationReady callback for when verificat…
Browse files Browse the repository at this point in the history
…ion is accepted

This callback supersedes the ScanQRCode and ShowQRCode callbacks.

Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Feb 13, 2025
1 parent 14008ca commit 5600dd4
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 129 deletions.
66 changes: 17 additions & 49 deletions crypto/verificationhelper/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ import (
type MockVerificationCallbacks interface {
GetRequestedVerifications() map[id.UserID][]id.VerificationTransactionID
GetScanQRCodeTransactions() []id.VerificationTransactionID
GetVerificationsReadyTransactions() []id.VerificationTransactionID
GetQRCodeShown(id.VerificationTransactionID) *verificationhelper.QRCode
}

type baseVerificationCallbacks struct {
scanQRCodeTransactions []id.VerificationTransactionID
verificationsRequested map[id.UserID][]id.VerificationTransactionID
verificationsReady []id.VerificationTransactionID
qrCodesShown map[id.VerificationTransactionID]*verificationhelper.QRCode
qrCodesScanned map[id.VerificationTransactionID]struct{}
doneTransactions map[id.VerificationTransactionID]struct{}
Expand All @@ -33,6 +35,7 @@ type baseVerificationCallbacks struct {
}

var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil)
var _ MockVerificationCallbacks = (*baseVerificationCallbacks)(nil)

func newBaseVerificationCallbacks() *baseVerificationCallbacks {
return &baseVerificationCallbacks{
Expand All @@ -55,6 +58,10 @@ func (c *baseVerificationCallbacks) GetScanQRCodeTransactions() []id.Verificatio
return c.scanQRCodeTransactions
}

func (c *baseVerificationCallbacks) GetVerificationsReadyTransactions() []id.VerificationTransactionID {
return c.verificationsReady
}

func (c *baseVerificationCallbacks) GetQRCodeShown(txnID id.VerificationTransactionID) *verificationhelper.QRCode {
return c.qrCodesShown[txnID]
}
Expand Down Expand Up @@ -85,6 +92,16 @@ func (c *baseVerificationCallbacks) VerificationRequested(ctx context.Context, t
c.verificationsRequested[from] = append(c.verificationsRequested[from], txnID)
}

func (c *baseVerificationCallbacks) VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, allowScanQRCode bool, qrCode *verificationhelper.QRCode) {
c.verificationsReady = append(c.verificationsReady, txnID)
if allowScanQRCode {
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
}
if qrCode != nil {
c.qrCodesShown[txnID] = qrCode
}
}

func (c *baseVerificationCallbacks) VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) {
c.verificationCancellation[txnID] = &event.VerificationCancelEventContent{
Code: code,
Expand Down Expand Up @@ -116,23 +133,6 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific
c.decimalsShown[txnID] = decimals
}

type scanQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
}

var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil)

func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks {
return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
}

func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks {
return &scanQRCodeVerificationCallbacks{base}
}
func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
}

type showQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
}
Expand All @@ -147,44 +147,13 @@ func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks)
return &showQRCodeVerificationCallbacks{base}
}

func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
c.qrCodesShown[txnID] = qrCode
}

func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
c.qrCodesScanned[txnID] = struct{}{}
}

type showAndScanQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
*showQRCodeVerificationCallbacks
*scanQRCodeVerificationCallbacks
}

var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)

func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks {
base := newBaseVerificationCallbacks()
return &showAndScanQRCodeVerificationCallbacks{
base,
newShowQRCodeVerificationCallbacks(),
newScanQRCodeVerificationCallbacks(),
}
}

func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks {
return &showAndScanQRCodeVerificationCallbacks{
base,
newShowQRCodeVerificationCallbacks(),
newScanQRCodeVerificationCallbacks(),
}
}

type allVerificationCallbacks struct {
*baseVerificationCallbacks
*sasVerificationCallbacks
*scanQRCodeVerificationCallbacks
*showQRCodeVerificationCallbacks
}

Expand All @@ -193,7 +162,6 @@ func newAllVerificationCallbacks() *allVerificationCallbacks {
return &allVerificationCallbacks{
base,
newSASVerificationCallbacksWithBase(base),
newScanQRCodeVerificationCallbacksWithBase(base),
newShowQRCodeVerificationCallbacksWithBase(base),
}
}
4 changes: 4 additions & 0 deletions crypto/verificationhelper/qrcode.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ func NewQRCodeFromBytes(data []byte) (*QRCode, error) {
//
// [Section 11.12.2.4.1]: https://spec.matrix.org/v1.9/client-server-api/#qr-code-format
func (q *QRCode) Bytes() []byte {
if q == nil {
return nil
}

var buf bytes.Buffer
buf.WriteString("MATRIX") // Header
buf.WriteByte(0x02) // Version
Expand Down
25 changes: 13 additions & 12 deletions crypto/verificationhelper/reciprocate.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,28 +270,30 @@ func (vh *VerificationHelper) ConfirmQRCodeScanned(ctx context.Context, txnID id
return nil
}

func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *VerificationTransaction) error {
func (vh *VerificationHelper) generateQRCode(ctx context.Context, txn *VerificationTransaction) (*QRCode, error) {
log := vh.getLog(ctx).With().
Str("verification_action", "generate and show QR code").
Stringer("transaction_id", txn.TransactionID).
Logger()
ctx = log.WithContext(ctx)
if vh.showQRCode == nil {
log.Info().Msg("Ignoring QR code generation request as showing a QR code is not enabled on this device")
return nil

if !slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) ||
!slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate) {
log.Info().Msg("Ignoring QR code generation request as reciprocating is not supported by both devices")
return nil, nil
} else if !slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeScan) {
log.Info().Msg("Ignoring QR code generation request as other device cannot scan QR codes")
return nil
return nil, nil
}

ownCrossSigningPublicKeys := vh.mach.GetOwnCrossSigningPublicKeys(ctx)
if ownCrossSigningPublicKeys == nil || len(ownCrossSigningPublicKeys.MasterKey) == 0 {
return errors.New("failed to get own cross-signing master public key")
return nil, errors.New("failed to get own cross-signing master public key")
}

ownMasterKeyTrusted, err := vh.mach.CryptoStore.IsKeySignedBy(ctx, vh.client.UserID, ownCrossSigningPublicKeys.MasterKey, vh.client.UserID, vh.mach.OwnIdentity().SigningKey)
if err != nil {
return err
return nil, err
}
mode := QRCodeModeCrossSigning
if vh.client.UserID == txn.TheirUserID {
Expand All @@ -304,7 +306,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
} else {
// This is a cross-signing situation.
if !ownMasterKeyTrusted {
return errors.New("cannot cross-sign other device when own master key is not trusted")
return nil, errors.New("cannot cross-sign other device when own master key is not trusted")
}
mode = QRCodeModeCrossSigning
}
Expand All @@ -318,7 +320,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
// Key 2 is the other user's master signing key.
theirSigningKeys, err := vh.mach.GetCrossSigningPublicKeys(ctx, txn.TheirUserID)
if err != nil {
return err
return nil, err
}
key2 = theirSigningKeys.MasterKey.Bytes()
case QRCodeModeSelfVerifyingMasterKeyTrusted:
Expand All @@ -328,7 +330,7 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve
// Key 2 is the other device's key.
theirDevice, err := vh.mach.GetOrFetchDevice(ctx, txn.TheirUserID, txn.TheirDeviceID)
if err != nil {
return err
return nil, err
}
key2 = theirDevice.SigningKey.Bytes()
case QRCodeModeSelfVerifyingMasterKeyUntrusted:
Expand All @@ -343,6 +345,5 @@ func (vh *VerificationHelper) generateAndShowQRCode(ctx context.Context, txn *Ve

qrCode := NewQRCode(mode, txn.TransactionID, [32]byte(key1), [32]byte(key2))
txn.QRCodeSharedSecret = qrCode.SharedSecret
vh.showQRCode(ctx, txn.TransactionID, qrCode)
return nil
return qrCode, nil
}
81 changes: 45 additions & 36 deletions crypto/verificationhelper/verificationhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ type RequiredCallbacks interface {
// from another device.
VerificationRequested(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID)

// VerificationReady is called when a verification request has been
// accepted by both parties.
VerificationReady(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode)

// VerificationCancelled is called when the verification is cancelled.
VerificationCancelled(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string)

Expand All @@ -48,18 +52,7 @@ type ShowSASCallbacks interface {
ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)
}

type ScanQRCodeCallbacks interface {
// ScanQRCode is called when another device has sent a
// m.key.verification.ready event and indicated that they are capable of
// showing a QR code.
ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID)
}

type ShowQRCodeCallbacks interface {
// ShowQRCode is called when the verification has been accepted and a QR
// code should be shown to the user.
ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)

// QRCodeScanned is called when the other user has scanned the QR code and
// sent the m.key.verification.start event.
QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID)
Expand All @@ -71,24 +64,25 @@ type VerificationHelper struct {

store VerificationStore
activeTransactionsLock sync.Mutex
// activeTransactions map[id.VerificationTransactionID]*verificationTransaction

// supportedMethods are the methods that *we* support
supportedMethods []event.VerificationMethod
verificationRequested func(ctx context.Context, txnID id.VerificationTransactionID, from id.UserID, fromDevice id.DeviceID)
verificationReady func(ctx context.Context, txnID id.VerificationTransactionID, otherDeviceID id.DeviceID, supportsSAS, supportsScanQRCode bool, qrCode *QRCode)
verificationCancelledCallback func(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string)
verificationDone func(ctx context.Context, txnID id.VerificationTransactionID)

// showSAS is a callback that will be called after the SAS verification
// dance is complete and we want the client to show the emojis/decimals
showSAS func(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)

scanQRCode func(ctx context.Context, txnID id.VerificationTransactionID)
showQRCode func(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)
// qrCodeScaned is a callback that will be called when the other device
// scanned the QR code we are showing
qrCodeScaned func(ctx context.Context, txnID id.VerificationTransactionID)
}

var _ mautrix.VerificationHelper = (*VerificationHelper)(nil)

func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsScan bool) *VerificationHelper {
func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, store VerificationStore, callbacks any, supportsQRShow, supportsQRScan bool) *VerificationHelper {
if client.Crypto == nil {
panic("client.Crypto is nil")
}
Expand All @@ -107,6 +101,7 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
panic("callbacks must implement RequiredCallbacks")
} else {
helper.verificationRequested = c.VerificationRequested
helper.verificationReady = c.VerificationReady
helper.verificationCancelledCallback = c.VerificationCancelled
helper.verificationDone = c.VerificationDone
}
Expand All @@ -115,16 +110,18 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS)
helper.showSAS = c.ShowSAS
}
if c, ok := callbacks.(ShowQRCodeCallbacks); ok {
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.showQRCode = c.ShowQRCode
helper.qrCodeScaned = c.QRCodeScanned
if supportsQRShow {
if c, ok := callbacks.(ShowQRCodeCallbacks); !ok {
panic("callbacks must implement ShowQRCodeCallbacks if supportsQRShow is true")
} else {
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.qrCodeScaned = c.QRCodeScanned
}
}
if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan {
if supportsQRScan {
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.scanQRCode = c.ScanQRCode
}
helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods)
return &helper
Expand Down Expand Up @@ -421,15 +418,19 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V
}
txn.VerificationState = VerificationStateReady

if vh.scanQRCode != nil &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
vh.scanQRCode(ctx, txn.TransactionID)
}
supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) &&
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS)
supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) &&
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate)
supportsScanQRCode := supportsReciprocate &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) &&
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow)

if err := vh.generateAndShowQRCode(ctx, &txn); err != nil {
qrCode, err := vh.generateQRCode(ctx, &txn)
if err != nil {
return err
}
vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode)
return vh.store.SaveVerificationTransaction(ctx, txn)
}

Expand Down Expand Up @@ -737,15 +738,23 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif
}
}

if vh.scanQRCode != nil &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
vh.scanQRCode(ctx, txn.TransactionID)
supportsSAS := slices.Contains(vh.supportedMethods, event.VerificationMethodSAS) &&
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodSAS)
supportsReciprocate := slices.Contains(vh.supportedMethods, event.VerificationMethodReciprocate) &&
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodReciprocate)
supportsScanQRCode := supportsReciprocate &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) &&
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow)

qrCode, err := vh.generateQRCode(ctx, &txn)
if err != nil {
vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate QR code: %w", err)
return
}

if err := vh.generateAndShowQRCode(ctx, &txn); err != nil {
vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to generate and show QR code: %w", err)
} else if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil {
vh.verificationReady(ctx, txn.TransactionID, txn.TheirDeviceID, supportsSAS, supportsScanQRCode, qrCode)

if err := vh.store.SaveVerificationTransaction(ctx, txn); err != nil {
vh.cancelVerificationTxn(ctx, txn, event.VerificationCancelCodeInternalError, "failed to save verification transaction: %w", err)
}
}
Expand Down
Loading

0 comments on commit 5600dd4

Please sign in to comment.