diff --git a/pymfdata/rdb/connection.py b/pymfdata/rdb/connection.py index 9c27a3c..1b919a1 100644 --- a/pymfdata/rdb/connection.py +++ b/pymfdata/rdb/connection.py @@ -1,6 +1,6 @@ from asyncio import current_task from contextlib import asynccontextmanager, contextmanager -from typing import AsyncIterable, Callable, Union, Optional +from typing import Callable, Union, Optional from sqlalchemy.engine import Engine, create_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_scoped_session, create_async_engine @@ -34,26 +34,26 @@ def init_session_factory(self, autocommit: bool = False, autoflush: bool = False async def session(self) -> Callable[..., AsyncSession]: assert self._session_factory is not None - session: AsyncSession = self._session_factory() + session: Union[AsyncSession, async_scoped_session] = self._session_factory() try: yield session except Exception: await session.rollback() raise finally: - await session.close() + await session.remove() async def get_db_session(self) -> Callable[..., AsyncSession]: assert self._session_factory is not None - session: AsyncSession = self._session_factory() + session: Union[AsyncSession, async_scoped_session] = self._session_factory() try: yield session except Exception: await session.rollback() raise finally: - await session.close() + await session.remove() @property def engine(self):