mirror of
https://github.com/swissmakers/fail2ban-ui.git
synced 2026-04-19 06:53:14 +02:00
Adding section comments for better readability
This commit is contained in:
@@ -33,8 +33,11 @@ import (
|
|||||||
"github.com/swissmakers/fail2ban-ui/pkg/web"
|
"github.com/swissmakers/fail2ban-ui/pkg/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Entrypoint
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// Get application settings from the config package.
|
|
||||||
settings := config.GetSettings()
|
settings := config.GetSettings()
|
||||||
|
|
||||||
if err := storage.Init(""); err != nil {
|
if err := storage.Init(""); err != nil {
|
||||||
@@ -50,69 +53,54 @@ func main() {
|
|||||||
log.Fatalf("failed to initialise fail2ban connectors: %v", err)
|
log.Fatalf("failed to initialise fail2ban connectors: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize OIDC authentication if enabled
|
// OIDC authentication (optional)
|
||||||
oidcConfig, err := config.GetOIDCConfigFromEnv()
|
oidcConfig, err := config.GetOIDCConfigFromEnv()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to load OIDC configuration: %v", err)
|
log.Fatalf("failed to load OIDC configuration: %v", err)
|
||||||
}
|
}
|
||||||
if oidcConfig != nil && oidcConfig.Enabled {
|
if oidcConfig != nil && oidcConfig.Enabled {
|
||||||
// Initialize session secret
|
|
||||||
if err := auth.InitializeSessionSecret(oidcConfig.SessionSecret); err != nil {
|
if err := auth.InitializeSessionSecret(oidcConfig.SessionSecret); err != nil {
|
||||||
log.Fatalf("failed to initialize session secret: %v", err)
|
log.Fatalf("failed to initialize session secret: %v", err)
|
||||||
}
|
}
|
||||||
// Initialize OIDC client
|
|
||||||
if _, err := auth.InitializeOIDC(oidcConfig); err != nil {
|
if _, err := auth.InitializeOIDC(oidcConfig); err != nil {
|
||||||
log.Fatalf("failed to initialize OIDC: %v", err)
|
log.Fatalf("failed to initialize OIDC: %v", err)
|
||||||
}
|
}
|
||||||
log.Println("OIDC authentication enabled")
|
log.Println("OIDC authentication enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set Gin mode based on the debug flag in settings.
|
|
||||||
if settings.Debug {
|
if settings.Debug {
|
||||||
gin.SetMode(gin.DebugMode)
|
gin.SetMode(gin.DebugMode)
|
||||||
} else {
|
} else {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new Gin router.
|
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
serverPort := strconv.Itoa(int(settings.Port))
|
serverPort := strconv.Itoa(int(settings.Port))
|
||||||
|
|
||||||
// Get bind address from environment variable, defaulting to 0.0.0.0
|
|
||||||
bindAddress, _ := config.GetBindAddressFromEnv()
|
bindAddress, _ := config.GetBindAddressFromEnv()
|
||||||
serverAddr := net.JoinHostPort(bindAddress, serverPort)
|
serverAddr := net.JoinHostPort(bindAddress, serverPort)
|
||||||
|
|
||||||
// Load HTML templates depending on whether the application is running inside a container.
|
// Container vs local: different paths for templates and static assets
|
||||||
_, container := os.LookupEnv("CONTAINER")
|
_, container := os.LookupEnv("CONTAINER")
|
||||||
if container {
|
if container {
|
||||||
// In container, templates are assumed to be in /app/templates
|
|
||||||
router.LoadHTMLGlob("/app/templates/*")
|
router.LoadHTMLGlob("/app/templates/*")
|
||||||
router.Static("/locales", "/app/locales")
|
router.Static("/locales", "/app/locales")
|
||||||
router.Static("/static", "/app/static")
|
router.Static("/static", "/app/static")
|
||||||
} else {
|
} else {
|
||||||
// When running locally, load templates from pkg/web/templates
|
|
||||||
router.LoadHTMLGlob("pkg/web/templates/*")
|
router.LoadHTMLGlob("pkg/web/templates/*")
|
||||||
router.Static("/locales", "./internal/locales")
|
router.Static("/locales", "./internal/locales")
|
||||||
router.Static("/static", "./pkg/web/static")
|
router.Static("/static", "./pkg/web/static")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize WebSocket hub
|
// WebSocket hub and console log capture
|
||||||
wsHub := web.NewHub()
|
wsHub := web.NewHub()
|
||||||
go wsHub.Run()
|
go wsHub.Run()
|
||||||
|
|
||||||
// Setup console log writer to capture stdout and send via WebSocket
|
|
||||||
web.SetupConsoleLogWriter(wsHub)
|
web.SetupConsoleLogWriter(wsHub)
|
||||||
// Update enabled state based on current settings
|
|
||||||
web.UpdateConsoleLogEnabled()
|
web.UpdateConsoleLogEnabled()
|
||||||
// Register callback to update console log state when settings change
|
|
||||||
config.SetUpdateConsoleLogStateFunc(func(enabled bool) {
|
config.SetUpdateConsoleLogStateFunc(func(enabled bool) {
|
||||||
web.SetConsoleLogEnabled(enabled)
|
web.SetConsoleLogEnabled(enabled)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Register all application routes, including the static files and templates.
|
|
||||||
web.RegisterRoutes(router, wsHub)
|
web.RegisterRoutes(router, wsHub)
|
||||||
|
|
||||||
// Check if LOTR mode is active
|
|
||||||
isLOTRMode := isLOTRModeActive(settings.AlertCountries)
|
isLOTRMode := isLOTRModeActive(settings.AlertCountries)
|
||||||
printWelcomeBanner(bindAddress, serverPort, isLOTRMode)
|
printWelcomeBanner(bindAddress, serverPort, isLOTRMode)
|
||||||
if isLOTRMode {
|
if isLOTRMode {
|
||||||
@@ -123,12 +111,10 @@ func main() {
|
|||||||
}
|
}
|
||||||
log.Printf("Server listening on %s:%s.\n", bindAddress, serverPort)
|
log.Printf("Server listening on %s:%s.\n", bindAddress, serverPort)
|
||||||
|
|
||||||
// Start the server on the configured address and port.
|
|
||||||
if err := router.Run(serverAddr); err != nil {
|
if err := router.Run(serverAddr); err != nil {
|
||||||
log.Fatalf("Could not start server: %v\n", err)
|
log.Fatalf("Could not start server: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func isLOTRModeActive(alertCountries []string) bool {
|
func isLOTRModeActive(alertCountries []string) bool {
|
||||||
if len(alertCountries) == 0 {
|
if len(alertCountries) == 0 {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -29,7 +29,10 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OIDCClient holds the OIDC provider, verifier, and OAuth2 configuration
|
// =========================================================================
|
||||||
|
// Types
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
type OIDCClient struct {
|
type OIDCClient struct {
|
||||||
Provider *oidc.Provider
|
Provider *oidc.Provider
|
||||||
Verifier *oidc.IDTokenVerifier
|
Verifier *oidc.IDTokenVerifier
|
||||||
@@ -37,7 +40,6 @@ type OIDCClient struct {
|
|||||||
Config *config.OIDCConfig
|
Config *config.OIDCConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserInfo represents the authenticated user information
|
|
||||||
type UserInfo struct {
|
type UserInfo struct {
|
||||||
ID string
|
ID string
|
||||||
Email string
|
Email string
|
||||||
@@ -49,7 +51,10 @@ var (
|
|||||||
oidcClient *OIDCClient
|
oidcClient *OIDCClient
|
||||||
)
|
)
|
||||||
|
|
||||||
// contextWithSkipVerify returns a context with an HTTP client that skips TLS verification if enabled
|
// =========================================================================
|
||||||
|
// Initialization
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
func contextWithSkipVerify(ctx context.Context, skipVerify bool) context.Context {
|
func contextWithSkipVerify(ctx context.Context, skipVerify bool) context.Context {
|
||||||
if !skipVerify {
|
if !skipVerify {
|
||||||
return ctx
|
return ctx
|
||||||
@@ -61,39 +66,32 @@ func contextWithSkipVerify(ctx context.Context, skipVerify bool) context.Context
|
|||||||
return oidc.ClientContext(ctx, client)
|
return oidc.ClientContext(ctx, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitializeOIDC sets up the OIDC client from configuration
|
|
||||||
func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) {
|
func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) {
|
||||||
if cfg == nil || !cfg.Enabled {
|
if cfg == nil || !cfg.Enabled {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retry OIDC provider discovery with exponential backoff
|
// Retry OIDC provider discovery with exponential backoff (e.g. because of Keycloak starting up)
|
||||||
// This handles cases where the provider isn't ready yet (e.g., Keycloak starting up)
|
|
||||||
maxRetries := 10
|
maxRetries := 10
|
||||||
retryDelay := 2 * time.Second
|
retryDelay := 2 * time.Second
|
||||||
var provider *oidc.Provider
|
var provider *oidc.Provider
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
// Create context with timeout for each attempt
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
ctx = contextWithSkipVerify(ctx, cfg.SkipVerify)
|
ctx = contextWithSkipVerify(ctx, cfg.SkipVerify)
|
||||||
|
|
||||||
// Try to discover OIDC provider
|
|
||||||
provider, err = oidc.NewProvider(ctx, cfg.IssuerURL)
|
provider, err = oidc.NewProvider(ctx, cfg.IssuerURL)
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// Success - provider discovered
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// Log retry attempt (but don't fail yet)
|
|
||||||
config.DebugLog("OIDC provider discovery attempt %d/%d failed: %v, retrying in %v...", attempt+1, maxRetries, err, retryDelay)
|
config.DebugLog("OIDC provider discovery attempt %d/%d failed: %v, retrying in %v...", attempt+1, maxRetries, err, retryDelay)
|
||||||
|
|
||||||
if attempt < maxRetries-1 {
|
if attempt < maxRetries-1 {
|
||||||
time.Sleep(retryDelay)
|
time.Sleep(retryDelay)
|
||||||
// Exponential backoff: increase delay for each retry
|
// Increases the delay for each retry (exponential backoff)
|
||||||
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
||||||
if retryDelay > 10*time.Second {
|
if retryDelay > 10*time.Second {
|
||||||
retryDelay = 10 * time.Second
|
retryDelay = 10 * time.Second
|
||||||
@@ -105,7 +103,6 @@ func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) {
|
|||||||
return nil, fmt.Errorf("failed to discover OIDC provider after %d attempts: %w", maxRetries, err)
|
return nil, fmt.Errorf("failed to discover OIDC provider after %d attempts: %w", maxRetries, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create OAuth2 configuration
|
|
||||||
oauth2Config := &oauth2.Config{
|
oauth2Config := &oauth2.Config{
|
||||||
ClientID: cfg.ClientID,
|
ClientID: cfg.ClientID,
|
||||||
ClientSecret: cfg.ClientSecret,
|
ClientSecret: cfg.ClientSecret,
|
||||||
@@ -114,7 +111,6 @@ func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) {
|
|||||||
Scopes: cfg.Scopes,
|
Scopes: cfg.Scopes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create ID token verifier
|
|
||||||
verifier := provider.Verifier(&oidc.Config{
|
verifier := provider.Verifier(&oidc.Config{
|
||||||
ClientID: cfg.ClientID,
|
ClientID: cfg.ClientID,
|
||||||
})
|
})
|
||||||
@@ -131,17 +127,18 @@ func InitializeOIDC(cfg *config.OIDCConfig) (*OIDCClient, error) {
|
|||||||
return oidcClient, nil
|
return oidcClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetOIDCClient returns the initialized OIDC client
|
// =========================================================================
|
||||||
|
// Public Accessors
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
func GetOIDCClient() *OIDCClient {
|
func GetOIDCClient() *OIDCClient {
|
||||||
return oidcClient
|
return oidcClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEnabled returns whether OIDC is enabled
|
|
||||||
func IsEnabled() bool {
|
func IsEnabled() bool {
|
||||||
return oidcClient != nil && oidcClient.Config != nil && oidcClient.Config.Enabled
|
return oidcClient != nil && oidcClient.Config != nil && oidcClient.Config.Enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConfig returns the OIDC configuration
|
|
||||||
func GetConfig() *config.OIDCConfig {
|
func GetConfig() *config.OIDCConfig {
|
||||||
if oidcClient == nil {
|
if oidcClient == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -149,12 +146,16 @@ func GetConfig() *config.OIDCConfig {
|
|||||||
return oidcClient.Config
|
return oidcClient.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAuthURL generates the authorization URL for OIDC login
|
// =========================================================================
|
||||||
|
// OAuth2 Flow
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
// Returns the OAuth2 authorization URL for the given state.
|
||||||
func (c *OIDCClient) GetAuthURL(state string) string {
|
func (c *OIDCClient) GetAuthURL(state string) string {
|
||||||
return c.OAuth2Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
return c.OAuth2Config.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExchangeCode exchanges the authorization code for tokens
|
// Exchanges the authorization code for tokens.
|
||||||
func (c *OIDCClient) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
|
func (c *OIDCClient) ExchangeCode(ctx context.Context, code string) (*oauth2.Token, error) {
|
||||||
if c.OAuth2Config == nil {
|
if c.OAuth2Config == nil {
|
||||||
return nil, fmt.Errorf("OIDC client not properly initialized")
|
return nil, fmt.Errorf("OIDC client not properly initialized")
|
||||||
@@ -169,7 +170,7 @@ func (c *OIDCClient) ExchangeCode(ctx context.Context, code string) (*oauth2.Tok
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyToken verifies the ID token and extracts user information
|
// Verifies the ID token and extracts user information.
|
||||||
func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*UserInfo, error) {
|
func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*UserInfo, error) {
|
||||||
rawIDToken, ok := token.Extra("id_token").(string)
|
rawIDToken, ok := token.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -177,13 +178,11 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx = contextWithSkipVerify(ctx, c.Config.SkipVerify)
|
ctx = contextWithSkipVerify(ctx, c.Config.SkipVerify)
|
||||||
// Verify the ID token
|
|
||||||
idToken, err := c.Verifier.Verify(ctx, rawIDToken)
|
idToken, err := c.Verifier.Verify(ctx, rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to verify ID token: %w", err)
|
return nil, fmt.Errorf("failed to verify ID token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract claims
|
|
||||||
var claims struct {
|
var claims struct {
|
||||||
Subject string `json:"sub"`
|
Subject string `json:"sub"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@@ -204,17 +203,15 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use
|
|||||||
Name: claims.Name,
|
Name: claims.Name,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine username based on configured claim
|
|
||||||
switch c.Config.UsernameClaim {
|
switch c.Config.UsernameClaim {
|
||||||
case "email":
|
case "email":
|
||||||
userInfo.Username = claims.Email
|
userInfo.Username = claims.Email
|
||||||
case "preferred_username":
|
case "preferred_username":
|
||||||
userInfo.Username = claims.PreferredUsername
|
userInfo.Username = claims.PreferredUsername
|
||||||
if userInfo.Username == "" {
|
if userInfo.Username == "" {
|
||||||
userInfo.Username = claims.Email // Fallback to email
|
userInfo.Username = claims.Email
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
// Try to get the claim value dynamically
|
|
||||||
var claimValue interface{}
|
var claimValue interface{}
|
||||||
if err := idToken.Claims(&map[string]interface{}{
|
if err := idToken.Claims(&map[string]interface{}{
|
||||||
c.Config.UsernameClaim: &claimValue,
|
c.Config.UsernameClaim: &claimValue,
|
||||||
@@ -231,7 +228,6 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback name construction
|
|
||||||
if userInfo.Name == "" {
|
if userInfo.Name == "" {
|
||||||
if claims.GivenName != "" || claims.FamilyName != "" {
|
if claims.GivenName != "" || claims.FamilyName != "" {
|
||||||
userInfo.Name = fmt.Sprintf("%s %s", claims.GivenName, claims.FamilyName)
|
userInfo.Name = fmt.Sprintf("%s %s", claims.GivenName, claims.FamilyName)
|
||||||
@@ -241,6 +237,5 @@ func (c *OIDCClient) VerifyToken(ctx context.Context, token *oauth2.Token) (*Use
|
|||||||
userInfo.Name = userInfo.Username
|
userInfo.Name = userInfo.Username
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return userInfo, nil
|
return userInfo, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,12 +28,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// =========================================================================
|
||||||
sessionCookieName = "fail2ban_ui_session"
|
// Types and Constants
|
||||||
sessionKeyLength = 32 // AES-256
|
// =========================================================================
|
||||||
)
|
|
||||||
|
|
||||||
// Session represents a user session
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
UserID string `json:"userID"`
|
UserID string `json:"userID"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
@@ -42,22 +40,28 @@ type Session struct {
|
|||||||
ExpiresAt time.Time `json:"expiresAt"`
|
ExpiresAt time.Time `json:"expiresAt"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
sessionCookieName = "fail2ban_ui_session"
|
||||||
|
sessionKeyLength = 32
|
||||||
|
)
|
||||||
|
|
||||||
var sessionSecret []byte
|
var sessionSecret []byte
|
||||||
|
|
||||||
// InitializeSessionSecret initializes the session encryption secret
|
// =========================================================================
|
||||||
|
// Session Management
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
// Initializes the encryption key for session cookies.
|
||||||
func InitializeSessionSecret(secret string) error {
|
func InitializeSessionSecret(secret string) error {
|
||||||
if secret == "" {
|
if secret == "" {
|
||||||
return fmt.Errorf("session secret cannot be empty")
|
return fmt.Errorf("session secret cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decode base64 secret or use directly if not base64
|
|
||||||
decoded, err := base64.URLEncoding.DecodeString(secret)
|
decoded, err := base64.URLEncoding.DecodeString(secret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Not base64, use as-is (but ensure it's 32 bytes for AES-256)
|
|
||||||
if len(secret) < sessionKeyLength {
|
if len(secret) < sessionKeyLength {
|
||||||
return fmt.Errorf("session secret must be at least %d bytes", sessionKeyLength)
|
return fmt.Errorf("session secret must be at least %d bytes", sessionKeyLength)
|
||||||
}
|
}
|
||||||
// Use first 32 bytes
|
|
||||||
sessionSecret = []byte(secret[:sessionKeyLength])
|
sessionSecret = []byte(secret[:sessionKeyLength])
|
||||||
} else {
|
} else {
|
||||||
if len(decoded) < sessionKeyLength {
|
if len(decoded) < sessionKeyLength {
|
||||||
@@ -69,7 +73,7 @@ func InitializeSessionSecret(secret string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateSession creates a new encrypted session cookie
|
// Creates a session cookie with the user info.
|
||||||
func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, maxAge int) error {
|
func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, maxAge int) error {
|
||||||
session := &Session{
|
session := &Session{
|
||||||
UserID: userInfo.ID,
|
UserID: userInfo.ID,
|
||||||
@@ -79,29 +83,25 @@ func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, m
|
|||||||
ExpiresAt: time.Now().Add(time.Duration(maxAge) * time.Second),
|
ExpiresAt: time.Now().Add(time.Duration(maxAge) * time.Second),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize session to JSON
|
|
||||||
sessionData, err := json.Marshal(session)
|
sessionData, err := json.Marshal(session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to marshal session: %w", err)
|
return fmt.Errorf("failed to marshal session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encrypt session data
|
|
||||||
encrypted, err := encrypt(sessionData)
|
encrypted, err := encrypt(sessionData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to encrypt session: %w", err)
|
return fmt.Errorf("failed to encrypt session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Determine if we're using HTTPS
|
|
||||||
isSecure := r != nil && (r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https")
|
isSecure := r != nil && (r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https")
|
||||||
|
|
||||||
// Create secure cookie
|
|
||||||
cookie := &http.Cookie{
|
cookie := &http.Cookie{
|
||||||
Name: sessionCookieName,
|
Name: sessionCookieName,
|
||||||
Value: encrypted,
|
Value: encrypted,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: maxAge,
|
MaxAge: maxAge,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: isSecure, // Only secure over HTTPS
|
Secure: isSecure,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -109,26 +109,23 @@ func CreateSession(w http.ResponseWriter, r *http.Request, userInfo *UserInfo, m
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSession retrieves and validates a session from the cookie
|
// Reads and validates the session cookie.
|
||||||
func GetSession(r *http.Request) (*Session, error) {
|
func GetSession(r *http.Request) (*Session, error) {
|
||||||
cookie, err := r.Cookie(sessionCookieName)
|
cookie, err := r.Cookie(sessionCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt session data
|
|
||||||
decrypted, err := decrypt(cookie.Value)
|
decrypted, err := decrypt(cookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to decrypt session: %w", err)
|
return nil, fmt.Errorf("failed to decrypt session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deserialize session
|
|
||||||
var session Session
|
var session Session
|
||||||
if err := json.Unmarshal(decrypted, &session); err != nil {
|
if err := json.Unmarshal(decrypted, &session); err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal session: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if session is expired
|
|
||||||
if time.Now().After(session.ExpiresAt) {
|
if time.Now().After(session.ExpiresAt) {
|
||||||
return nil, fmt.Errorf("session expired")
|
return nil, fmt.Errorf("session expired")
|
||||||
}
|
}
|
||||||
@@ -136,9 +133,8 @@ func GetSession(r *http.Request) (*Session, error) {
|
|||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSession clears the session cookie
|
// Clears the session cookie.
|
||||||
func DeleteSession(w http.ResponseWriter, r *http.Request) {
|
func DeleteSession(w http.ResponseWriter, r *http.Request) {
|
||||||
// Determine if we're using HTTPS
|
|
||||||
isSecure := r != nil && (r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https")
|
isSecure := r != nil && (r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https")
|
||||||
|
|
||||||
cookie := &http.Cookie{
|
cookie := &http.Cookie{
|
||||||
@@ -147,13 +143,16 @@ func DeleteSession(w http.ResponseWriter, r *http.Request) {
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: isSecure, // Only secure over HTTPS
|
Secure: isSecure,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
}
|
}
|
||||||
http.SetCookie(w, cookie)
|
http.SetCookie(w, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypt encrypts data using AES-GCM
|
// =========================================================================
|
||||||
|
// Encryption Helpers
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
func encrypt(plaintext []byte) (string, error) {
|
func encrypt(plaintext []byte) (string, error) {
|
||||||
block, err := aes.NewCipher(sessionSecret)
|
block, err := aes.NewCipher(sessionSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -174,7 +173,6 @@ func encrypt(plaintext []byte) (string, error) {
|
|||||||
return base64.URLEncoding.EncodeToString(ciphertext), nil
|
return base64.URLEncoding.EncodeToString(ciphertext), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decrypt decrypts data using AES-GCM
|
|
||||||
func decrypt(ciphertext string) ([]byte, error) {
|
func decrypt(ciphertext string) ([]byte, error) {
|
||||||
data, err := base64.URLEncoding.DecodeString(ciphertext)
|
data, err := base64.URLEncoding.DecodeString(ciphertext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -20,18 +20,20 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DebugLog prints debug messages only if debug mode is enabled.
|
// =========================================================================
|
||||||
|
// Debug Logging
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
// Prints debug messages if debug mode is enabled.
|
||||||
func DebugLog(format string, v ...interface{}) {
|
func DebugLog(format string, v ...interface{}) {
|
||||||
// Avoid deadlocks by not calling GetSettings() inside DebugLog.
|
|
||||||
debugEnabled := false
|
debugEnabled := false
|
||||||
debugEnabled = currentSettings.Debug
|
debugEnabled = currentSettings.Debug
|
||||||
if !debugEnabled {
|
if !debugEnabled {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Ensure correct usage of fmt.Printf-style formatting
|
|
||||||
if len(v) > 0 {
|
if len(v) > 0 {
|
||||||
log.Printf(format, v...) // Uses format directives
|
log.Printf(format, v...)
|
||||||
} else {
|
} else {
|
||||||
log.Println(format) // Just prints the message
|
log.Println(format)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user