From 50398a6d4da488769eea2375961d9b2610944fc1 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 24 Jul 2021 23:58:35 +0300 Subject: [PATCH] Handle OTK counts and device lists coming in through the transaction websocket --- appservice/http.go | 40 ++++++++++++++++++++++------------------ appservice/websocket.go | 7 +------ 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/appservice/http.go b/appservice/http.go index f11dd63d..ce4387c6 100644 --- a/appservice/http.go +++ b/appservice/http.go @@ -130,29 +130,33 @@ func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { Message: "Failed to parse body JSON", }.Write(w) } else { - if as.Registration.EphemeralEvents { - if txn.EphemeralEvents != nil { - as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType) - } else if txn.MSC2409EphemeralEvents != nil { - as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType) - } - } - as.handleEvents(txn.Events, event.UnknownEventType) - if txn.DeviceLists != nil { - as.handleDeviceLists(txn.DeviceLists) - } else if txn.MSC3202DeviceLists != nil { - as.handleDeviceLists(txn.MSC3202DeviceLists) - } - if txn.DeviceOTKCount != nil { - as.handleOTKCounts(txn.DeviceOTKCount) - } else if txn.MSC3202DeviceOTKCount != nil { - as.handleOTKCounts(txn.MSC3202DeviceOTKCount) - } + as.handleTransaction(&txn) WriteBlankOK(w) } as.lastProcessedTransaction = txnID } +func (as *AppService) handleTransaction(txn *Transaction) { + if as.Registration.EphemeralEvents { + if txn.EphemeralEvents != nil { + as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType) + } else if txn.MSC2409EphemeralEvents != nil { + as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType) + } + } + as.handleEvents(txn.Events, event.UnknownEventType) + if txn.DeviceLists != nil { + as.handleDeviceLists(txn.DeviceLists) + } else if txn.MSC3202DeviceLists != nil { + as.handleDeviceLists(txn.MSC3202DeviceLists) + } + if txn.DeviceOTKCount != nil { + as.handleOTKCounts(txn.DeviceOTKCount) + } else if txn.MSC3202DeviceOTKCount != nil { + as.handleOTKCounts(txn.MSC3202DeviceOTKCount) + } +} + func (as *AppService) handleOTKCounts(otks map[id.UserID]mautrix.OTKCount) { for userID, otkCounts := range otks { otkCounts.UserID = userID diff --git a/appservice/websocket.go b/appservice/websocket.go index e34715d7..35dcbc4f 100644 --- a/appservice/websocket.go +++ b/appservice/websocket.go @@ -18,8 +18,6 @@ import ( "sync/atomic" "github.com/gorilla/websocket" - - "maunium.net/go/mautrix/event" ) type WebsocketRequest struct { @@ -188,10 +186,7 @@ func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) return } if msg.Command == "" || msg.Command == "transaction" { - if as.Registration.EphemeralEvents && msg.EphemeralEvents != nil { - as.handleEvents(msg.EphemeralEvents, event.EphemeralEventType) - } - as.handleEvents(msg.Events, event.UnknownEventType) + as.handleTransaction(&msg.Transaction) } else if msg.Command == "connect" { as.Log.Debugln("Websocket connect confirmation received") } else if msg.Command == "response" || msg.Command == "error" {