package storage

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	// Blank import for its side effect: registers the "sqlite" driver name
	// that Init passes to sql.Open.
	_ "modernc.org/sqlite"
)

var (
	// db is the shared handle assigned by Init. Every accessor in this package
	// treats a nil db as "storage not initialised".
	db *sql.DB
	// initOnce makes the body of Init run at most once per process.
	initOnce sync.Once
	// initErr is the outcome of that single run; every Init call returns it.
	initErr error
	// defaultPath is the database file Init uses when given an empty path.
	defaultPath = "fail2ban-ui.db"
)

// boolToInt encodes b as 1 (true) or 0 (false) for INTEGER flag columns.
func boolToInt(b bool) int {
	if b {
		return 1
	}
	return 0
}

// intToBool decodes an INTEGER flag column: any non-zero value is true.
func intToBool(i int) bool { return i != 0 }

// stringFromNull returns the value of ns, or "" when the column was NULL.
func stringFromNull(ns sql.NullString) string {
	if ns.Valid {
		return ns.String
	}
	return ""
}

// intFromNull returns the value of ni converted to int, or 0 when the column
// was NULL.
func intFromNull(ni sql.NullInt64) int {
	if ni.Valid {
		return int(ni.Int64)
	}
	return 0
}

// AppSettingsRecord mirrors the single row (id = 1) of the app_settings table.
// Boolean fields are stored as 0/1 INTEGER columns. SMTPPassword is written to
// smtp_password exactly as given (no encryption happens in this package).
// NOTE(review): AlertCountriesJSON presumably holds a JSON-encoded country
// list, and Bantime/Findtime fail2ban duration strings — confirm with callers.
type AppSettingsRecord struct {
	Language           string
	Port               int
	Debug              bool
	CallbackURL        string
	RestartNeeded      bool
	AlertCountriesJSON string
	SMTPHost           string
	SMTPPort           int
	SMTPUsername       string
	SMTPPassword       string
	SMTPFrom           string
	SMTPUseTLS         bool
	BantimeIncrement   bool
	IgnoreIP           string
	Bantime            string
	Findtime           string
	MaxRetry           int
	DestEmail          string
}

// ServerRecord mirrors one row of the servers table, keyed by ID.
// AgentSecret is stored exactly as given. CreatedAt/UpdatedAt are persisted
// as RFC 3339 (nano) text by ReplaceServers.
// NOTE(review): the connection fields (Host/Port, SocketPath, SSHUser/
// SSHKeyPath, AgentURL/AgentSecret) presumably apply to different Type
// values, and TagsJSON presumably holds a JSON array — TODO confirm.
type ServerRecord struct {
	ID           string
	Name         string
	Type         string
	Host         string
	Port         int
	SocketPath   string
	LogPath      string
	SSHUser      string
	SSHKeyPath   string
	AgentURL     string
	AgentSecret  string
	Hostname     string
	TagsJSON     string
	IsDefault    bool
	Enabled      bool
	NeedsRestart bool
	CreatedAt    time.Time
	UpdatedAt    time.Time
}

// BanEventRecord represents a single ban event stored in the internal database.
// ServerID is required by RecordBanEvent. OccurredAt and CreatedAt default to
// the insertion time when left zero.
type BanEventRecord struct {
	ID         int64     `json:"id"`
	ServerID   string    `json:"serverId"`
	ServerName string    `json:"serverName"`
	Jail       string    `json:"jail"`
	IP         string    `json:"ip"`
	Country    string    `json:"country"`
	Hostname   string    `json:"hostname"`
	Failures   string    `json:"failures"` // stored as text; presumably fail2ban's failure count — confirm
	Whois      string    `json:"whois"`
	Logs       string    `json:"logs"`
	OccurredAt time.Time `json:"occurredAt"`
	CreatedAt  time.Time `json:"createdAt"`
}

// RecurringIPStat represents aggregation info for repeatedly banned IPs.
// Count is the number of matching ban events for IP, and LastSeen is the
// latest occurred_at among them (zero when it could not be parsed).
type RecurringIPStat struct {
	IP       string    `json:"ip"`
	Country  string    `json:"country"`
	Count    int64     `json:"count"`
	LastSeen time.Time `json:"lastSeen"`
}

