Skip to content

Commit 2c19fd5

Browse files
authored
Merge pull request #8 from pires/pires/pr_2_rework
api: introduce net.Listener and net.Conn wrappers
2 parents 9c0bafa + 265cd46 commit 2c19fd5

File tree

7 files changed

+635
-14
lines changed

7 files changed

+635
-14
lines changed

README.md

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,46 @@ $ go get -u github.com/pires/go-proxyproto
2525

2626
### Client (TODO)
2727

28-
### Server (TODO)
28+
### Server
29+
30+
```go
31+
package main
32+
33+
import (
34+
"log"
35+
"net"
36+
37+
proxyproto "github.com/pires/go-proxyproto"
38+
)
39+
40+
func main() {
41+
// Create a listener
42+
addr := "localhost:9876"
43+
list, err := net.Listen("tcp", addr)
44+
if err != nil {
45+
log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
46+
}
47+
48+
// Wrap listener in a proxyproto listener
49+
proxyListener := &proxyproto.Listener{Listener: list}
50+
defer proxyListener.Close()
51+
52+
// Wait for a connection and accept it
53+
conn, err := proxyListener.Accept()
54+
defer conn.Close()
55+
56+
// Print connection details
57+
if conn.LocalAddr() == nil {
58+
log.Fatal("couldn't retrieve local address")
59+
}
60+
log.Printf("local address: %q", conn.LocalAddr().String())
61+
62+
if conn.RemoteAddr() == nil {
63+
log.Fatal("couldn't retrieve remote address")
64+
}
65+
log.Printf("remote address: %q", conn.RemoteAddr().String())
66+
}
67+
```
2968

3069
## Documentation
3170

header.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,22 @@ type Header struct {
4141
DestinationPort uint16
4242
}
4343

