diff --git a/internal/fail2ban/connector_ssh.go b/internal/fail2ban/connector_ssh.go
index e2abde2..f27c78c 100644
--- a/internal/fail2ban/connector_ssh.go
+++ b/internal/fail2ban/connector_ssh.go
@@ -2,6 +2,7 @@ package fail2ban
 
 import (
 	"bufio"
+	"bytes"
 	"context"
 	"encoding/base64"
 	"errors"
@@ -13,6 +14,7 @@ import (
 	"strconv"
 	"strings"
 	"sync"
+	"syscall"
 	"time"
 
 	"github.com/swissmakers/fail2ban-ui/internal/config"
@@ -308,6 +310,15 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
 
 	args := sc.buildSSHArgs([]string{"sh", "-s"})
-	cmd := exec.CommandContext(ctx, "ssh", args...)
+	// Not exec.CommandContext: waitOrKillGroup handles cancellation itself so
+	// it can signal the whole process group instead of only the ssh process.
+	cmd := exec.Command("ssh", args...)
+	// Put ssh in its own process group (Pgid 0: group ID = child PID) so that
+	// every process it starts in that group can be signalled together.
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		Setpgid: true,
+		Pgid:    0,
+	}
+
 	// Create a script that reads the base64 string from stdin and pipes it through base64 -d | bash
 	// We use a here-document to pass the base64 string
 	scriptContent := fmt.Sprintf("cat <<'ENDBASE64' | base64 -d | bash\n%s\nENDBASE64\n", scriptB64)
@@ -318,8 +329,22 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
 		config.DebugLog("SSH ensureAction command [%s]: ssh %s (with here-doc via stdin)", sc.server.Name, strings.Join(args, " "))
 	}
 
-	out, err := cmd.CombinedOutput()
-	output := strings.TrimSpace(string(out))
+	// Stdout and stderr share one buffer so the output stays interleaved in
+	// write order, as it was with cmd.CombinedOutput.
+	var combined bytes.Buffer
+	cmd.Stdout = &combined
+	cmd.Stderr = &combined
+
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("failed to start ssh command: %w", err)
+	}
+
+	cancelled, err := waitOrKillGroup(ctx, cmd)
+	if cancelled {
+		// Cancelled: err is ctx.Err().
+		return err
+	}
+	output := strings.TrimSpace(combined.String())
 	if err != nil {
 		config.DebugLog("Failed to ensure action file for server %s: %v (output: %s)", sc.server.Name, err, output)
 		return fmt.Errorf("failed to ensure action file on remote server %s: %w (remote output: %s)", sc.server.Name, err, output)
@@ -395,23 +420,102 @@ func (sc *SSHConnector) buildFail2banArgs(args ...string) []string {
 
 func (sc *SSHConnector) runRemoteCommand(ctx context.Context, command []string) (string, error) {
 	args := sc.buildSSHArgs(command)
-	cmd := exec.CommandContext(ctx, "ssh", args...)
+	// Not exec.CommandContext: waitOrKillGroup handles cancellation itself so
+	// it can signal the whole process group instead of only the ssh process.
+	cmd := exec.Command("ssh", args...)
+	// Put ssh in its own process group (Pgid 0: group ID = child PID) so that
+	// every process it starts in that group can be signalled together.
+	cmd.SysProcAttr = &syscall.SysProcAttr{
+		Setpgid: true,
+		Pgid:    0,
+	}
+
 	settingSnapshot := config.GetSettings()
 	if settingSnapshot.Debug {
 		config.DebugLog("SSH command [%s]: ssh %s", sc.server.Name, strings.Join(args, " "))
 	}
-	out, err := cmd.CombinedOutput()
-	output := strings.TrimSpace(string(out))
+
+	// Stdout and stderr share one buffer so the output stays interleaved in
+	// write order, as it was with cmd.CombinedOutput.
+	var combined bytes.Buffer
+	cmd.Stdout = &combined
+	cmd.Stderr = &combined
+
+	if err := cmd.Start(); err != nil {
+		return "", fmt.Errorf("failed to start ssh command: %w", err)
+	}
+
+	cancelled, err := waitOrKillGroup(ctx, cmd)
+	if cancelled {
+		// Cancelled: err is ctx.Err().
+		return "", err
+	}
+	output := strings.TrimSpace(combined.String())
 	if err != nil {
 		if settingSnapshot.Debug {
 			config.DebugLog("SSH command error [%s]: %v | output: %s", sc.server.Name, err, output)
 		}
 		return output, fmt.Errorf("ssh command failed: %w (output: %s)", err, output)
 	}
 	if settingSnapshot.Debug {
 		config.DebugLog("SSH command output [%s]: %s", sc.server.Name, output)
 	}
 	return output, nil
 }
 
+// waitOrKillGroup waits for an already started cmd to exit or for ctx to end,
+// whichever happens first.
+//
+// cmd must have been started with SysProcAttr.Setpgid set and Pgid 0, so that
+// its PID is also its process group ID. If ctx ends first, the group is sent
+// SIGTERM and, if cmd has not exited within a short grace period, SIGKILL;
+// the function then returns (true, ctx.Err()). Otherwise it returns
+// (false, err), where err is the result of cmd.Wait.
+//
+// It relies on syscall.Kill and process groups and is therefore Unix-only.
+func waitOrKillGroup(ctx context.Context, cmd *exec.Cmd) (cancelled bool, err error) {
+	// termGrace is how long the group gets to exit after SIGTERM.
+	const termGrace = 100 * time.Millisecond
+	// reapTimeout bounds the wait for cmd.Wait to return after SIGKILL.
+	const reapTimeout = 2 * time.Second
+
+	// cmd.Wait is the only place the child is reaped; also calling
+	// cmd.Process.Wait would race with it for the exit status.
+	done := make(chan error, 1)
+	go func() {
+		done <- cmd.Wait()
+	}()
+
+	select {
+	case waitErr := <-done:
+		return false, waitErr
+	case <-ctx.Done():
+	}
+
+	if cmd.Process == nil || cmd.Process.Pid <= 0 {
+		// Nothing to signal. Never pass 0 to Kill: it addresses our own group.
+		return true, ctx.Err()
+	}
+	pgid := cmd.Process.Pid
+
+	// A negative PID addresses the whole process group. Errors are ignored on
+	// purpose: the group may already be gone.
+	_ = syscall.Kill(-pgid, syscall.SIGTERM)
+	select {
+	case <-done:
+		// Already reaped: do not signal again, the PID could have been reused.
+		return true, ctx.Err()
+	case <-time.After(termGrace):
+	}
+
+	_ = syscall.Kill(-pgid, syscall.SIGKILL)
+	select {
+	case <-done:
+	case <-time.After(reapTimeout):
+		// cmd.Wait also waits for the output pipes to close; a process that
+		// still holds them open must not block the caller forever.
+	}
+	return true, ctx.Err()
+}
+
 func (sc *SSHConnector) buildSSHArgs(command []string) []string {