diff --git a/adafruit_requests.py b/adafruit_requests.py index a756d5e..2edce6e 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -46,6 +46,8 @@ from adafruit_connection_manager import get_connection_manager +SEEK_END = 2 + if not sys.implementation.name == "circuitpython": from types import TracebackType from typing import Any, Dict, Optional, Type @@ -344,14 +346,6 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> byt self.close() -def _generate_boundary_str(): - hex_characters = "0123456789abcdef" - _boundary = "" - for _ in range(32): - _boundary += random.choice(hex_characters) - return _boundary - - class Session: """HTTP session that shares sockets and ssl context.""" @@ -366,10 +360,74 @@ def __init__( self._session_id = session_id self._last_response = None + def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals + boundary_string = self._build_boundary_string() + content_length = 0 + boundary_objects = [] + + for field_name, field_values in files.items(): + file_name = field_values[0] + file_handle = field_values[1] + + boundary_data = f"--{boundary_string}\r\n" + boundary_data += f'Content-Disposition: form-data; name="{field_name}"' + if file_name is not None: + boundary_data += f'; filename="{file_name}"' + boundary_data += "\r\n" + if len(field_values) >= 3: + file_content_type = field_values[2] + boundary_data += f"Content-Type: {file_content_type}\r\n" + if len(field_values) >= 4: + file_headers = field_values[3] + for file_header_key, file_header_value in file_headers.items(): + boundary_data += f"{file_header_key}: {file_header_value}\r\n" + boundary_data += "\r\n" + + content_length += len(boundary_data) + boundary_objects.append(boundary_data) + + if hasattr(file_handle, "read"): + is_binary = False + try: + content = file_handle.read(1) + is_binary = isinstance(content, bytes) + except UnicodeError: + is_binary = False + + if not is_binary: + raise AttributeError("Files must be opened in binary mode") + + file_handle.seek(0, SEEK_END) + content_length += file_handle.tell() + file_handle.seek(0) + boundary_objects.append(file_handle) + boundary_data = "" + else: + boundary_data = file_handle + + boundary_data += "\r\n" + content_length += len(boundary_data) + boundary_objects.append(boundary_data) + + boundary_data = f"--{boundary_string}--\r\n" + + content_length += len(boundary_data) + boundary_objects.append(boundary_data) + + return boundary_string, content_length, boundary_objects + + @staticmethod + def _build_boundary_string(): + hex_characters = "0123456789abcdef" + _boundary = "" + for _ in range(32): + _boundary += random.choice(hex_characters) + return _boundary + @staticmethod def _check_headers(headers: Dict[str, str]): if not isinstance(headers, dict): - raise AttributeError("headers must be in dict format") + raise AttributeError("Headers must be in dict format") for key, value in headers.items(): if isinstance(value, (str, bytes)) or value is None: @@ -403,6 +461,19 @@ def _send(socket: SocketType, data: bytes): def _send_as_bytes(self, socket: SocketType, data: str): return self._send(socket, bytes(data, "utf-8")) + def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any): + for boundary_object in boundary_objects: + if isinstance(boundary_object, str): + self._send_as_bytes(socket, boundary_object) + else: + chunk_size = 32 + b = bytearray(chunk_size) + while True: + size = boundary_object.readinto(b) + if size == 0: + break + self._send(socket, b[:size]) + def _send_header(self, socket, header, value): if value is None: return @@ -440,6 +511,7 @@ def _send_request( # pylint: disable=too-many-arguments # If data is sent and it's a dict, set content type header and convert to string if data and isinstance(data, dict): + assert files is None content_type_header = "application/x-www-form-urlencoded" _post_data = "" for k in data: @@ -451,8 +523,18 @@ def _send_request( # pylint: disable=too-many-arguments if data and isinstance(data, str): data = bytes(data, "utf-8") - if data is None: - data = b"" + # If files are send, build data to send and calculate length + content_length = 0 + boundary_objects = None + if files and isinstance(files, dict): + boundary_string, content_length, boundary_objects = ( + self._build_boundary_data(files) + ) + content_type_header = f"multipart/form-data; boundary={boundary_string}" + else: + if data is None: + data = b"" + content_length = len(data) self._send_as_bytes(socket, method) self._send(socket, b" /") @@ -461,60 +543,6 @@ def _send_request( # pylint: disable=too-many-arguments # create lower-case supplied header list supplied_headers = {header.lower() for header in headers} - boundary_str = None - - # pylint: disable=too-many-nested-blocks - if files is not None and isinstance(files, dict): - boundary_str = _generate_boundary_str() - content_type_header = f"multipart/form-data; boundary={boundary_str}" - - for fieldname in files.keys(): - if not fieldname.endswith("-name"): - if files[fieldname][0] is not None: - file_content = files[fieldname][1].read() - - data += b"--" + boundary_str.encode() + b"\r\n" - data += ( - b'Content-Disposition: form-data; name="' - + fieldname.encode() - + b'"; filename="' - + files[fieldname][0].encode() - + b'"\r\n' - ) - if len(files[fieldname]) >= 3: - data += ( - b"Content-Type: " - + files[fieldname][2].encode() - + b"\r\n" - ) - if len(files[fieldname]) >= 4: - for custom_header_key in files[fieldname][3].keys(): - data += ( - custom_header_key.encode() - + b": " - + files[fieldname][3][custom_header_key].encode() - + b"\r\n" - ) - data += b"\r\n" - data += file_content + b"\r\n" - else: - # filename is None - data += b"--" + boundary_str.encode() + b"\r\n" - data += ( - b'Content-Disposition: form-data; name="' - + fieldname.encode() - + b'"; \r\n' - ) - if len(files[fieldname]) >= 3: - data += ( - b"Content-Type: " - + files[fieldname][2].encode() - + b"\r\n" - ) - data += b"\r\n" - data += files[fieldname][1].encode() + b"\r\n" - - data += b"--" + boundary_str.encode() + b"--" # Send headers if not "host" in supplied_headers: @@ -523,8 +551,8 @@ def _send_request( # pylint: disable=too-many-arguments self._send_header(socket, "User-Agent", "Adafruit CircuitPython") if content_type_header and not "content-type" in supplied_headers: self._send_header(socket, "Content-Type", content_type_header) - if data and not "content-length" in supplied_headers: - self._send_header(socket, "Content-Length", str(len(data))) + if (data or files) and not "content-length" in supplied_headers: + self._send_header(socket, "Content-Length", str(content_length)) # Iterate over keys to avoid tuple alloc for header in headers: self._send_header(socket, header, headers[header]) @@ -533,6 +561,8 @@ def _send_request( # pylint: disable=too-many-arguments # Send data if data: self._send(socket, bytes(data)) + elif boundary_objects: + self._send_boundary_objects(socket, boundary_objects) # pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals def request( diff --git a/examples/wifi/expanded/requests_wifi_file_upload.py b/examples/wifi/expanded/requests_wifi_file_upload.py index 3ceaef0..bd9ac2a 100644 --- a/examples/wifi/expanded/requests_wifi_file_upload.py +++ b/examples/wifi/expanded/requests_wifi_file_upload.py @@ -12,10 +12,10 @@ ssl_context = adafruit_connection_manager.get_radio_ssl_context(wifi.radio) requests = adafruit_requests.Session(pool, ssl_context) -with open("raspi_snip.png", "rb") as file_handle: +with open("requests_wifi_file_upload_image.png", "rb") as file_handle: files = { "file": ( - "raspi_snip.png", + "requests_wifi_file_upload_image.png", file_handle, "image/png", {"CustomHeader": "BlinkaRocks"}, @@ -23,5 +23,5 @@ "othervalue": (None, "HelloWorld"), } - with requests.post(URL, files=files) as resp: - print(resp.content) + with requests.post(URL, files=files) as response: + print(response.content) diff --git a/tests/header_test.py b/tests/header_test.py index 8bcb354..ddfd61a 100644 --- a/tests/header_test.py +++ b/tests/header_test.py @@ -11,7 +11,7 @@ def test_check_headers_not_dict(requests): with pytest.raises(AttributeError) as context: requests._check_headers("") - assert "headers must be in dict format" in str(context) + assert "Headers must be in dict format" in str(context) def test_check_headers_not_valid(requests): diff --git a/tests/method_test.py b/tests/method_test.py index d75e754..1cda6c2 100644 --- a/tests/method_test.py +++ b/tests/method_test.py @@ -52,7 +52,10 @@ def test_post_string(sock, requests): def test_post_form(sock, requests): - data = {"Date": "July 25, 2019", "Time": "12:00"} + data = { + "Date": "July 25, 2019", + "Time": "12:00", + } requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=data) sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) sock.send.assert_has_calls( @@ -67,7 +70,10 @@ def test_post_form(sock, requests): def test_post_json(sock, requests): - json_data = {"Date": "July 25, 2019", "Time": "12:00"} + json_data = { + "Date": "July 25, 2019", + "Time": "12:00", + } requests.post("http://" + mocket.MOCK_HOST_1 + "/post", json=json_data) sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80)) sock.send.assert_has_calls(