Implement process-group termination on context cancellation for SSH processes to prevent zombies

This commit is contained in:
2026-01-19 23:45:02 +01:00
parent d64eb3db95
commit de92a640e2

View File

@@ -2,6 +2,7 @@ package fail2ban
import (
"bufio"
"bytes"
"context"
"encoding/base64"
"errors"
@@ -13,6 +14,7 @@ import (
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/swissmakers/fail2ban-ui/internal/config"
@@ -308,6 +310,13 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
args := sc.buildSSHArgs([]string{"sh", "-s"})
cmd := exec.CommandContext(ctx, "ssh", args...)
// Set process group to ensure all child processes (including SSH control master) are killed
// when the context is cancelled. This prevents zombie processes.
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
Pgid: 0,
}
// Create a script that reads the base64 string from stdin and pipes it through base64 -d | bash
// We use a here-document to pass the base64 string
scriptContent := fmt.Sprintf("cat <<'ENDBASE64' | base64 -d | bash\n%s\nENDBASE64\n", scriptB64)
@@ -318,8 +327,43 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
config.DebugLog("SSH ensureAction command [%s]: ssh %s (with here-doc via stdin)", sc.server.Name, strings.Join(args, " "))
}
out, err := cmd.CombinedOutput()
output := strings.TrimSpace(string(out))
// Capture stdout and stderr
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
// Start the command
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start ssh command: %w", err)
}
// Monitor context cancellation and command completion
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()
var err error
select {
case err = <-done:
// Command completed normally
case <-ctx.Done():
// Context cancelled - kill the entire process group to prevent zombies
if cmd.Process != nil && cmd.Process.Pid > 0 {
// Kill the entire process group (negative PID kills the process group)
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM)
// Give it a moment to exit gracefully
time.Sleep(100 * time.Millisecond)
// Force kill if still running
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
// Wait for the process to exit to prevent zombies
_, _ = cmd.Process.Wait()
}
return ctx.Err()
}
combinedOutput := append(stdout.Bytes(), stderr.Bytes()...)
output := strings.TrimSpace(string(combinedOutput))
if err != nil {
config.DebugLog("Failed to ensure action file for server %s: %v (output: %s)", sc.server.Name, err, output)
return fmt.Errorf("failed to ensure action file on remote server %s: %w (remote output: %s)", sc.server.Name, err, output)
@@ -395,13 +439,41 @@ func (sc *SSHConnector) buildFail2banArgs(args ...string) []string {
func (sc *SSHConnector) runRemoteCommand(ctx context.Context, command []string) (string, error) {
args := sc.buildSSHArgs(command)
cmd := exec.CommandContext(ctx, "ssh", args...)
cmd := exec.Command("ssh", args...)
// Set process group to ensure all child processes (including SSH control master) are killed
// when we need to terminate. This prevents zombie processes.
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,
Pgid: 0,
}
settingSnapshot := config.GetSettings()
if settingSnapshot.Debug {
config.DebugLog("SSH command [%s]: ssh %s", sc.server.Name, strings.Join(args, " "))
}
out, err := cmd.CombinedOutput()
output := strings.TrimSpace(string(out))
// Capture stdout and stderr
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
// Start the command
if err := cmd.Start(); err != nil {
return "", fmt.Errorf("failed to start ssh command: %w", err)
}
// Monitor context cancellation and command completion
done := make(chan error, 1)
go func() {
done <- cmd.Wait()
}()
select {
case err := <-done:
// Command completed
combinedOutput := append(stdout.Bytes(), stderr.Bytes()...)
output := strings.TrimSpace(string(combinedOutput))
if err != nil {
if settingSnapshot.Debug {
config.DebugLog("SSH command error [%s]: %v | output: %s", sc.server.Name, err, output)
@@ -412,6 +484,20 @@ func (sc *SSHConnector) runRemoteCommand(ctx context.Context, command []string)
config.DebugLog("SSH command output [%s]: %s", sc.server.Name, output)
}
return output, nil
case <-ctx.Done():
// Context cancelled - kill the entire process group to prevent zombies
if cmd.Process != nil && cmd.Process.Pid > 0 {
// Kill the entire process group (negative PID kills the process group)
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM)
// Give it a moment to exit gracefully
time.Sleep(100 * time.Millisecond)
// Force kill if still running
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
// Wait for the process to exit to prevent zombies
_, _ = cmd.Process.Wait()
}
return "", ctx.Err()
}
}
func (sc *SSHConnector) buildSSHArgs(command []string) []string {