diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 8e841871..43339346 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -14,6 +14,7 @@ import sys import time import traceback +import typing import warnings from . import compat @@ -25,6 +26,7 @@ from . import protocol from . import serverversion from . import transaction +from . import types from . import utils @@ -179,11 +181,11 @@ def remove_log_listener(self, callback): """ self._log_listeners.discard(callback) - def get_server_pid(self): + def get_server_pid(self) -> int: """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() - def get_server_version(self): + def get_server_version(self) -> types.ServerVersion: """Return the version of the connected PostgreSQL server. The returned value is a named tuple similar to that in @@ -199,7 +201,7 @@ def get_server_version(self): """ return self._server_version - def get_settings(self): + def get_settings(self) -> protocol.ConnectionSettings: """Return connection settings. :return: :class:`~asyncpg.ConnectionSettings`. @@ -207,7 +209,7 @@ def get_settings(self): return self._protocol.get_settings() def transaction(self, *, isolation='read_committed', readonly=False, - deferrable=False): + deferrable=False) -> transaction.Transaction: """Create a :class:`~transaction.Transaction` object. Refer to `PostgreSQL documentation`_ on the meaning of transaction @@ -230,7 +232,7 @@ def transaction(self, *, isolation='read_committed', readonly=False, self._check_open() return transaction.Transaction(self, isolation, readonly, deferrable) - def is_in_transaction(self): + def is_in_transaction(self) -> bool: """Return True if Connection is currently inside a transaction. :return bool: True if inside transaction, False otherwise. @@ -275,7 +277,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: _, status, _ = await self._execute(query, args, 0, timeout, True) return status.decode() - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany(self, command: str, args, *, timeout: float=None) \ + -> None: """Execute an SQL *command* for each sequence of arguments in *args*. Example: @@ -378,7 +381,8 @@ async def _introspect_types(self, typeoids, timeout): return await self.__execute( self._intro_query, (list(typeoids),), 0, timeout) - def cursor(self, query, *args, prefetch=None, timeout=None): + def cursor(self, query, *args, prefetch=None, timeout=None) \ + -> cursor.CursorFactory: """Return a *cursor factory* for the specified query. :param args: Query arguments. @@ -392,7 +396,8 @@ def cursor(self, query, *args, prefetch=None, timeout=None): return cursor.CursorFactory(self, query, None, args, prefetch, timeout) - async def prepare(self, query, *, timeout=None): + async def prepare(self, query, *, timeout=None) \ + -> prepared_stmt.PreparedStatement: """Create a *prepared statement* for the specified query. :param str query: Text of the query to create a prepared statement for. @@ -408,7 +413,8 @@ async def _prepare(self, query, *, timeout=None, use_cache: bool=False): use_cache=use_cache) return prepared_stmt.PreparedStatement(self, query, stmt) - async def fetch(self, query, *args, timeout=None) -> list: + async def fetch(self, query, *args, timeout=None) \ + -> typing.List[protocol.Record]: """Run a query and return the results as a list of :class:`Record`. :param str query: Query text. @@ -420,7 +426,8 @@ async def fetch(self, query, *args, timeout=None) -> list: self._check_open() return await self._execute(query, args, 0, timeout) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval(self, query, *args, column=0, timeout=None) \ + -> typing.Any: """Run a query and return a value in the first row. :param str query: Query text. @@ -441,7 +448,8 @@ async def fetchval(self, query, *args, column=0, timeout=None): return None return data[0][column] - async def fetchrow(self, query, *args, timeout=None): + async def fetchrow(self, query, *args, timeout=None) \ + -> typing.Optional[protocol.Record]: """Run a query and return the first row. :param str query: Query text @@ -461,7 +469,8 @@ async def copy_from_table(self, table_name, *, output, columns=None, schema_name=None, timeout=None, format=None, oids=None, delimiter=None, null=None, header=None, quote=None, - escape=None, force_quote=None, encoding=None): + escape=None, force_quote=None, encoding=None) \ + -> str: """Copy table contents to a file or file-like object. :param str table_name: @@ -533,7 +542,7 @@ async def copy_from_query(self, query, *args, output, timeout=None, format=None, oids=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, - encoding=None): + encoding=None) -> str: """Copy the results of a query to a file or file-like object. :param str query: @@ -597,7 +606,7 @@ async def copy_to_table(self, table_name, *, source, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, force_not_null=None, force_null=None, - encoding=None): + encoding=None) -> str: """Copy data to the specified table. :param str table_name: @@ -668,7 +677,7 @@ async def copy_to_table(self, table_name, *, source, async def copy_records_to_table(self, table_name, *, records, columns=None, schema_name=None, - timeout=None): + timeout=None) -> str: """Copy a list of records to the specified table using binary COPY. :param str table_name: @@ -1060,7 +1069,7 @@ async def set_builtin_type_codec(self, typename, *, # Statement cache is no longer valid due to codec changes. self._drop_local_statement_cache() - def is_closed(self): + def is_closed(self) -> bool: """Return ``True`` if the connection is closed, ``False`` otherwise. :return bool: ``True`` if the connection is closed, ``False`` @@ -1503,7 +1512,7 @@ async def connect(dsn=None, *, command_timeout=None, ssl=None, connection_class=Connection, - server_settings=None): + server_settings=None) -> Connection: r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py index 030def0e..56c289c0 100644 --- a/asyncpg/cursor.py +++ b/asyncpg/cursor.py @@ -6,10 +6,12 @@ import collections +import typing from . import compat from . import connresource from . import exceptions +from . import protocol class CursorFactory(connresource.ConnectionResource): @@ -33,7 +35,7 @@ def __init__(self, connection, query, state, args, prefetch, timeout): @compat.aiter_compat @connresource.guarded - def __aiter__(self): + def __aiter__(self) -> 'CursorIterator': prefetch = 50 if self._prefetch is None else self._prefetch return CursorIterator(self._connection, self._query, self._state, @@ -41,7 +43,7 @@ def __aiter__(self): self._timeout) @connresource.guarded - def __await__(self): + def __await__(self) -> 'Cursor': if self._prefetch is not None: raise exceptions.InterfaceError( 'prefetch argument can only be specified for iterable cursor') @@ -164,11 +166,11 @@ def __init__(self, connection, query, state, args, prefetch, timeout): @compat.aiter_compat @connresource.guarded - def __aiter__(self): + def __aiter__(self) -> 'CursorIterator': return self @connresource.guarded - async def __anext__(self): + async def __anext__(self) -> protocol.Record: if self._state is None: self._state = await self._connection._get_statement( self._query, self._timeout, named=True) @@ -203,7 +205,7 @@ async def _init(self, timeout): return self @connresource.guarded - async def fetch(self, n, *, timeout=None): + async def fetch(self, n, *, timeout=None) -> typing.List[protocol.Record]: r"""Return the next *n* rows as a list of :class:`Record` objects. :param float timeout: Optional timeout value in seconds. @@ -221,7 +223,7 @@ async def fetch(self, n, *, timeout=None): return recs @connresource.guarded - async def fetchrow(self, *, timeout=None): + async def fetchrow(self, *, timeout=None) -> protocol.Record: r"""Return the next row. :param float timeout: Optional timeout value in seconds. diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 64f4071e..d103b429 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -10,11 +10,13 @@ import inspect import logging import time +import typing import warnings from . import connection from . import connect_utils from . import exceptions +from . import protocol logger = logging.getLogger(__name__) @@ -508,7 +510,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: async with self.acquire() as con: return await con.execute(query, *args, timeout=timeout) - async def executemany(self, command: str, args, *, timeout: float=None): + async def executemany(self, command: str, args, *, timeout: float=None) \ + -> None: """Execute an SQL *command* for each sequence of arguments in *args*. Pool performs this operation using one of its connections. Other than @@ -520,7 +523,8 @@ async def executemany(self, command: str, args, *, timeout: float=None): async with self.acquire() as con: return await con.executemany(command, args, timeout=timeout) - async def fetch(self, query, *args, timeout=None) -> list: + async def fetch(self, query, *args, timeout=None) \ + -> typing.List[protocol.Record]: """Run a query and return the results as a list of :class:`Record`. Pool performs this operation using one of its connections. Other than @@ -532,7 +536,8 @@ async def fetch(self, query, *args, timeout=None) -> list: async with self.acquire() as con: return await con.fetch(query, *args, timeout=timeout) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval(self, query, *args, column=0, timeout=None) \ + -> typing.Any: """Run a query and return a value in the first row. Pool performs this operation using one of its connections. Other than @@ -545,7 +550,8 @@ async def fetchval(self, query, *args, column=0, timeout=None): return await con.fetchval( query, *args, column=column, timeout=timeout) - async def fetchrow(self, query, *args, timeout=None): + async def fetchrow(self, query, *args, timeout=None) \ + -> typing.Optional[protocol.Record]: """Run a query and return the first row. Pool performs this operation using one of its connections. Other than @@ -557,7 +563,7 @@ async def fetchrow(self, query, *args, timeout=None): async with self.acquire() as con: return await con.fetchrow(query, *args, timeout=timeout) - def acquire(self, *, timeout=None): + def acquire(self, *, timeout=None) -> connection.Connection: """Acquire a database connection from the pool. :param float timeout: A timeout for acquiring a Connection. @@ -784,7 +790,7 @@ def create_pool(dsn=None, *, init=None, loop=None, connection_class=connection.Connection, - **connect_kwargs): + **connect_kwargs) -> Pool: r"""Create a connection pool. Can be used either with an ``async with`` block: diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index 09a0a2ec..a459b200 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -6,10 +6,13 @@ import json +import typing from . import connresource from . import cursor from . import exceptions +from . import protocol +from . import types class PreparedStatement(connresource.ConnectionResource): @@ -50,7 +53,7 @@ def get_statusmsg(self) -> str: return self._last_status.decode() @connresource.guarded - def get_parameters(self): + def get_parameters(self) -> typing.Tuple[types.Type, ...]: """Return a description of statement parameters types. :return: A tuple of :class:`asyncpg.types.Type`. @@ -67,7 +70,7 @@ def get_parameters(self): return self._state._get_parameters() @connresource.guarded - def get_attributes(self): + def get_attributes(self) -> typing.Tuple[types.Attribute, ...]: """Return a description of relation attributes (columns). :return: A tuple of :class:`asyncpg.types.Attribute`. @@ -108,7 +111,7 @@ def cursor(self, *args, prefetch=None, timeout) @connresource.guarded - async def explain(self, *args, analyze=False): + async def explain(self, *args, analyze=False) -> typing.Dict: """Return the execution plan of the statement. :param args: Query arguments. @@ -150,7 +153,7 @@ async def explain(self, *args, analyze=False): return json.loads(data) @connresource.guarded - async def fetch(self, *args, timeout=None): + async def fetch(self, *args, timeout=None) -> typing.List[protocol.Record]: r"""Execute the statement and return a list of :class:`Record` objects. :param str query: Query text @@ -163,7 +166,7 @@ async def fetch(self, *args, timeout=None): return data @connresource.guarded - async def fetchval(self, *args, column=0, timeout=None): + async def fetchval(self, *args, column=0, timeout=None) -> typing.Any: """Execute the statement and return a value in the first row. :param args: Query arguments. @@ -182,7 +185,8 @@ async def fetchval(self, *args, column=0, timeout=None): return data[0][column] @connresource.guarded - async def fetchrow(self, *args, timeout=None): + async def fetchrow(self, *args, timeout=None) \ + -> typing.Optional[protocol.Record]: """Execute the statement and return the first row. :param str query: Query text diff --git a/asyncpg/protocol/__init__.py b/asyncpg/protocol/__init__.py index e872e2fa..16186260 100644 --- a/asyncpg/protocol/__init__.py +++ b/asyncpg/protocol/__init__.py @@ -5,4 +5,4 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 -from .protocol import Protocol, Record, NO_TIMEOUT # NOQA +from .protocol import Protocol, Record, ConnectionSettings, NO_TIMEOUT # NOQA