Skip to content

Commit

Permalink
verificationhelper: don't request QR scan if not enabled
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Feb 5, 2025
1 parent 475c4bf commit 890db20
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 33 deletions.
69 changes: 58 additions & 11 deletions crypto/verificationhelper/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ type baseVerificationCallbacks struct {
decimalsShown map[id.VerificationTransactionID][]int
}

var _ verificationhelper.RequiredCallbacks = (*baseVerificationCallbacks)(nil)

func newBaseVerificationCallbacks() *baseVerificationCallbacks {
return &baseVerificationCallbacks{
verificationsRequested: map[id.UserID][]id.VerificationTransactionID{},
Expand Down Expand Up @@ -98,6 +100,8 @@ type sasVerificationCallbacks struct {
*baseVerificationCallbacks
}

var _ verificationhelper.ShowSASCallbacks = (*sasVerificationCallbacks)(nil)

func newSASVerificationCallbacks() *sasVerificationCallbacks {
return &sasVerificationCallbacks{newBaseVerificationCallbacks()}
}
Expand All @@ -112,41 +116,84 @@ func (c *sasVerificationCallbacks) ShowSAS(ctx context.Context, txnID id.Verific
c.decimalsShown[txnID] = decimals
}

type qrCodeVerificationCallbacks struct {
type scanQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
}

func newQRCodeVerificationCallbacks() *qrCodeVerificationCallbacks {
return &qrCodeVerificationCallbacks{newBaseVerificationCallbacks()}
}
var _ verificationhelper.ScanQRCodeCallbacks = (*scanQRCodeVerificationCallbacks)(nil)

func newQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *qrCodeVerificationCallbacks {
return &qrCodeVerificationCallbacks{base}
func newScanQRCodeVerificationCallbacks() *scanQRCodeVerificationCallbacks {
return &scanQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
}

func (c *qrCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
func newScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *scanQRCodeVerificationCallbacks {
return &scanQRCodeVerificationCallbacks{base}
}
func (c *scanQRCodeVerificationCallbacks) ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID) {
c.scanQRCodeTransactions = append(c.scanQRCodeTransactions, txnID)
}

func (c *qrCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
type showQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
}

var _ verificationhelper.ShowQRCodeCallbacks = (*showQRCodeVerificationCallbacks)(nil)

func newShowQRCodeVerificationCallbacks() *showQRCodeVerificationCallbacks {
return &showQRCodeVerificationCallbacks{newBaseVerificationCallbacks()}
}

func newShowQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showQRCodeVerificationCallbacks {
return &showQRCodeVerificationCallbacks{base}
}

func (c *showQRCodeVerificationCallbacks) ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *verificationhelper.QRCode) {
c.qrCodesShown[txnID] = qrCode
}

func (c *qrCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
func (c *showQRCodeVerificationCallbacks) QRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) {
c.qrCodesScanned[txnID] = struct{}{}
}

type showAndScanQRCodeVerificationCallbacks struct {
*baseVerificationCallbacks
*showQRCodeVerificationCallbacks
*scanQRCodeVerificationCallbacks
}

var _ verificationhelper.ScanQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)
var _ verificationhelper.ShowQRCodeCallbacks = (*showAndScanQRCodeVerificationCallbacks)(nil)

func newShowAndScanQRCodeVerificationCallbacks() *showAndScanQRCodeVerificationCallbacks {
base := newBaseVerificationCallbacks()
return &showAndScanQRCodeVerificationCallbacks{
base,
newShowQRCodeVerificationCallbacks(),
newScanQRCodeVerificationCallbacks(),
}
}

func newShowAndScanQRCodeVerificationCallbacksWithBase(base *baseVerificationCallbacks) *showAndScanQRCodeVerificationCallbacks {
return &showAndScanQRCodeVerificationCallbacks{
base,
newShowQRCodeVerificationCallbacks(),
newScanQRCodeVerificationCallbacks(),
}
}

type allVerificationCallbacks struct {
*baseVerificationCallbacks
*sasVerificationCallbacks
*qrCodeVerificationCallbacks
*scanQRCodeVerificationCallbacks
*showQRCodeVerificationCallbacks
}

func newAllVerificationCallbacks() *allVerificationCallbacks {
base := newBaseVerificationCallbacks()
return &allVerificationCallbacks{
base,
newSASVerificationCallbacksWithBase(base),
newQRCodeVerificationCallbacksWithBase(base),
newScanQRCodeVerificationCallbacksWithBase(base),
newShowQRCodeVerificationCallbacksWithBase(base),
}
}
31 changes: 18 additions & 13 deletions crypto/verificationhelper/verificationhelper.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"time"

"github.com/rs/zerolog"
"go.mau.fi/util/exslices"
"go.mau.fi/util/jsontime"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
Expand Down Expand Up @@ -47,12 +48,14 @@ type ShowSASCallbacks interface {
ShowSAS(ctx context.Context, txnID id.VerificationTransactionID, emojis []rune, emojiDescriptions []string, decimals []int)
}

type ShowQRCodeCallbacks interface {
type ScanQRCodeCallbacks interface {
// ScanQRCode is called when another device has sent a
// m.key.verification.ready event and indicated that they are capable of
// showing a QR code.
ScanQRCode(ctx context.Context, txnID id.VerificationTransactionID)
}

type ShowQRCodeCallbacks interface {
// ShowQRCode is called when the verification has been accepted and a QR
// code should be shown to the user.
ShowQRCode(ctx context.Context, txnID id.VerificationTransactionID, qrCode *QRCode)
Expand Down Expand Up @@ -108,24 +111,22 @@ func NewVerificationHelper(client *mautrix.Client, mach *crypto.OlmMachine, stor
helper.verificationDone = c.VerificationDone
}

supportedMethods := map[event.VerificationMethod]struct{}{}
if c, ok := callbacks.(ShowSASCallbacks); ok {
supportedMethods[event.VerificationMethodSAS] = struct{}{}
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodSAS)
helper.showSAS = c.ShowSAS
}
if c, ok := callbacks.(ShowQRCodeCallbacks); ok {
supportedMethods[event.VerificationMethodQRCodeShow] = struct{}{}
supportedMethods[event.VerificationMethodReciprocate] = struct{}{}
helper.scanQRCode = c.ScanQRCode
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeShow)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.showQRCode = c.ShowQRCode
helper.qrCodeScaned = c.QRCodeScanned
}
if supportsScan {
supportedMethods[event.VerificationMethodQRCodeScan] = struct{}{}
supportedMethods[event.VerificationMethodReciprocate] = struct{}{}
if c, ok := callbacks.(ScanQRCodeCallbacks); ok && supportsScan {
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodQRCodeScan)
helper.supportedMethods = append(helper.supportedMethods, event.VerificationMethodReciprocate)
helper.scanQRCode = c.ScanQRCode
}

helper.supportedMethods = maps.Keys(supportedMethods)
helper.supportedMethods = exslices.DeduplicateUnsorted(helper.supportedMethods)
return &helper
}

