Skip to content

Commit 000e303

Browse files
committed
Support SET SESSION AUTHORIZATION on trino-python-client
1 parent 73b0a58 commit 000e303

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

tests/unit/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ def test_request_headers(mock_get_and_post):
8787
catalog = "test_catalog"
8888
schema = "test_schema"
8989
user = "test_user"
90+
authorization_user = "test_authorization_user"
9091
source = "test_source"
9192
timezone = "Europe/Brussels"
9293
accept_encoding_header = "accept-encoding"
@@ -100,6 +101,7 @@ def test_request_headers(mock_get_and_post):
100101
port=8080,
101102
client_session=ClientSession(
102103
user=user,
104+
authorization_user=authorization_user,
103105
source=source,
104106
catalog=catalog,
105107
schema=schema,
@@ -125,6 +127,7 @@ def assert_headers(headers):
125127
assert headers[constants.HEADER_SCHEMA] == schema
126128
assert headers[constants.HEADER_SOURCE] == source
127129
assert headers[constants.HEADER_USER] == user
130+
assert headers[constants.HEADER_AUTHORIZATION_USER] == authorization_user
128131
assert headers[constants.HEADER_SESSION] == ""
129132
assert headers[constants.HEADER_TIMEZONE] == timezone
130133
assert headers[constants.HEADER_CLIENT_CAPABILITIES] == "PARAMETRIC_DATETIME"
@@ -136,7 +139,7 @@ def assert_headers(headers):
136139
"catalog1=NONE,"
137140
"catalog2=" + urllib.parse.quote("ROLE{catalog2_role}")
138141
)
139-
assert len(headers.keys()) == 11
142+
assert len(headers.keys()) == 12
140143

141144
req.post("URL")
142145
_, post_kwargs = post.call_args

trino/client.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ class ClientSession(object):
9696
9797
:param user: associated with the query. It is useful for access control
9898
and query scheduling.
99+
:param authorization_user: associated with the query. It is useful for access control
100+
and query scheduling.
99101
:param source: associated with the query. It is useful for access
100102
control and query scheduling.
101103
:param catalog: to query. The *catalog* is associated with a Trino
@@ -127,6 +129,7 @@ class ClientSession(object):
127129
def __init__(
128130
self,
129131
user: str,
132+
authorization_user: str = None,
130133
catalog: str = None,
131134
schema: str = None,
132135
source: str = None,
@@ -139,6 +142,7 @@ def __init__(
139142
timezone: str = None,
140143
):
141144
self._user = user
145+
self._authorization_user = authorization_user
142146
self._catalog = catalog
143147
self._schema = schema
144148
self._source = source
@@ -158,6 +162,16 @@ def __init__(
158162
def user(self):
159163
return self._user
160164

165+
@property
166+
def authorization_user(self):
167+
with self._object_lock:
168+
return self._authorization_user
169+
170+
@authorization_user.setter
171+
def authorization_user(self, authorization_user):
172+
with self._object_lock:
173+
self._authorization_user = authorization_user
174+
161175
@property
162176
def catalog(self):
163177
with self._object_lock:
@@ -455,6 +469,7 @@ def http_headers(self) -> Dict[str, str]:
455469
headers[constants.HEADER_SCHEMA] = self._client_session.schema
456470
headers[constants.HEADER_SOURCE] = self._client_session.source
457471
headers[constants.HEADER_USER] = self._client_session.user
472+
headers[constants.HEADER_AUTHORIZATION_USER] = self._client_session.authorization_user
458473
headers[constants.HEADER_TIMEZONE] = self._client_session.timezone
459474
headers[constants.HEADER_CLIENT_CAPABILITIES] = 'PARAMETRIC_DATETIME'
460475
if len(self._client_session.roles.values()):
@@ -658,6 +673,12 @@ def process(self, http_response) -> TrinoStatus:
658673
):
659674
self._client_session.prepared_statements.pop(name, None)
660675

676+
if constants.HEADER_SET_AUTHORIZATION_USER in http_response.headers:
677+
self._client_session.authorization_user = http_response.headers[constants.HEADER_SET_AUTHORIZATION_USER]
678+
679+
if constants.HEADER_RESET_AUTHORIZATION_USER in http_response.headers:
680+
self._client_session.authorization_user = None
681+
661682
self._next_uri = response.get("nextUri")
662683

663684
return TrinoStatus(

trino/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@
5454

5555
HEADER_CLIENT_CAPABILITIES = "X-Trino-Client-Capabilities"
5656

57+
HEADER_AUTHORIZATION_USER = "X-Trino-Authorization-User"
58+
HEADER_SET_AUTHORIZATION_USER = "X-Trino-Set-Authorization-User"
59+
HEADER_RESET_AUTHORIZATION_USER = "X-Trino-Reset-Authorization-User"
60+
5761
LENGTH_TYPES = ["char", "varchar"]
5862
PRECISION_TYPES = ["time", "time with time zone", "timestamp", "timestamp with time zone", "decimal"]
5963
SCALE_TYPES = ["decimal"]

0 commit comments

Comments
 (0)