Skip to content

Commit c70a91b

Browse files
committed
Iptables mode selection
1 parent b9dff5a commit c70a91b

File tree

2 files changed

+134
-63
lines changed

2 files changed

+134
-63
lines changed

iptables/iptables.go

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ const (
6464
ProtocolIPv6
6565
)
6666

67+
// Mode to differentiate between legacy and nf_tables
68+
type ModeType string
69+
70+
const (
71+
// ModeTypeAuto is the default mode, which uses the system default
72+
ModeTypeAuto ModeType = "auto"
73+
// ModeTypeLegacy forces the use of the legacy iptables mode
74+
ModeTypeLegacy ModeType = "legacy"
75+
// ModeTypeNFTables forces the use of the nf_tables iptables mode
76+
ModeTypeNFTables ModeType = "nf_tables"
77+
)
78+
6779
type IPTables struct {
6880
path string
6981
proto Protocol
@@ -74,8 +86,8 @@ type IPTables struct {
7486
v1 int
7587
v2 int
7688
v3 int
77-
mode string // the underlying iptables operating mode, e.g. nf_tables
78-
timeout int // time to wait for the iptables lock, default waits forever
89+
mode ModeType // the underlying iptables operating mode, e.g. nf_tables
90+
timeout int // time to wait for the iptables lock, default waits forever
7991
}
8092

8193
// Stat represents a structured statistic entry.
@@ -106,6 +118,12 @@ func Timeout(timeout int) option {
106118
}
107119
}
108120

