@@ -25,6 +25,23 @@ import (
2525 "testing"
2626)
2727
28+ var (
29+ protos = []Protocol {ProtocolIPv4 , ProtocolIPv6 }
30+ modes = []ModeType {ModeTypeAuto , ModeTypeLegacy , ModeTypeNFTables }
31+ )
32+
33+ // getProtoName returns the name of the protocol, for use in test names.
34+ func getProtoName (proto Protocol ) string {
35+ switch proto {
36+ case ProtocolIPv4 :
37+ return "IPv4"
38+ case ProtocolIPv6 :
39+ return "IPv6"
40+ default :
41+ panic ("unknown protocol" )
42+ }
43+ }
44+
2845func TestProto (t * testing.T ) {
2946 ipt , err := New ()
3047 if err != nil {
@@ -34,40 +51,72 @@ func TestProto(t *testing.T) {
3451 t .Fatalf ("Expected default protocol IPv4, got %v" , ipt .Proto ())
3552 }
3653
37- ip4t , err := NewWithProtocol (ProtocolIPv4 )
38- if err != nil {
39- t .Fatalf ("NewWithProtocol(ProtocolIPv4) failed: %v" , err )
40- }
41- if ip4t .Proto () != ProtocolIPv4 {
42- t .Fatalf ("Expected protocol IPv4, got %v" , ip4t .Proto ())
54+ for _ , proto := range protos {
55+ protoName := getProtoName (proto )
56+ ipt , err := New (IPFamily (proto ))
57+ if err != nil {
58+ t .Fatalf ("NewWithProtocol(%s) failed: %v" , protoName , err )
59+ }
60+ if ipt .Proto () != proto {
61+ t .Fatalf ("Expected protocol %s, got %v" , protoName , ipt .Proto ())
62+ }
63+ if ipt .mode != ModeTypeAuto {
64+ t .Fatalf ("Expected mode auto, got %v" , ipt .mode )
65+ }
4366 }
4467
45- ip6t , err := NewWithProtocol (ProtocolIPv6 )
46- if err != nil {
47- t .Fatalf ("NewWithProtocol(ProtocolIPv6) failed: %v" , err )
48- }
49- if ip6t .Proto () != ProtocolIPv6 {
50- t .Fatalf ("Expected protocol IPv6, got %v" , ip6t .Proto ())
68+ for _ , proto := range protos {
69+ for _ , mode := range modes {
70+ protoName := getProtoName (proto )
71+ ipt , err := New (Mode (mode ), IPFamily (proto ))
72+ if err != nil {
73+ t .Fatalf ("New(Mode(%v), IPFamily(%v)) failed: %v" , mode , protoName , err )
74+ }
75+ if ipt .Proto () != proto {
76+ t .Fatalf ("Expected protocol %v, got %v" , protoName , ipt .Proto ())
77+ }
78+ if ipt .mode != mode {
79+ t .Fatalf ("Expected mode %v, got %v" , mode , ipt .mode )
80+ }
81+ }
5182 }
5283}
5384
5485func TestTimeout (t * testing.T ) {
55- ipt , err := New ()
56- if err != nil {
57- t .Fatalf ("New failed: %v" , err )
58- }
59- if ipt .timeout != 0 {
60- t .Fatalf ("Expected timeout 0 (wait forever), got %v" , ipt .timeout )
61- }
86+ for _ , proto := range protos {
87+ for _ , mode := range modes {
88+ ipt , err := New (IPFamily (proto ), Mode (mode ))
89+ if err != nil {
90+ t .Fatalf ("New failed: %v" , err )
91+ }
92+ if ipt .timeout != 0 {
93+ t .Fatalf ("Expected timeout 0 (wait forever), got %v" , ipt .timeout )
94+ }
6295
63- ipt2 , err := New (Timeout (5 ))
64- if err != nil {
65- t .Fatalf ("New failed: %v" , err )
66- }
67- if ipt2 .timeout != 5 {
68- t .Fatalf ("Expected timeout 5, got %v" , ipt .timeout )
96+ ipt2 , err := New (Timeout (5 ))
97+ if err != nil {
98+ t .Fatalf ("New failed: %v" , err )
99+ }
100+ if ipt2 .timeout != 5 {
101+ t .Fatalf ("Expected timeout 5, got %v" , ipt .timeout )
102+ }
103+ }
69104 }
105+ }
70106
107+ func TestGetIptablesVersionMode (t * testing.T ) {
108+ for _ , proto := range protos {
109+ for _ , mode := range modes {
110+ ipt , err := New (IPFamily (proto ), Mode (mode ))
111+ if err != nil {
112+ t .Fatalf ("New failed: %v" , err )
113+ }
114+ _ , _ , _ , getmode := ipt .GetIptablesVersion ()
115+ if getmode != mode {
116+ t .Fatalf ("Expected mode %v, got %v" , mode , mode )
117+ }
118+ }
119+ }
71120}
72121
73122func randChain (t * testing.T ) string {
@@ -92,27 +141,20 @@ func contains(list []string, value string) bool {
92141// features enabled & disabled, to test compatibility.
93142// We used to test noWait as well, but that was removed as of iptables v1.6.0
94143func mustTestableIptables () []* IPTables {
95- ipt , err := New ()
96- if err != nil {
97- panic (fmt .Sprintf ("New failed: %v" , err ))
98- }
99- ip6t , err := NewWithProtocol (ProtocolIPv6 )
100- if err != nil {
101- panic (fmt .Sprintf ("NewWithProtocol(ProtocolIPv6) failed: %v" , err ))
102- }
103- ipts := []* IPTables {ipt , ip6t }
104-
105- // ensure we check one variant without built-in checking
106- if ipt .hasCheck {
107- i := * ipt
108- i .hasCheck = false
109- ipts = append (ipts , & i )
110-
111- i6 := * ip6t
112- i6 .hasCheck = false
113- ipts = append (ipts , & i6 )
114- } else {
115- panic ("iptables on this machine is too old -- missing -C" )
144+ ipts := []* IPTables {}
145+ for _ , proto := range protos {
146+ for _ , mode := range modes {
147+ ipt , err := New (IPFamily (proto ), Mode (mode ))
148+ if err != nil {
149+ panic (fmt .Sprintf ("New(IPFamily(%v), Mode(%v)) failed: %v" , proto , mode , err ))
150+ }
151+ if ipt .hasCheck {
152+ ipt .hasCheck = false
153+ ipts = append (ipts , ipt )
154+ } else {
155+ panic ("iptables on this machine is too old -- missing -C" )
156+ }
157+ }
116158 }
117159 return ipts
118160}
@@ -251,7 +293,7 @@ func TestRules(t *testing.T) {
251293}
252294
253295func runRulesTests (t * testing.T , ipt * IPTables ) {
254- t .Logf ("testing %s (hasWait=%t, hasCheck=%t)" , getIptablesCommand (ipt .Proto ()), ipt .hasWait , ipt .hasCheck )
296+ t .Logf ("testing %s (hasWait=%t, hasCheck=%t)" , getIptablesCommand (ipt .Proto (), ModeTypeAuto ), ipt .hasWait , ipt .hasCheck )
255297
256298 var address1 , address2 , subnet1 , subnet2 string
257299 if ipt .Proto () == ProtocolIPv6 {
@@ -689,7 +731,7 @@ func TestExtractIptablesVersion(t *testing.T) {
689731 t .Fatalf ("unexpected err %s" , err )
690732 }
691733
692- if v1 != tt .v1 || v2 != tt .v2 || v3 != tt .v3 || mode != tt .mode {
734+ if v1 != tt .v1 || v2 != tt .v2 || v3 != tt .v3 || mode != ModeType ( tt .mode ) {
693735 t .Fatalf ("expected %d %d %d %s, got %d %d %d %s" ,
694736 tt .v1 , tt .v2 , tt .v3 , tt .mode ,
695737 v1 , v2 , v3 , mode )
0 commit comments