Skip to content

Commit 1d7b10b

Browse files
committed
get_request_headers combomethod
1 parent 9e21be0 commit 1d7b10b

File tree

5 files changed

+46
-10
lines changed

5 files changed

+46
-10
lines changed

tests/core/providers/test_async_http_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ async def test_async_user_provided_session() -> None:
9494
assert cached_session == session
9595

9696

97-
def test_get_request_headers():
98-
provider = AsyncHTTPProvider()
97+
@pytest.mark.parametrize("provider", (AsyncHTTPProvider(), AsyncHTTPProvider))
98+
def test_get_request_headers(provider):
9999
headers = provider.get_request_headers()
100100
assert len(headers) == 2
101101
assert headers["Content-Type"] == "application/json"

tests/core/providers/test_http_provider.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def test_user_provided_session():
101101
assert adapter._pool_maxsize == 20
102102

103103

104-
def test_get_request_headers():
105-
provider = HTTPProvider()
104+
@pytest.mark.parametrize("provider", (HTTPProvider(), HTTPProvider))
105+
def test_get_request_headers(provider):
106106
headers = provider.get_request_headers()
107107
assert len(headers) == 2
108108
assert headers["Content-Type"] == "application/json"

web3/_utils/http.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,29 @@
1+
from typing import (
2+
TYPE_CHECKING,
3+
Type,
4+
Union,
5+
)
6+
7+
if TYPE_CHECKING:
8+
from web3 import (
9+
AsyncHTTPProvider,
10+
HTTPProvider,
11+
)
12+
113
DEFAULT_HTTP_TIMEOUT = 30.0
214

315

4-
def construct_user_agent(class_type: type) -> str:
16+
def construct_user_agent(
17+
class_type: Union[
18+
"AsyncHTTPProvider",
19+
"HTTPProvider",
20+
Type["HTTPProvider"],
21+
Type["AsyncHTTPProvider"],
22+
],
23+
class_name: str,
24+
) -> str:
525
from web3 import (
626
__version__ as web3_version,
727
)
828

9-
return f"web3.py/{web3_version}/{class_type.__module__}.{class_type.__qualname__}"
29+
return f"web3.py/{web3_version}/{class_type.__module__}.{class_name}"

web3/providers/rpc/async_rpc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
URI,
2020
)
2121
from eth_utils import (
22+
combomethod,
2223
to_dict,
2324
)
2425

@@ -67,6 +68,7 @@ def __init__(
6768
**kwargs: Any,
6869
) -> None:
6970
self._request_session_manager = HTTPSessionManager()
71+
self.initialized = True
7072

7173
if endpoint_uri is None:
7274
self.endpoint_uri = (
@@ -108,10 +110,16 @@ def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:
108110
yield "headers", self.get_request_headers()
109111
yield from self._request_kwargs.items()
110112

111-
def get_request_headers(self) -> Dict[str, str]:
113+
@combomethod
114+
def get_request_headers(cls) -> Dict[str, str]:
115+
if isinstance(cls, AsyncHTTPProvider):
116+
cls_name = cls.__class__.__name__
117+
else:
118+
cls_name = cls.__name__
119+
112120
return {
113121
"Content-Type": "application/json",
114-
"User-Agent": construct_user_agent(type(self)),
122+
"User-Agent": construct_user_agent(cls, cls_name),
115123
}
116124

117125
async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes:

web3/providers/rpc/rpc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
URI,
1717
)
1818
from eth_utils import (
19+
combomethod,
1920
to_dict,
2021
)
2122
import requests
@@ -70,6 +71,7 @@ def __init__(
7071
] = empty,
7172
**kwargs: Any,
7273
) -> None:
74+
self.initialized = True
7375
self._request_session_manager = HTTPSessionManager()
7476

7577
if endpoint_uri is None:
@@ -116,10 +118,16 @@ def get_request_kwargs(self) -> Iterable[Tuple[str, Any]]:
116118
yield "headers", self.get_request_headers()
117119
yield from self._request_kwargs.items()
118120

119-
def get_request_headers(self) -> Dict[str, str]:
121+
@combomethod
122+
def get_request_headers(cls) -> Dict[str, str]:
123+
if isinstance(cls, HTTPProvider):
124+
cls_name = cls.__class__.__name__
125+
else:
126+
cls_name = cls.__name__
127+
120128
return {
121129
"Content-Type": "application/json",
122-
"User-Agent": construct_user_agent(type(self)),
130+
"User-Agent": construct_user_agent(cls, cls_name),
123131
}
124132

125133
def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes:

0 commit comments

Comments
 (0)