Expand Down Expand Up @@ -420,7 +421,9 @@ func (vh *VerificationHelper) AcceptVerification(ctx context.Context, txnID id.V
}
txn.VerificationState = VerificationStateReady

if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
if vh.scanQRCode != nil &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
vh.scanQRCode(ctx, txn.TransactionID)
}

Expand Down Expand Up @@ -734,7 +737,9 @@ func (vh *VerificationHelper) onVerificationReady(ctx context.Context, txn Verif
}
}

if vh.scanQRCode != nil && slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
if vh.scanQRCode != nil &&
slices.Contains(vh.supportedMethods, event.VerificationMethodQRCodeScan) && // technically redundant because vh.scanQRCode is only set if this is true
slices.Contains(txn.TheirSupportedMethods, event.VerificationMethodQRCodeShow) {
vh.scanQRCode(ctx, txn.TransactionID)
}

Expand Down
19 changes: 10 additions & 9 deletions crypto/verificationhelper/verificationhelper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,12 @@ func TestVerification_Start(t *testing.T) {
expectedVerificationMethods []event.VerificationMethod
}{
{false, newBaseVerificationCallbacks(), "no supported verification methods", nil},
{true, newBaseVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{false, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodSAS}},
{true, newSASVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
{true, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, newQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{true, newScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{false, newScanQRCodeVerificationCallbacks(), "no supported verification methods", nil},
{false, newShowQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{true, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, newShowAndScanQRCodeVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
{true, newAllVerificationCallbacks(), "", []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
}
Expand All @@ -124,7 +125,7 @@ func TestVerification_Start(t *testing.T) {
return
}

assert.NoError(t, err)
require.NoError(t, err)
assert.NotEmpty(t, txnID)

toDeviceInbox := ts.DeviceInbox[aliceUserID]
Expand Down Expand Up @@ -283,8 +284,8 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
expectedVerificationMethods []event.VerificationMethod
}{
{false, false, newSASVerificationCallbacks(), newSASVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodSAS}},
{true, false, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, true, newQRCodeVerificationCallbacks(), newQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{true, false, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate}},
{false, true, newShowAndScanQRCodeVerificationCallbacks(), newShowAndScanQRCodeVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate}},
{true, false, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
{true, true, newAllVerificationCallbacks(), newAllVerificationCallbacks(), []event.VerificationMethod{event.VerificationMethodQRCodeShow, event.VerificationMethodQRCodeScan, event.VerificationMethodReciprocate, event.VerificationMethodSAS}},
}
Expand Down Expand Up @@ -321,10 +322,10 @@ func TestVerification_Accept_CorrectMethodsPresented(t *testing.T) {
err = receivingHelper.AcceptVerification(ctx, txnID)
require.NoError(t, err)

_, sendingIsQRCallbacks := tc.sendingCallbacks.(*qrCodeVerificationCallbacks)
_, sendingIsQRCallbacks := tc.sendingCallbacks.(*showQRCodeVerificationCallbacks)
_, sendingIsAllCallbacks := tc.sendingCallbacks.(*allVerificationCallbacks)
sendingCanShowQR := sendingIsQRCallbacks || sendingIsAllCallbacks
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*qrCodeVerificationCallbacks)
_, receivingIsQRCallbacks := tc.receivingCallbacks.(*showQRCodeVerificationCallbacks)
_, receivingIsAllCallbacks := tc.receivingCallbacks.(*allVerificationCallbacks)
receivingCanShowQR := receivingIsQRCallbacks || receivingIsAllCallbacks

Expand Down

0 comments on commit 890db20

Please sign in to comment.