Skip to content

Files arg plus foamyguy #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 97 additions & 67 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

from adafruit_connection_manager import get_connection_manager

SEEK_END = 2

if not sys.implementation.name == "circuitpython":
from types import TracebackType
from typing import Any, Dict, Optional, Type
Expand Down Expand Up @@ -344,14 +346,6 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> byt
self.close()


def _generate_boundary_str():
hex_characters = "0123456789abcdef"
_boundary = ""
for _ in range(32):
_boundary += random.choice(hex_characters)
return _boundary


class Session:
"""HTTP session that shares sockets and ssl context."""

Expand All @@ -366,10 +360,74 @@ def __init__(
self._session_id = session_id
self._last_response = None

def _build_boundary_data(self, files: dict): # pylint: disable=too-many-locals
boundary_string = self._build_boundary_string()
content_length = 0
boundary_objects = []

for field_name, field_values in files.items():
file_name = field_values[0]
file_handle = field_values[1]

boundary_data = f"--{boundary_string}\r\n"
boundary_data += f'Content-Disposition: form-data; name="{field_name}"'
if file_name is not None:
boundary_data += f'; filename="{file_name}"'
boundary_data += "\r\n"
if len(field_values) >= 3:
file_content_type = field_values[2]
boundary_data += f"Content-Type: {file_content_type}\r\n"
if len(field_values) >= 4:
file_headers = field_values[3]
for file_header_key, file_header_value in file_headers.items():
boundary_data += f"{file_header_key}: {file_header_value}\r\n"
boundary_data += "\r\n"

content_length += len(boundary_data)
boundary_objects.append(boundary_data)

if hasattr(file_handle, "read"):
is_binary = False
try:
content = file_handle.read(1)
is_binary = isinstance(content, bytes)
except UnicodeError:
is_binary = False

if not is_binary:
raise AttributeError("Files must be opened in binary mode")

file_handle.seek(0, SEEK_END)
content_length += file_handle.tell()
file_handle.seek(0)
boundary_objects.append(file_handle)
boundary_data = ""
else:
boundary_data = file_handle

boundary_data += "\r\n"
content_length += len(boundary_data)
boundary_objects.append(boundary_data)

boundary_data = f"--{boundary_string}--\r\n"

content_length += len(boundary_data)
boundary_objects.append(boundary_data)

return boundary_string, content_length, boundary_objects

@staticmethod
def _build_boundary_string():
hex_characters = "0123456789abcdef"
_boundary = ""
for _ in range(32):
_boundary += random.choice(hex_characters)
return _boundary

@staticmethod
def _check_headers(headers: Dict[str, str]):
if not isinstance(headers, dict):
raise AttributeError("headers must be in dict format")
raise AttributeError("Headers must be in dict format")

for key, value in headers.items():
if isinstance(value, (str, bytes)) or value is None:
Expand Down Expand Up @@ -403,6 +461,19 @@ def _send(socket: SocketType, data: bytes):
def _send_as_bytes(self, socket: SocketType, data: str):
return self._send(socket, bytes(data, "utf-8"))

def _send_boundary_objects(self, socket: SocketType, boundary_objects: Any):
for boundary_object in boundary_objects:
if isinstance(boundary_object, str):
self._send_as_bytes(socket, boundary_object)
else:
chunk_size = 32
b = bytearray(chunk_size)
while True:
size = boundary_object.readinto(b)
if size == 0:
break
self._send(socket, b[:size])

def _send_header(self, socket, header, value):
if value is None:
return
Expand Down Expand Up @@ -440,6 +511,7 @@ def _send_request( # pylint: disable=too-many-arguments

# If data is sent and it's a dict, set content type header and convert to string
if data and isinstance(data, dict):
assert files is None
content_type_header = "application/x-www-form-urlencoded"
_post_data = ""
for k in data:
Expand All @@ -451,8 +523,18 @@ def _send_request( # pylint: disable=too-many-arguments
if data and isinstance(data, str):
data = bytes(data, "utf-8")

if data is None:
data = b""
# If files are send, build data to send and calculate length
content_length = 0
boundary_objects = None
if files and isinstance(files, dict):
boundary_string, content_length, boundary_objects = (
self._build_boundary_data(files)
)
content_type_header = f"multipart/form-data; boundary={boundary_string}"
else:
if data is None:
data = b""
content_length = len(data)

self._send_as_bytes(socket, method)
self._send(socket, b" /")
Expand All @@ -461,60 +543,6 @@ def _send_request( # pylint: disable=too-many-arguments

# create lower-case supplied header list
supplied_headers = {header.lower() for header in headers}
boundary_str = None

# pylint: disable=too-many-nested-blocks
if files is not None and isinstance(files, dict):
boundary_str = _generate_boundary_str()
content_type_header = f"multipart/form-data; boundary={boundary_str}"

for fieldname in files.keys():
if not fieldname.endswith("-name"):
if files[fieldname][0] is not None:
file_content = files[fieldname][1].read()

data += b"--" + boundary_str.encode() + b"\r\n"
data += (
b'Content-Disposition: form-data; name="'
+ fieldname.encode()
+ b'"; filename="'
+ files[fieldname][0].encode()
+ b'"\r\n'
)
if len(files[fieldname]) >= 3:
data += (
b"Content-Type: "
+ files[fieldname][2].encode()
+ b"\r\n"
)
if len(files[fieldname]) >= 4:
for custom_header_key in files[fieldname][3].keys():
data += (
custom_header_key.encode()
+ b": "
+ files[fieldname][3][custom_header_key].encode()
+ b"\r\n"
)
data += b"\r\n"
data += file_content + b"\r\n"
else:
# filename is None
data += b"--" + boundary_str.encode() + b"\r\n"
data += (
b'Content-Disposition: form-data; name="'
+ fieldname.encode()
+ b'"; \r\n'
)
if len(files[fieldname]) >= 3:
data += (
b"Content-Type: "
+ files[fieldname][2].encode()
+ b"\r\n"
)
data += b"\r\n"
data += files[fieldname][1].encode() + b"\r\n"

data += b"--" + boundary_str.encode() + b"--"

# Send headers
if not "host" in supplied_headers:
Expand All @@ -523,8 +551,8 @@ def _send_request( # pylint: disable=too-many-arguments
self._send_header(socket, "User-Agent", "Adafruit CircuitPython")
if content_type_header and not "content-type" in supplied_headers:
self._send_header(socket, "Content-Type", content_type_header)
if data and not "content-length" in supplied_headers:
self._send_header(socket, "Content-Length", str(len(data)))
if (data or files) and not "content-length" in supplied_headers:
self._send_header(socket, "Content-Length", str(content_length))
# Iterate over keys to avoid tuple alloc
for header in headers:
self._send_header(socket, header, headers[header])
Expand All @@ -533,6 +561,8 @@ def _send_request( # pylint: disable=too-many-arguments
# Send data
if data:
self._send(socket, bytes(data))
elif boundary_objects:
self._send_boundary_objects(socket, boundary_objects)

# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
def request(
Expand Down
8 changes: 4 additions & 4 deletions examples/wifi/expanded/requests_wifi_file_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
ssl_context = adafruit_connection_manager.get_radio_ssl_context(wifi.radio)
requests = adafruit_requests.Session(pool, ssl_context)

with open("raspi_snip.png", "rb") as file_handle:
with open("requests_wifi_file_upload_image.png", "rb") as file_handle:
files = {
"file": (
"raspi_snip.png",
"requests_wifi_file_upload_image.png",
file_handle,
"image/png",
{"CustomHeader": "BlinkaRocks"},
),
"othervalue": (None, "HelloWorld"),
}

with requests.post(URL, files=files) as resp:
print(resp.content)
with requests.post(URL, files=files) as response:
print(response.content)
2 changes: 1 addition & 1 deletion tests/header_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def test_check_headers_not_dict(requests):
with pytest.raises(AttributeError) as context:
requests._check_headers("")
assert "headers must be in dict format" in str(context)
assert "Headers must be in dict format" in str(context)


def test_check_headers_not_valid(requests):
Expand Down
10 changes: 8 additions & 2 deletions tests/method_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def test_post_string(sock, requests):


def test_post_form(sock, requests):
data = {"Date": "July 25, 2019", "Time": "12:00"}
data = {
"Date": "July 25, 2019",
"Time": "12:00",
}
requests.post("http://" + mocket.MOCK_HOST_1 + "/post", data=data)
sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80))
sock.send.assert_has_calls(
Expand All @@ -67,7 +70,10 @@ def test_post_form(sock, requests):


def test_post_json(sock, requests):
json_data = {"Date": "July 25, 2019", "Time": "12:00"}
json_data = {
"Date": "July 25, 2019",
"Time": "12:00",
}
requests.post("http://" + mocket.MOCK_HOST_1 + "/post", json=json_data)
sock.connect.assert_called_once_with((mocket.MOCK_POOL_IP, 80))
sock.send.assert_has_calls(
Expand Down
Loading