From b6c383350ececa2534e1455ecd2bedd19242d517 Mon Sep 17 00:00:00 2001 From: Dmytro Vovk Date: Thu, 13 Feb 2025 08:13:27 +0000 Subject: [PATCH] Created websocket server (#104) * updated ui * added websocket server for ui connections --- api/api_handler.go | 1 + api/bridge_handler.go | 3 - api/internal/end_points.go | 1 + api/websocket_handler.go | 191 ++++++++++++++++++++++++++ cmd/diagnostics/main.go | 14 +- internal/erigon_node/erigon_client.go | 3 + internal/erigon_node/subscribe.go | 53 +++++++ 7 files changed, 258 insertions(+), 8 deletions(-) create mode 100644 api/websocket_handler.go create mode 100644 internal/erigon_node/subscribe.go diff --git a/api/api_handler.go b/api/api_handler.go index 2c763a4..cc3e3b0 100644 --- a/api/api_handler.go +++ b/api/api_handler.go @@ -309,6 +309,7 @@ func NewAPIHandler( r.Get("/sessions/{sessionId}", r.GetSession) // Erigon Node data + r.Get("/v2/sessions/{sessionId}/nodes/{nodeId}/ws", r.HandleWebSocket) r.Get("/sessions/{sessionId}/nodes/{nodeId}/logs/{file}", r.Log) r.Get("/sessions/{sessionId}/nodes/{nodeId}/dbs/*", r.Tables) r.Get("/sessions/{sessionId}/nodes/{nodeId}/reorgs", r.ReOrg) diff --git a/api/bridge_handler.go b/api/bridge_handler.go index 1c06623..8a02414 100644 --- a/api/bridge_handler.go +++ b/api/bridge_handler.go @@ -37,7 +37,6 @@ const ( var wsBufferPool = new(sync.Pool) func (h BridgeHandler) Bridge(w http.ResponseWriter, r *http.Request) { - //Sends a success Message to the Node client, to receive more information ctx, cancel := context.WithCancel(r.Context()) defer cancel() @@ -131,8 +130,6 @@ func (h BridgeHandler) Bridge(w http.ResponseWriter, r *http.Request) { continue } - //fmt.Printf("Sending request %s\n", string(bytes)) - requestMutex.Lock() requestMap[rpcRequest.Id] = request requestMutex.Unlock() diff --git a/api/internal/end_points.go b/api/internal/end_points.go index 7cfb096..5f2a067 100644 --- a/api/internal/end_points.go +++ b/api/internal/end_points.go @@ -3,4 +3,5 @@ package internal const ( HealthCheckEndPoint = "/healthcheck" BridgeEndPoint = "/bridge" + WSEndPoint = "/ws" ) diff --git a/api/websocket_handler.go b/api/websocket_handler.go new file mode 100644 index 0000000..234c32b --- /dev/null +++ b/api/websocket_handler.go @@ -0,0 +1,191 @@ +package api + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + ActionSubscribe = "subscribe" + ActionUnsubscribe = "unsubscribe" +) + +// SubscriptionResponse is the response sent back to the client after an action is processed. +type ClientResponse struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` + Data *string `json:"data,omitempty"` +} + +type WebsocketHandler struct { + mu sync.Mutex + writeQueue chan []byte + conn *websocket.Conn + closeChan chan struct{} +} + +// **NewWebsocketHandler initializes WebsocketHandler** +func NewWebsocketHandler(conn *websocket.Conn) *WebsocketHandler { + handler := &WebsocketHandler{ + writeQueue: make(chan []byte, 100), + conn: conn, + closeChan: make(chan struct{}), + } + + go handler.startWriter() // Start dedicated writer goroutine + return handler +} + +// **Sends response safely** +func (h *WebsocketHandler) sendResponse(response *ClientResponse) { + resp, err := json.Marshal(response) + if err != nil { + fmt.Printf("Error marshaling response: %v\n", err) + return + } + + select { + case h.writeQueue <- resp: + default: + fmt.Println("Warning: writeQueue is full, dropping message") + } +} + +// **Dedicated writer goroutine** +func (h *WebsocketHandler) startWriter() { + for { + select { + case msg := <-h.writeQueue: + h.mu.Lock() + err := h.conn.WriteMessage(websocket.TextMessage, msg) + h.mu.Unlock() + + if err != nil { + fmt.Printf("Error writing response: %v\n", err) + return + } + + case <-h.closeChan: + fmt.Println("Writer goroutine stopped") + return + } + } +} + +// **Close WebSocket connection and stop writer** +func (h *WebsocketHandler) closeConnection() { + close(h.closeChan) + h.conn.Close() +} + +// **WebSocket handler function** +func (h *APIHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + type wsMessage struct { + Service string `json:"service"` + Action string `json:"action"` + } + + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Println("WebSocket upgrade failed:", err) + return + } + defer conn.Close() + + handler := NewWebsocketHandler(conn) + defer handler.closeConnection() + + channel := make(chan []byte) + + // **Goroutine to forward messages from the channel to the client** + go func() { + for { + select { + case <-r.Context().Done(): + return + case <-handler.closeChan: // Graceful shutdown + return + case message := <-channel: + handler.sendResponse(&ClientResponse{ + Status: "success", + Message: string(message), + }) + } + } + }() + + // **Enable Ping/Pong Handling** + conn.SetPongHandler(func(appData string) error { + return nil + }) + + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + + for { + select { + case <-handler.closeChan: + return + case <-ticker.C: + handler.mu.Lock() + err := conn.WriteMessage(websocket.PingMessage, nil) + handler.mu.Unlock() + + if err != nil { + fmt.Println("Ping failed, closing connection:", err) + handler.closeConnection() + return + } + } + } + }() + + for { + _, msg, err := conn.ReadMessage() + if err != nil { + fmt.Println("Error reading message:", err) + break + } + fmt.Printf("Received: %s\n", msg) + + client, err := h.findNodeClient(r) + if err != nil { + handler.sendResponse(&ClientResponse{ + Status: "error", + Message: "Client not found: " + err.Error(), + }) + return + } + + var inMsg wsMessage + if err := json.Unmarshal(msg, &inMsg); err != nil { + handler.sendResponse(&ClientResponse{ + Status: "error", + Message: "Invalid JSON: " + err.Error(), + }) + continue + } + + switch inMsg.Action { + case ActionSubscribe: + go client.Subscribe(r.Context(), channel, inMsg.Service) + case ActionUnsubscribe: + client.Unsubscribe(r.Context(), channel, inMsg.Service) + default: + handler.sendResponse(&ClientResponse{ + Status: "error", + Message: "Unknown action " + inMsg.Action, + }) + } + } +} diff --git a/cmd/diagnostics/main.go b/cmd/diagnostics/main.go index 202fa87..2c7fb64 100644 --- a/cmd/diagnostics/main.go +++ b/cmd/diagnostics/main.go @@ -59,11 +59,7 @@ func main() { } }() - packagePath := "github.com/erigontech/erigonwatch" - version, err := GetPackageVersion(packagePath) - if err == nil { - fmt.Printf("Diagnostics version: %s\n", version) - } + printUIVersion() fmt.Printf("Diagnostics UI is running on http://%s:%d\n", listenAddr, listenPort) //open(fmt.Sprintf("http://%s:%d", listenAddr, listenPort)) @@ -81,6 +77,14 @@ func main() { } } +func printUIVersion() { + packagePath := "github.com/erigontech/erigonwatch" + version, err := GetPackageVersion(packagePath) + if err == nil { + fmt.Printf("Diagnostics version: %s\n", version) + } +} + // open opens the specified URL in the default browser of the user. /*func open(url string) error { var cmd string diff --git a/internal/erigon_node/erigon_client.go b/internal/erigon_node/erigon_client.go index 1532b90..3cbe4c0 100644 --- a/internal/erigon_node/erigon_client.go +++ b/internal/erigon_node/erigon_client.go @@ -96,4 +96,7 @@ type Client interface { FindProfile(ctx context.Context, profile string) ([]byte, error) fetch(ctx context.Context, method string, params url.Values) (*NodeRequest, error) + + Subscribe(ctx context.Context, channel chan []byte, service string) error + Unsubscribe(ctx context.Context, channel chan []byte, service string) error } diff --git a/internal/erigon_node/subscribe.go b/internal/erigon_node/subscribe.go new file mode 100644 index 0000000..5cf6795 --- /dev/null +++ b/internal/erigon_node/subscribe.go @@ -0,0 +1,53 @@ +package erigon_node + +import ( + "context" +) + +func (c *NodeClient) Subscribe(ctx context.Context, channel chan []byte, service string) error { + request, err := c.fetch(ctx, "subscribe/"+service, nil) + + if err != nil { + return err + } + + for { + more, result, err := request.nextResult(ctx) + + if err != nil { + return err + } + + channel <- result + + if !more { + break + } + } + + return nil +} + +func (c *NodeClient) Unsubscribe(ctx context.Context, channel chan []byte, service string) error { + request, err := c.fetch(ctx, "unsubscribe/"+service, nil) + + if err != nil { + return err + } + + for { + more, result, err := request.nextResult(ctx) + + if err != nil { + return err + } + + channel <- result + + if !more { + break + } + } + + return nil +}