Skip to content

Commit 68834f4

Browse files
committed
Merge branches 'quenting/rust-1.66.0' and 'quenting/rust-twisted-http' into quenting/merge
2 parents eedfd24 + e6d59a9 commit 68834f4

File tree

6 files changed

+318
-2
lines changed

6 files changed

+318
-2
lines changed

Cargo.lock

Lines changed: 90 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

changelog.d/17081.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add helpers to transform Twisted requests to Rust http Requests/Responses.

rust/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ name = "synapse.synapse_rust"
2323

2424
[dependencies]
2525
anyhow = "1.0.63"
26+
bytes = "1.6.0"
27+
headers = "0.4.0"
28+
http = "1.1.0"
2629
lazy_static = "1.4.0"
2730
log = "0.4.17"
2831
pyo3 = { version = "0.20.0", features = [

rust/src/errors.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* This file is licensed under the Affero General Public License (AGPL) version 3.
3+
*
4+
* Copyright (C) 2024 New Vector, Ltd
5+
*
6+
* This program is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU Affero General Public License as
8+
* published by the Free Software Foundation, either version 3 of the
9+
* License, or (at your option) any later version.
10+
*
11+
* See the GNU Affero General Public License for more details:
12+
* <https://www.gnu.org/licenses/agpl-3.0.html>.
13+
*/
14+
15+
#![allow(clippy::new_ret_no_self)]
16+
17+
use std::collections::HashMap;
18+
19+
use http::{HeaderMap, StatusCode};
20+
use pyo3::import_exception;
21+
22+
import_exception!(synapse.api.errors, SynapseError);
23+
24+
impl SynapseError {
25+
pub fn new(
26+
code: StatusCode,
27+
message: String,
28+
errcode: &'static str,
29+
additional_fields: Option<HashMap<String, String>>,
30+
headers: Option<HeaderMap>,
31+
) -> pyo3::PyErr {
32+
// Transform the HeaderMap into a HashMap<String, String>
33+
let headers = headers.map(|headers| {
34+
headers
35+
.iter()
36+
.map(|(key, value)| {
37+
(
38+
key.as_str().to_owned(),
39+
value
40+
.to_str()
41+
// XXX: will that ever panic?
42+
.expect("header value is valid ASCII")
43+
.to_owned(),
44+
)
45+
})
46+
.collect::<HashMap<String, String>>()
47+
});
48+
49+
SynapseError::new_err((code.as_u16(), message, errcode, additional_fields, headers))
50+
}
51+
}
52+
53+
import_exception!(synapse.api.errors, NotFoundError);
54+
55+
impl NotFoundError {
56+
pub fn new() -> pyo3::PyErr {
57+
NotFoundError::new_err(())
58+
}
59+
}

rust/src/http.rs

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/*
2+
* This file is licensed under the Affero General Public License (AGPL) version 3.
3+
*
4+
* Copyright (C) 2024 New Vector, Ltd
5+
*
6+
* This program is free software: you can redistribute it and/or modify
7+
* it under the terms of the GNU Affero General Public License as
8+
* published by the Free Software Foundation, either version 3 of the
9+
* License, or (at your option) any later version.
10+
*
11+
* See the GNU Affero General Public License for more details:
12+
* <https://www.gnu.org/licenses/agpl-3.0.html>.
13+
*/
14+
15+
use bytes::{Buf, BufMut, Bytes, BytesMut};
16+
use headers::{Header, HeaderMapExt};
17+
use http::{HeaderName, HeaderValue, Method, Request, Response, StatusCode, Uri};
18+
use pyo3::{
19+
exceptions::PyValueError,
20+
types::{PyBytes, PySequence, PyTuple},
21+
PyAny, PyResult,
22+
};
23+
24+
use crate::errors::SynapseError;
25+
26+
/// Read a file-like Python object by chunks
27+
///
28+
/// # Errors
29+
///
30+
/// Returns an error if calling the ``read`` on the Python object failed
31+
fn read_io_body(body: &PyAny, chunk_size: usize) -> PyResult<Bytes> {
32+
let mut buf = BytesMut::new();
33+
loop {
34+
let bytes: &PyBytes = body.call_method1("read", (chunk_size,))?.downcast()?;
35+
if bytes.as_bytes().is_empty() {
36+
return Ok(buf.into());
37+
}
38+
buf.put(bytes.as_bytes());
39+
}
40+
}
41+
42+
/// Transform a Twisted ``IRequest`` to an [`http::Request`]
43+
///
44+
/// It uses the following members of ``IRequest``:
45+
/// - ``content``, which is expected to be a file-like object with a ``read`` method
46+
/// - ``uri``, which is expected to be a valid URI as ``bytes``
47+
/// - ``method``, which is expected to be a valid HTTP method as ``bytes``
48+
/// - ``requestHeaders``, which is expected to have a ``getAllRawHeaders`` method
49+
///
50+
/// # Errors
51+
///
52+
/// Returns an error if the Python object doens't properly implement ``IRequest``
53+
pub fn http_request_from_twisted(request: &PyAny) -> PyResult<Request<Bytes>> {
54+
let content = request.getattr("content")?;
55+
let body = read_io_body(content, 4096)?;
56+
57+
let mut req = Request::new(body);
58+
59+
let uri: &PyBytes = request.getattr("uri")?.downcast()?;
60+
*req.uri_mut() =
61+
Uri::try_from(uri.as_bytes()).map_err(|_| PyValueError::new_err("invalid uri"))?;
62+
63+
let method: &PyBytes = request.getattr("method")?.downcast()?;
64+
*req.method_mut() = Method::from_bytes(method.as_bytes())
65+
.map_err(|_| PyValueError::new_err("invalid method"))?;
66+
67+
let headers_iter = request
68+
.getattr("requestHeaders")?
69+
.call_method0("getAllRawHeaders")?
70+
.iter()?;
71+
72+
for header in headers_iter {
73+
let header = header?;
74+
let header: &PyTuple = header.downcast()?;
75+
let name: &PyBytes = header.get_item(0)?.downcast()?;
76+
let name = HeaderName::from_bytes(name.as_bytes())
77+
.map_err(|_| PyValueError::new_err("invalid header name"))?;
78+
79+
let values: &PySequence = header.get_item(1)?.downcast()?;
80+
for index in 0..values.len()? {
81+
let value: &PyBytes = values.get_item(index)?.downcast()?;
82+
let value = HeaderValue::from_bytes(value.as_bytes())
83+
.map_err(|_| PyValueError::new_err("invalid header value"))?;
84+
req.headers_mut().append(name.clone(), value);
85+
}
86+
}
87+
88+
Ok(req)
89+
}
90+
91+
/// Send an [`http::Response`] through a Twisted ``IRequest``
92+
///
93+
/// It uses the following members of ``IRequest``:
94+
///
95+
/// - ``responseHeaders``, which is expected to have a `addRawHeader(bytes, bytes)` method
96+
/// - ``setResponseCode(int)`` method
97+
/// - ``write(bytes)`` method
98+
/// - ``finish()`` method
99+
///
100+
/// # Errors
101+
///
102+
/// Returns an error if the Python object doens't properly implement ``IRequest``
103+
pub fn http_response_to_twisted<B>(request: &PyAny, response: Response<B>) -> PyResult<()>
104+
where
105+
B: Buf,
106+
{
107+
let (parts, mut body) = response.into_parts();
108+
109+
request.call_method1("setResponseCode", (parts.status.as_u16(),))?;
110+
111+
let response_headers = request.getattr("responseHeaders")?;
112+
for (name, value) in parts.headers.iter() {
113+
response_headers.call_method1("addRawHeader", (name.as_str(), value.as_bytes()))?;
114+
}
115+
116+
while body.remaining() != 0 {
117+
let chunk = body.chunk();
118+
request.call_method1("write", (chunk,))?;
119+
body.advance(chunk.len());
120+
}
121+
122+
request.call_method0("finish")?;
123+
124+
Ok(())
125+
}
126+
127+
/// An extension trait for [`HeaderMap`] that provides typed access to headers, and throws the
128+
/// right python exceptions when the header is missing or fails to parse.
129+
pub trait HeaderMapPyExt: HeaderMapExt {
130+
/// Get a header from the map, returning an error if it is missing or invalid.
131+
fn typed_get_required<H>(&self) -> PyResult<H>
132+
where
133+
H: Header,
134+
{
135+
self.typed_get_optional::<H>()?.ok_or_else(|| {
136+
SynapseError::new(
137+
StatusCode::BAD_REQUEST,
138+
format!("Missing required header: {}", H::name()),
139+
"M_MISSING_HEADER",
140+
None,
141+
None,
142+
)
143+
})
144+
}
145+
146+
/// Get a header from the map, returning `None` if it is missing and an error if it is invalid.
147+
fn typed_get_optional<H>(&self) -> PyResult<Option<H>>
148+
where
149+
H: Header,
150+
{
151+
self.typed_try_get::<H>().map_err(|_| {
152+
SynapseError::new(
153+
StatusCode::BAD_REQUEST,
154+
format!("Invalid header: {}", H::name()),
155+
"M_INVALID_HEADER",
156+
None,
157+
None,
158+
)
159+
})
160+
}
161+
}
162+
163+
impl<T: HeaderMapExt> HeaderMapPyExt for T {}

0 commit comments

Comments
 (0)