121+
func Mode(mode ModeType) option {
122+
return func(ipt *IPTables) {
123+
ipt.mode = mode
124+
}
125+
}
126+
109127
// New creates a new IPTables configured with the options passed as parameter.
110128
// For backwards compatibility, by default always uses IPv4 and timeout 0.
111129
// i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing
@@ -116,14 +134,15 @@ func New(opts ...option) (*IPTables, error) {
116134

117135
ipt := &IPTables{
118136
proto: ProtocolIPv4,
137+
mode: ModeTypeAuto,
119138
timeout: 0,
120139
}
121140

122141
for _, opt := range opts {
123142
opt(ipt)
124143
}
125144

126-
path, err := exec.LookPath(getIptablesCommand(ipt.proto))
145+
path, err := exec.LookPath(getIptablesCommand(ipt.proto, ipt.mode))
127146
if err != nil {
128147
return nil, err
129148
}
@@ -133,14 +152,13 @@ func New(opts ...option) (*IPTables, error) {
133152
if err != nil {
134153
return nil, fmt.Errorf("could not get iptables version: %v", err)
135154
}
136-
v1, v2, v3, mode, err := extractIptablesVersion(vstring)
155+
v1, v2, v3, _, err := extractIptablesVersion(vstring)
137156
if err != nil {
138157
return nil, fmt.Errorf("failed to extract iptables version from [%s]: %v", vstring, err)
139158
}
140159
ipt.v1 = v1
141160
ipt.v2 = v2
142161
ipt.v3 = v3
143-
ipt.mode = mode
144162

145163
checkPresent, waitPresent, waitSupportSecond, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
146164
ipt.hasCheck = checkPresent
@@ -518,8 +536,8 @@ func (ipt *IPTables) HasRandomFully() bool {
518536
}
519537

520538
// Return version components of the underlying iptables command
521-
func (ipt *IPTables) GetIptablesVersion() (int, int, int) {
522-
return ipt.v1, ipt.v2, ipt.v3
539+
func (ipt *IPTables) GetIptablesVersion() (int, int, int, ModeType) {
540+
return ipt.v1, ipt.v2, ipt.v3, ipt.mode
523541
}
524542

525543
// run runs an iptables command with the given arguments, ignoring
@@ -573,12 +591,23 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
573591
}
574592

575593
// getIptablesCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
576-
func getIptablesCommand(proto Protocol) string {
577-
if proto == ProtocolIPv6 {
578-
return "ip6tables"
579-
} else {
580-
return "iptables"
594+
func getIptablesCommand(proto Protocol, mode ModeType) string {
595+
var cmd string
596+
switch proto {
597+
case ProtocolIPv4:
598+
cmd = "iptables"
599+
case ProtocolIPv6:
600+
cmd = "ip6tables"
601+
}
602+
// Append a suffix to the command to get the correct binary,
603+
// If the mode is auto (default), the suffix is not applied and the system default is used.
604+
switch mode {
605+
case ModeTypeNFTables:
606+
cmd = fmt.Sprintf("%s-nft", cmd)
607+
case ModeTypeLegacy:
608+
cmd = fmt.Sprintf("%s-legacy", cmd)
581609
}
610+
return cmd
582611
}
583612

584613
// Checks if iptables has the "-C" and "--wait" flag
@@ -589,7 +618,7 @@ func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, bool)
589618
// getIptablesVersion returns the first three components of the iptables version
590619
// and the operating mode (e.g. nf_tables or legacy)
591620
// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
592-
func extractIptablesVersion(str string) (int, int, int, string, error) {
621+
func extractIptablesVersion(str string) (int, int, int, ModeType, error) {
593622
versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
594623
result := versionMatcher.FindStringSubmatch(str)
595624
if result == nil {
@@ -611,9 +640,9 @@ func extractIptablesVersion(str string) (int, int, int, string, error) {
611640
return 0, 0, 0, "", err
612641
}
613642

614-
mode := "legacy"
643+
mode := ModeTypeLegacy
615644
if result[4] != "" {
616-
mode = result[4]
645+
mode = ModeType(result[4])
617646
}
618647
return v1, v2, v3, mode, nil
619648
}

iptables/iptables_test.go

Lines changed: 90 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,23 @@ import (
2525
"testing"
2626
)
2727

28+
var (
29+
protos = []Protocol{ProtocolIPv4, ProtocolIPv6}
30+
modes = []ModeType{ModeTypeAuto, ModeTypeLegacy, ModeTypeNFTables}
31+
)
32+
33+
// getProtoName returns the name of the protocol, for use in test names.
34+
func getProtoName(proto Protocol) string {
35+
switch proto {
36+
case ProtocolIPv4:
37+
return "IPv4"
38+
case ProtocolIPv6:
39+
return "IPv6"
40+
default:
41+
panic("unknown protocol")
42+
}
43+
}
44+
2845
func TestProto(t *testing.T) {
2946
ipt, err := New()
3047
if err != nil {
@@ -34,40 +51,72 @@ func TestProto(t *testing.T) {
3451
t.Fatalf("Expected default protocol IPv4, got %v", ipt.Proto())
3552
}
3653

37-
ip4t, err := NewWithProtocol(ProtocolIPv4)
38-
if err != nil {
39-
t.Fatalf("NewWithProtocol(ProtocolIPv4) failed: %v", err)
40-
}
41-
if ip4t.Proto() != ProtocolIPv4 {
42-
t.Fatalf("Expected protocol IPv4, got %v", ip4t.Proto())
54+
for _, proto := range protos {
55+
protoName := getProtoName(proto)
56+
ipt, err := New(IPFamily(proto))
57+
if err != nil {
58+
t.Fatalf("NewWithProtocol(%s) failed: %v", protoName, err)
59+
}
60+
if ipt.Proto() != proto {
61+
t.Fatalf("Expected protocol %s, got %v", protoName, ipt.Proto())
62+
}
63+
if ipt.mode != ModeTypeAuto {
64+
t.Fatalf("Expected mode auto, got %v", ipt.mode)
65+
}
4366
}
4467

45-
ip6t, err := NewWithProtocol(ProtocolIPv6)
46-
if err != nil {
47-
t.Fatalf("NewWithProtocol(ProtocolIPv6) failed: %v", err)
48-
}
49-
if ip6t.Proto() != ProtocolIPv6 {
50-
t.Fatalf("Expected protocol IPv6, got %v", ip6t.Proto())
68+
for _, proto := range protos {
69+
for _, mode := range modes {
70+
protoName := getProtoName(proto)
71+
ipt, err := New(Mode(mode), IPFamily(proto))
72+
if err != nil {
73+
t.Fatalf("New(Mode(%v), IPFamily(%v)) failed: %v", mode, protoName, err)
74+
}
75+
if ipt.Proto() != proto {
76+
t.Fatalf("Expected protocol %v, got %v", protoName, ipt.Proto())
77+
}
78+
if ipt.mode != mode {
79+
t.Fatalf("Expected mode %v, got %v", mode, ipt.mode)
80+
}
81+
}
5182
}
5283
}
5384

5485
func TestTimeout(t *testing.T) {
55-
ipt, err := New()
56-
if err != nil {
57-
t.Fatalf("New failed: %v", err)
58-
}
59-
if ipt.timeout != 0 {
60-
t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
61-
}
86+
for _, proto := range protos {
87+
for _, mode := range modes {
88+
ipt, err := New(IPFamily(proto), Mode(mode))
89+
if err != nil {
90+
t.Fatalf("New failed: %v", err)
91+
}
92+
if ipt.timeout != 0 {
93+
t.Fatalf("Expected timeout 0 (wait forever), got %v", ipt.timeout)
94+
}
6295

63-
ipt2, err := New(Timeout(5))
64-
if err != nil {
65-
t.Fatalf("New failed: %v", err)
66-
}
67-
if ipt2.timeout != 5 {
68-
t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
96+
ipt2, err := New(Timeout(5))
97+
if err != nil {
98+
t.Fatalf("New failed: %v", err)
99+
}
100+
if ipt2.timeout != 5 {
101+
t.Fatalf("Expected timeout 5, got %v", ipt.timeout)
102+
}
103+
}
69104
}
105+
}
70106

107+
func TestGetIptablesVersionMode(t *testing.T) {
108+
for _, proto := range protos {
109+
for _, mode := range modes {
110+
ipt, err := New(IPFamily(proto), Mode(mode))
111+
if err != nil {
112+
t.Fatalf("New failed: %v", err)
113+
}
114+
_, _, _, getmode := ipt.GetIptablesVersion()
115+
if getmode != mode {
116+
t.Fatalf("Expected mode %v, got %v", mode, mode)
117+
}
118+
}
119+
}
71120
}
72121

73122
func randChain(t *testing.T) string {
@@ -92,27 +141,20 @@ func contains(list []string, value string) bool {
92141
// features enabled & disabled, to test compatibility.
93142
// We used to test noWait as well, but that was removed as of iptables v1.6.0
94143
func mustTestableIptables() []*IPTables {
95-
ipt, err := New()
96-
if err != nil {
97-
panic(fmt.Sprintf("New failed: %v", err))
98-
}
99-
ip6t, err := NewWithProtocol(ProtocolIPv6)
100-
if err != nil {
101-
panic(fmt.Sprintf("NewWithProtocol(ProtocolIPv6) failed: %v", err))
102-
}
103-
ipts := []*IPTables{ipt, ip6t}
104-
105-
// ensure we check one variant without built-in checking
106-
if ipt.hasCheck {
107-
i := *ipt
108-
i.hasCheck = false
109-
ipts = append(ipts, &i)
110-
111-
i6 := *ip6t
112-
i6.hasCheck = false
113-
ipts = append(ipts, &i6)
114-
} else {
115-
panic("iptables on this machine is too old -- missing -C")
144+
ipts := []*IPTables{}
145+
for _, proto := range protos {
146+
for _, mode := range modes {
147+
ipt, err := New(IPFamily(proto), Mode(mode))
148+
if err != nil {
149+
panic(fmt.Sprintf("New(IPFamily(%v), Mode(%v)) failed: %v", proto, mode, err))
150+
}
151+
if ipt.hasCheck {
152+
ipt.hasCheck = false
153+
ipts = append(ipts, ipt)
154+
} else {
155+
panic("iptables on this machine is too old -- missing -C")
156+
}
157+
}
116158
}
117159
return ipts
118160
}
@@ -251,7 +293,7 @@ func TestRules(t *testing.T) {
251293
}
252294

253295
func runRulesTests(t *testing.T, ipt *IPTables) {
254-
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto()), ipt.hasWait, ipt.hasCheck)
296+
t.Logf("testing %s (hasWait=%t, hasCheck=%t)", getIptablesCommand(ipt.Proto(), ModeTypeAuto), ipt.hasWait, ipt.hasCheck)
255297

256298
var address1, address2, subnet1, subnet2 string
257299
if ipt.Proto() == ProtocolIPv6 {
@@ -689,7 +731,7 @@ func TestExtractIptablesVersion(t *testing.T) {
689731
t.Fatalf("unexpected err %s", err)
690732
}
691733

692-
if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
734+
if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != ModeType(tt.mode) {
693735
t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
694736
tt.v1, tt.v2, tt.v3, tt.mode,
695737
v1, v2, v3, mode)

0 commit comments

Comments
 (0)