44+
// RemoteAddr returns the address of the remote endpoint of the connection.
45+
func (header *Header) RemoteAddr() net.Addr {
46+
return &net.TCPAddr{
47+
IP: header.SourceAddress,
48+
Port: int(header.SourcePort),
49+
}
50+
}
51+
52+
// LocalAddr returns the address of the local endpoint of the connection.
53+
func (header *Header) LocalAddr() net.Addr {
54+
return &net.TCPAddr{
55+
IP: header.DestinationAddress,
56+
Port: int(header.DestinationPort),
57+
}
58+
}
59+
4460
// EqualTo returns true if headers are equivalent, false otherwise.
4561
// Deprecated: use EqualsTo instead. This method will eventually be removed.
4662
func (header *Header) EqualTo(otherHeader *Header) bool {
@@ -63,15 +79,25 @@ func (header *Header) EqualsTo(otherHeader *Header) bool {
6379
header.DestinationPort == otherHeader.DestinationPort
6480
}
6581

66-
// WriteTo renders a proxy protocol header in a format to write over the wire.
82+
// WriteTo renders a proxy protocol header in a format and writes it to an io.Writer.
6783
func (header *Header) WriteTo(w io.Writer) (int64, error) {
84+
buf, err := header.Format()
85+
if err != nil {
86+
return 0, err
87+
}
88+
89+
return bytes.NewBuffer(buf).WriteTo(w)
90+
}
91+
92+
// Format renders a proxy protocol header in a format to write over the wire.
93+
func (header *Header) Format() ([]byte, error) {
6894
switch header.Version {
6995
case 1:
70-
return header.writeVersion1(w)
96+
return header.formatVersion1()
7197
case 2:
72-
return header.writeVersion2(w)
98+
return header.formatVersion2()
7399
default:
74-
return 0, ErrUnknownProxyProtocolVersion
100+
return nil, ErrUnknownProxyProtocolVersion
75101
}
76102
}
77103

header_test.go

Lines changed: 169 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxyproto
22

33
import (
44
"bufio"
5+
"bytes"
56
"net"
67
"testing"
78
"time"
@@ -34,13 +35,13 @@ func TestReadTimeoutV1Invalid(t *testing.T) {
3435
reader := bufio.NewReader(&b)
3536
_, err := ReadTimeout(reader, 50*time.Millisecond)
3637
if err == nil {
37-
t.Fatalf("TestReadTimeoutV1Invalid: expected error %s", ErrNoProxyProtocol)
38+
t.Fatalf("expected error %s", ErrNoProxyProtocol)
3839
} else if err != ErrNoProxyProtocol {
39-
t.Fatalf("TestReadTimeoutV1Invalid: expected %s, actual %s", ErrNoProxyProtocol, err)
40+
t.Fatalf("expected %s, actual %s", ErrNoProxyProtocol, err)
4041
}
4142
}
4243

43-
func TestEqualTo(t *testing.T) {
44+
func TestEqualsTo(t *testing.T) {
4445
var headersEqual = []struct {
4546
this, that *Header
4647
expected bool
@@ -104,7 +105,171 @@ func TestEqualTo(t *testing.T) {
104105

105106
for _, tt := range headersEqual {
106107
if actual := tt.this.EqualsTo(tt.that); actual != tt.expected {
107-
t.Fatalf("TestEqualTo: expected %t, actual %t", tt.expected, actual)
108+
t.Fatalf("expected %t, actual %t", tt.expected, actual)
109+
}
110+
}
111+
}
112+
113+
// This is here just because of coveralls
114+
func TestEqualTo(t *testing.T) {
115+
TestEqualsTo(t)
116+
}
117+
118+
func TestLocalAddr(t *testing.T) {
119+
var headers = []struct {
120+
header *Header
121+
expectedAddr net.Addr
122+
expected bool
123+
}{
124+
{
125+
&Header{
126+
Version: 1,
127+
Command: PROXY,
128+
TransportProtocol: TCPv4,
129+
SourceAddress: net.ParseIP("10.1.1.1"),
130+
SourcePort: 1000,
131+
DestinationAddress: net.ParseIP("20.2.2.2"),
132+
DestinationPort: 2000,
133+
},
134+
&net.TCPAddr{
135+
IP: net.ParseIP("20.2.2.2"),
136+
Port: 2000,
137+
},
138+
true,
139+
},
140+
{
141+
&Header{
142+
Version: 1,
143+
Command: PROXY,
144+
TransportProtocol: TCPv4,
145+
SourceAddress: net.ParseIP("10.1.1.1"),
146+
SourcePort: 1000,
147+
DestinationAddress: net.ParseIP("20.2.2.2"),
148+
DestinationPort: 2000,
149+
},
150+
&net.TCPAddr{
151+
IP: net.ParseIP("10.1.1.1"),
152+
Port: 1000,
153+
},
154+
false,
155+
},
156+
}
157+
158+
for _, tt := range headers {
159+
actualAddr := tt.header.LocalAddr()
160+
if actual := actualAddr.String() == tt.expectedAddr.String(); actual != tt.expected {
161+
t.Fatalf("expected %t, actual %t for expectedAddr %+v and actualAddr %+v", tt.expected, actual, tt.expectedAddr, actualAddr)
162+
}
163+
}
164+
}
165+
166+
func TestRemoteAddr(t *testing.T) {
167+
var headers = []struct {
168+
header *Header
169+
expectedAddr net.Addr
170+
expected bool
171+
}{
172+
{
173+
&Header{
174+
Version: 1,
175+
Command: PROXY,
176+
TransportProtocol: TCPv4,
177+
SourceAddress: net.ParseIP("10.1.1.1"),
178+
SourcePort: 1000,
179+
DestinationAddress: net.ParseIP("20.2.2.2"),
180+
DestinationPort: 2000,
181+
},
182+
&net.TCPAddr{
183+
IP: net.ParseIP("20.2.2.2"),
184+
Port: 2000,
185+
},
186+
true,
187+
},
188+
{
189+
&Header{
190+
Version: 1,
191+
Command: PROXY,
192+
TransportProtocol: TCPv4,
193+
SourceAddress: net.ParseIP("10.1.1.1"),
194+
SourcePort: 1000,
195+
DestinationAddress: net.ParseIP("20.2.2.2"),
196+
DestinationPort: 2000,
197+
},
198+
&net.TCPAddr{
199+
IP: net.ParseIP("10.1.1.1"),
200+
Port: 1000,
201+
},
202+
false,
203+
},
204+
}
205+
206+
for _, tt := range headers {
207+
actualAddr := tt.header.LocalAddr()
208+
if actual := actualAddr.String() == tt.expectedAddr.String(); actual != tt.expected {
209+
t.Fatalf("expected %t, actual %t for expectedAddr %+v and actualAddr %+v", tt.expected, actual, tt.expectedAddr, actualAddr)
210+
}
211+
}
212+
}
213+
214+
func TestWriteTo(t *testing.T) {
215+
var buf bytes.Buffer
216+
217+
validHeader := &Header{
218+
Version: 1,
219+
Command: PROXY,
220+
TransportProtocol: TCPv4,
221+
SourceAddress: net.ParseIP("10.1.1.1"),
222+
SourcePort: 1000,
223+
DestinationAddress: net.ParseIP("20.2.2.2"),
224+
DestinationPort: 2000,
225+
}
226+
227+
if _, err := validHeader.WriteTo(&buf); err != nil {
228+
t.Fatalf("shouldn't have thrown error %q", err.Error())
229+
}
230+
231+
invalidHeader := &Header{
232+
SourceAddress: net.ParseIP("10.1.1.1"),
233+
SourcePort: 1000,
234+
DestinationAddress: net.ParseIP("20.2.2.2"),
235+
DestinationPort: 2000,
236+
}
237+
238+
if _, err := invalidHeader.WriteTo(&buf); err == nil {
239+
t.Fatalf("should have thrown error %q", err.Error())
240+
}
241+
}
242+
243+
func TestFormat(t *testing.T) {
244+
validHeader := &Header{
245+
Version: 1,
246+
Command: PROXY,
247+
TransportProtocol: TCPv4,
248+
SourceAddress: net.ParseIP("10.1.1.1"),
249+
SourcePort: 1000,
250+
DestinationAddress: net.ParseIP("20.2.2.2"),
251+
DestinationPort: 2000,
252+
}
253+
254+
if _, err := validHeader.Format(); err != nil {
255+
t.Fatalf("shouldn't have thrown error %q", err.Error())
256+
}
257+
258+
invalidHeader := &Header{
259+
Version: 3,
260+
Command: PROXY,
261+
TransportProtocol: TCPv4,
262+
SourceAddress: net.ParseIP("10.1.1.1"),
263+
SourcePort: 1000,
264+
DestinationAddress: net.ParseIP("20.2.2.2"),
265+
DestinationPort: 2000,
266+
}
267+
268+
if _, err := invalidHeader.Format(); err == nil {
269+
t.Fatalf("should have thrown error %q", err.Error())
270+
} else {
271+
if err != ErrUnknownProxyProtocolVersion {
272+
t.Fatalf("expected %q, actual %q", ErrUnknownProxyProtocolVersion.Error(), err.Error())
108273
}
109274
}
110275
}

0 commit comments

Comments
 (0)