diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index d06549566..698f23f87 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -116,13 +116,16 @@ jobs: fail-fast: false matrix: sdk-version: - - '1.10.11-0-gf0b0e7ecf-r470' - - '2.8.3-21-g7d35cd2be-r470' + - 'bundle-1.10.11-0-gf0b0e7ecf-r470' coveralls: [false] fuzzing: [false] ssl: [false] include: - - sdk-version: '2.10.0-1-gfa775b383-r486-linux-x86_64' + - sdk-version: 'bundle-2.10.0-1-gfa775b383-r486-linux-x86_64' + coveralls: false + ssl: true + - sdk-path: 'dev/linux/x86_64/master/' + sdk-version: 'sdk-gc64-2.11.0-entrypoint-113-g803baaffe-r529.linux.x86_64' coveralls: true ssl: true @@ -141,8 +144,8 @@ jobs: - name: Setup Tarantool ${{ matrix.sdk-version }} run: | - ARCHIVE_NAME=tarantool-enterprise-bundle-${{ matrix.sdk-version }}.tar.gz - curl -O -L https://${{ secrets.SDK_DOWNLOAD_TOKEN }}@download.tarantool.io/enterprise/${ARCHIVE_NAME} + ARCHIVE_NAME=tarantool-enterprise-${{ matrix.sdk-version }}.tar.gz + curl -O -L https://${{ secrets.SDK_DOWNLOAD_TOKEN }}@download.tarantool.io/enterprise/${{ matrix.sdk-path }}${ARCHIVE_NAME} tar -xzf ${ARCHIVE_NAME} rm -f ${ARCHIVE_NAME} diff --git a/CHANGELOG.md b/CHANGELOG.md index 7930b34d1..ca81649e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ Versioning](http://semver.org/spec/v2.0.0.html) except to the first release. - Error type support in MessagePack (#209) - Event subscription support (#119) - Session settings support (#215) +- pap-sha256 authorization method support (Tarantool EE feature) (#243) ### Changed diff --git a/auth.go b/auth.go index 60c219d69..2e5ddc4c4 100644 --- a/auth.go +++ b/auth.go @@ -3,8 +3,46 @@ package tarantool import ( "crypto/sha1" "encoding/base64" + "fmt" ) +const ( + chapSha1 = "chap-sha1" + papSha256 = "pap-sha256" +) + +// Auth is used as a parameter to set up an authentication method. +type Auth int + +const ( + // AutoAuth does not force any authentication method. A method will be + // selected automatically (a value from IPROTO_ID response or + // ChapSha1Auth). + AutoAuth Auth = iota + // ChapSha1Auth forces chap-sha1 authentication method. The method is + // available both in the Tarantool Community Edition (CE) and the + // Tarantool Enterprise Edition (EE) + ChapSha1Auth + // PapSha256Auth forces pap-sha256 authentication method. The method is + // available only for the Tarantool Enterprise Edition (EE) with + // SSL transport. + PapSha256Auth +) + +// String returns a string representation of an authentication method. +func (a Auth) String() string { + switch a { + case AutoAuth: + return "auto" + case ChapSha1Auth: + return chapSha1 + case PapSha256Auth: + return papSha256 + default: + return fmt.Sprintf("unknown auth type (code %d)", a) + } +} + func scramble(encodedSalt, pass string) (scramble []byte, err error) { /* ================================================================== According to: http://tarantool.org/doc/dev_guide/box-protocol.html diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 000000000..6964f2552 --- /dev/null +++ b/auth_test.go @@ -0,0 +1,28 @@ +package tarantool_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + . "github.com/tarantool/go-tarantool" +) + +func TestAuth_String(t *testing.T) { + unknownId := int(PapSha256Auth) + 1 + tests := []struct { + auth Auth + expected string + }{ + {AutoAuth, "auto"}, + {ChapSha1Auth, "chap-sha1"}, + {PapSha256Auth, "pap-sha256"}, + {Auth(unknownId), fmt.Sprintf("unknown auth type (code %d)", unknownId)}, + } + + for _, tc := range tests { + t.Run(tc.expected, func(t *testing.T) { + assert.Equal(t, tc.auth.String(), tc.expected) + }) + } +} diff --git a/config.lua b/config.lua index c8a853ff4..eadfb3825 100644 --- a/config.lua +++ b/config.lua @@ -1,6 +1,12 @@ -- Do not set listen for now so connector won't be -- able to send requests until everything is configured. +local auth_type = os.getenv("TEST_TNT_AUTH_TYPE") +if auth_type == "auto" then + auth_type = nil +end + box.cfg{ + auth_type = auth_type, work_dir = os.getenv("TEST_TNT_WORK_DIR"), memtx_use_mvcc_engine = os.getenv("TEST_TNT_MEMTX_USE_MVCC_ENGINE") == 'true' or nil, } @@ -267,5 +273,6 @@ box.space.test:truncate() -- Set listen only when every other thing is configured. box.cfg{ + auth_type = auth_type, listen = os.getenv("TEST_TNT_LISTEN"), } diff --git a/connection.go b/connection.go index b657dc2ae..c52d2e0a7 100644 --- a/connection.go +++ b/connection.go @@ -226,6 +226,8 @@ type Greeting struct { // Opts is a way to configure Connection type Opts struct { + // Auth is an authentication method. + Auth Auth // Timeout for response to a particular request. The timeout is reset when // push messages are received. If Timeout is zero, any request can be // blocked infinitely. @@ -546,19 +548,40 @@ func (conn *Connection) dial() (err error) { // Auth. if opts.User != "" { - scr, err := scramble(conn.Greeting.auth, opts.Pass) - if err != nil { - err = errors.New("auth: scrambling failure " + err.Error()) + auth := opts.Auth + if opts.Auth == AutoAuth { + if conn.serverProtocolInfo.Auth != AutoAuth { + auth = conn.serverProtocolInfo.Auth + } else { + auth = ChapSha1Auth + } + } + + var req Request + if auth == ChapSha1Auth { + salt := conn.Greeting.auth + req, err = newChapSha1AuthRequest(conn.opts.User, salt, opts.Pass) + if err != nil { + return fmt.Errorf("auth: %w", err) + } + } else if auth == PapSha256Auth { + if opts.Transport != connTransportSsl { + return errors.New("auth: forbidden to use " + auth.String() + + " unless SSL is enabled for the connection") + } + req = newPapSha256AuthRequest(conn.opts.User, opts.Pass) + } else { connection.Close() - return err + return errors.New("auth: " + auth.String()) } - if err = conn.writeAuthRequest(w, scr); err != nil { + + if err = conn.writeRequest(w, req); err != nil { connection.Close() - return err + return fmt.Errorf("auth: %w", err) } - if err = conn.readAuthResponse(r); err != nil { + if _, err = conn.readResponse(r); err != nil { connection.Close() - return err + return fmt.Errorf("auth: %w", err) } } @@ -662,28 +685,6 @@ func (conn *Connection) writeRequest(w *bufio.Writer, req Request) error { return err } -func (conn *Connection) writeAuthRequest(w *bufio.Writer, scramble []byte) error { - req := newAuthRequest(conn.opts.User, string(scramble)) - - err := conn.writeRequest(w, req) - if err != nil { - return fmt.Errorf("auth: %w", err) - } - - return nil -} - -func (conn *Connection) writeIdRequest(w *bufio.Writer, protocolInfo ProtocolInfo) error { - req := NewIdRequest(protocolInfo) - - err := conn.writeRequest(w, req) - if err != nil { - return fmt.Errorf("identify: %w", err) - } - - return nil -} - func (conn *Connection) readResponse(r io.Reader) (Response, error) { respBytes, err := conn.read(r) if err != nil { @@ -707,24 +708,6 @@ func (conn *Connection) readResponse(r io.Reader) (Response, error) { return resp, nil } -func (conn *Connection) readAuthResponse(r io.Reader) error { - _, err := conn.readResponse(r) - if err != nil { - return fmt.Errorf("auth: %w", err) - } - - return nil -} - -func (conn *Connection) readIdResponse(r io.Reader) (Response, error) { - resp, err := conn.readResponse(r) - if err != nil { - return resp, fmt.Errorf("identify: %w", err) - } - - return resp, nil -} - func (conn *Connection) createConnection(reconnect bool) (err error) { var reconnects uint for conn.c == nil && conn.state == connDisconnected { @@ -1625,19 +1608,20 @@ func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error { func (conn *Connection) identify(w *bufio.Writer, r *bufio.Reader) error { var ok bool - werr := conn.writeIdRequest(w, clientProtocolInfo) + req := NewIdRequest(clientProtocolInfo) + werr := conn.writeRequest(w, req) if werr != nil { - return werr + return fmt.Errorf("identify: %w", werr) } - resp, rerr := conn.readIdResponse(r) + resp, rerr := conn.readResponse(r) if rerr != nil { if resp.Code == ErrUnknownRequestType { // IPROTO_ID requests are not supported by server. return nil } - return rerr + return fmt.Errorf("identify: %w", rerr) } if len(resp.Data) == 0 { @@ -1664,5 +1648,7 @@ func (conn *Connection) ServerProtocolInfo() ProtocolInfo { // supported by Go connection client. // Since 1.10.0 func (conn *Connection) ClientProtocolInfo() ProtocolInfo { - return clientProtocolInfo.Clone() + info := clientProtocolInfo.Clone() + info.Auth = conn.opts.Auth + return info } diff --git a/connection_pool/watcher.go b/connection_pool/watcher.go index 6a1fde8f4..2876f90bc 100644 --- a/connection_pool/watcher.go +++ b/connection_pool/watcher.go @@ -26,22 +26,24 @@ func (c *watcherContainer) add(watcher *poolWatcher) { } // remove removes a watcher from the container. -func (c *watcherContainer) remove(watcher *poolWatcher) { +func (c *watcherContainer) remove(watcher *poolWatcher) bool { c.mutex.Lock() defer c.mutex.Unlock() if watcher == c.head { c.head = watcher.next + return true } else { cur := c.head for cur.next != nil { if cur.next == watcher { cur.next = watcher.next - break + return true } cur = cur.next } } + return false } // foreach iterates over the container to the end or until the call returns @@ -83,15 +85,13 @@ type poolWatcher struct { // Unregister unregisters the pool watcher. func (w *poolWatcher) Unregister() { - w.mutex.Lock() - defer w.mutex.Unlock() - - if !w.unregistered { - w.container.remove(w) + if !w.unregistered && w.container.remove(w) { + w.mutex.Lock() w.unregistered = true for _, watcher := range w.watchers { watcher.Unregister() } + w.mutex.Unlock() } } diff --git a/const.go b/const.go index acd6a4861..ead151878 100644 --- a/const.go +++ b/const.go @@ -51,6 +51,7 @@ const ( KeyEvent = 0x57 KeyEventData = 0x58 KeyTxnIsolation = 0x59 + KeyAuthType = 0x5b KeyFieldName = 0x00 KeyFieldType = 0x01 diff --git a/protocol.go b/protocol.go index ae6142b4d..06506ee5a 100644 --- a/protocol.go +++ b/protocol.go @@ -13,6 +13,8 @@ type ProtocolFeature uint64 // ProtocolInfo type aggregates Tarantool protocol version and features info. type ProtocolInfo struct { + // Auth is an authentication method. + Auth Auth // Version is the supported protocol version. Version ProtocolVersion // Features are supported protocol features. diff --git a/request.go b/request.go index 66eb4be41..1b135eaa6 100644 --- a/request.go +++ b/request.go @@ -3,6 +3,7 @@ package tarantool import ( "context" "errors" + "fmt" "reflect" "strings" "sync" @@ -591,14 +592,30 @@ func (req *spaceIndexRequest) setIndex(index interface{}) { type authRequest struct { baseRequest - user, scramble string + auth Auth + user, pass string } -func newAuthRequest(user, scramble string) *authRequest { +func newChapSha1AuthRequest(user, salt, password string) (*authRequest, error) { + scr, err := scramble(salt, password) + if err != nil { + return nil, fmt.Errorf("scrambling failure: %w", err) + } + + req := new(authRequest) + req.requestCode = AuthRequestCode + req.auth = ChapSha1Auth + req.user = user + req.pass = string(scr) + return req, nil +} + +func newPapSha256AuthRequest(user, password string) *authRequest { req := new(authRequest) req.requestCode = AuthRequestCode + req.auth = PapSha256Auth req.user = user - req.scramble = scramble + req.pass = password return req } @@ -606,7 +623,7 @@ func newAuthRequest(user, scramble string) *authRequest { func (req *authRequest) Body(res SchemaResolver, enc *encoder) error { return enc.Encode(map[uint32]interface{}{ KeyUserName: req.user, - KeyTuple: []interface{}{string("chap-sha1"), string(req.scramble)}, + KeyTuple: []interface{}{req.auth.String(), req.pass}, }) } diff --git a/response.go b/response.go index 6c3f69c99..dc747c852 100644 --- a/response.go +++ b/response.go @@ -213,6 +213,21 @@ func (resp *Response) decodeBody() (err error) { } serverProtocolInfo.Features[i] = feature } + case KeyAuthType: + var auth string + if auth, err = d.DecodeString(); err != nil { + return err + } + found := false + for _, a := range [...]Auth{ChapSha1Auth, PapSha256Auth} { + if auth == a.String() { + serverProtocolInfo.Auth = a + found = true + } + } + if !found { + return fmt.Errorf("unknown auth type %s", auth) + } default: if err = d.Skip(); err != nil { return err diff --git a/ssl_test.go b/ssl_test.go index 769508284..ca0773b58 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -94,7 +94,7 @@ func createClientServerSslOk(t testing.TB, serverOpts, return l, c, msgs, errs } -func serverTnt(serverOpts, clientOpts SslOpts) (test_helpers.TarantoolInstance, error) { +func serverTnt(serverOpts, clientOpts SslOpts, auth Auth) (test_helpers.TarantoolInstance, error) { listen := tntHost + "?transport=ssl&" key := serverOpts.KeyFile @@ -120,6 +120,7 @@ func serverTnt(serverOpts, clientOpts SslOpts) (test_helpers.TarantoolInstance, listen = listen[:len(listen)-1] return test_helpers.StartTarantool(test_helpers.StartOpts{ + Auth: auth, InitScript: "config.lua", Listen: listen, SslCertsDir: "testdata", @@ -170,7 +171,7 @@ func assertConnectionSslOk(t testing.TB, serverOpts, clientOpts SslOpts) { func assertConnectionTntFail(t testing.TB, serverOpts, clientOpts SslOpts) { t.Helper() - inst, err := serverTnt(serverOpts, clientOpts) + inst, err := serverTnt(serverOpts, clientOpts, AutoAuth) serverTntStop(inst) if err == nil { @@ -181,7 +182,7 @@ func assertConnectionTntFail(t testing.TB, serverOpts, clientOpts SslOpts) { func assertConnectionTntOk(t testing.TB, serverOpts, clientOpts SslOpts) { t.Helper() - inst, err := serverTnt(serverOpts, clientOpts) + inst, err := serverTnt(serverOpts, clientOpts, AutoAuth) serverTntStop(inst) if err != nil { @@ -432,12 +433,14 @@ var tests = []test{ }, } -func TestSslOpts(t *testing.T) { +func isTestTntSsl() bool { testTntSsl, exists := os.LookupEnv("TEST_TNT_SSL") - isTntSsl := false - if exists && (testTntSsl == "1" || strings.ToUpper(testTntSsl) == "TRUE") { - isTntSsl = true - } + return exists && + (testTntSsl == "1" || strings.ToUpper(testTntSsl) == "TRUE") +} + +func TestSslOpts(t *testing.T) { + isTntSsl := isTestTntSsl() for _, test := range tests { if test.ok { @@ -463,3 +466,39 @@ func TestSslOpts(t *testing.T) { } } } + +func TestOpts_PapSha256Auth(t *testing.T) { + isTntSsl := isTestTntSsl() + if !isTntSsl { + t.Skip("TEST_TNT_SSL is not set") + } + + isLess, err := test_helpers.IsTarantoolVersionLess(2, 11, 0) + if err != nil { + t.Fatalf("Could not check Tarantool version.") + } + if isLess { + t.Skip("Skipping test for Tarantoo without pap-sha256 support") + } + + sslOpts := SslOpts{ + KeyFile: "testdata/localhost.key", + CertFile: "testdata/localhost.crt", + } + inst, err := serverTnt(sslOpts, sslOpts, PapSha256Auth) + defer serverTntStop(inst) + if err != nil { + t.Errorf("An unexpected server error: %s", err) + } + + clientOpts := opts + clientOpts.Transport = "ssl" + clientOpts.Ssl = sslOpts + clientOpts.Auth = PapSha256Auth + conn := test_helpers.ConnectWithValidation(t, tntHost, clientOpts) + conn.Close() + + clientOpts.Auth = AutoAuth + conn = test_helpers.ConnectWithValidation(t, tntHost, clientOpts) + conn.Close() +} diff --git a/tarantool_test.go b/tarantool_test.go index 96bf02254..2c76071b3 100644 --- a/tarantool_test.go +++ b/tarantool_test.go @@ -726,6 +726,38 @@ func BenchmarkSQLSerial(b *testing.B) { } } +func TestOptsAuth_Default(t *testing.T) { + defaultOpts := opts + defaultOpts.Auth = AutoAuth + + conn := test_helpers.ConnectWithValidation(t, server, defaultOpts) + defer conn.Close() +} + +func TestOptsAuth_ChapSha1Auth(t *testing.T) { + chapSha1Opts := opts + chapSha1Opts.Auth = ChapSha1Auth + + conn := test_helpers.ConnectWithValidation(t, server, chapSha1Opts) + defer conn.Close() +} + +func TestOptsAuth_PapSha256AuthForbit(t *testing.T) { + papSha256Opts := opts + papSha256Opts.Auth = PapSha256Auth + + conn, err := Connect(server, papSha256Opts) + if err == nil { + t.Error("An error expected.") + conn.Close() + } + + if err.Error() != "auth: forbidden to use pap-sha256 unless "+ + "SSL is enabled for the connection" { + t.Errorf("An unexpected error: %s", err) + } +} + func TestFutureMultipleGetGetTyped(t *testing.T) { conn := test_helpers.ConnectWithValidation(t, server, opts) defer conn.Close() diff --git a/test_helpers/main.go b/test_helpers/main.go index f6e745617..f2a97bef9 100644 --- a/test_helpers/main.go +++ b/test_helpers/main.go @@ -27,6 +27,9 @@ import ( ) type StartOpts struct { + // Auth is an authentication method for a Tarantool instance. + Auth tarantool.Auth + // InitScript is a Lua script for tarantool to run on start. InitScript string @@ -223,6 +226,7 @@ func StartTarantool(startOpts StartOpts) (TarantoolInstance, error) { fmt.Sprintf("TEST_TNT_WORK_DIR=%s", startOpts.WorkDir), fmt.Sprintf("TEST_TNT_LISTEN=%s", startOpts.Listen), fmt.Sprintf("TEST_TNT_MEMTX_USE_MVCC_ENGINE=%t", startOpts.MemtxUseMvccEngine), + fmt.Sprintf("TEST_TNT_AUTH_TYPE=%s", startOpts.Auth), ) // Copy SSL certificates. @@ -248,6 +252,7 @@ func StartTarantool(startOpts StartOpts) (TarantoolInstance, error) { time.Sleep(startOpts.WaitStart) opts := tarantool.Opts{ + Auth: startOpts.Auth, Timeout: 500 * time.Millisecond, User: startOpts.User, Pass: startOpts.Pass,