Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ type IPTables struct {
v3 int
mode string // the underlying iptables operating mode, e.g. nf_tables
timeout int // time to wait for the iptables lock, default waits forever
execCommandFunc func(string, ...string) *exec.Cmd
}

// Stat represents a structured statistic entry.
Expand Down Expand Up @@ -118,6 +119,15 @@ func Path(path string) option {
}
}

// ExecCommandFunc allows for overriding the [exec.Command] used when spawning
// iptables sub-processes. Stdout and Stderr should be nil on the returned
// [exec.Cmd].
func ExecCommandFunc(fn func(name string, arg ...string) *exec.Cmd) option {
return func(ipt *IPTables) {
ipt.execCommandFunc = fn
}
}

// New creates a new IPTables configured with the options passed as parameters.
// Supported parameters are:
//
Expand All @@ -133,9 +143,10 @@ func Path(path string) option {
func New(opts ...option) (*IPTables, error) {

ipt := &IPTables{
proto: ProtocolIPv4,
timeout: 0,
path: "",
proto: ProtocolIPv4,
timeout: 0,
path: "",
execCommandFunc: exec.Command,
}

for _, opt := range opts {
Expand All @@ -155,7 +166,7 @@ func New(opts ...option) (*IPTables, error) {
}
ipt.path = path

vstring, err := getIptablesVersionString(path)
vstring, err := getIptablesVersionString(ipt.execCommandFunc, path)
if err != nil {
return nil, fmt.Errorf("could not get iptables version: %v", err)
}
Expand Down Expand Up @@ -563,7 +574,7 @@ func (ipt *IPTables) run(args ...string) error {
// runWithOutput runs an iptables command with the given arguments,
// writing any stdout output to the given writer
func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
args = append([]string{ipt.path}, args...)
args = append([]string(nil), args...) // copy input args
if ipt.hasWait {
args = append(args, "--wait")
if ipt.timeout != 0 && ipt.waitSupportSecond {
Expand All @@ -585,17 +596,14 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
}

var stderr bytes.Buffer
cmd := exec.Cmd{
Path: ipt.path,
Args: args,
Stdout: stdout,
Stderr: &stderr,
}
cmd := ipt.execCommandFunc(ipt.path, args...)
cmd.Stdout = stdout
cmd.Stderr = &stderr

if err := cmd.Run(); err != nil {
switch e := err.(type) {
case *exec.ExitError:
return &Error{*e, cmd, stderr.String(), nil}
return &Error{*e, *cmd, stderr.String(), nil}
default:
return err
}
Expand Down Expand Up @@ -651,8 +659,11 @@ func extractIptablesVersion(str string) (int, int, int, string, error) {
}

// Runs "iptables --version" to get the version string
func getIptablesVersionString(path string) (string, error) {
cmd := exec.Command(path, "--version")
func getIptablesVersionString(
execCommandFunc func(string, ...string) *exec.Cmd,
path string,
) (string, error) {
cmd := execCommandFunc(path, "--version")
var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
Expand Down