mirror of
https://github.com/swissmakers/fail2ban-ui.git
synced 2026-04-17 05:53:15 +02:00
Implement process-group termination on context cancellation for SSH processes to prevent zombies
This commit is contained in:
@@ -2,6 +2,7 @@ package fail2ban
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -13,6 +14,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/swissmakers/fail2ban-ui/internal/config"
|
"github.com/swissmakers/fail2ban-ui/internal/config"
|
||||||
@@ -308,6 +310,13 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
|
|||||||
args := sc.buildSSHArgs([]string{"sh", "-s"})
|
args := sc.buildSSHArgs([]string{"sh", "-s"})
|
||||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||||
|
|
||||||
|
// Set process group to ensure all child processes (including SSH control master) are killed
|
||||||
|
// when the context is cancelled. This prevents zombie processes.
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
Setpgid: true,
|
||||||
|
Pgid: 0,
|
||||||
|
}
|
||||||
|
|
||||||
// Create a script that reads the base64 string from stdin and pipes it through base64 -d | bash
|
// Create a script that reads the base64 string from stdin and pipes it through base64 -d | bash
|
||||||
// We use a here-document to pass the base64 string
|
// We use a here-document to pass the base64 string
|
||||||
scriptContent := fmt.Sprintf("cat <<'ENDBASE64' | base64 -d | bash\n%s\nENDBASE64\n", scriptB64)
|
scriptContent := fmt.Sprintf("cat <<'ENDBASE64' | base64 -d | bash\n%s\nENDBASE64\n", scriptB64)
|
||||||
@@ -318,8 +327,43 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
|
|||||||
config.DebugLog("SSH ensureAction command [%s]: ssh %s (with here-doc via stdin)", sc.server.Name, strings.Join(args, " "))
|
config.DebugLog("SSH ensureAction command [%s]: ssh %s (with here-doc via stdin)", sc.server.Name, strings.Join(args, " "))
|
||||||
}
|
}
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
// Capture stdout and stderr
|
||||||
output := strings.TrimSpace(string(out))
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
// Start the command
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start ssh command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Monitor context cancellation and command completion
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- cmd.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
select {
|
||||||
|
case err = <-done:
|
||||||
|
// Command completed normally
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Context cancelled - kill the entire process group to prevent zombies
|
||||||
|
if cmd.Process != nil && cmd.Process.Pid > 0 {
|
||||||
|
// Kill the entire process group (negative PID kills the process group)
|
||||||
|
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM)
|
||||||
|
// Give it a moment to exit gracefully
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// Force kill if still running
|
||||||
|
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||||
|
// Wait for the process to exit to prevent zombies
|
||||||
|
_, _ = cmd.Process.Wait()
|
||||||
|
}
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
combinedOutput := append(stdout.Bytes(), stderr.Bytes()...)
|
||||||
|
output := strings.TrimSpace(string(combinedOutput))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
config.DebugLog("Failed to ensure action file for server %s: %v (output: %s)", sc.server.Name, err, output)
|
config.DebugLog("Failed to ensure action file for server %s: %v (output: %s)", sc.server.Name, err, output)
|
||||||
return fmt.Errorf("failed to ensure action file on remote server %s: %w (remote output: %s)", sc.server.Name, err, output)
|
return fmt.Errorf("failed to ensure action file on remote server %s: %w (remote output: %s)", sc.server.Name, err, output)
|
||||||
@@ -395,23 +439,65 @@ func (sc *SSHConnector) buildFail2banArgs(args ...string) []string {
|
|||||||
|
|
||||||
func (sc *SSHConnector) runRemoteCommand(ctx context.Context, command []string) (string, error) {
|
func (sc *SSHConnector) runRemoteCommand(ctx context.Context, command []string) (string, error) {
|
||||||
args := sc.buildSSHArgs(command)
|
args := sc.buildSSHArgs(command)
|
||||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
cmd := exec.Command("ssh", args...)
|
||||||
|
|
||||||
|
// Set process group to ensure all child processes (including SSH control master) are killed
|
||||||
|
// when we need to terminate. This prevents zombie processes.
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
Setpgid: true,
|
||||||
|
Pgid: 0,
|
||||||
|
}
|
||||||
|
|
||||||
settingSnapshot := config.GetSettings()
|
settingSnapshot := config.GetSettings()
|
||||||
if settingSnapshot.Debug {
|
if settingSnapshot.Debug {
|
||||||
config.DebugLog("SSH command [%s]: ssh %s", sc.server.Name, strings.Join(args, " "))
|
config.DebugLog("SSH command [%s]: ssh %s", sc.server.Name, strings.Join(args, " "))
|
||||||
}
|
}
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
output := strings.TrimSpace(string(out))
|
// Capture stdout and stderr
|
||||||
if err != nil {
|
var stdout, stderr bytes.Buffer
|
||||||
if settingSnapshot.Debug {
|
cmd.Stdout = &stdout
|
||||||
config.DebugLog("SSH command error [%s]: %v | output: %s", sc.server.Name, err, output)
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
// Start the command
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to start ssh command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Monitor context cancellation and command completion
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- cmd.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
// Command completed
|
||||||
|
combinedOutput := append(stdout.Bytes(), stderr.Bytes()...)
|
||||||
|
output := strings.TrimSpace(string(combinedOutput))
|
||||||
|
if err != nil {
|
||||||
|
if settingSnapshot.Debug {
|
||||||
|
config.DebugLog("SSH command error [%s]: %v | output: %s", sc.server.Name, err, output)
|
||||||
|
}
|
||||||
|
return output, fmt.Errorf("ssh command failed: %w (output: %s)", err, output)
|
||||||
}
|
}
|
||||||
return output, fmt.Errorf("ssh command failed: %w (output: %s)", err, output)
|
if settingSnapshot.Debug {
|
||||||
|
config.DebugLog("SSH command output [%s]: %s", sc.server.Name, output)
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
// Context cancelled - kill the entire process group to prevent zombies
|
||||||
|
if cmd.Process != nil && cmd.Process.Pid > 0 {
|
||||||
|
// Kill the entire process group (negative PID kills the process group)
|
||||||
|
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM)
|
||||||
|
// Give it a moment to exit gracefully
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
// Force kill if still running
|
||||||
|
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||||
|
// Wait for the process to exit to prevent zombies
|
||||||
|
_, _ = cmd.Process.Wait()
|
||||||
|
}
|
||||||
|
return "", ctx.Err()
|
||||||
}
|
}
|
||||||
if settingSnapshot.Debug {
|
|
||||||
config.DebugLog("SSH command output [%s]: %s", sc.server.Name, output)
|
|
||||||
}
|
|
||||||
return output, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sc *SSHConnector) buildSSHArgs(command []string) []string {
|
func (sc *SSHConnector) buildSSHArgs(command []string) []string {
|
||||||
|
|||||||
Reference in New Issue
Block a user