From 9e1be2560c7aa86d0612492edc7d1a2ff99d9dcc Mon Sep 17 00:00:00 2001 From: Michael Reber Date: Sat, 21 Feb 2026 19:29:30 +0100 Subject: [PATCH] Add a check to validate the origin of the WebSocket connection to prevent cross-origin hijacking --- pkg/web/websocket.go | 38 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/pkg/web/websocket.go b/pkg/web/websocket.go index 8ca526f..a2f572f 100644 --- a/pkg/web/websocket.go +++ b/pkg/web/websocket.go @@ -20,6 +20,8 @@ import ( "encoding/json" "log" "net/http" + "net/url" + "strings" "sync" "time" @@ -55,9 +57,41 @@ const ( var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { + CheckOrigin: checkWSOrigin, +} + +// ========================================================================= +// WebSocket Origin Validation +// ========================================================================= + +// Checks the origin of the WebSocket connection to prevent cross-origin hijacking. +func checkWSOrigin(r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin == "" { return true - }, + } + u, err := url.Parse(origin) + if err != nil { + log.Printf("WebSocket rejected: malformed Origin header %q", origin) + return false + } + reqHost := r.Host + if reqHost == "" { + reqHost = r.URL.Host + } + originHost := u.Host + if !strings.Contains(originHost, ":") && strings.Contains(reqHost, ":") { + if u.Scheme == "https" { + originHost += ":443" + } else { + originHost += ":80" + } + } + if !strings.EqualFold(originHost, reqHost) { + log.Printf("WebSocket rejected: origin %q does not match host %q", origin, reqHost) + return false + } + return true } // =========================================================================