From cffbc41d224fe0de6cfc8c69211b43df9547de79 Mon Sep 17 00:00:00 2001 From: Michael Reber Date: Sat, 14 Feb 2026 19:14:27 +0100 Subject: [PATCH] Adding section comments for better readability --- cmd/server/main.go | 28 ++++++---------------- internal/auth/oidc.go | 49 +++++++++++++++++--------------------- internal/auth/session.go | 46 +++++++++++++++++------------------ internal/config/logging.go | 12 ++++++---- 4 files changed, 58 insertions(+), 77 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index d4bd356..6195676 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -33,8 +33,11 @@ import ( "github.com/swissmakers/fail2ban-ui/pkg/web" ) +// ========================================================================= +// Entrypoint +// ========================================================================= + func main() { - // Get application settings from the config package. settings := config.GetSettings() if err := storage.Init(""); err != nil { @@ -50,69 +53,54 @@ func main() { log.Fatalf("failed to initialise fail2ban connectors: %v", err) } - // Initialize OIDC authentication if enabled + // OIDC authentication (optional) oidcConfig, err := config.GetOIDCConfigFromEnv() if err != nil { log.Fatalf("failed to load OIDC configuration: %v", err) } if oidcConfig != nil && oidcConfig.Enabled { - // Initialize session secret if err := auth.InitializeSessionSecret(oidcConfig.SessionSecret); err != nil { log.Fatalf("failed to initialize session secret: %v", err) } - // Initialize OIDC client if _, err := auth.InitializeOIDC(oidcConfig); err != nil { log.Fatalf("failed to initialize OIDC: %v", err) } log.Println("OIDC authentication enabled") } - // Set Gin mode based on the debug flag in settings. if settings.Debug { gin.SetMode(gin.DebugMode) } else { gin.SetMode(gin.ReleaseMode) } - // Create a new Gin router. router := gin.Default() serverPort := strconv.Itoa(int(settings.Port)) - - // Get bind address from environment variable, defaulting to 0.0.0.0 bindAddress, _ := config.GetBindAddressFromEnv() serverAddr := net.JoinHostPort(bindAddress, serverPort) - // Load HTML templates depending on whether the application is running inside a container. + // Container vs local: different paths for templates and static assets _, container := os.LookupEnv("CONTAINER") if container { - // In container, templates are assumed to be in /app/templates router.LoadHTMLGlob("/app/templates/*") router.Static("/locales", "/app/locales") router.Static("/static", "/app/static") } else { - // When running locally, load templates from pkg/web/templates router.LoadHTMLGlob("pkg/web/templates/*") router.Static("/locales", "./internal/locales") router.Static("/static", "./pkg/web/static") } - // Initialize WebSocket hub + // WebSocket hub and console log capture wsHub := web.NewHub() go wsHub.Run() - - // Setup console log writer to capture stdout and send via WebSocket web.SetupConsoleLogWriter(wsHub) - // Update enabled state based on current settings web.UpdateConsoleLogEnabled() - // Register callback to update console log state when settings change config.SetUpdateConsoleLogStateFunc(func(enabled bool) { web.SetConsoleLogEnabled(enabled) }) - // Register all application routes, including the static files and templates. web.RegisterRoutes(router, wsHub) - - // Check if LOTR mode is active isLOTRMode := isLOTRModeActive(settings.AlertCountries) printWelcomeBanner(bindAddress, serverPort, isLOTRMode) if isLOTRMode { @@ -123,12 +111,10 @@ func main() { } log.Printf("Server listening on %s:%s.\n", bindAddress, serverPort) - // Start the server on the configured address and port. if err := router.Run(serverAddr); err != nil { log.Fatalf("Could not start server: %v\n", err) } } - func isLOTRModeActive(alertCountries []string) bool { if len(alertCountries) == 0 { return false diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index fdd1640..1297d2f 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -29,7 +29,10 @@ import ( "golang.org/x/oauth2" ) -// OIDCClient holds the OIDC provider, verifier, and OAuth2 configuration +// ========================================================================= +// Types +// ========================================================================= + type OIDCClient struct { Provider *oidc.Provider Verifier *oidc.IDTokenVerifier @@ -37,7 +40,6 @@ type OIDCClient struct { Config *config.OIDCConfig } -// UserInfo represents the authenticated user information type UserInfo struct { ID string Email string @@ -49,7 +51,10 @@ var ( oidcClient *OIDCClient ) -// contextWithSkipVerify returns a context with an HTTP client that skips TLS verification if enabled +// ========================================================================= +// Initialization +// ========================================================================= + func contextWithSkipVerify(ctx context.Context, skipVerify bool) context.Context { if !skipVerify { return ctx @@ -61,39 +66,32 @@ func contextWithSkipVerify(ctx context.Context, skipVerify bool) context.Context return oidc.ClientContext(ctx, client) } -// InitializeOIDC sets up the OIDC client from configuration func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) { if cfg == nil || !cfg.Enabled { return nil, nil } - // Retry OIDC provider discovery with exponential backoff - // This handles cases where the provider isn't ready yet (e.g., Keycloak starting up) + // Retry OIDC provider discovery with exponential backoff (e.g. because of Keycloak starting up) maxRetries := 10 retryDelay := 2 * time.Second var provider *oidc.Provider var err error for attempt := 0; attempt < maxRetries; attempt++ { - // Create context with timeout for each attempt ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx = contextWithSkipVerify(ctx, cfg.SkipVerify) - // Try to discover OIDC provider provider, err = oidc.NewProvider(ctx, cfg.IssuerURL) cancel() if err == nil { - // Success - provider discovered break } - - // Log retry attempt (but don't fail yet) config.DebugLog("OIDC provider discovery attempt %d/%d failed: %v, retrying in %v...", attempt+1, maxRetries, err, retryDelay) if attempt < maxRetries-1 { time.Sleep(retryDelay) - // Exponential backoff: increase delay for each retry + // Increases the delay for each retry (exponential backoff) retryDelay = time.Duration(float64(retryDelay) * 1.5) if retryDelay > 10*time.Second { retryDelay = 10 * time.Second @@ -105,7 +103,6 @@ func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) { return nil, fmt.Errorf("failed to discover OIDC provider after %d attempts: %w", maxRetries, err) } - // Create OAuth2 configuration oauth2Config := &oauth2.Config{ ClientID: cfg.ClientID, ClientSecret: cfg.ClientSecret, @@ -114,7 +111,6 @@ func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) { Scopes: cfg.Scopes, } - // Create ID token verifier verifier := provider.Verifier(&oidc.Config{ ClientID: cfg.ClientID, }) @@ -131,17 +127,18 @@ func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) { return oidcClient, nil } -// GetOIDCClient returns the initialized OIDC client +// ========================================================================= +// Public Accessors +// ========================================================================= + func GetOIDCClient() *OIDCClient { return oidcClient } -// IsEnabled returns whether OIDC is enabled func IsEnabled() bool { return oidcClient != nil && oidcClient.Config != nil && oidcClient.Config.Enabled } -// GetConfig returns the OIDC configuration func GetConfig() *config.OIDCConfig { if oidcClient == nil { return nil @@ -149,12 +146,16 @@ func GetConfig() *config.OIDCConfig { return oidcClient.Config } -// GetAuthURL generates the authorization URL for OIDC login +// ========================================================================= +// OAuth2 Flow +// ========================================================================= + +// Returns the OAuth2 authorization URL for the given state. func (c *OIDCClient) GetAuthURL(state string) string { return c.OAuth2Config.AuthCodeURL(state, oauth2.AccessTypeOffline) } -// ExchangeCode exchanges the authorization code for tokens +// Exchanges the authorization code for tokens. func (c *OIDCClient) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) { if c.OAuth2Config == nil { return nil, fmt.Errorf("OIDC client not properly initialized") @@ -169,7 +170,7 @@ func (c *OIDCClient) ExchangeCode(ctx context.Context, code string) (*oauth2.Tok return token, nil } -// VerifyToken verifies the ID token and extracts user information +// Verifies the ID token and extracts user information. func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*UserInfo, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok { @@ -177,13 +178,11 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use } ctx = contextWithSkipVerify(ctx, c.Config.SkipVerify) - // Verify the ID token idToken, err := c.Verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("failed to verify ID token: %w", err) } - // Extract claims var claims struct { Subject string `json:"sub"` Email string `json:"email"` @@ -204,17 +203,15 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use Name: claims.Name, } - // Determine username based on configured claim switch c.Config.UsernameClaim { case "email": userInfo.Username = claims.Email case "preferred_username": userInfo.Username = claims.PreferredUsername if userInfo.Username == "" { - userInfo.Username = claims.Email // Fallback to email + userInfo.Username = claims.Email } default: - // Try to get the claim value dynamically var claimValue interface{} if err := idToken.Claims(&map[string]interface{}{ c.Config.UsernameClaim: &claimValue, @@ -231,7 +228,6 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use } } - // Fallback name construction if userInfo.Name == "" { if claims.GivenName != "" || claims.FamilyName != "" { userInfo.Name = fmt.Sprintf("%s %s", claims.GivenName, claims.FamilyName) @@ -241,6 +237,5 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use userInfo.Name = userInfo.Username } } - return userInfo, nil } diff --git a/internal/auth/session.go b/internal/auth/session.go index fe7bb61..3a6a04c 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -28,12 +28,10 @@ import ( "time" ) -const ( - sessionCookieName = "fail2ban_ui_session" - sessionKeyLength = 32 // AES-256 -) +// ========================================================================= +// Types and Constants +// ========================================================================= -// Session represents a user session type Session struct { UserID string `json:"userID"` Email string `json:"email"` @@ -42,22 +40,28 @@ type Session struct { ExpiresAt time.Time `json:"expiresAt"` } +const ( + sessionCookieName = "fail2ban_ui_session" + sessionKeyLength = 32 +) + var sessionSecret []byte -// InitializeSessionSecret initializes the session encryption secret +// ========================================================================= +// Session Management +// ========================================================================= + +// Initializes the encryption key for session cookies. func InitializeSessionSecret(secret string) error { if secret == "" { return fmt.Errorf("session secret cannot be empty") } - // Decode base64 secret or use directly if not base64 decoded, err := base64.URLEncoding.DecodeString(secret) if err != nil { - // Not base64, use as-is (but ensure it's 32 bytes for AES-256) if len(secret) < sessionKeyLength { return fmt.Errorf("session secret must be at least %d bytes", sessionKeyLength) } - // Use first 32 bytes sessionSecret = []byte(secret[:sessionKeyLength]) } else { if len(decoded) < sessionKeyLength { @@ -69,7 +73,7 @@ func InitializeSessionSecret(secret string) error { return nil } -// CreateSession creates a new encrypted session cookie +// Creates a session cookie with the user info. func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, maxAge int) error { session := &Session{ UserID: userInfo.ID, @@ -79,29 +83,25 @@ func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, m ExpiresAt: time.Now().Add(time.Duration(maxAge) * time.Second), } - // Serialize session to JSON sessionData, err := json.Marshal(session) if err != nil { return fmt.Errorf("failed to marshal session: %w", err) } - // Encrypt session data encrypted, err := encrypt(sessionData) if err != nil { return fmt.Errorf("failed to encrypt session: %w", err) } - // Determine if we're using HTTPS isSecure := r != nil && (r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https") - // Create secure cookie cookie := &http.Cookie{ Name: sessionCookieName, Value: encrypted, Path: "/", MaxAge: maxAge, HttpOnly: true, - Secure: isSecure, // Only secure over HTTPS + Secure: isSecure, SameSite: http.SameSiteLaxMode, } @@ -109,26 +109,23 @@ func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, m return nil } -// GetSession retrieves and validates a session from the cookie +// Reads and validates the session cookie. func GetSession(r *http.Request) (*Session, error) { cookie, err := r.Cookie(sessionCookieName) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) } - // Decrypt session data decrypted, err := decrypt(cookie.Value) if err != nil { return nil, fmt.Errorf("failed to decrypt session: %w", err) } - // Deserialize session var session Session if err := json.Unmarshal(decrypted, &session); err != nil { return nil, fmt.Errorf("failed to unmarshal session: %w", err) } - // Check if session is expired if time.Now().After(session.ExpiresAt) { return nil, fmt.Errorf("session expired") } @@ -136,9 +133,8 @@ func GetSession(r *http.Request) (*Session, error) { return &session, nil } -// DeleteSession clears the session cookie +// Clears the session cookie. func DeleteSession(w http.ResponseWriter, r *http.Request) { - // Determine if we're using HTTPS isSecure := r != nil && (r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https") cookie := &http.Cookie{ @@ -147,13 +143,16 @@ func DeleteSession(w http.ResponseWriter, r *http.Request) { Path: "/", MaxAge: -1, HttpOnly: true, - Secure: isSecure, // Only secure over HTTPS + Secure: isSecure, SameSite: http.SameSiteLaxMode, } http.SetCookie(w, cookie) } -// encrypt encrypts data using AES-GCM +// ========================================================================= +// Encryption Helpers +// ========================================================================= + func encrypt(plaintext []byte) (string, error) { block, err := aes.NewCipher(sessionSecret) if err != nil { @@ -174,7 +173,6 @@ func encrypt(plaintext []byte) (string, error) { return base64.URLEncoding.EncodeToString(ciphertext), nil } -// decrypt decrypts data using AES-GCM func decrypt(ciphertext string) ([]byte, error) { data, err := base64.URLEncoding.DecodeString(ciphertext) if err != nil { diff --git a/internal/config/logging.go b/internal/config/logging.go index fcf8040..2fd69cc 100644 --- a/internal/config/logging.go +++ b/internal/config/logging.go @@ -20,18 +20,20 @@ import ( "log" ) -// DebugLog prints debug messages only if debug mode is enabled. +// ========================================================================= +// Debug Logging +// ========================================================================= + +// Prints debug messages if debug mode is enabled. func DebugLog(format string, v ...interface{}) { - // Avoid deadlocks by not calling GetSettings() inside DebugLog. debugEnabled := false debugEnabled = currentSettings.Debug if !debugEnabled { return } - // Ensure correct usage of fmt.Printf-style formatting if len(v) > 0 { - log.Printf(format, v...) // Uses format directives + log.Printf(format, v...) } else { - log.Println(format) // Just prints the message + log.Println(format) } }