Add a check to validate the origin of the WebSocket connection to prevent cross-origin hijacking

This commit is contained in:
2026-02-21 19:29:30 +01:00
parent 16112af26d
commit 9e1be2560c

View File

@@ -20,6 +20,8 @@ import (
"encoding/json"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
@@ -55,9 +57,41 @@ const (
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
CheckOrigin: checkWSOrigin,
}
// =========================================================================
// WebSocket Origin Validation
// =========================================================================
// Checks the origin of the WebSocket connection to prevent cross-origin hijacking.
func checkWSOrigin(r *http.Request) bool {
origin := r.Header.Get("Origin")
if origin == "" {
return true
}
u, err := url.Parse(origin)
if err != nil {
log.Printf("WebSocket rejected: malformed Origin header %q", origin)
return false
}
reqHost := r.Host
if reqHost == "" {
reqHost = r.URL.Host
}
originHost := u.Host
if !strings.Contains(originHost, ":") && strings.Contains(reqHost, ":") {
if u.Scheme == "https" {
originHost += ":443"
} else {
originHost += ":80"
}
}
if !strings.EqualFold(originHost, reqHost) {
log.Printf("WebSocket rejected: origin %q does not match host %q", origin, reqHost)
return false
}
return true
},
}
// =========================================================================