diff --git a/.github/workflows/dart.yml b/.github/workflows/dart.yml index 9d4f6b45..d4bf7909 100644 --- a/.github/workflows/dart.yml +++ b/.github/workflows/dart.yml @@ -14,6 +14,7 @@ on: jobs: build: runs-on: ubuntu-latest + timeout-minutes: 10 steps: - uses: actions/checkout@v3 diff --git a/lib/src/message_window.dart b/lib/src/message_window.dart index 9d266cee..d81c7865 100644 --- a/lib/src/message_window.dart +++ b/lib/src/message_window.dart @@ -29,6 +29,7 @@ Map _messageTypeMap = { 110: (d) => NoDataMessage(), 116: ParameterDescriptionMessage.new, $3: (d) => CloseCompleteMessage(), + $N: NoticeMessage.new, }; class MessageFramer { diff --git a/lib/src/server_messages.dart b/lib/src/server_messages.dart index c8d926c3..3c073ab5 100644 --- a/lib/src/server_messages.dart +++ b/lib/src/server_messages.dart @@ -11,10 +11,10 @@ import 'types.dart'; abstract class ServerMessage extends BaseMessage {} -class ErrorResponseMessage implements ServerMessage { +sealed class ErrorOrNoticeMessage implements ServerMessage { final fields = []; - ErrorResponseMessage(Uint8List bytes) { + ErrorOrNoticeMessage(Uint8List bytes) { final reader = ByteDataReader()..add(bytes); int? identificationToken; @@ -39,6 +39,10 @@ class ErrorResponseMessage implements ServerMessage { } } +class ErrorResponseMessage extends ErrorOrNoticeMessage { + ErrorResponseMessage(super.bytes); +} + class AuthenticationMessage implements ServerMessage { static const int KindOK = 0; static const int KindKerberosV5 = 2; @@ -244,6 +248,10 @@ class NoDataMessage extends ServerMessage { String toString() => 'No Data Message'; } +class NoticeMessage extends ErrorOrNoticeMessage { + NoticeMessage(super.bytes); +} + /// Identifies the message as a Start Copy Both response. /// This message is used only for Streaming Replication. class CopyBothResponseMessage implements ServerMessage { diff --git a/lib/src/v3/connection.dart b/lib/src/v3/connection.dart index d83f9ed4..119c21e3 100644 --- a/lib/src/v3/connection.dart +++ b/lib/src/v3/connection.dart @@ -12,7 +12,7 @@ import 'package:postgres/src/replication.dart'; import 'package:stream_channel/stream_channel.dart'; import '../auth/auth.dart'; -import '../connection.dart' show PostgreSQLException; +import '../connection.dart' show PostgreSQLException, PostgreSQLSeverity; import 'protocol.dart'; import 'query_description.dart'; @@ -73,40 +73,41 @@ abstract class _PgSessionBase implements PgSession { PgConnectionImplementation get _connection; + /// Runs [callback], guarded by [_operationLock] and cleans up the pending + /// resource afterwards. + Future _withResource(FutureOr Function() callback) { + if (_connection._isClosing) { + throw PostgreSQLException('Connection is closing down'); + } + + return _operationLock.withResource(() async { + assert(_connection._pending == null); + + try { + return await callback(); + } finally { + _connection._pending = null; + } + }); + } + /// Sends a message to the server and waits for a response [T], gracefully /// handling error messages that might come in instead. Future _sendAndWaitForQuery(ClientMessage send) { final trace = StackTrace.current; - return _operationLock.withResource(() { + return _withResource(() { _connection._channel.sink .add(AggregatedClientMessage([send, const SyncMessage()])); - final completer = Completer(); - final syncComplete = Completer.sync(); - - _connection._pending = _CallbackOperation(_connection, (message) async { - if (message is T) { - completer.complete(message); - } else if (message is ErrorResponseMessage) { - completer.completeError( - PostgreSQLException.fromFields(message.fields), trace); - } else if (message is ReadyForQueryMessage) { - if (!completer.isCompleted) { - completer.completeError( - StateError('Operation did not complete'), trace); - } + final wait = _connection._pending = _WaitForMessage(this, trace); - syncComplete.complete(); - } else { - completer.completeError( - StateError('Unexpected message $message'), trace); - } - }); + return wait.doneWithOperation.future.then((value) { + final effectiveResult = wait.result ?? + Result.error(StateError('Operation did not complete'), trace); - return syncComplete.future - .whenComplete(() => _connection._pending = null) - .then((value) => completer.future); + return effectiveResult.asFuture; + }); }); } @@ -278,6 +279,7 @@ class PgConnectionImplementation extends _PgSessionBase final StreamChannel _channel; late final StreamSubscription _serverMessages; + bool _isClosing = false; final _ResolvedSettings _settings; @@ -301,7 +303,7 @@ class PgConnectionImplementation extends _PgSessionBase } Future _startup() { - return _operationLock.withResource(() { + return _withResource(() { final result = _pending = _AuthenticationProcedure(this); _channel.sink.add(StartupMessage( @@ -322,10 +324,20 @@ class PgConnectionImplementation extends _PgSessionBase if (message is ParameterStatusMessage) { _parameters[message.name] = message.value; - } else if (message is BackendKeyMessage) { + } else if (message is BackendKeyMessage || message is NoticeMessage) { // ignore for now } else if (message is NotificationResponseMessage) { _channels.deliverNotification(message); + } else if (message is ErrorResponseMessage) { + final exception = PostgreSQLException.fromFields(message.fields); + + // Close the connection in response to fatal errors or if we get them + // out of nowhere. + if (exception.willAbortConnection || _pending == null) { + _closeAfterError(exception); + } else { + _pending!.handleError(exception); + } } else if (_pending != null) { await _pending!.handleMessage(message); } @@ -370,12 +382,30 @@ class PgConnectionImplementation extends _PgSessionBase @override Future close() async { - await _operationLock.withResource(() { - // Use lock to await earlier operations - _channel.sink.add(const TerminateMessage()); - }); + await _close(false, null); + } + + Future _close(bool interruptRunning, PostgreSQLException? cause) async { + if (!_isClosing) { + _isClosing = true; + + if (interruptRunning) { + _pending?.handleConnectionClosed(cause); + _channel.sink.add(const TerminateMessage()); + } else { + // Wait for the previous operation to complete by using the lock + await _operationLock.withResource(() { + // Use lock to await earlier operations + _channel.sink.add(const TerminateMessage()); + }); + } + + await Future.wait([_channel.sink.close(), _serverMessages.cancel()]); + } + } - await Future.wait([_channel.sink.close(), _serverMessages.cancel()]); + void _closeAfterError([PostgreSQLException? cause]) { + _close(true, cause); } } @@ -393,8 +423,11 @@ class _PreparedStatement extends PgStatement { @override Future dispose() async { - await _session._sendAndWaitForQuery( - CloseMessage.statement(_name)); + // Don't send a dispose message if the connection is already closed. + if (!_session._connection._isClosing) { + await _session._sendAndWaitForQuery( + CloseMessage.statement(_name)); + } } } @@ -438,7 +471,7 @@ class _PgResultStreamSubscription _BoundStatement statement, this._controller, this._source) : session = statement.statement._session, ignoreRows = false { - session._operationLock.withResource(() async { + session._withResource(() async { connection._pending = this; connection._channel.sink.add(AggregatedClientMessage([ @@ -462,7 +495,7 @@ class _PgResultStreamSubscription _PgResultStreamSubscription.simpleQueryAndIgnoreRows( String sql, this.session, this._controller, this._source) : ignoreRows = true { - session._operationLock.withResource(() async { + session._withResource(() async { connection._pending = this; connection._channel.sink.add(QueryMessage(sql)); @@ -476,59 +509,80 @@ class _PgResultStreamSubscription @override Future get schema => _schema.future; + Future _completeQuery() async { + _done.complete(); + + // Make sure the affectedRows and schema futures complete with something + // after the query is done, even if we didn't get a row description + // message. + if (!_affectedRows.isCompleted) { + _affectedRows.complete(0); + } + if (!_schema.isCompleted) { + _schema.complete(PgResultSchema(const [])); + } + await _controller.close(); + } + @override - Future handleMessage(ServerMessage message) async { - if (message is ErrorResponseMessage) { - _controller.addError( - PostgreSQLException.fromFields(message.fields), StackTrace.current); - } else if (message is BindCompleteMessage) { - // Nothing to do - } else if (message is RowDescriptionMessage) { - final schema = _resultSchema = PgResultSchema([ - for (final field in message.fieldDescriptions) - PgResultColumn( - type: PgDataType.byTypeOid[field.typeId] ?? PgDataType.byteArray, - tableName: field.tableName, - columnName: field.columnName, - columnOid: field.columnID, - tableOid: field.tableID, - binaryEncoding: field.formatCode != 0, - ), - ]); - _schema.complete(schema); - } else if (message is DataRowMessage) { - if (!ignoreRows) { - final schema = _resultSchema!; - - final columnValues = []; - for (var i = 0; i < message.values.length; i++) { - final field = schema.columns[i]; - - final type = field.type; - final codec = - field.binaryEncoding ? type.binaryCodec : type.textCodec; - - columnValues.add(codec.decode(message.values[i])); - } + void handleConnectionClosed(PostgreSQLException? dueToException) { + if (dueToException != null) { + _controller.addError(dueToException); + } + _completeQuery(); + } - final row = _ResultRow(schema, columnValues); - _controller.add(row); - } - } else if (message is CommandCompleteMessage) { - _affectedRows.complete(message.rowsAffected); - } else if (message is ReadyForQueryMessage) { - _done.complete(); + @override + void handleError(PostgreSQLException exception) { + _controller.addError(exception); + } - // Make sure the affectedRows and schema futures complete with something - // after the query is done, even if we didn't get a row description - // message. - if (!_affectedRows.isCompleted) { - _affectedRows.complete(0); - } - if (!_schema.isCompleted) { - _schema.complete(PgResultSchema(const [])); - } - await _controller.close(); + @override + Future handleMessage(ServerMessage message) async { + switch (message) { + case BindCompleteMessage(): + case NoDataMessage(): + // Nothing to do! + break; + case RowDescriptionMessage(): + final schema = _resultSchema = PgResultSchema([ + for (final field in message.fieldDescriptions) + PgResultColumn( + type: PgDataType.byTypeOid[field.typeId] ?? PgDataType.byteArray, + tableName: field.tableName, + columnName: field.columnName, + columnOid: field.columnID, + tableOid: field.tableID, + binaryEncoding: field.formatCode != 0, + ), + ]); + _schema.complete(schema); + case DataRowMessage(): + if (!ignoreRows) { + final schema = _resultSchema!; + + final columnValues = []; + for (var i = 0; i < message.values.length; i++) { + final field = schema.columns[i]; + + final type = field.type; + final codec = + field.binaryEncoding ? type.binaryCodec : type.textCodec; + + columnValues.add(codec.decode(message.values[i])); + } + + final row = _ResultRow(schema, columnValues); + _controller.add(row); + } + case CommandCompleteMessage(): + _affectedRows.complete(message.rowsAffected); + case ReadyForQueryMessage(): + await _completeQuery(); + default: + // Unexpected message - either a severe bug in this package or in the + // connection. We better close it. + session._connection._closeAfterError(); } } @@ -681,6 +735,18 @@ abstract class _PendingOperation { _PendingOperation(this.session); + /// Handle the connection being closed, either because it has been closed + /// explicitly or because a fatal exception is interrupting the connection. + void handleConnectionClosed(PostgreSQLException? dueToException); + + /// Handles an [ErrorResponseMessage] in an exception form. If the exception + /// is severe enough to close the connection, [handleConnectionClosed] will + /// be called instead. + void handleError(PostgreSQLException exception); + + /// Handles a message from the postgres server. The [message] will never be + /// a [ErrorResponseMessage] - these are delivered through [handleError] or + /// [handleConnectionClosed]. Future handleMessage(ServerMessage message); } @@ -691,13 +757,48 @@ class _ResultRow extends UnmodifiableListView implements PgResultRow { _ResultRow(this.schema, super.source); } -class _CallbackOperation extends _PendingOperation { - final Future Function(ServerMessage message) handle; +class _WaitForMessage extends _PendingOperation { + final StackTrace trace; + final doneWithOperation = Completer.sync(); + Result? result; - _CallbackOperation(super.connection, this.handle); + _WaitForMessage(super.session, this.trace); @override - Future handleMessage(ServerMessage message) => handle(message); + void handleConnectionClosed(PostgreSQLException? dueToException) { + result = Result.error( + dueToException ?? + PostgreSQLException('Connection closed while waiting for message'), + trace, + ); + doneWithOperation.complete(); + } + + @override + void handleError(PostgreSQLException exception) { + result = Result.error(exception, trace); + // We're not done yet! Exceptions delivered through handleError aren't + // fatal, so we'll continue waiting for a ReadyForQuery message. + } + + @override + Future handleMessage(ServerMessage message) async { + if (message is T) { + result = Result.value(message); + // Don't complete, we're still waiting for a ready for query message. + } else if (message is ReadyForQueryMessage) { + // This is the message we've been waiting for, the server is signalling + // that it's ready for another message - so we can release the lock. + doneWithOperation.complete(); + } else { + result = Result.error(StateError('Unexpected message $message'), trace); + + // If we get here, we clearly have a misunderstanding about the + // protocol or something is very seriously broken. Treat this as a + // critical flaw and close the connection as well. + session._connection._closeAfterError(); + } + } } class _AuthenticationProcedure extends _PendingOperation { @@ -719,6 +820,21 @@ class _AuthenticationProcedure extends _PendingOperation { ..onMessage(message); } + @override + void handleConnectionClosed(PostgreSQLException? dueToException) { + _done.completeError(dueToException ?? + PostgreSQLException('Connection closed during authentication')); + } + + @override + void handleError(PostgreSQLException exception) { + _done.completeError(exception); + + // If the authentication procedure fails, the connection is unusable - so we + // might as well close it right away. + session._connection._closeAfterError(); + } + @override Future handleMessage(ServerMessage message) async { if (message is ErrorResponseMessage) { @@ -758,3 +874,10 @@ class _AuthenticationProcedure extends _PendingOperation { } } } + +extension on PostgreSQLException { + bool get willAbortConnection { + return severity == PostgreSQLSeverity.fatal || + severity == PostgreSQLSeverity.panic; + } +} diff --git a/test/fixme/v3_close_test.dart b/test/fixme/v3_close_test.dart deleted file mode 100644 index 4e7cde07..00000000 --- a/test/fixme/v3_close_test.dart +++ /dev/null @@ -1,69 +0,0 @@ -import 'package:postgres/postgres_v3_experimental.dart'; -import 'package:test/test.dart'; - -import '../docker.dart'; - -void main() { - // NOTE: The Docker Container will not close after stopping this test so that needs to be done manually. - usePostgresDocker(); - - group('service-side connection close', - skip: 'the error is not caught or handled properly', () { - // ignore: unused_local_variable - late final PgConnection conn1; - late final PgConnection conn2; - - setUpAll(() async { - conn1 = await PgConnection.open( - PgEndpoint( - host: 'localhost', - database: 'dart_test', - username: 'dart', - password: 'dart', - ), - sessionSettings: PgSessionSettings( - onBadSslCertificate: (cert) => true, - ), - ); - - conn2 = await PgConnection.open( - PgEndpoint( - host: 'localhost', - database: 'dart_test', - username: 'postgres', - password: 'postgres', - ), - sessionSettings: PgSessionSettings( - onBadSslCertificate: (cert) => true, - ), - ); - }); - - test('produce error', () async { - // get conn1 PID - final res = await conn2 - .execute("SELECT pid FROM pg_stat_activity where usename = 'dart';"); - final conn1PID = res.first.first as int; - - // Simulate issue by terminating a connection during a query - // ignore: unawaited_futures - conn1.execute( - 'select * from pg_stat_activity;'); // comment this out and a different error will appear - - // Terminate the conn1 while the query is running - await conn2.execute( - 'select pg_terminate_backend($conn1PID) from pg_stat_activity;'); - // this will cause the following exception: - // PostgreSQLException (PostgreSQLSeverity.fatal 57P01: terminating connection due to administrator command ) - - expect(true, true); - }); - - tearDownAll(() async { - print('closing conn1'); - await conn1.close(); // this will never close & execution will hang here - print('closing conn2'); - await conn2.close(); - }); - }); -} diff --git a/test/v3_close_test.dart b/test/v3_close_test.dart new file mode 100644 index 00000000..6c6478c0 --- /dev/null +++ b/test/v3_close_test.dart @@ -0,0 +1,81 @@ +import 'package:postgres/postgres.dart' show PostgreSQLException; +import 'package:postgres/postgres_v3_experimental.dart'; +import 'package:test/test.dart'; + +import 'docker.dart'; + +final _endpoint = PgEndpoint( + host: 'localhost', + database: 'dart_test', + username: 'dart', + password: 'dart', +); + +void main() { + usePostgresDocker(); + + late PgConnection conn1; + late PgConnection conn2; + + setUp(() async { + conn1 = await PgConnection.open( + _endpoint, + sessionSettings: PgSessionSettings( + onBadSslCertificate: (cert) => true, + //transformer: _loggingTransformer('c1'), + ), + ); + + conn2 = await PgConnection.open( + _endpoint, + sessionSettings: PgSessionSettings( + onBadSslCertificate: (cert) => true, + ), + ); + }); + + tearDown(() async { + await conn1.close(); + await conn2.close(); + }); + + for (final concurrentQuery in [false, true]) { + test( + 'with concurrent query: $concurrentQuery', + () async { + final res = await conn2.execute( + "SELECT pid FROM pg_stat_activity where usename = 'dart';"); + final conn1PID = res.first.first as int; + + // Simulate issue by terminating a connection during a query + if (concurrentQuery) { + // We expect that terminating the connection will throw. + expect(conn1.execute('select pg_sleep(1) from pg_stat_activity;'), + _throwsPostgresException); + } + + // Terminate the conn1 while the query is running + await conn2.execute('select pg_terminate_backend($conn1PID);'); + }, + ); + } + + test('with simple query protocol', () async { + // Get the PID for conn1 + final res = await conn2 + .execute("SELECT pid FROM pg_stat_activity where usename = 'dart';"); + final conn1PID = res.first.first as int; + + // ignore: unawaited_futures + expect( + conn1.execute('select pg_sleep(1) from pg_stat_activity;', + ignoreRows: true), + _throwsPostgresException); + + await conn2.execute( + 'select pg_terminate_backend($conn1PID) from pg_stat_activity;'); + }); +} + +final _isPostgresException = isA(); +final _throwsPostgresException = throwsA(_isPostgresException); diff --git a/test/v3_test.dart b/test/v3_test.dart index b4db4cfd..a4a9669e 100644 --- a/test/v3_test.dart +++ b/test/v3_test.dart @@ -23,9 +23,10 @@ final _endpoint = PgEndpoint( // // Logger.root.level = Level.ALL; // Logger.root.onRecord.listen((r) => print('${r.loggerName}: ${r.message}')); -StreamChannelTransformer get _loggingTransformer { - final inLogger = Logger('postgres.connection.in'); - final outLogger = Logger('postgres.connection.out'); +StreamChannelTransformer _loggingTransformer( + String prefix) { + final inLogger = Logger('postgres.connection.$prefix.in'); + final outLogger = Logger('postgres.connection.$prefix.out'); return StreamChannelTransformer( StreamTransformer.fromHandlers( @@ -47,7 +48,7 @@ final _sessionSettings = PgSessionSettings( // To test SSL, we're running postgres with a self-signed certificate. onBadSslCertificate: (cert) => true, - transformer: _loggingTransformer, + transformer: _loggingTransformer('conn'), ); void main() { @@ -189,8 +190,12 @@ void main() { test('for duplicate with simple query', () async { await expectLater( - () => connection.execute('INSERT INTO foo VALUES (1);'), - _throwsPostgresException); + () => connection.execute('INSERT INTO foo VALUES (1);'), + _throwsPostgresException, + ); + + // Make sure the connection is still usable. + await connection.execute('SELECT 1'); }); test('for duplicate with extended query', () async { @@ -201,6 +206,9 @@ void main() { ), _throwsPostgresException, ); + + // Make sure the connection is still in a usable state. + await connection.execute('SELECT 1'); }); test('for duplicate in prepared statement', () async { @@ -409,6 +417,82 @@ void main() { expect(incoming, contains(isA())); expect(outgoing, contains(isA())); }); + + group('can close connection after error conditions', () { + late PgConnection conn1; + late PgConnection conn2; + + setUp(() async { + conn1 = await PgConnection.open( + PgEndpoint( + host: 'localhost', + database: 'dart_test', + username: 'dart', + password: 'dart', + ), + sessionSettings: PgSessionSettings( + transformer: _loggingTransformer('c1'), + onBadSslCertificate: (cert) => true, + ), + ); + + conn2 = await PgConnection.open( + PgEndpoint( + host: 'localhost', + database: 'dart_test', + username: 'postgres', + password: 'postgres', + ), + sessionSettings: PgSessionSettings( + transformer: _loggingTransformer('c2'), + onBadSslCertificate: (cert) => true, + ), + ); + }); + + tearDown(() async { + await conn1.close(); + await conn2.close(); + }); + + for (final concurrentQuery in [false, true]) { + test( + 'with concurrent query: $concurrentQuery', + () async { + final res = await conn2.execute( + "SELECT pid FROM pg_stat_activity where usename = 'dart';"); + final conn1PID = res.first.first as int; + + // Simulate issue by terminating a connection during a query + if (concurrentQuery) { + // We expect that terminating the connection will throw. Use + // pg_sleep to avoid flaky race conditions between the conditions. + expect(conn1.execute('select pg_sleep(1) from pg_stat_activity;'), + _throwsPostgresException); + } + + // Terminate the conn1 while the query is running + await conn2.execute('select pg_terminate_backend($conn1PID);'); + }, + ); + } + + test('with simple query protocol', () async { + // Get the PID for conn1 + final res = await conn2 + .execute("SELECT pid FROM pg_stat_activity where usename = 'dart';"); + final conn1PID = res.first.first as int; + + expect( + conn1.execute('select pg_sleep(1) from pg_stat_activity;', + ignoreRows: true), + _throwsPostgresException, + ); + + await conn2.execute( + 'select pg_terminate_backend($conn1PID) from pg_stat_activity;'); + }); + }); } final _isPostgresException = isA();