mirror of
https://github.com/swissmakers/fail2ban-ui.git
synced 2026-04-11 13:47:05 +02:00
Implement on context cancellation for ssh processes to prevent zombies
This commit is contained in:
@@ -2,6 +2,7 @@ package fail2ban
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/swissmakers/fail2ban-ui/internal/config"
|
||||
@@ -308,6 +310,13 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
|
||||
args := sc.buildSSHArgs([]string{"sh", "-s"})
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
|
||||
// Set process group to ensure all child processes (including SSH control master) are killed
|
||||
// when the context is cancelled. This prevents zombie processes.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
Pgid: 0,
|
||||
}
|
||||
|
||||
// Create a script that reads the base64 string from stdin and pipes it through base64 -d | bash
|
||||
// We use a here-document to pass the base64 string
|
||||
scriptContent := fmt.Sprintf("cat <<'ENDBASE64' | base64 -d | bash\n%s\nENDBASE64\n", scriptB64)
|
||||
@@ -318,8 +327,43 @@ func (sc *SSHConnector) ensureAction(ctx context.Context) error {
|
||||
config.DebugLog("SSH ensureAction command [%s]: ssh %s (with here-doc via stdin)", sc.server.Name, strings.Join(args, " "))
|
||||
}
|
||||
|
||||
out, err := cmd.CombinedOutput()
|
||||
output := strings.TrimSpace(string(out))
|
||||
// Capture stdout and stderr
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Start the command
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("failed to start ssh command: %w", err)
|
||||
}
|
||||
|
||||
// Monitor context cancellation and command completion
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cmd.Wait()
|
||||
}()
|
||||
|
||||
var err error
|
||||
select {
|
||||
case err = <-done:
|
||||
// Command completed normally
|
||||
case <-ctx.Done():
|
||||
// Context cancelled - kill the entire process group to prevent zombies
|
||||
if cmd.Process != nil && cmd.Process.Pid > 0 {
|
||||
// Kill the entire process group (negative PID kills the process group)
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM)
|
||||
// Give it a moment to exit gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Force kill if still running
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||
// Wait for the process to exit to prevent zombies
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
combinedOutput := append(stdout.Bytes(), stderr.Bytes()...)
|
||||
output := strings.TrimSpace(string(combinedOutput))
|
||||
if err != nil {
|
||||
config.DebugLog("Failed to ensure action file for server %s: %v (output: %s)", sc.server.Name, err, output)
|
||||
return fmt.Errorf("failed to ensure action file on remote server %s: %w (remote output: %s)", sc.server.Name, err, output)
|
||||
@@ -395,23 +439,65 @@ func (sc *SSHConnector) buildFail2banArgs(args ...string) []string {
|
||||
|
||||
func (sc *SSHConnector) runRemoteCommand(ctx context.Context, command []string) (string, error) {
|
||||
args := sc.buildSSHArgs(command)
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
// Set process group to ensure all child processes (including SSH control master) are killed
|
||||
// when we need to terminate. This prevents zombie processes.
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
Pgid: 0,
|
||||
}
|
||||
|
||||
settingSnapshot := config.GetSettings()
|
||||
if settingSnapshot.Debug {
|
||||
config.DebugLog("SSH command [%s]: ssh %s", sc.server.Name, strings.Join(args, " "))
|
||||
}
|
||||
out, err := cmd.CombinedOutput()
|
||||
output := strings.TrimSpace(string(out))
|
||||
if err != nil {
|
||||
if settingSnapshot.Debug {
|
||||
config.DebugLog("SSH command error [%s]: %v | output: %s", sc.server.Name, err, output)
|
||||
|
||||
// Capture stdout and stderr
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
// Start the command
|
||||
if err := cmd.Start(); err != nil {
|
||||
return "", fmt.Errorf("failed to start ssh command: %w", err)
|
||||
}
|
||||
|
||||
// Monitor context cancellation and command completion
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
// Command completed
|
||||
combinedOutput := append(stdout.Bytes(), stderr.Bytes()...)
|
||||
output := strings.TrimSpace(string(combinedOutput))
|
||||
if err != nil {
|
||||
if settingSnapshot.Debug {
|
||||
config.DebugLog("SSH command error [%s]: %v | output: %s", sc.server.Name, err, output)
|
||||
}
|
||||
return output, fmt.Errorf("ssh command failed: %w (output: %s)", err, output)
|
||||
}
|
||||
return output, fmt.Errorf("ssh command failed: %w (output: %s)", err, output)
|
||||
if settingSnapshot.Debug {
|
||||
config.DebugLog("SSH command output [%s]: %s", sc.server.Name, output)
|
||||
}
|
||||
return output, nil
|
||||
case <-ctx.Done():
|
||||
// Context cancelled - kill the entire process group to prevent zombies
|
||||
if cmd.Process != nil && cmd.Process.Pid > 0 {
|
||||
// Kill the entire process group (negative PID kills the process group)
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGTERM)
|
||||
// Give it a moment to exit gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Force kill if still running
|
||||
_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
|
||||
// Wait for the process to exit to prevent zombies
|
||||
_, _ = cmd.Process.Wait()
|
||||
}
|
||||
return "", ctx.Err()
|
||||
}
|
||||
if settingSnapshot.Debug {
|
||||
config.DebugLog("SSH command output [%s]: %s", sc.server.Name, output)
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
func (sc *SSHConnector) buildSSHArgs(command []string) []string {
|
||||
|
||||
Reference in New Issue
Block a user