Skip to content

ssh: export a transport interface #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ssh/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ func (ch *channel) writePacket(packet []byte) error {
return io.EOF
}
ch.sentClose = (packet[0] == msgChannelClose)
err := ch.mux.conn.writePacket(packet)
err := ch.mux.conn.WritePacket(packet)
ch.writeMu.Unlock()
return err
}
Expand Down
7 changes: 7 additions & 0 deletions ssh/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
}

func NewClientConnFromTransport(t Transport) (Conn, <-chan NewChannel, <-chan *Request, error) {
conn := &connection{
mux: newMux(t),
}
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil
}

// clientHandshake performs the client side key exchange. See RFC 4253 Section
// 7.
func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error {
Expand Down
8 changes: 8 additions & 0 deletions ssh/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ func (t *handshakeTransport) printPacket(p []byte, write bool) {
}
}

func (t *handshakeTransport) ReadPacket() ([]byte, error) {
return t.readPacket()
}

func (t *handshakeTransport) readPacket() ([]byte, error) {
p, ok := <-t.incoming
if !ok {
Expand Down Expand Up @@ -479,6 +483,10 @@ func (t *handshakeTransport) sendKexInit() error {
return nil
}

func (t *handshakeTransport) WritePacket(p []byte) error {
return t.writePacket(p)
}

func (t *handshakeTransport) writePacket(p []byte) error {
switch p[0] {
case msgKexInit:
Expand Down
10 changes: 9 additions & 1 deletion ssh/mempipe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ type memTransport struct {
*sync.Cond
}

func (t *memTransport) ReadPacket() ([]byte, error) {
return t.readPacket()
}

func (t *memTransport) readPacket() ([]byte, error) {
t.Lock()
defer t.Unlock()
Expand Down Expand Up @@ -53,6 +57,10 @@ func (t *memTransport) Close() error {
return err
}

func (t *memTransport) WritePacket(p []byte) error {
return t.writePacket(p)
}

func (t *memTransport) writePacket(p []byte) error {
t.write.Lock()
defer t.write.Unlock()
Expand All @@ -66,7 +74,7 @@ func (t *memTransport) writePacket(p []byte) error {
return nil
}

func memPipe() (a, b packetConn) {
func memPipe() (a, b *memTransport) {
t1 := memTransport{}
t2 := memTransport{}
t1.write = &t2
Expand Down
8 changes: 4 additions & 4 deletions ssh/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ func (c *chanList) dropAll() []*channel {
// mux represents the state for the SSH connection protocol, which
// multiplexes many channels onto a single packet transport.
type mux struct {
conn packetConn
conn Transport
chanList chanList

incomingChannels chan NewChannel
Expand All @@ -113,7 +113,7 @@ func (m *mux) Wait() error {
}

// newMux returns a mux that runs over the given connection.
func newMux(p packetConn) *mux {
func newMux(p Transport) *mux {
m := &mux{
conn: p,
incomingChannels: make(chan NewChannel, chanSize),
Expand All @@ -134,7 +134,7 @@ func (m *mux) sendMessage(msg interface{}) error {
if debugMux {
log.Printf("send global(%d): %#v", m.chanList.offset, msg)
}
return m.conn.writePacket(p)
return m.conn.WritePacket(p)
}

func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
Expand Down Expand Up @@ -212,7 +212,7 @@ func (m *mux) loop() {

// onePacket reads and processes one packet.
func (m *mux) onePacket() error {
packet, err := m.conn.readPacket()
packet, err := m.conn.ReadPacket()
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions ssh/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ func TestMuxChannelOverflow(t *testing.T) {
marshalUint32(packet[5:], uint32(1))
packet[9] = 42

if err := writer.mux.conn.writePacket(packet); err != nil {
if err := writer.mux.conn.WritePacket(packet); err != nil {
t.Errorf("could not send packet")
}
if _, err := reader.SendRequest("hello", true, nil); err == nil {
Expand Down Expand Up @@ -432,7 +432,7 @@ func TestMuxInvalidRecord(t *testing.T) {
marshalUint32(packet[5:], 1)
packet[9] = 42

a.conn.writePacket(packet)
a.conn.WritePacket(packet)
go a.SendRequest("hello", false, nil)
// 'a' wrote an invalid packet, so 'b' has exited.
req, ok := <-b.incomingRequests
Expand Down Expand Up @@ -475,7 +475,7 @@ func TestMuxMaxPacketSize(t *testing.T) {
marshalUint32(packet[5:], uint32(len(large)))
packet[9] = 42

if err := a.mux.conn.writePacket(packet); err != nil {
if err := a.mux.conn.WritePacket(packet); err != nil {
t.Errorf("could not send packet")
}

Expand Down
15 changes: 15 additions & 0 deletions ssh/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,21 @@ type packetConn interface {
Close() error
}

// Transport represents a connection that implements packet based operations as
// specified by SSH Transport Protocol (RFC 4253).
type Transport interface {
// WritePacket encrypts and sends a packet of data to the remote peer.
WritePacket([]byte) error

// ReadPacket reads and decrypts a packet of data from the remote peer. The
// read is blocking. If error is nil then the returned byte slice is always
// non-empty.
ReadPacket() ([]byte, error)

// Close closes the connection with the remote peer.
Close() error
}

// transport is the keyingTransport that implements the SSH packet
// protocol.
type transport struct {
Expand Down