// Init initializes the internal storage. Safe to call multiple times.
func Init(dbPath string) error { initOnce.Do(func() { if dbPath == "" { dbPath = defaultPath } if err := ensureDirectory(dbPath); err != nil { initErr = err return } var err error db, err = sql.Open("sqlite", fmt.Sprintf("file:%s?_pragma=journal_mode(WAL)&_pragma=busy_timeout=5000", dbPath)) if err != nil { initErr = err return } if err = db.Ping(); err != nil { initErr = err return } initErr = ensureSchema(context.Background()) }) return initErr } // Close closes the underlying database if it has been initialised. func Close() error { if db == nil { return nil } return db.Close() } func GetAppSettings(ctx context.Context) (AppSettingsRecord, bool, error) { if db == nil { return AppSettingsRecord{}, false, errors.New("storage not initialised") } row := db.QueryRowContext(ctx, ` SELECT language, port, debug, callback_url, restart_needed, alert_countries, smtp_host, smtp_port, smtp_username, smtp_password, smtp_from, smtp_use_tls, bantime_increment, ignore_ip, bantime, findtime, maxretry, destemail FROM app_settings WHERE id = 1`) var ( lang, callback, alerts, smtpHost, smtpUser, smtpPass, smtpFrom, ignoreIP, bantime, findtime, destemail sql.NullString port, smtpPort, maxretry sql.NullInt64 debug, restartNeeded, smtpTLS, bantimeInc sql.NullInt64 ) err := row.Scan(&lang, &port, &debug, &callback, &restartNeeded, &alerts, &smtpHost, &smtpPort, &smtpUser, &smtpPass, &smtpFrom, &smtpTLS, &bantimeInc, &ignoreIP, &bantime, &findtime, &maxretry, &destemail) if errors.Is(err, sql.ErrNoRows) { return AppSettingsRecord{}, false, nil } if err != nil { return AppSettingsRecord{}, false, err } rec := AppSettingsRecord{ Language: stringFromNull(lang), Port: intFromNull(port), Debug: intToBool(intFromNull(debug)), CallbackURL: stringFromNull(callback), RestartNeeded: intToBool(intFromNull(restartNeeded)), AlertCountriesJSON: stringFromNull(alerts), SMTPHost: stringFromNull(smtpHost), SMTPPort: intFromNull(smtpPort), SMTPUsername: stringFromNull(smtpUser), SMTPPassword: 
stringFromNull(smtpPass), SMTPFrom: stringFromNull(smtpFrom), SMTPUseTLS: intToBool(intFromNull(smtpTLS)), BantimeIncrement: intToBool(intFromNull(bantimeInc)), IgnoreIP: stringFromNull(ignoreIP), Bantime: stringFromNull(bantime), Findtime: stringFromNull(findtime), MaxRetry: intFromNull(maxretry), DestEmail: stringFromNull(destemail), } return rec, true, nil } func SaveAppSettings(ctx context.Context, rec AppSettingsRecord) error { if db == nil { return errors.New("storage not initialised") } _, err := db.ExecContext(ctx, ` INSERT INTO app_settings ( id, language, port, debug, callback_url, restart_needed, alert_countries, smtp_host, smtp_port, smtp_username, smtp_password, smtp_from, smtp_use_tls, bantime_increment, ignore_ip, bantime, findtime, maxretry, destemail ) VALUES ( 1, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ) ON CONFLICT(id) DO UPDATE SET language = excluded.language, port = excluded.port, debug = excluded.debug, callback_url = excluded.callback_url, restart_needed = excluded.restart_needed, alert_countries = excluded.alert_countries, smtp_host = excluded.smtp_host, smtp_port = excluded.smtp_port, smtp_username = excluded.smtp_username, smtp_password = excluded.smtp_password, smtp_from = excluded.smtp_from, smtp_use_tls = excluded.smtp_use_tls, bantime_increment = excluded.bantime_increment, ignore_ip = excluded.ignore_ip, bantime = excluded.bantime, findtime = excluded.findtime, maxretry = excluded.maxretry, destemail = excluded.destemail `, rec.Language, rec.Port, boolToInt(rec.Debug), rec.CallbackURL, boolToInt(rec.RestartNeeded), rec.AlertCountriesJSON, rec.SMTPHost, rec.SMTPPort, rec.SMTPUsername, rec.SMTPPassword, rec.SMTPFrom, boolToInt(rec.SMTPUseTLS), boolToInt(rec.BantimeIncrement), rec.IgnoreIP, rec.Bantime, rec.Findtime, rec.MaxRetry, rec.DestEmail, ) return err } func ListServers(ctx context.Context) ([]ServerRecord, error) { if db == nil { return nil, errors.New("storage not initialised") } rows, err := db.QueryContext(ctx, ` 
SELECT id, name, type, host, port, socket_path, log_path, ssh_user, ssh_key_path, agent_url, agent_secret, hostname, tags, is_default, enabled, needs_restart, created_at, updated_at FROM servers ORDER BY created_at`) if err != nil { return nil, err } defer rows.Close() var records []ServerRecord for rows.Next() { var rec ServerRecord var host, socket, logPath, sshUser, sshKey, agentURL, agentSecret, hostname, tags sql.NullString var name, serverType sql.NullString var created, updated sql.NullString var port sql.NullInt64 var isDefault, enabled, needsRestart sql.NullInt64 if err := rows.Scan( &rec.ID, &name, &serverType, &host, &port, &socket, &logPath, &sshUser, &sshKey, &agentURL, &agentSecret, &hostname, &tags, &isDefault, &enabled, &needsRestart, &created, &updated, ); err != nil { return nil, err } rec.Name = stringFromNull(name) rec.Type = stringFromNull(serverType) rec.Host = stringFromNull(host) rec.Port = intFromNull(port) rec.SocketPath = stringFromNull(socket) rec.LogPath = stringFromNull(logPath) rec.SSHUser = stringFromNull(sshUser) rec.SSHKeyPath = stringFromNull(sshKey) rec.AgentURL = stringFromNull(agentURL) rec.AgentSecret = stringFromNull(agentSecret) rec.Hostname = stringFromNull(hostname) rec.TagsJSON = stringFromNull(tags) rec.IsDefault = intToBool(intFromNull(isDefault)) rec.Enabled = intToBool(intFromNull(enabled)) rec.NeedsRestart = intToBool(intFromNull(needsRestart)) if created.Valid { if t, err := time.Parse(time.RFC3339Nano, created.String); err == nil { rec.CreatedAt = t } } if updated.Valid { if t, err := time.Parse(time.RFC3339Nano, updated.String); err == nil { rec.UpdatedAt = t } } records = append(records, rec) } return records, rows.Err() } func ReplaceServers(ctx context.Context, servers []ServerRecord) error { if db == nil { return errors.New("storage not initialised") } tx, err := db.BeginTx(ctx, nil) if err != nil { return err } defer func() { if err != nil { _ = tx.Rollback() } }() if _, err = tx.ExecContext(ctx, `DELETE FROM 
servers`); err != nil { return err } stmt, err := tx.PrepareContext(ctx, ` INSERT INTO servers ( id, name, type, host, port, socket_path, log_path, ssh_user, ssh_key_path, agent_url, agent_secret, hostname, tags, is_default, enabled, needs_restart, created_at, updated_at ) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? )`) if err != nil { return err } defer stmt.Close() for _, srv := range servers { createdAt := srv.CreatedAt if createdAt.IsZero() { createdAt = time.Now().UTC() } updatedAt := srv.UpdatedAt if updatedAt.IsZero() { updatedAt = createdAt } if _, err = stmt.ExecContext(ctx, srv.ID, srv.Name, srv.Type, srv.Host, srv.Port, srv.SocketPath, srv.LogPath, srv.SSHUser, srv.SSHKeyPath, srv.AgentURL, srv.AgentSecret, srv.Hostname, srv.TagsJSON, boolToInt(srv.IsDefault), boolToInt(srv.Enabled), boolToInt(srv.NeedsRestart), createdAt.Format(time.RFC3339Nano), updatedAt.Format(time.RFC3339Nano), ); err != nil { return err } } err = tx.Commit() return err } func DeleteServer(ctx context.Context, id string) error { if db == nil { return errors.New("storage not initialised") } _, err := db.ExecContext(ctx, `DELETE FROM servers WHERE id = ?`, id) return err } // RecordBanEvent stores a ban event in the database. 
func RecordBanEvent(ctx context.Context, record BanEventRecord) error { if db == nil { return errors.New("storage not initialised") } if record.ServerID == "" { return errors.New("server id is required") } now := time.Now().UTC() if record.CreatedAt.IsZero() { record.CreatedAt = now } if record.OccurredAt.IsZero() { record.OccurredAt = now } const query = ` INSERT INTO ban_events ( server_id, server_name, jail, ip, country, hostname, failures, whois, logs, occurred_at, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` _, err := db.ExecContext( ctx, query, record.ServerID, record.ServerName, record.Jail, record.IP, record.Country, record.Hostname, record.Failures, record.Whois, record.Logs, record.OccurredAt.UTC(), record.CreatedAt.UTC(), ) return err } // ListBanEvents returns ban events ordered by creation date descending. func ListBanEvents(ctx context.Context, serverID string, limit int, since time.Time) ([]BanEventRecord, error) { if db == nil { return nil, errors.New("storage not initialised") } if limit <= 0 || limit > 500 { limit = 100 } baseQuery := ` SELECT id, server_id, server_name, jail, ip, country, hostname, failures, whois, logs, occurred_at, created_at FROM ban_events WHERE 1=1` args := []any{} if serverID != "" { baseQuery += " AND server_id = ?" args = append(args, serverID) } if !since.IsZero() { baseQuery += " AND occurred_at >= ?" args = append(args, since.UTC()) } baseQuery += " ORDER BY occurred_at DESC LIMIT ?" args = append(args, limit) rows, err := db.QueryContext(ctx, baseQuery, args...) 
if err != nil { return nil, err } defer rows.Close() var results []BanEventRecord for rows.Next() { var rec BanEventRecord if err := rows.Scan( &rec.ID, &rec.ServerID, &rec.ServerName, &rec.Jail, &rec.IP, &rec.Country, &rec.Hostname, &rec.Failures, &rec.Whois, &rec.Logs, &rec.OccurredAt, &rec.CreatedAt, ); err != nil { return nil, err } results = append(results, rec) } return results, rows.Err() } // CountBanEventsByServer returns simple aggregation per server. func CountBanEventsByServer(ctx context.Context, since time.Time) (map[string]int64, error) { if db == nil { return nil, errors.New("storage not initialised") } query := ` SELECT server_id, COUNT(*) FROM ban_events WHERE 1=1` args := []any{} if !since.IsZero() { query += " AND occurred_at >= ?" args = append(args, since.UTC()) } query += " GROUP BY server_id" rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() result := make(map[string]int64) for rows.Next() { var serverID string var count int64 if err := rows.Scan(&serverID, &count); err != nil { return nil, err } result[serverID] = count } return result, rows.Err() } // CountBanEvents returns total number of ban events optionally filtered by time. func CountBanEvents(ctx context.Context, since time.Time) (int64, error) { if db == nil { return 0, errors.New("storage not initialised") } query := ` SELECT COUNT(*) FROM ban_events WHERE 1=1` args := []any{} if !since.IsZero() { query += " AND occurred_at >= ?" args = append(args, since.UTC()) } var total int64 if err := db.QueryRowContext(ctx, query, args...).Scan(&total); err != nil { return 0, err } return total, nil } // CountBanEventsByCountry returns aggregation per country code. 
func CountBanEventsByCountry(ctx context.Context, since time.Time) (map[string]int64, error) { if db == nil { return nil, errors.New("storage not initialised") } query := ` SELECT COALESCE(country, '') AS country, COUNT(*) FROM ban_events WHERE 1=1` args := []any{} if !since.IsZero() { query += " AND occurred_at >= ?" args = append(args, since.UTC()) } query += " GROUP BY COALESCE(country, '')" rows, err := db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() result := make(map[string]int64) for rows.Next() { var country sql.NullString var count int64 if err := rows.Scan(&country, &count); err != nil { return nil, err } result[stringFromNull(country)] = count } return result, rows.Err() } // ListRecurringIPStats returns IPs that have been banned at least minCount times. func ListRecurringIPStats(ctx context.Context, since time.Time, minCount, limit int) ([]RecurringIPStat, error) { if db == nil { return nil, errors.New("storage not initialised") } if minCount < 2 { minCount = 2 } if limit <= 0 || limit > 500 { limit = 100 } query := ` SELECT ip, COALESCE(country, '') AS country, COUNT(*) AS cnt, MAX(occurred_at) AS last_seen FROM ban_events WHERE ip != ''` args := []any{} if !since.IsZero() { query += " AND occurred_at >= ?" args = append(args, since.UTC()) } query += ` GROUP BY ip, COALESCE(country, '') HAVING cnt >= ? ORDER BY cnt DESC, last_seen DESC LIMIT ?` args = append(args, minCount, limit) rows, err := db.QueryContext(ctx, query, args...) 
if err != nil { return nil, err } defer rows.Close() var results []RecurringIPStat for rows.Next() { var stat RecurringIPStat var lastSeen sql.NullString if err := rows.Scan(&stat.IP, &stat.Country, &stat.Count, &lastSeen); err != nil { return nil, err } if lastSeen.Valid { if parsed, err := time.Parse(time.RFC3339Nano, lastSeen.String); err == nil { stat.LastSeen = parsed } } results = append(results, stat) } return results, rows.Err() } func ensureSchema(ctx context.Context) error { if db == nil { return errors.New("storage not initialised") } const createTable = ` CREATE TABLE IF NOT EXISTS app_settings ( id INTEGER PRIMARY KEY CHECK (id = 1), language TEXT, port INTEGER, debug INTEGER, callback_url TEXT, restart_needed INTEGER, alert_countries TEXT, smtp_host TEXT, smtp_port INTEGER, smtp_username TEXT, smtp_password TEXT, smtp_from TEXT, smtp_use_tls INTEGER, bantime_increment INTEGER, ignore_ip TEXT, bantime TEXT, findtime TEXT, maxretry INTEGER, destemail TEXT ); CREATE TABLE IF NOT EXISTS servers ( id TEXT PRIMARY KEY, name TEXT, type TEXT, host TEXT, port INTEGER, socket_path TEXT, log_path TEXT, ssh_user TEXT, ssh_key_path TEXT, agent_url TEXT, agent_secret TEXT, hostname TEXT, tags TEXT, is_default INTEGER, enabled INTEGER, needs_restart INTEGER DEFAULT 0, created_at TEXT, updated_at TEXT ); CREATE TABLE IF NOT EXISTS ban_events ( id INTEGER PRIMARY KEY AUTOINCREMENT, server_id TEXT NOT NULL, server_name TEXT NOT NULL, jail TEXT NOT NULL, ip TEXT NOT NULL, country TEXT, hostname TEXT, failures TEXT, whois TEXT, logs TEXT, occurred_at DATETIME NOT NULL, created_at DATETIME NOT NULL ); CREATE INDEX IF NOT EXISTS idx_ban_events_server_id ON ban_events(server_id); CREATE INDEX IF NOT EXISTS idx_ban_events_occurred_at ON ban_events(occurred_at); ` if _, err := db.ExecContext(ctx, createTable); err != nil { return err } // Backfill needs_restart column for existing databases that predate it. 
if _, err := db.ExecContext(ctx, `ALTER TABLE servers ADD COLUMN needs_restart INTEGER DEFAULT 0`); err != nil { if !strings.Contains(strings.ToLower(err.Error()), "duplicate column name") { return err } } return nil } func ensureDirectory(path string) error { if path == ":memory:" { return nil } dir := filepath.Dir(path) if dir == "." || dir == "" { return nil } return os.MkdirAll(dir, 0o755) }