diff --git a/iptables/iptables.go b/iptables/iptables.go index b058995..28b48f6 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -82,6 +82,7 @@ type IPTables struct { v3 int mode string // the underlying iptables operating mode, e.g. nf_tables timeout int // time to wait for the iptables lock, default waits forever + execCommandFunc func(string, ...string) *exec.Cmd } // Stat represents a structured statistic entry. @@ -118,6 +119,15 @@ func Path(path string) option { } } +// ExecCommandFunc allows for overriding the [exec.Command] used when spawning +// iptables sub-processes. Stdout and Stderr should be nil on the returned +// [exec.Cmd]. +func ExecCommandFunc(fn func(name string, arg ...string) *exec.Cmd) option { + return func(ipt *IPTables) { + ipt.execCommandFunc = fn + } +} + // New creates a new IPTables configured with the options passed as parameters. // Supported parameters are: // @@ -133,9 +143,10 @@ func Path(path string) option { func New(opts ...option) (*IPTables, error) { ipt := &IPTables{ - proto: ProtocolIPv4, - timeout: 0, - path: "", + proto: ProtocolIPv4, + timeout: 0, + path: "", + execCommandFunc: exec.Command, } for _, opt := range opts { @@ -155,7 +166,7 @@ func New(opts ...option) (*IPTables, error) { } ipt.path = path - vstring, err := getIptablesVersionString(path) + vstring, err := getIptablesVersionString(ipt.execCommandFunc, path) if err != nil { return nil, fmt.Errorf("could not get iptables version: %v", err) } @@ -563,7 +574,7 @@ func (ipt *IPTables) run(args ...string) error { // runWithOutput runs an iptables command with the given arguments, // writing any stdout output to the given writer func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { - args = append([]string{ipt.path}, args...) + args = append([]string(nil), args...) // copy input args if ipt.hasWait { args = append(args, "--wait") if ipt.timeout != 0 && ipt.waitSupportSecond { @@ -585,17 +596,14 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { } var stderr bytes.Buffer - cmd := exec.Cmd{ - Path: ipt.path, - Args: args, - Stdout: stdout, - Stderr: &stderr, - } + cmd := ipt.execCommandFunc(ipt.path, args...) + cmd.Stdout = stdout + cmd.Stderr = &stderr if err := cmd.Run(); err != nil { switch e := err.(type) { case *exec.ExitError: - return &Error{*e, cmd, stderr.String(), nil} + return &Error{*e, *cmd, stderr.String(), nil} default: return err } @@ -651,8 +659,11 @@ func extractIptablesVersion(str string) (int, int, int, string, error) { } // Runs "iptables --version" to get the version string -func getIptablesVersionString(path string) (string, error) { - cmd := exec.Command(path, "--version") +func getIptablesVersionString( + execCommandFunc func(string, ...string) *exec.Cmd, + path string, +) (string, error) { + cmd := execCommandFunc(path, "--version") var out bytes.Buffer cmd.Stdout = &out err := cmd.Run()