diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 0dab7b3dd..2cff15cff 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -13,21 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install Rust - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@stable with: - profile: minimal - toolchain: stable - override: true components: rustfmt - - name: cargo fmt --check - uses: actions-rs/cargo@v1 - with: - command: fmt - args: --all -- --check + - run: cargo fmt --all --check test: name: Test @@ -43,38 +36,26 @@ jobs: - stable steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v3 - name: Install Rust (${{ matrix.rust }}) - uses: actions-rs/toolchain@v1 + uses: dtolnay/rust-toolchain@master with: - profile: minimal toolchain: ${{ matrix.rust }} - override: true - name: Install libssl-dev run: sudo apt-get update && sudo apt-get install libssl-dev - name: Build without unstable flag - uses: actions-rs/cargo@v1 - with: - command: build + run: cargo build - name: Check with unstable flag - uses: actions-rs/cargo@v1 - with: - command: check - args: --features unstable + run: cargo check --features unstable - name: Run lib tests and doc tests - uses: actions-rs/cargo@v1 - with: - command: test + run: cargo test - name: Run integration tests - uses: actions-rs/cargo@v1 - with: - command: test - args: -p h2-tests + run: cargo test -p h2-tests - name: Run h2spec run: ./ci/h2spec.sh @@ -83,3 +64,26 @@ jobs: - name: Check minimal versions run: cargo clean; cargo update -Zminimal-versions; cargo check if: matrix.rust == 'nightly' + + msrv: + name: Check MSRV + needs: [style] + + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v3 + + - name: Get MSRV from package metadata + id: metadata + run: | + cargo metadata --no-deps --format-version 1 | + jq -r '"msrv=" + (.packages[] | select(.name == "h2")).rust_version' >> $GITHUB_OUTPUT + + - name: Install Rust (${{ steps.metadata.outputs.msrv }}) + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ steps.metadata.outputs.msrv }} + + - run: cargo check diff --git a/CHANGELOG.md b/CHANGELOG.md index 14fa00350..31852daff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,67 @@ +# 0.3.18 (April 17, 2023) + +* Fix panic because of opposite check in `is_remote_local()`. + +# 0.3.17 (April 13, 2023) + +* Add `Error::is_library()` method to check if the originated inside `h2`. +* Add `max_pending_accept_reset_streams(usize)` option to client and server + builders. +* Fix theoretical memory growth when receiving too many HEADERS and then + RST_STREAM frames faster than an application can accept them off the queue. + (CVE-2023-26964) + +# 0.3.16 (February 27, 2023) + +* Set `Protocol` extension on requests when received Extended CONNECT requests. +* Remove `B: Unpin + 'static` bound requiremented of bufs +* Fix releasing of frames when stream is finished, reducing memory usage. +* Fix panic when trying to send data and connection window is available, but stream window is not. +* Fix spurious wakeups when stream capacity is not available. + +# 0.3.15 (October 21, 2022) + +* Remove `B: Buf` bound on `SendStream`'s parameter +* add accessor for `StreamId` u32 + +# 0.3.14 (August 16, 2022) + +* Add `Error::is_reset` function. +* Bump MSRV to Rust 1.56. +* Return `RST_STREAM(NO_ERROR)` when the server early responds. 
+ +# 0.3.13 (March 31, 2022) + +* Update private internal `tokio-util` dependency. + +# 0.3.12 (March 9, 2022) + +* Avoid time operations that can panic (#599) +* Bump MSRV to Rust 1.49 (#606) +* Fix header decoding error when a header name is contained at a continuation + header boundary (#589) +* Remove I/O type names from handshake `tracing` spans (#608) + +# 0.3.11 (January 26, 2022) + +* Make `SendStream::poll_capacity` never return `Ok(Some(0))` (#596) +* Fix panic when receiving already reset push promise (#597) + +# 0.3.10 (January 6, 2022) + +* Add `Error::is_go_away()` and `Error::is_remote()` methods. +* Fix panic if receiving malformed PUSH_PROMISE with stream ID of 0. + +# 0.3.9 (December 9, 2021) + +* Fix hang related to new `max_send_buffer_size`. + +# 0.3.8 (December 8, 2021) + +* Add "extended CONNECT support". Adds `h2::ext::Protocol`, which is used for request and response extensions to connect new protocols over an HTTP/2 stream. +* Add `max_send_buffer_size` options to client and server builders, and a default of ~400MB. This acts like a high-water mark for the `poll_capacity()` method. +* Fix panic if receiving malformed HEADERS with stream ID of 0. + # 0.3.7 (October 22, 2021) * Fix panic if server sends a malformed frame on a stream client was about to open. diff --git a/Cargo.toml b/Cargo.toml index 7bbf1647e..767961d0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,13 +5,13 @@ name = "h2" # - html_root_url. # - Update CHANGELOG.md. # - Create git tag -version = "0.3.7" +version = "0.3.18" license = "MIT" authors = [ "Carl Lerche ", "Sean McArthur ", ] -description = "An HTTP/2.0 client and server" +description = "An HTTP/2 client and server" documentation = "https://docs.rs/h2" repository = "https://github.com/hyperium/h2" readme = "README.md" @@ -19,6 +19,7 @@ keywords = ["http", "async", "non-blocking"] categories = ["asynchronous", "web-programming", "network-programming"] exclude = ["fixtures/**", "ci/**"] edition = "2018" +rust-version = "1.56" [features] # Enables `futures::Stream` implementations for various types. @@ -43,7 +44,7 @@ members = [ futures-core = { version = "0.3", default-features = false } futures-sink = { version = "0.3", default-features = false } futures-util = { version = "0.3", default-features = false } -tokio-util = { version = "0.6", features = ["codec"] } +tokio-util = { version = "0.7.1", features = ["codec"] } tokio = { version = "1", features = ["io-util"] } bytes = "1" http = "0.2" @@ -55,22 +56,20 @@ indexmap = { version = "1.5.2", features = ["std"] } [dev-dependencies] # Fuzzing -quickcheck = { version = "0.4.1", default-features = false } -rand = "0.3.15" +quickcheck = { version = "1.0.3", default-features = false } +rand = "0.8.4" # HPACK fixtures -hex = "0.2.0" -walkdir = "1.0.0" +hex = "0.4.3" +walkdir = "2.3.2" serde = "1.0.0" serde_json = "1.0.0" # Examples tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync", "net"] } -env_logger = { version = "0.5.3", default-features = false } -rustls = "0.19" -tokio-rustls = "0.22" -webpki = "0.21" -webpki-roots = "0.21" +env_logger = { version = "0.9", default-features = false } +tokio-rustls = "0.23.2" +webpki-roots = "0.22.2" [package.metadata.docs.rs] features = ["stream"] diff --git a/README.md b/README.md index 63627b706..2e1599914 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # H2 -A Tokio aware, HTTP/2.0 client & server implementation for Rust. +A Tokio aware, HTTP/2 client & server implementation for Rust. 
[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Crates.io](https://img.shields.io/crates/v/h2.svg)](https://crates.io/crates/h2) @@ -12,21 +12,21 @@ More information about this crate can be found in the [crate documentation][dox] ## Features -* Client and server HTTP/2.0 implementation. -* Implements the full HTTP/2.0 specification. +* Client and server HTTP/2 implementation. +* Implements the full HTTP/2 specification. * Passes [h2spec](https://github.com/summerwind/h2spec). * Focus on performance and correctness. * Built on [Tokio](https://tokio.rs). ## Non goals -This crate is intended to only be an implementation of the HTTP/2.0 +This crate is intended to only be an implementation of the HTTP/2 specification. It does not handle: * Managing TCP connections * HTTP 1.0 upgrade * TLS -* Any feature not described by the HTTP/2.0 specification. +* Any feature not described by the HTTP/2 specification. This crate is now used by [hyper](https://github.com/hyperium/hyper), which will provide all of these features. @@ -55,7 +55,7 @@ fn main() { **How does h2 compare to [solicit] or [rust-http2]?** -The h2 library has implemented more of the details of the HTTP/2.0 specification +The h2 library has implemented more of the details of the HTTP/2 specification than any other Rust library. It also passes the [h2spec] set of tests. The h2 library is rapidly approaching "production ready" quality. diff --git a/examples/akamai.rs b/examples/akamai.rs index 29d8a9347..1d0b17baf 100644 --- a/examples/akamai.rs +++ b/examples/akamai.rs @@ -3,9 +3,9 @@ use http::{Method, Request}; use tokio::net::TcpStream; use tokio_rustls::TlsConnector; -use rustls::Session; -use webpki::DNSNameRef; +use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore, ServerName}; +use std::convert::TryFrom; use std::error::Error; use std::net::ToSocketAddrs; @@ -16,9 +16,19 @@ pub async fn main() -> Result<(), Box> { let _ = env_logger::try_init(); let tls_client_config = std::sync::Arc::new({ - let mut c = rustls::ClientConfig::new(); - c.root_store - .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS); + let mut root_store = RootCertStore::empty(); + root_store.add_server_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| { + OwnedTrustAnchor::from_subject_spki_name_constraints( + ta.subject, + ta.spki, + ta.name_constraints, + ) + })); + + let mut c = tokio_rustls::rustls::ClientConfig::builder() + .with_safe_defaults() + .with_root_certificates(root_store) + .with_no_client_auth(); c.alpn_protocols.push(ALPN_H2.as_bytes().to_owned()); c }); @@ -33,17 +43,14 @@ pub async fn main() -> Result<(), Box> { println!("ADDR: {:?}", addr); let tcp = TcpStream::connect(&addr).await?; - let dns_name = DNSNameRef::try_from_ascii_str("http2.akamai.com").unwrap(); + let dns_name = ServerName::try_from("http2.akamai.com").unwrap(); let connector = TlsConnector::from(tls_client_config); let res = connector.connect(dns_name, tcp).await; let tls = res.unwrap(); { let (_, session) = tls.get_ref(); - let negotiated_protocol = session.get_alpn_protocol(); - assert_eq!( - Some(ALPN_H2.as_bytes()), - negotiated_protocol.as_ref().map(|x| &**x) - ); + let negotiated_protocol = session.alpn_protocol(); + assert_eq!(Some(ALPN_H2.as_bytes()), negotiated_protocol); } println!("Starting client handshake"); diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index ca32138e2..aafb60ae7 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -13,12 +13,10 @@ cargo-fuzz = true 
arbitrary = { version = "1", features = ["derive"] } libfuzzer-sys = { version = "0.4.0", features = ["arbitrary-derive"] } tokio = { version = "1", features = [ "full" ] } -bytes = "0.5.2" h2 = { path = "../", features = [ "unstable" ] } h2-support = { path = "../tests/h2-support" } futures = { version = "0.3", default-features = false, features = ["std"] } http = "0.2" -env_logger = { version = "0.5.3", default-features = false } # Prevent this from interfering with workspaces [workspace] diff --git a/src/client.rs b/src/client.rs index 9cd0b8f46..4147e8a46 100644 --- a/src/client.rs +++ b/src/client.rs @@ -136,6 +136,7 @@ //! [`Error`]: ../struct.Error.html use crate::codec::{Codec, SendError, UserError}; +use crate::ext::Protocol; use crate::frame::{Headers, Pseudo, Reason, Settings, StreamId}; use crate::proto::{self, Error}; use crate::{FlowControl, PingPong, RecvStream, SendStream}; @@ -319,9 +320,16 @@ pub struct Builder { /// Initial target window size for new connections. initial_target_connection_window_size: Option, + /// Maximum amount of bytes to "buffer" for writing per stream. + max_send_buffer_size: usize, + /// Maximum number of locally reset streams to keep at a time. reset_stream_max: usize, + /// Maximum number of remotely reset streams to allow in the pending + /// accept queue. + pending_accept_reset_stream_max: usize, + /// Initial `Settings` frame to send as part of the handshake. settings: Settings, @@ -337,7 +345,7 @@ pub(crate) struct Peer; impl SendRequest where - B: Buf + 'static, + B: Buf, { /// Returns `Ready` when the connection can initialize a new HTTP/2 /// stream. @@ -517,6 +525,19 @@ where (response, stream) }) } + + /// Returns whether the [extended CONNECT protocol][1] is enabled or not. + /// + /// This setting is configured by the server peer by sending the + /// [`SETTINGS_ENABLE_CONNECT_PROTOCOL` parameter][2] in a `SETTINGS` frame. + /// This method returns the currently acknowledged value received from the + /// remote. + /// + /// [1]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + /// [2]: https://datatracker.ietf.org/doc/html/rfc8441#section-3 + pub fn is_extended_connect_protocol_enabled(&self) -> bool { + self.inner.is_extended_connect_protocol_enabled() + } } impl fmt::Debug for SendRequest @@ -567,7 +588,7 @@ where impl Future for ReadySendRequest where - B: Buf + 'static, + B: Buf, { type Output = Result, crate::Error>; @@ -614,8 +635,10 @@ impl Builder { /// ``` pub fn new() -> Builder { Builder { + max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE, reset_stream_duration: Duration::from_secs(proto::DEFAULT_RESET_STREAM_SECS), reset_stream_max: proto::DEFAULT_RESET_STREAM_MAX, + pending_accept_reset_stream_max: proto::DEFAULT_REMOTE_RESET_STREAM_MAX, initial_target_connection_window_size: None, initial_max_send_streams: usize::MAX, settings: Default::default(), @@ -948,10 +971,71 @@ impl Builder { self } + /// Sets the maximum number of pending-accept remotely-reset streams. + /// + /// Streams that have been received by the peer, but not accepted by the + /// user, can also receive a RST_STREAM. This is a legitimate pattern: one + /// could send a request and then shortly after, realize it is not needed, + /// sending a CANCEL. + /// + /// However, since those streams are now "closed", they don't count towards + /// the max concurrent streams. So, they will sit in the accept queue, + /// using memory. 
+ /// + /// When the number of remotely-reset streams sitting in the pending-accept + /// queue reaches this maximum value, a connection error with the code of + /// `ENHANCE_YOUR_CALM` will be sent to the peer, and returned by the + /// `Future`. + /// + /// The default value is currently 20, but could change. + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use h2::client::*; + /// # use bytes::Bytes; + /// # + /// # async fn doc(my_io: T) + /// # -> Result<((SendRequest, Connection)), h2::Error> + /// # { + /// // `client_fut` is a future representing the completion of the HTTP/2 + /// // handshake. + /// let client_fut = Builder::new() + /// .max_pending_accept_reset_streams(100) + /// .handshake(my_io); + /// # client_fut.await + /// # } + /// # + /// # pub fn main() {} + /// ``` + pub fn max_pending_accept_reset_streams(&mut self, max: usize) -> &mut Self { + self.pending_accept_reset_stream_max = max; + self + } + + /// Sets the maximum send buffer size per stream. + /// + /// Once a stream has buffered up to (or over) the maximum, the stream's + /// flow control will not "poll" additional capacity. Once bytes for the + /// stream have been written to the connection, the send buffer capacity + /// will be freed up again. + /// + /// The default is currently ~400MB, but may change. + /// + /// # Panics + /// + /// This function panics if `max` is larger than `u32::MAX`. + pub fn max_send_buffer_size(&mut self, max: usize) -> &mut Self { + assert!(max <= std::u32::MAX as usize); + self.max_send_buffer_size = max; + self + } + /// Enables or disables server push promises. /// - /// This value is included in the initial SETTINGS handshake. When set, the - /// server MUST NOT send a push promise. Setting this value to value to + /// This value is included in the initial SETTINGS handshake. + /// Setting this value to value to /// false in the initial SETTINGS handshake guarantees that the remote server /// will never send a push promise. /// @@ -1064,7 +1148,7 @@ impl Builder { ) -> impl Future, Connection), crate::Error>> where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { Connection::handshake2(io, self.clone()) } @@ -1118,7 +1202,7 @@ where let builder = Builder::new(); builder .handshake(io) - .instrument(tracing::trace_span!("client_handshake", io = %std::any::type_name::())) + .instrument(tracing::trace_span!("client_handshake")) .await } @@ -1141,7 +1225,7 @@ where impl Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { async fn handshake2( mut io: T, @@ -1170,8 +1254,10 @@ where proto::Config { next_stream_id: builder.stream_id, initial_max_send_streams: builder.initial_max_send_streams, + max_send_buffer_size: builder.max_send_buffer_size, reset_stream_duration: builder.reset_stream_duration, reset_stream_max: builder.reset_stream_max, + remote_reset_stream_max: builder.pending_accept_reset_stream_max, settings: builder.settings.clone(), }, ); @@ -1243,14 +1329,13 @@ where /// /// This limit is configured by the server peer by sending the /// [`SETTINGS_MAX_CONCURRENT_STREAMS` parameter][1] in a `SETTINGS` frame. - /// This method returns the currently acknowledged value recieved from the + /// This method returns the currently acknowledged value received from the /// remote. 
/// - /// [settings]: https://tools.ietf.org/html/rfc7540#section-5.1.2 + /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 pub fn max_concurrent_send_streams(&self) -> usize { self.inner.max_send_streams() } - /// Returns the maximum number of concurrent streams that may be initiated /// by the server on this connection. /// @@ -1270,7 +1355,7 @@ where impl Future for Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { type Output = Result<(), crate::Error>; @@ -1416,6 +1501,7 @@ impl Peer { pub fn convert_send_message( id: StreamId, request: Request<()>, + protocol: Option, end_of_stream: bool, ) -> Result { use http::request::Parts; @@ -1435,7 +1521,7 @@ impl Peer { // Build the set pseudo header set. All requests will include `method` // and `path`. - let mut pseudo = Pseudo::request(method, uri); + let mut pseudo = Pseudo::request(method, uri, protocol); if pseudo.scheme.is_none() { // If the scheme is not set, then there are a two options. diff --git a/src/codec/framed_read.rs b/src/codec/framed_read.rs index 7c3bbb3ba..a874d7732 100644 --- a/src/codec/framed_read.rs +++ b/src/codec/framed_read.rs @@ -109,7 +109,7 @@ fn decode_frame( if partial_inout.is_some() && head.kind() != Kind::Continuation { proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); - return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } let kind = head.kind(); @@ -231,7 +231,7 @@ fn decode_frame( if head.stream_id() == 0 { // Invalid stream identifier proto_err!(conn: "invalid stream ID 0"); - return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { @@ -257,14 +257,14 @@ fn decode_frame( Some(partial) => partial, None => { proto_err!(conn: "received unexpected CONTINUATION frame"); - return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } }; // The stream identifiers must match if partial.frame.stream_id() != head.stream_id() { proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); - return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } // Extend the buf @@ -287,7 +287,7 @@ fn decode_frame( // the attacker to go away. if partial.buf.len() + bytes.len() > max_header_list_size { proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); - return Err(Error::library_go_away(Reason::COMPRESSION_ERROR).into()); + return Err(Error::library_go_away(Reason::COMPRESSION_ERROR)); } } partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); diff --git a/src/error.rs b/src/error.rs index fdbfc0d1b..1b1438e48 100644 --- a/src/error.rs +++ b/src/error.rs @@ -57,12 +57,9 @@ impl Error { } } - /// Returns the true if the error is an io::Error + /// Returns true if the error is an io::Error pub fn is_io(&self) -> bool { - match self.kind { - Kind::Io(_) => true, - _ => false, - } + matches!(self.kind, Kind::Io(..)) } /// Returns the error if the error is an io::Error @@ -86,6 +83,36 @@ impl Error { kind: Kind::Io(err), } } + + /// Returns true if the error is from a `GOAWAY`. + pub fn is_go_away(&self) -> bool { + matches!(self.kind, Kind::GoAway(..)) + } + + /// Returns true if the error is from a `RST_STREAM`. 
+ pub fn is_reset(&self) -> bool { + matches!(self.kind, Kind::Reset(..)) + } + + /// Returns true if the error was received in a frame from the remote. + /// + /// Such as from a received `RST_STREAM` or `GOAWAY` frame. + pub fn is_remote(&self) -> bool { + matches!( + self.kind, + Kind::GoAway(_, _, Initiator::Remote) | Kind::Reset(_, _, Initiator::Remote) + ) + } + + /// Returns true if the error was created by `h2. + /// + /// Such as noticing some protocol error and sending a GOAWAY or RST_STREAM. + pub fn is_library(&self) -> bool { + matches!( + self.kind, + Kind::GoAway(_, _, Initiator::Library) | Kind::Reset(_, _, Initiator::Library) + ) + } } impl From for Error { diff --git a/src/ext.rs b/src/ext.rs new file mode 100644 index 000000000..cf383a495 --- /dev/null +++ b/src/ext.rs @@ -0,0 +1,55 @@ +//! Extensions specific to the HTTP/2 protocol. + +use crate::hpack::BytesStr; + +use bytes::Bytes; +use std::fmt; + +/// Represents the `:protocol` pseudo-header used by +/// the [Extended CONNECT Protocol]. +/// +/// [Extended CONNECT Protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 +#[derive(Clone, Eq, PartialEq)] +pub struct Protocol { + value: BytesStr, +} + +impl Protocol { + /// Converts a static string to a protocol name. + pub const fn from_static(value: &'static str) -> Self { + Self { + value: BytesStr::from_static(value), + } + } + + /// Returns a str representation of the header. + pub fn as_str(&self) -> &str { + self.value.as_str() + } + + pub(crate) fn try_from(bytes: Bytes) -> Result { + Ok(Self { + value: BytesStr::try_from(bytes)?, + }) + } +} + +impl<'a> From<&'a str> for Protocol { + fn from(value: &'a str) -> Self { + Self { + value: BytesStr::from(value), + } + } +} + +impl AsRef<[u8]> for Protocol { + fn as_ref(&self) -> &[u8] { + self.value.as_ref() + } +} + +impl fmt::Debug for Protocol { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.value.fmt(f) + } +} diff --git a/src/frame/data.rs b/src/frame/data.rs index e253d5e23..d0cdf5f69 100644 --- a/src/frame/data.rs +++ b/src/frame/data.rs @@ -16,7 +16,7 @@ pub struct Data { pad_len: Option, } -#[derive(Copy, Clone, Eq, PartialEq)] +#[derive(Copy, Clone, Default, Eq, PartialEq)] struct DataFlags(u8); const END_STREAM: u8 = 0x1; @@ -211,12 +211,6 @@ impl DataFlags { } } -impl Default for DataFlags { - fn default() -> Self { - DataFlags(0) - } -} - impl From for u8 { fn from(src: DataFlags) -> u8 { src.0 diff --git a/src/frame/headers.rs b/src/frame/headers.rs index 0851d7660..9d5c8cefe 100644 --- a/src/frame/headers.rs +++ b/src/frame/headers.rs @@ -1,4 +1,5 @@ use super::{util, StreamDependency, StreamId}; +use crate::ext::Protocol; use crate::frame::{Error, Frame, Head, Kind}; use crate::hpack::{self, BytesStr}; @@ -66,6 +67,7 @@ pub struct Pseudo { pub scheme: Option, pub authority: Option, pub path: Option, + pub protocol: Option, // Response pub status: Option, @@ -146,6 +148,10 @@ impl Headers { tracing::trace!("loading headers; flags={:?}", flags); + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + // Read the padding length if flags.is_padded() { if src.is_empty() { @@ -288,6 +294,10 @@ impl fmt::Debug for Headers { .field("stream_id", &self.stream_id) .field("flags", &self.flags); + if let Some(ref protocol) = self.header_block.pseudo.protocol { + builder.field("protocol", protocol); + } + if let Some(ref dep) = self.stream_dep { builder.field("stream_dep", dep); } @@ -299,17 +309,20 @@ impl fmt::Debug for Headers { // ===== util ===== -pub fn 
parse_u64(src: &[u8]) -> Result { +#[derive(Debug, PartialEq, Eq)] +pub struct ParseU64Error; + +pub fn parse_u64(src: &[u8]) -> Result { if src.len() > 19 { // At danger for overflow... - return Err(()); + return Err(ParseU64Error); } let mut ret = 0; for &d in src { if d < b'0' || d > b'9' { - return Err(()); + return Err(ParseU64Error); } ret *= 10; @@ -323,7 +336,7 @@ pub fn parse_u64(src: &[u8]) -> Result { #[derive(Debug)] pub enum PushPromiseHeaderError { - InvalidContentLength(Result), + InvalidContentLength(Result), NotSafeAndCacheable, } @@ -371,7 +384,7 @@ impl PushPromise { fn safe_and_cacheable(method: &Method) -> bool { // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods - return method == Method::GET || method == Method::HEAD; + method == Method::GET || method == Method::HEAD } pub fn fields(&self) -> &HeaderMap { @@ -390,6 +403,10 @@ impl PushPromise { let flags = PushPromiseFlag(head.flag()); let mut pad = 0; + if head.stream_id().is_zero() { + return Err(Error::InvalidStreamId); + } + // Read the padding length if flags.is_padded() { if src.is_empty() { @@ -525,7 +542,7 @@ impl Continuation { // ===== impl Pseudo ===== impl Pseudo { - pub fn request(method: Method, uri: Uri) -> Self { + pub fn request(method: Method, uri: Uri, protocol: Option) -> Self { let parts = uri::Parts::from(uri); let mut path = parts @@ -546,6 +563,7 @@ impl Pseudo { scheme: None, authority: None, path: Some(path).filter(|p| !p.is_empty()), + protocol, status: None, }; @@ -571,6 +589,7 @@ impl Pseudo { scheme: None, authority: None, path: None, + protocol: None, status: Some(status), } } @@ -589,6 +608,11 @@ impl Pseudo { self.scheme = Some(bytes_str); } + #[cfg(feature = "unstable")] + pub fn set_protocol(&mut self, protocol: Protocol) { + self.protocol = Some(protocol); + } + pub fn set_authority(&mut self, authority: BytesStr) { self.authority = Some(authority); } @@ -677,6 +701,10 @@ impl Iterator for Iter { return Some(Path(path)); } + if let Some(protocol) = pseudo.protocol.take() { + return Some(Protocol(protocol)); + } + if let Some(status) = pseudo.status.take() { return Some(Status(status)); } @@ -875,6 +903,7 @@ impl HeaderBlock { Method(v) => set_pseudo!(method, v), Scheme(v) => set_pseudo!(scheme, v), Path(v) => set_pseudo!(path, v), + Protocol(v) => set_pseudo!(protocol, v), Status(v) => set_pseudo!(status, v), } }); diff --git a/src/frame/mod.rs b/src/frame/mod.rs index 5a682b634..570a162a8 100644 --- a/src/frame/mod.rs +++ b/src/frame/mod.rs @@ -11,7 +11,8 @@ use std::fmt; /// /// # Examples /// -/// ```rust +/// ```ignore +/// # // We ignore this doctest because the macro is not exported. /// let buf: [u8; 4] = [0, 0, 0, 1]; /// assert_eq!(1u32, unpack_octets_4!(buf, 0, u32)); /// ``` @@ -25,6 +26,15 @@ macro_rules! 
unpack_octets_4 { }; } +#[cfg(test)] +mod tests { + #[test] + fn test_unpack_octets_4() { + let buf: [u8; 4] = [0, 0, 0, 1]; + assert_eq!(1u32, unpack_octets_4!(buf, 0, u32)); + } +} + mod data; mod go_away; mod head; diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 523f20b06..0c913f059 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -13,6 +13,7 @@ pub struct Settings { initial_window_size: Option, max_frame_size: Option, max_header_list_size: Option, + enable_connect_protocol: Option, } /// An enum that lists all valid settings that can be sent in a SETTINGS @@ -27,6 +28,7 @@ pub enum Setting { InitialWindowSize(u32), MaxFrameSize(u32), MaxHeaderListSize(u32), + EnableConnectProtocol(u32), } #[derive(Copy, Clone, Eq, PartialEq, Default)] @@ -107,6 +109,14 @@ impl Settings { self.enable_push = Some(enable as u32); } + pub fn is_extended_connect_protocol_enabled(&self) -> Option { + self.enable_connect_protocol.map(|val| val != 0) + } + + pub fn set_enable_connect_protocol(&mut self, val: Option) { + self.enable_connect_protocol = val; + } + pub fn header_table_size(&self) -> Option { self.header_table_size } @@ -172,15 +182,23 @@ impl Settings { } } Some(MaxFrameSize(val)) => { - if val < DEFAULT_MAX_FRAME_SIZE || val > MAX_MAX_FRAME_SIZE { - return Err(Error::InvalidSettingValue); - } else { + if DEFAULT_MAX_FRAME_SIZE <= val && val <= MAX_MAX_FRAME_SIZE { settings.max_frame_size = Some(val); + } else { + return Err(Error::InvalidSettingValue); } } Some(MaxHeaderListSize(val)) => { settings.max_header_list_size = Some(val); } + Some(EnableConnectProtocol(val)) => match val { + 0 | 1 => { + settings.enable_connect_protocol = Some(val); + } + _ => { + return Err(Error::InvalidSettingValue); + } + }, None => {} } } @@ -236,6 +254,10 @@ impl Settings { if let Some(v) = self.max_header_list_size { f(MaxHeaderListSize(v)); } + + if let Some(v) = self.enable_connect_protocol { + f(EnableConnectProtocol(v)); + } } } @@ -269,6 +291,9 @@ impl fmt::Debug for Settings { Setting::MaxHeaderListSize(v) => { builder.field("max_header_list_size", &v); } + Setting::EnableConnectProtocol(v) => { + builder.field("enable_connect_protocol", &v); + } }); builder.finish() @@ -291,6 +316,7 @@ impl Setting { 4 => Some(InitialWindowSize(val)), 5 => Some(MaxFrameSize(val)), 6 => Some(MaxHeaderListSize(val)), + 8 => Some(EnableConnectProtocol(val)), _ => None, } } @@ -322,6 +348,7 @@ impl Setting { InitialWindowSize(v) => (4, v), MaxFrameSize(v) => (5, v), MaxHeaderListSize(v) => (6, v), + EnableConnectProtocol(v) => (8, v), }; dst.put_u16(kind); diff --git a/src/fuzz_bridge.rs b/src/fuzz_bridge.rs index 6132deeb4..3ea8b591c 100644 --- a/src/fuzz_bridge.rs +++ b/src/fuzz_bridge.rs @@ -1,7 +1,7 @@ #[cfg(fuzzing)] pub mod fuzz_logic { use crate::hpack; - use bytes::{BufMut, BytesMut}; + use bytes::BytesMut; use http::header::HeaderName; use std::io::Cursor; diff --git a/src/hpack/decoder.rs b/src/hpack/decoder.rs index e4b34d1fc..b45c37927 100644 --- a/src/hpack/decoder.rs +++ b/src/hpack/decoder.rs @@ -142,6 +142,12 @@ struct Table { max_size: usize, } +struct StringMarker { + offset: usize, + len: usize, + string: Option, +} + // ===== impl Decoder ===== impl Decoder { @@ -279,10 +285,13 @@ impl Decoder { // First, read the header name if table_idx == 0 { + let old_pos = buf.position(); + let name_marker = self.try_decode_string(buf)?; + let value_marker = self.try_decode_string(buf)?; + buf.set_position(old_pos); // Read the name as a literal - let name = self.decode_string(buf)?; - 
let value = self.decode_string(buf)?; - + let name = name_marker.consume(buf); + let value = value_marker.consume(buf); Header::new(name, value) } else { let e = self.table.get(table_idx)?; @@ -292,7 +301,11 @@ impl Decoder { } } - fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result { + fn try_decode_string( + &mut self, + buf: &mut Cursor<&mut BytesMut>, + ) -> Result { + let old_pos = buf.position(); const HUFF_FLAG: u8 = 0b1000_0000; // The first bit in the first byte contains the huffman encoded flag. @@ -309,17 +322,34 @@ impl Decoder { return Err(DecoderError::NeedMore(NeedMore::StringUnderflow)); } + let offset = (buf.position() - old_pos) as usize; if huff { let ret = { let raw = &buf.chunk()[..len]; - huffman::decode(raw, &mut self.buffer).map(BytesMut::freeze) + huffman::decode(raw, &mut self.buffer).map(|buf| StringMarker { + offset, + len, + string: Some(BytesMut::freeze(buf)), + }) }; buf.advance(len); - return ret; + ret + } else { + buf.advance(len); + Ok(StringMarker { + offset, + len, + string: None, + }) } + } - Ok(take(buf, len)) + fn decode_string(&mut self, buf: &mut Cursor<&mut BytesMut>) -> Result { + let old_pos = buf.position(); + let marker = self.try_decode_string(buf)?; + buf.set_position(old_pos); + Ok(marker.consume(buf)) } } @@ -433,6 +463,19 @@ fn take(buf: &mut Cursor<&mut BytesMut>, n: usize) -> Bytes { head.freeze() } +impl StringMarker { + fn consume(self, buf: &mut Cursor<&mut BytesMut>) -> Bytes { + buf.advance(self.offset); + match self.string { + Some(string) => { + buf.advance(self.len); + string + } + None => take(buf, self.len), + } + } +} + fn consume(buf: &mut Cursor<&mut BytesMut>) { // remove bytes from the internal BytesMut when they have been successfully // decoded. This is a more permanent cursor position, which will be @@ -809,8 +852,7 @@ mod test { fn test_decode_empty() { let mut de = Decoder::new(0); let mut buf = BytesMut::new(); - let empty = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap(); - assert_eq!(empty, ()); + let _: () = de.decode(&mut Cursor::new(&mut buf), |_| {}).unwrap(); } #[test] @@ -818,17 +860,16 @@ mod test { let mut de = Decoder::new(0); let mut buf = BytesMut::new(); - buf.extend(&[0b01000000, 0x80 | 2]); + buf.extend([0b01000000, 0x80 | 2]); buf.extend(huff_encode(b"foo")); - buf.extend(&[0x80 | 3]); + buf.extend([0x80 | 3]); buf.extend(huff_encode(b"bar")); let mut res = vec![]; - let _ = de - .decode(&mut Cursor::new(&mut buf), |h| { - res.push(h); - }) - .unwrap(); + de.decode(&mut Cursor::new(&mut buf), |h| { + res.push(h); + }) + .unwrap(); assert_eq!(res.len(), 1); assert_eq!(de.table.size(), 0); @@ -850,4 +891,47 @@ mod test { huffman::encode(src, &mut buf); buf } + + #[test] + fn test_decode_continuation_header_with_non_huff_encoded_name() { + let mut de = Decoder::new(0); + let value = huff_encode(b"bar"); + let mut buf = BytesMut::new(); + // header name is non_huff encoded + buf.extend([0b01000000, 3]); + buf.extend(b"foo"); + // header value is partial + buf.extend([0x80 | 3]); + buf.extend(&value[0..1]); + + let mut res = vec![]; + let e = de + .decode(&mut Cursor::new(&mut buf), |h| { + res.push(h); + }) + .unwrap_err(); + // decode error because the header value is partial + assert_eq!(e, DecoderError::NeedMore(NeedMore::StringUnderflow)); + + // extend buf with the remaining header value + buf.extend(&value[1..]); + de.decode(&mut Cursor::new(&mut buf), |h| { + res.push(h); + }) + .unwrap(); + + assert_eq!(res.len(), 1); + assert_eq!(de.table.size(), 0); + + match res[0] { 
+ Header::Field { + ref name, + ref value, + } => { + assert_eq!(name, "foo"); + assert_eq!(value, "bar"); + } + _ => panic!(), + } + } } diff --git a/src/hpack/encoder.rs b/src/hpack/encoder.rs index 76b373830..d121a2aaf 100644 --- a/src/hpack/encoder.rs +++ b/src/hpack/encoder.rs @@ -118,12 +118,12 @@ impl Encoder { encode_int(idx, 7, 0x80, dst); } Index::Name(idx, _) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); encode_not_indexed(idx, header.value_slice(), header.is_sensitive(), dst); } Index::Inserted(_) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); assert!(!header.is_sensitive()); @@ -133,7 +133,7 @@ impl Encoder { encode_str(header.value_slice(), dst); } Index::InsertedValue(idx, _) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); assert!(!header.is_sensitive()); @@ -141,7 +141,7 @@ impl Encoder { encode_str(header.value_slice(), dst); } Index::NotIndexed(_) => { - let header = self.table.resolve(&index); + let header = self.table.resolve(index); encode_not_indexed2( header.name().as_slice(), diff --git a/src/hpack/header.rs b/src/hpack/header.rs index 8d6136e16..0b5d1fded 100644 --- a/src/hpack/header.rs +++ b/src/hpack/header.rs @@ -1,4 +1,5 @@ use super::{DecoderError, NeedMore}; +use crate::ext::Protocol; use bytes::Bytes; use http::header::{HeaderName, HeaderValue}; @@ -14,6 +15,7 @@ pub enum Header { Method(Method), Scheme(BytesStr), Path(BytesStr), + Protocol(Protocol), Status(StatusCode), } @@ -25,6 +27,7 @@ pub enum Name<'a> { Method, Scheme, Path, + Protocol, Status, } @@ -51,6 +54,7 @@ impl Header> { Method(v) => Method(v), Scheme(v) => Scheme(v), Path(v) => Path(v), + Protocol(v) => Protocol(v), Status(v) => Status(v), }) } @@ -79,6 +83,10 @@ impl Header { let value = BytesStr::try_from(value)?; Ok(Header::Path(value)) } + b"protocol" => { + let value = Protocol::try_from(value)?; + Ok(Header::Protocol(value)) + } b"status" => { let status = StatusCode::from_bytes(&value)?; Ok(Header::Status(status)) @@ -104,6 +112,7 @@ impl Header { Header::Method(ref v) => 32 + 7 + v.as_ref().len(), Header::Scheme(ref v) => 32 + 7 + v.len(), Header::Path(ref v) => 32 + 5 + v.len(), + Header::Protocol(ref v) => 32 + 9 + v.as_str().len(), Header::Status(_) => 32 + 7 + 3, } } @@ -116,6 +125,7 @@ impl Header { Header::Method(..) => Name::Method, Header::Scheme(..) => Name::Scheme, Header::Path(..) => Name::Path, + Header::Protocol(..) => Name::Protocol, Header::Status(..) => Name::Status, } } @@ -127,6 +137,7 @@ impl Header { Header::Method(ref v) => v.as_ref().as_ref(), Header::Scheme(ref v) => v.as_ref(), Header::Path(ref v) => v.as_ref(), + Header::Protocol(ref v) => v.as_ref(), Header::Status(ref v) => v.as_str().as_ref(), } } @@ -156,6 +167,10 @@ impl Header { Header::Path(ref b) => a == b, _ => false, }, + Header::Protocol(ref a) => match *other { + Header::Protocol(ref b) => a == b, + _ => false, + }, Header::Status(ref a) => match *other { Header::Status(ref b) => a == b, _ => false, @@ -175,18 +190,18 @@ impl Header { use http::header; match *self { - Header::Field { ref name, .. } => match *name { + Header::Field { ref name, .. 
} => matches!( + *name, header::AGE - | header::AUTHORIZATION - | header::CONTENT_LENGTH - | header::ETAG - | header::IF_MODIFIED_SINCE - | header::IF_NONE_MATCH - | header::LOCATION - | header::COOKIE - | header::SET_COOKIE => true, - _ => false, - }, + | header::AUTHORIZATION + | header::CONTENT_LENGTH + | header::ETAG + | header::IF_MODIFIED_SINCE + | header::IF_NONE_MATCH + | header::LOCATION + | header::COOKIE + | header::SET_COOKIE + ), Header::Path(..) => true, _ => false, } @@ -205,6 +220,7 @@ impl From
for Header> { Header::Method(v) => Header::Method(v), Header::Scheme(v) => Header::Scheme(v), Header::Path(v) => Header::Path(v), + Header::Protocol(v) => Header::Protocol(v), Header::Status(v) => Header::Status(v), } } @@ -215,12 +231,13 @@ impl<'a> Name<'a> { match self { Name::Field(name) => Ok(Header::Field { name: name.clone(), - value: HeaderValue::from_bytes(&*value)?, + value: HeaderValue::from_bytes(&value)?, }), Name::Authority => Ok(Header::Authority(BytesStr::try_from(value)?)), - Name::Method => Ok(Header::Method(Method::from_bytes(&*value)?)), + Name::Method => Ok(Header::Method(Method::from_bytes(&value)?)), Name::Scheme => Ok(Header::Scheme(BytesStr::try_from(value)?)), Name::Path => Ok(Header::Path(BytesStr::try_from(value)?)), + Name::Protocol => Ok(Header::Protocol(Protocol::try_from(value)?)), Name::Status => { match StatusCode::from_bytes(&value) { Ok(status) => Ok(Header::Status(status)), @@ -238,6 +255,7 @@ impl<'a> Name<'a> { Name::Method => b":method", Name::Scheme => b":scheme", Name::Path => b":path", + Name::Protocol => b":protocol", Name::Status => b":status", } } diff --git a/src/hpack/huffman/mod.rs b/src/hpack/huffman/mod.rs index 07b3fd925..86c97eb58 100644 --- a/src/hpack/huffman/mod.rs +++ b/src/hpack/huffman/mod.rs @@ -112,7 +112,7 @@ mod test { #[test] fn decode_single_byte() { assert_eq!("o", decode(&[0b00111111]).unwrap()); - assert_eq!("0", decode(&[0x0 + 7]).unwrap()); + assert_eq!("0", decode(&[7]).unwrap()); assert_eq!("A", decode(&[(0x21 << 2) + 3]).unwrap()); } @@ -138,7 +138,7 @@ mod test { dst.clear(); encode(b"0", &mut dst); - assert_eq!(&dst[..], &[0x0 + 7]); + assert_eq!(&dst[..], &[7]); dst.clear(); encode(b"A", &mut dst); @@ -147,7 +147,7 @@ mod test { #[test] fn encode_decode_str() { - const DATA: &'static [&'static str] = &[ + const DATA: &[&str] = &[ "hello world", ":method", ":scheme", @@ -184,8 +184,7 @@ mod test { #[test] fn encode_decode_u8() { - const DATA: &'static [&'static [u8]] = - &[b"\0", b"\0\0\0", b"\0\x01\x02\x03\x04\x05", b"\xFF\xF8"]; + const DATA: &[&[u8]] = &[b"\0", b"\0\0\0", b"\0\x01\x02\x03\x04\x05", b"\xFF\xF8"]; for s in DATA { let mut dst = BytesMut::with_capacity(s.len()); diff --git a/src/hpack/table.rs b/src/hpack/table.rs index 2328743a8..a1a780451 100644 --- a/src/hpack/table.rs +++ b/src/hpack/table.rs @@ -404,7 +404,7 @@ impl Table { // Find the associated position probe_loop!(probe < self.indices.len(), { - debug_assert!(!self.indices[probe].is_none()); + debug_assert!(self.indices[probe].is_some()); let mut pos = self.indices[probe].unwrap(); @@ -751,6 +751,7 @@ fn index_static(header: &Header) -> Option<(usize, bool)> { "/index.html" => Some((5, true)), _ => Some((4, false)), }, + Header::Protocol(..) => None, Header::Status(ref v) => match u16::from(*v) { 200 => Some((8, true)), 204 => Some((9, true)), diff --git a/src/hpack/test/fixture.rs b/src/hpack/test/fixture.rs index 6d0448425..0d33ca2de 100644 --- a/src/hpack/test/fixture.rs +++ b/src/hpack/test/fixture.rs @@ -52,8 +52,8 @@ fn test_story(story: Value) { Case { seqno: case.get("seqno").unwrap().as_u64().unwrap(), - wire: wire, - expect: expect, + wire, + expect, header_table_size: size, } }) @@ -134,6 +134,7 @@ fn key_str(e: &Header) -> &str { Header::Method(..) => ":method", Header::Scheme(..) => ":scheme", Header::Path(..) => ":path", + Header::Protocol(..) => ":protocol", Header::Status(..) => ":status", } } @@ -141,10 +142,11 @@ fn key_str(e: &Header) -> &str { fn value_str(e: &Header) -> &str { match *e { Header::Field { ref value, .. 
} => value.to_str().unwrap(), - Header::Authority(ref v) => &**v, + Header::Authority(ref v) => v, Header::Method(ref m) => m.as_str(), - Header::Scheme(ref v) => &**v, - Header::Path(ref v) => &**v, + Header::Scheme(ref v) => v, + Header::Path(ref v) => v, + Header::Protocol(ref v) => v.as_str(), Header::Status(ref v) => v.as_str(), } } diff --git a/src/hpack/test/fuzz.rs b/src/hpack/test/fuzz.rs index 1d05a97c5..af9e8ea23 100644 --- a/src/hpack/test/fuzz.rs +++ b/src/hpack/test/fuzz.rs @@ -4,7 +4,9 @@ use http::header::{HeaderName, HeaderValue}; use bytes::BytesMut; use quickcheck::{Arbitrary, Gen, QuickCheck, TestResult}; -use rand::{Rng, SeedableRng, StdRng}; +use rand::distributions::Slice; +use rand::rngs::StdRng; +use rand::{thread_rng, Rng, SeedableRng}; use std::io::Cursor; @@ -46,9 +48,9 @@ struct HeaderFrame { } impl FuzzHpack { - fn new(seed: [usize; 4]) -> FuzzHpack { + fn new(seed: [u8; 32]) -> FuzzHpack { // Seed the RNG - let mut rng = StdRng::from_seed(&seed); + let mut rng = StdRng::from_seed(seed); // Generates a bunch of source headers let mut source: Vec>> = vec![]; @@ -58,12 +60,12 @@ impl FuzzHpack { } // Actual test run headers - let num: usize = rng.gen_range(40, 500); + let num: usize = rng.gen_range(40..500); let mut frames: Vec = vec![]; let mut added = 0; - let skew: i32 = rng.gen_range(1, 5); + let skew: i32 = rng.gen_range(1..5); // Rough number of headers to add while added < num { @@ -72,24 +74,24 @@ impl FuzzHpack { headers: vec![], }; - match rng.gen_range(0, 20) { + match rng.gen_range(0..20) { 0 => { // Two resizes - let high = rng.gen_range(128, MAX_CHUNK * 2); - let low = rng.gen_range(0, high); + let high = rng.gen_range(128..MAX_CHUNK * 2); + let low = rng.gen_range(0..high); - frame.resizes.extend(&[low, high]); + frame.resizes.extend([low, high]); } 1..=3 => { - frame.resizes.push(rng.gen_range(128, MAX_CHUNK * 2)); + frame.resizes.push(rng.gen_range(128..MAX_CHUNK * 2)); } _ => {} } let mut is_name_required = true; - for _ in 0..rng.gen_range(1, (num - added) + 1) { - let x: f64 = rng.gen_range(0.0, 1.0); + for _ in 0..rng.gen_range(1..(num - added) + 1) { + let x: f64 = rng.gen_range(0.0..1.0); let x = x.powi(skew); let i = (x * source.len() as f64) as usize; @@ -177,31 +179,31 @@ impl FuzzHpack { } impl Arbitrary for FuzzHpack { - fn arbitrary(g: &mut G) -> Self { - FuzzHpack::new(quickcheck::Rng::gen(g)) + fn arbitrary(_: &mut Gen) -> Self { + FuzzHpack::new(thread_rng().gen()) } } fn gen_header(g: &mut StdRng) -> Header> { use http::{Method, StatusCode}; - if g.gen_weighted_bool(10) { - match g.next_u32() % 5 { + if g.gen_ratio(1, 10) { + match g.gen_range(0u32..5) { 0 => { let value = gen_string(g, 4, 20); Header::Authority(to_shared(value)) } 1 => { - let method = match g.next_u32() % 6 { + let method = match g.gen_range(0u32..6) { 0 => Method::GET, 1 => Method::POST, 2 => Method::PUT, 3 => Method::PATCH, 4 => Method::DELETE, 5 => { - let n: usize = g.gen_range(3, 7); + let n: usize = g.gen_range(3..7); let bytes: Vec = (0..n) - .map(|_| g.choose(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ").unwrap().clone()) + .map(|_| *g.sample(Slice::new(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ").unwrap())) .collect(); Method::from_bytes(&bytes).unwrap() @@ -212,7 +214,7 @@ fn gen_header(g: &mut StdRng) -> Header> { Header::Method(method) } 2 => { - let value = match g.next_u32() % 2 { + let value = match g.gen_range(0u32..2) { 0 => "http", 1 => "https", _ => unreachable!(), @@ -221,7 +223,7 @@ fn gen_header(g: &mut StdRng) -> Header> { Header::Scheme(to_shared(value.to_string())) 
} 3 => { - let value = match g.next_u32() % 100 { + let value = match g.gen_range(0u32..100) { 0 => "/".to_string(), 1 => "/index.html".to_string(), _ => gen_string(g, 2, 20), @@ -237,14 +239,14 @@ fn gen_header(g: &mut StdRng) -> Header> { _ => unreachable!(), } } else { - let name = if g.gen_weighted_bool(10) { + let name = if g.gen_ratio(1, 10) { None } else { Some(gen_header_name(g)) }; let mut value = gen_header_value(g); - if g.gen_weighted_bool(30) { + if g.gen_ratio(1, 30) { value.set_sensitive(true); } @@ -255,84 +257,86 @@ fn gen_header(g: &mut StdRng) -> Header> { fn gen_header_name(g: &mut StdRng) -> HeaderName { use http::header; - if g.gen_weighted_bool(2) { - g.choose(&[ - header::ACCEPT, - header::ACCEPT_CHARSET, - header::ACCEPT_ENCODING, - header::ACCEPT_LANGUAGE, - header::ACCEPT_RANGES, - header::ACCESS_CONTROL_ALLOW_CREDENTIALS, - header::ACCESS_CONTROL_ALLOW_HEADERS, - header::ACCESS_CONTROL_ALLOW_METHODS, - header::ACCESS_CONTROL_ALLOW_ORIGIN, - header::ACCESS_CONTROL_EXPOSE_HEADERS, - header::ACCESS_CONTROL_MAX_AGE, - header::ACCESS_CONTROL_REQUEST_HEADERS, - header::ACCESS_CONTROL_REQUEST_METHOD, - header::AGE, - header::ALLOW, - header::ALT_SVC, - header::AUTHORIZATION, - header::CACHE_CONTROL, - header::CONNECTION, - header::CONTENT_DISPOSITION, - header::CONTENT_ENCODING, - header::CONTENT_LANGUAGE, - header::CONTENT_LENGTH, - header::CONTENT_LOCATION, - header::CONTENT_RANGE, - header::CONTENT_SECURITY_POLICY, - header::CONTENT_SECURITY_POLICY_REPORT_ONLY, - header::CONTENT_TYPE, - header::COOKIE, - header::DNT, - header::DATE, - header::ETAG, - header::EXPECT, - header::EXPIRES, - header::FORWARDED, - header::FROM, - header::HOST, - header::IF_MATCH, - header::IF_MODIFIED_SINCE, - header::IF_NONE_MATCH, - header::IF_RANGE, - header::IF_UNMODIFIED_SINCE, - header::LAST_MODIFIED, - header::LINK, - header::LOCATION, - header::MAX_FORWARDS, - header::ORIGIN, - header::PRAGMA, - header::PROXY_AUTHENTICATE, - header::PROXY_AUTHORIZATION, - header::PUBLIC_KEY_PINS, - header::PUBLIC_KEY_PINS_REPORT_ONLY, - header::RANGE, - header::REFERER, - header::REFERRER_POLICY, - header::REFRESH, - header::RETRY_AFTER, - header::SERVER, - header::SET_COOKIE, - header::STRICT_TRANSPORT_SECURITY, - header::TE, - header::TRAILER, - header::TRANSFER_ENCODING, - header::USER_AGENT, - header::UPGRADE, - header::UPGRADE_INSECURE_REQUESTS, - header::VARY, - header::VIA, - header::WARNING, - header::WWW_AUTHENTICATE, - header::X_CONTENT_TYPE_OPTIONS, - header::X_DNS_PREFETCH_CONTROL, - header::X_FRAME_OPTIONS, - header::X_XSS_PROTECTION, - ]) - .unwrap() + if g.gen_ratio(1, 2) { + g.sample( + Slice::new(&[ + header::ACCEPT, + header::ACCEPT_CHARSET, + header::ACCEPT_ENCODING, + header::ACCEPT_LANGUAGE, + header::ACCEPT_RANGES, + header::ACCESS_CONTROL_ALLOW_CREDENTIALS, + header::ACCESS_CONTROL_ALLOW_HEADERS, + header::ACCESS_CONTROL_ALLOW_METHODS, + header::ACCESS_CONTROL_ALLOW_ORIGIN, + header::ACCESS_CONTROL_EXPOSE_HEADERS, + header::ACCESS_CONTROL_MAX_AGE, + header::ACCESS_CONTROL_REQUEST_HEADERS, + header::ACCESS_CONTROL_REQUEST_METHOD, + header::AGE, + header::ALLOW, + header::ALT_SVC, + header::AUTHORIZATION, + header::CACHE_CONTROL, + header::CONNECTION, + header::CONTENT_DISPOSITION, + header::CONTENT_ENCODING, + header::CONTENT_LANGUAGE, + header::CONTENT_LENGTH, + header::CONTENT_LOCATION, + header::CONTENT_RANGE, + header::CONTENT_SECURITY_POLICY, + header::CONTENT_SECURITY_POLICY_REPORT_ONLY, + header::CONTENT_TYPE, + header::COOKIE, + header::DNT, + header::DATE, + 
header::ETAG, + header::EXPECT, + header::EXPIRES, + header::FORWARDED, + header::FROM, + header::HOST, + header::IF_MATCH, + header::IF_MODIFIED_SINCE, + header::IF_NONE_MATCH, + header::IF_RANGE, + header::IF_UNMODIFIED_SINCE, + header::LAST_MODIFIED, + header::LINK, + header::LOCATION, + header::MAX_FORWARDS, + header::ORIGIN, + header::PRAGMA, + header::PROXY_AUTHENTICATE, + header::PROXY_AUTHORIZATION, + header::PUBLIC_KEY_PINS, + header::PUBLIC_KEY_PINS_REPORT_ONLY, + header::RANGE, + header::REFERER, + header::REFERRER_POLICY, + header::REFRESH, + header::RETRY_AFTER, + header::SERVER, + header::SET_COOKIE, + header::STRICT_TRANSPORT_SECURITY, + header::TE, + header::TRAILER, + header::TRANSFER_ENCODING, + header::USER_AGENT, + header::UPGRADE, + header::UPGRADE_INSECURE_REQUESTS, + header::VARY, + header::VIA, + header::WARNING, + header::WWW_AUTHENTICATE, + header::X_CONTENT_TYPE_OPTIONS, + header::X_DNS_PREFETCH_CONTROL, + header::X_FRAME_OPTIONS, + header::X_XSS_PROTECTION, + ]) + .unwrap(), + ) .clone() } else { let value = gen_string(g, 1, 25); @@ -349,9 +353,7 @@ fn gen_string(g: &mut StdRng, min: usize, max: usize) -> String { let bytes: Vec<_> = (min..max) .map(|_| { // Chars to pick from - g.choose(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----") - .unwrap() - .clone() + *g.sample(Slice::new(b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----").unwrap()) }) .collect(); diff --git a/src/lib.rs b/src/lib.rs index cb02acaff..420e0fee1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -78,9 +78,10 @@ //! [`server::handshake`]: server/fn.handshake.html //! [`client::handshake`]: client/fn.handshake.html -#![doc(html_root_url = "https://docs.rs/h2/0.3.7")] +#![doc(html_root_url = "https://docs.rs/h2/0.3.18")] #![deny(missing_debug_implementations, missing_docs)] #![cfg_attr(test, deny(warnings))] +#![allow(clippy::type_complexity, clippy::manual_range_contains)] macro_rules! proto_err { (conn: $($msg:tt)+) => { @@ -120,6 +121,7 @@ mod frame; pub mod frame; pub mod client; +pub mod ext; pub mod server; mod share; @@ -132,44 +134,3 @@ pub use crate::share::{FlowControl, Ping, PingPong, Pong, RecvStream, SendStream #[cfg(feature = "unstable")] pub use codec::{Codec, SendError, UserError}; - -use std::task::Poll; - -// TODO: Get rid of this trait once https://github.com/rust-lang/rust/pull/63512 -// is stabilized. -trait PollExt { - /// Changes the success value of this `Poll` with the closure provided. - fn map_ok_(self, f: F) -> Poll>> - where - F: FnOnce(T) -> U; - /// Changes the error value of this `Poll` with the closure provided. 
- fn map_err_(self, f: F) -> Poll>> - where - F: FnOnce(E) -> U; -} - -impl PollExt for Poll>> { - fn map_ok_(self, f: F) -> Poll>> - where - F: FnOnce(T) -> U, - { - match self { - Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(f(t)))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } - - fn map_err_(self, f: F) -> Poll>> - where - F: FnOnce(E) -> U, - { - match self { - Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(Ok(t))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(f(e)))), - Poll::Ready(None) => Poll::Ready(None), - Poll::Pending => Poll::Pending, - } - } -} diff --git a/src/proto/connection.rs b/src/proto/connection.rs index a75df4369..619973df8 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -77,8 +77,10 @@ struct DynConnection<'a, B: Buf = Bytes> { pub(crate) struct Config { pub next_stream_id: StreamId, pub initial_max_send_streams: usize, + pub max_send_buffer_size: usize, pub reset_stream_duration: Duration, pub reset_stream_max: usize, + pub remote_reset_stream_max: usize, pub settings: frame::Settings, } @@ -108,10 +110,16 @@ where .initial_window_size() .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), initial_max_send_streams: config.initial_max_send_streams, + local_max_buffer_size: config.max_send_buffer_size, local_next_stream_id: config.next_stream_id, local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), + extended_connect_protocol_enabled: config + .settings + .is_extended_connect_protocol_enabled() + .unwrap_or(false), local_reset_duration: config.reset_stream_duration, local_reset_max: config.reset_stream_max, + remote_reset_max: config.remote_reset_stream_max, remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, remote_max_initiated: config .settings @@ -147,6 +155,13 @@ where self.inner.settings.send_settings(settings) } + /// Send a new SETTINGS frame with extended CONNECT protocol enabled. + pub(crate) fn set_enable_connect_protocol(&mut self) -> Result<(), UserError> { + let mut settings = frame::Settings::default(); + settings.set_enable_connect_protocol(Some(1)); + self.inner.settings.send_settings(settings) + } + /// Returns the maximum number of concurrent streams that may be initiated /// by this peer. pub(crate) fn max_send_streams(&self) -> usize { @@ -159,6 +174,11 @@ where self.inner.streams.max_recv_streams() } + #[cfg(feature = "unstable")] + pub fn num_wired_streams(&self) -> usize { + self.inner.streams.num_wired_streams() + } + /// Returns `Ready` when the connection is ready to receive a frame. /// /// Returns `Error` as this may raise errors that are caused by delayed @@ -202,7 +222,7 @@ where }); match (ours, theirs) { - (Reason::NO_ERROR, Reason::NO_ERROR) => return Ok(()), + (Reason::NO_ERROR, Reason::NO_ERROR) => Ok(()), (ours, Reason::NO_ERROR) => Err(Error::GoAway(Bytes::new(), ours, initiator)), // If both sides reported an error, give their // error back to th user. 
We assume our error diff --git a/src/proto/error.rs b/src/proto/error.rs index 197237263..2c00c7ea6 100644 --- a/src/proto/error.rs +++ b/src/proto/error.rs @@ -13,7 +13,7 @@ pub enum Error { Io(io::ErrorKind, Option), } -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Initiator { User, Library, @@ -70,7 +70,7 @@ impl fmt::Display for Error { impl From for Error { fn from(src: io::ErrorKind) -> Self { - Error::Io(src.into(), None) + Error::Io(src, None) } } diff --git a/src/proto/mod.rs b/src/proto/mod.rs index d505e77f3..d71ee9c42 100644 --- a/src/proto/mod.rs +++ b/src/proto/mod.rs @@ -31,5 +31,7 @@ pub type WindowSize = u32; // Constants pub const MAX_WINDOW_SIZE: WindowSize = (1 << 31) - 1; +pub const DEFAULT_REMOTE_RESET_STREAM_MAX: usize = 20; pub const DEFAULT_RESET_STREAM_MAX: usize = 10; pub const DEFAULT_RESET_STREAM_SECS: u64 = 30; +pub const DEFAULT_MAX_SEND_BUFFER_SIZE: usize = 1024 * 400; diff --git a/src/proto/ping_pong.rs b/src/proto/ping_pong.rs index 844c5fbb9..59023e26a 100644 --- a/src/proto/ping_pong.rs +++ b/src/proto/ping_pong.rs @@ -200,10 +200,7 @@ impl PingPong { impl ReceivedPing { pub(crate) fn is_shutdown(&self) -> bool { - match *self { - ReceivedPing::Shutdown => true, - _ => false, - } + matches!(*self, Self::Shutdown) } } diff --git a/src/proto/settings.rs b/src/proto/settings.rs index 44f4c2df4..6cc617209 100644 --- a/src/proto/settings.rs +++ b/src/proto/settings.rs @@ -117,6 +117,8 @@ impl Settings { tracing::trace!("ACK sent; applying settings"); + streams.apply_remote_settings(settings)?; + if let Some(val) = settings.header_table_size() { dst.set_send_header_table_size(val as usize); } @@ -124,8 +126,6 @@ impl Settings { if let Some(val) = settings.max_frame_size() { dst.set_max_send_frame_size(val as usize); } - - streams.apply_remote_settings(settings)?; } self.remote = None; diff --git a/src/proto/streams/counts.rs b/src/proto/streams/counts.rs index 70dfc7851..6a5aa9ccd 100644 --- a/src/proto/streams/counts.rs +++ b/src/proto/streams/counts.rs @@ -21,10 +21,16 @@ pub(super) struct Counts { num_recv_streams: usize, /// Maximum number of pending locally reset streams - max_reset_streams: usize, + max_local_reset_streams: usize, /// Current number of pending locally reset streams - num_reset_streams: usize, + num_local_reset_streams: usize, + + /// Max number of "pending accept" streams that were remotely reset + max_remote_reset_streams: usize, + + /// Current number of "pending accept" streams that were remotely reset + num_remote_reset_streams: usize, } impl Counts { @@ -36,8 +42,10 @@ impl Counts { num_send_streams: 0, max_recv_streams: config.remote_max_initiated.unwrap_or(usize::MAX), num_recv_streams: 0, - max_reset_streams: config.local_reset_max, - num_reset_streams: 0, + max_local_reset_streams: config.local_reset_max, + num_local_reset_streams: 0, + max_remote_reset_streams: config.remote_reset_max, + num_remote_reset_streams: 0, } } @@ -90,7 +98,7 @@ impl Counts { /// Returns true if the number of pending reset streams can be incremented. pub fn can_inc_num_reset_streams(&self) -> bool { - self.max_reset_streams > self.num_reset_streams + self.max_local_reset_streams > self.num_local_reset_streams } /// Increments the number of pending reset streams. 
@@ -101,7 +109,34 @@ impl Counts { pub fn inc_num_reset_streams(&mut self) { assert!(self.can_inc_num_reset_streams()); - self.num_reset_streams += 1; + self.num_local_reset_streams += 1; + } + + pub(crate) fn max_remote_reset_streams(&self) -> usize { + self.max_remote_reset_streams + } + + /// Returns true if the number of pending REMOTE reset streams can be + /// incremented. + pub(crate) fn can_inc_num_remote_reset_streams(&self) -> bool { + self.max_remote_reset_streams > self.num_remote_reset_streams + } + + /// Increments the number of pending REMOTE reset streams. + /// + /// # Panics + /// + /// Panics on failure as this should have been validated before hand. + pub(crate) fn inc_num_remote_reset_streams(&mut self) { + assert!(self.can_inc_num_remote_reset_streams()); + + self.num_remote_reset_streams += 1; + } + + pub(crate) fn dec_num_remote_reset_streams(&mut self) { + assert!(self.num_remote_reset_streams > 0); + + self.num_remote_reset_streams -= 1; } pub fn apply_remote_settings(&mut self, settings: &frame::Settings) { @@ -194,8 +229,8 @@ impl Counts { } fn dec_num_reset_streams(&mut self) { - assert!(self.num_reset_streams > 0); - self.num_reset_streams -= 1; + assert!(self.num_local_reset_streams > 0); + self.num_local_reset_streams -= 1; } } diff --git a/src/proto/streams/flow_control.rs b/src/proto/streams/flow_control.rs index 4a47f08dd..73a7754db 100644 --- a/src/proto/streams/flow_control.rs +++ b/src/proto/streams/flow_control.rs @@ -19,6 +19,7 @@ const UNCLAIMED_NUMERATOR: i32 = 1; const UNCLAIMED_DENOMINATOR: i32 = 2; #[test] +#[allow(clippy::assertions_on_constants)] fn sanity_unclaimed_ratio() { assert!(UNCLAIMED_NUMERATOR < UNCLAIMED_DENOMINATOR); assert!(UNCLAIMED_NUMERATOR >= 0); @@ -172,12 +173,15 @@ impl FlowControl { self.available ); - // Ensure that the argument is correct - assert!(self.window_size >= sz as usize); + // If send size is zero it's meaningless to update flow control window + if sz > 0 { + // Ensure that the argument is correct + assert!(self.window_size >= sz as usize); - // Update values - self.window_size -= sz; - self.available -= sz; + // Update values + self.window_size -= sz; + self.available -= sz; + } } } @@ -188,7 +192,7 @@ impl FlowControl { /// /// This type tries to centralize the knowledge of addition and subtraction /// to this capacity, instead of having integer casts throughout the source. -#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd)] pub struct Window(i32); impl Window { diff --git a/src/proto/streams/mod.rs b/src/proto/streams/mod.rs index 608395c0f..fbe32c7b0 100644 --- a/src/proto/streams/mod.rs +++ b/src/proto/streams/mod.rs @@ -7,6 +7,7 @@ mod send; mod state; mod store; mod stream; +#[allow(clippy::module_inception)] mod streams; pub(crate) use self::prioritize::Prioritized; @@ -41,18 +42,28 @@ pub struct Config { /// MAX_CONCURRENT_STREAMS specified in the frame. pub initial_max_send_streams: usize, + /// Max amount of DATA bytes to buffer per stream. + pub local_max_buffer_size: usize, + /// The stream ID to start the next local stream with pub local_next_stream_id: StreamId, /// If the local peer is willing to receive push promises pub local_push_enabled: bool, + /// If extended connect protocol is enabled. 
+ pub extended_connect_protocol_enabled: bool, + /// How long a locally reset stream should ignore frames pub local_reset_duration: Duration, /// Maximum number of locally reset streams to keep at a time pub local_reset_max: usize, + /// Maximum number of remotely reset "pending accept" streams to keep at a + /// time. Going over this number results in a connection error. + pub remote_reset_max: usize, + /// Initial window size of remote initiated streams pub remote_init_window_sz: WindowSize, diff --git a/src/proto/streams/prioritize.rs b/src/proto/streams/prioritize.rs index eaaee162b..88204ddcc 100644 --- a/src/proto/streams/prioritize.rs +++ b/src/proto/streams/prioritize.rs @@ -7,9 +7,11 @@ use crate::codec::UserError; use crate::codec::UserError::*; use bytes::buf::{Buf, Take}; -use std::io; -use std::task::{Context, Poll, Waker}; -use std::{cmp, fmt, mem}; +use std::{ + cmp::{self, Ordering}, + fmt, io, mem, + task::{Context, Poll, Waker}, +}; /// # Warning /// @@ -51,6 +53,9 @@ pub(super) struct Prioritize { /// What `DATA` frame is currently being sent in the codec. in_flight_data_frame: InFlightData, + + /// The maximum amount of bytes a stream should buffer. + max_buffer_size: usize, } #[derive(Debug, Eq, PartialEq)] @@ -93,9 +98,14 @@ impl Prioritize { flow, last_opened_id: StreamId::ZERO, in_flight_data_frame: InFlightData::Nothing, + max_buffer_size: config.local_max_buffer_size, } } + pub(crate) fn max_buffer_size(&self) -> usize { + self.max_buffer_size + } + /// Queue a frame to be sent to the remote pub fn queue_frame( &mut self, @@ -227,39 +237,43 @@ impl Prioritize { // If it were less, then we could never send out the buffered data. let capacity = (capacity as usize) + stream.buffered_send_data; - if capacity == stream.requested_send_capacity as usize { - // Nothing to do - } else if capacity < stream.requested_send_capacity as usize { - // Update the target requested capacity - stream.requested_send_capacity = capacity as WindowSize; + match capacity.cmp(&(stream.requested_send_capacity as usize)) { + Ordering::Equal => { + // Nothing to do + } + Ordering::Less => { + // Update the target requested capacity + stream.requested_send_capacity = capacity as WindowSize; - // Currently available capacity assigned to the stream - let available = stream.send_flow.available().as_size(); + // Currently available capacity assigned to the stream + let available = stream.send_flow.available().as_size(); - // If the stream has more assigned capacity than requested, reclaim - // some for the connection - if available as usize > capacity { - let diff = available - capacity as WindowSize; + // If the stream has more assigned capacity than requested, reclaim + // some for the connection + if available as usize > capacity { + let diff = available - capacity as WindowSize; - stream.send_flow.claim_capacity(diff); + stream.send_flow.claim_capacity(diff); - self.assign_connection_capacity(diff, stream, counts); - } - } else { - // If trying to *add* capacity, but the stream send side is closed, - // there's nothing to be done. - if stream.state.is_send_closed() { - return; + self.assign_connection_capacity(diff, stream, counts); + } } + Ordering::Greater => { + // If trying to *add* capacity, but the stream send side is closed, + // there's nothing to be done. 
+ if stream.state.is_send_closed() { + return; + } - // Update the target requested capacity - stream.requested_send_capacity = - cmp::min(capacity, WindowSize::MAX as usize) as WindowSize; + // Update the target requested capacity + stream.requested_send_capacity = + cmp::min(capacity, WindowSize::MAX as usize) as WindowSize; - // Try to assign additional capacity to the stream. If none is - // currently available, the stream will be queued to receive some - // when more becomes available. - self.try_assign_capacity(stream); + // Try to assign additional capacity to the stream. If none is + // currently available, the stream will be queued to receive some + // when more becomes available. + self.try_assign_capacity(stream); + } } } @@ -309,9 +323,11 @@ impl Prioritize { /// connection pub fn reclaim_all_capacity(&mut self, stream: &mut store::Ptr, counts: &mut Counts) { let available = stream.send_flow.available().as_size(); - stream.send_flow.claim_capacity(available); - // Re-assign all capacity to the connection - self.assign_connection_capacity(available, stream, counts); + if available > 0 { + stream.send_flow.claim_capacity(available); + // Re-assign all capacity to the connection + self.assign_connection_capacity(available, stream, counts); + } } /// Reclaim just reserved capacity, not buffered capacity, and re-assign @@ -364,11 +380,11 @@ impl Prioritize { continue; } - counts.transition(stream, |_, mut stream| { + counts.transition(stream, |_, stream| { // Try to assign capacity to the stream. This will also re-queue the // stream if there isn't enough connection level capacity to fulfill // the capacity request. - self.try_assign_capacity(&mut stream); + self.try_assign_capacity(stream); }) } } @@ -424,7 +440,7 @@ impl Prioritize { tracing::trace!(capacity = assign, "assigning"); // Assign the capacity to the stream - stream.assign_capacity(assign); + stream.assign_capacity(assign, self.max_buffer_size); // Claim the capacity from the connection self.flow.claim_capacity(assign); @@ -730,16 +746,19 @@ impl Prioritize { // capacity at this point. debug_assert!(len <= self.flow.window_size()); + // Check if the stream level window the peer knows is available. In some + // scenarios, maybe the window we know is available but the window which + // peer knows is not. 
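The comment above introduces a guard whose code continues immediately below: even when the connection window covers a DATA frame, the frame must also fit the stream window as the peer currently sees it. A hedged restatement as a standalone function (the name and `u32` widths are illustrative):

```rust
/// A DATA frame of `len` bytes may go out only if the connection window
/// covers it and, unless the frame is empty, the peer-visible stream
/// window covers it too; otherwise it is pushed back onto the send queue.
fn can_send(len: u32, connection_window: u32, stream_window: u32) -> bool {
    len <= connection_window && (len == 0 || len <= stream_window)
}

fn main() {
    // Connection window is fine, but the stream window the peer knows
    // about is exhausted, so the frame waits for a WINDOW_UPDATE.
    assert!(!can_send(100, 1_000, 0));
    assert!(can_send(100, 1_000, 100));
}
```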
+ if len > 0 && len > stream.send_flow.window_size() { + stream.pending_send.push_front(buffer, frame.into()); + continue; + } + tracing::trace!(len, "sending data frame"); // Update the flow control tracing::trace_span!("updating stream flow").in_scope(|| { - stream.send_flow.send_data(len); - - // Decrement the stream's buffered data counter - debug_assert!(stream.buffered_send_data >= len as usize); - stream.buffered_send_data -= len as usize; - stream.requested_send_capacity -= len; + stream.send_data(len, self.max_buffer_size); // Assign the capacity back to the connection that // was just consumed from the stream in the previous diff --git a/src/proto/streams/recv.rs b/src/proto/streams/recv.rs index be996b963..0fe2bdd57 100644 --- a/src/proto/streams/recv.rs +++ b/src/proto/streams/recv.rs @@ -2,12 +2,12 @@ use super::*; use crate::codec::UserError; use crate::frame::{self, PushPromiseHeaderError, Reason, DEFAULT_INITIAL_WINDOW_SIZE}; use crate::proto::{self, Error}; -use std::task::Context; use http::{HeaderMap, Request, Response}; +use std::cmp::Ordering; use std::io; -use std::task::{Poll, Waker}; +use std::task::{Context, Poll, Waker}; use std::time::{Duration, Instant}; #[derive(Debug)] @@ -56,6 +56,9 @@ pub(super) struct Recv { /// If push promises are allowed to be received. is_push_enabled: bool, + + /// If extended connect protocol is enabled. + is_extended_connect_protocol_enabled: bool, } #[derive(Debug)] @@ -103,6 +106,7 @@ impl Recv { buffer: Buffer::new(), refused: None, is_push_enabled: config.local_push_enabled, + is_extended_connect_protocol_enabled: config.extended_connect_protocol_enabled, } } @@ -174,7 +178,7 @@ impl Recv { if let Some(content_length) = frame.fields().get(header::CONTENT_LENGTH) { let content_length = match frame::parse_u64(content_length.as_bytes()) { Ok(v) => v, - Err(()) => { + Err(_) => { proto_err!(stream: "could not parse content-length; stream={:?}", stream.id); return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); } @@ -216,6 +220,15 @@ impl Recv { let stream_id = frame.stream_id(); let (pseudo, fields) = frame.into_parts(); + + if pseudo.protocol.is_some() + && counts.peer().is_server() + && !self.is_extended_connect_protocol_enabled + { + proto_err!(stream: "cannot use :protocol if extended connect protocol is disabled; stream={:?}", stream.id); + return Err(Error::library_reset(stream.id, Reason::PROTOCOL_ERROR).into()); + } + if !pseudo.is_informational() { let message = counts .peer() @@ -449,60 +462,62 @@ impl Recv { settings: &frame::Settings, store: &mut Store, ) -> Result<(), proto::Error> { - let target = if let Some(val) = settings.initial_window_size() { - val - } else { - return Ok(()); - }; + if let Some(val) = settings.is_extended_connect_protocol_enabled() { + self.is_extended_connect_protocol_enabled = val; + } - let old_sz = self.init_window_sz; - self.init_window_sz = target; + if let Some(target) = settings.initial_window_size() { + let old_sz = self.init_window_sz; + self.init_window_sz = target; - tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,); + tracing::trace!("update_initial_window_size; new={}; old={}", target, old_sz,); - // Per RFC 7540 §6.9.2: - // - // In addition to changing the flow-control window for streams that are - // not yet active, a SETTINGS frame can alter the initial flow-control - // window size for streams with active flow-control windows (that is, - // streams in the "open" or "half-closed (remote)" state). 
When the - // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust - // the size of all stream flow-control windows that it maintains by the - // difference between the new value and the old value. - // - // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available - // space in a flow-control window to become negative. A sender MUST - // track the negative flow-control window and MUST NOT send new - // flow-controlled frames until it receives WINDOW_UPDATE frames that - // cause the flow-control window to become positive. - - if target < old_sz { - // We must decrease the (local) window on every open stream. - let dec = old_sz - target; - tracing::trace!("decrementing all windows; dec={}", dec); - - store.for_each(|mut stream| { - stream.recv_flow.dec_recv_window(dec); - Ok(()) - }) - } else if target > old_sz { - // We must increase the (local) window on every open stream. - let inc = target - old_sz; - tracing::trace!("incrementing all windows; inc={}", inc); - store.for_each(|mut stream| { - // XXX: Shouldn't the peer have already noticed our - // overflow and sent us a GOAWAY? - stream - .recv_flow - .inc_window(inc) - .map_err(proto::Error::library_go_away)?; - stream.recv_flow.assign_capacity(inc); - Ok(()) - }) - } else { - // size is the same... so do nothing - Ok(()) + // Per RFC 7540 §6.9.2: + // + // In addition to changing the flow-control window for streams that are + // not yet active, a SETTINGS frame can alter the initial flow-control + // window size for streams with active flow-control windows (that is, + // streams in the "open" or "half-closed (remote)" state). When the + // value of SETTINGS_INITIAL_WINDOW_SIZE changes, a receiver MUST adjust + // the size of all stream flow-control windows that it maintains by the + // difference between the new value and the old value. + // + // A change to `SETTINGS_INITIAL_WINDOW_SIZE` can cause the available + // space in a flow-control window to become negative. A sender MUST + // track the negative flow-control window and MUST NOT send new + // flow-controlled frames until it receives WINDOW_UPDATE frames that + // cause the flow-control window to become positive. + + match target.cmp(&old_sz) { + Ordering::Less => { + // We must decrease the (local) window on every open stream. + let dec = old_sz - target; + tracing::trace!("decrementing all windows; dec={}", dec); + + store.for_each(|mut stream| { + stream.recv_flow.dec_recv_window(dec); + }) + } + Ordering::Greater => { + // We must increase the (local) window on every open stream. + let inc = target - old_sz; + tracing::trace!("incrementing all windows; inc={}", inc); + store.try_for_each(|mut stream| { + // XXX: Shouldn't the peer have already noticed our + // overflow and sent us a GOAWAY? 
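Before the adjustment loop continues below, a worked example of the RFC 7540 §6.9.2 rule quoted in this hunk; the numbers are illustrative only:

```rust
fn main() {
    // The peer lowers SETTINGS_INITIAL_WINDOW_SIZE from 65_535 to 16_384
    // while one open stream still has 10_000 bytes of window left.
    let old_initial: i64 = 65_535;
    let new_initial: i64 = 16_384;
    let current_stream_window: i64 = 10_000;

    // Each open stream's window shifts by (new - old). It can go negative,
    // and no flow-controlled frames may be sent on that stream until
    // WINDOW_UPDATE frames bring it back above zero.
    let adjusted = current_stream_window + (new_initial - old_initial);
    assert_eq!(adjusted, -39_151);
}
```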
+ stream + .recv_flow + .inc_window(inc) + .map_err(proto::Error::library_go_away)?; + stream.recv_flow.assign_capacity(inc); + Ok::<_, proto::Error>(()) + })?; + } + Ordering::Equal => (), + } } + + Ok(()) } pub fn is_end_stream(&self, stream: &store::Ptr) -> bool { @@ -546,7 +561,7 @@ impl Recv { "recv_data; frame ignored on locally reset {:?} for some time", stream.id, ); - return Ok(self.ignore_data(sz)?); + return self.ignore_data(sz); } // Ensure that there is enough capacity on the connection before acting @@ -586,10 +601,20 @@ impl Recv { if stream.state.recv_close().is_err() { proto_err!(conn: "recv_data: failed to transition to closed state; stream={:?}", stream.id); - return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } } + // Received a frame, but no one cared about it. fix issue#648 + if !stream.is_recv { + tracing::trace!( + "recv_data; frame ignored on stream release {:?} for some time", + stream.id, + ); + self.release_connection_capacity(sz, &mut None); + return Ok(()); + } + // Update stream level flow control stream.recv_flow.send_data(sz); @@ -716,12 +741,39 @@ impl Recv { } /// Handle remote sending an explicit RST_STREAM. - pub fn recv_reset(&mut self, frame: frame::Reset, stream: &mut Stream) { + pub fn recv_reset( + &mut self, + frame: frame::Reset, + stream: &mut Stream, + counts: &mut Counts, + ) -> Result<(), Error> { + // Reseting a stream that the user hasn't accepted is possible, + // but should be done with care. These streams will continue + // to take up memory in the accept queue, but will no longer be + // counted as "concurrent" streams. + // + // So, we have a separate limit for these. + // + // See https://github.com/hyperium/hyper/issues/2877 + if stream.is_pending_accept { + if counts.can_inc_num_remote_reset_streams() { + counts.inc_num_remote_reset_streams(); + } else { + tracing::warn!( + "recv_reset; remotely-reset pending-accept streams reached limit ({:?})", + counts.max_remote_reset_streams(), + ); + return Err(Error::library_go_away(Reason::ENHANCE_YOUR_CALM)); + } + } + // Notify the stream stream.state.recv_reset(frame, stream.is_pending_send); stream.notify_send(); stream.notify_recv(); + + Ok(()) } /// Handle a connection-level error @@ -746,7 +798,7 @@ impl Recv { } pub(super) fn clear_recv_buffer(&mut self, stream: &mut Stream) { - while let Some(_) = stream.pending_recv.pop_front(&mut self.buffer) { + while stream.pending_recv.pop_front(&mut self.buffer).is_some() { // drop it } } @@ -776,6 +828,16 @@ impl Recv { } } + pub(super) fn maybe_reset_next_stream_id(&mut self, id: StreamId) { + if let Ok(next_id) = self.next_stream_id { + // !Peer::is_local_init should have been called beforehand + debug_assert_eq!(id.is_server_initiated(), next_id.is_server_initiated()); + if id >= next_id { + self.next_stream_id = id.next_id(); + } + } + } + /// Returns true if the remote peer can reserve a stream with the given ID. pub fn ensure_can_reserve(&self) -> Result<(), Error> { if !self.is_push_enabled { @@ -840,7 +902,10 @@ impl Recv { let reset_duration = self.reset_duration; while let Some(stream) = self.pending_reset_expired.pop_if(store, |stream| { let reset_at = stream.reset_at.expect("reset_at must be set if in queue"); - now - reset_at > reset_duration + // rust-lang/rust#86470 tracks a bug in the standard library where `Instant` + // subtraction can panic (because, on some platforms, `Instant` isn't actually + // monotonic). 
We use a saturating operation to avoid this panic here. + now.saturating_duration_since(reset_at) > reset_duration }) { counts.transition_after(stream, true); } @@ -995,7 +1060,6 @@ impl Recv { cx: &Context, stream: &mut Stream, ) -> Poll>> { - // TODO: Return error when the stream is reset match stream.pending_recv.pop_front(&mut self.buffer) { Some(Event::Data(payload)) => Poll::Ready(Some(Ok(payload))), Some(event) => { @@ -1056,12 +1120,7 @@ impl Recv { impl Open { pub fn is_push_promise(&self) -> bool { - use self::Open::*; - - match *self { - PushPromise => true, - _ => false, - } + matches!(*self, Self::PushPromise) } } diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 3735d13dd..20aba38d4 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -7,11 +7,11 @@ use crate::frame::{self, Reason}; use crate::proto::{Error, Initiator}; use bytes::Buf; -use http; -use std::task::{Context, Poll, Waker}; use tokio::io::AsyncWrite; +use std::cmp::Ordering; use std::io; +use std::task::{Context, Poll, Waker}; /// Manages state transitions related to outbound frames. #[derive(Debug)] @@ -35,6 +35,9 @@ pub(super) struct Send { prioritize: Prioritize, is_push_enabled: bool, + + /// If extended connect protocol is enabled. + is_extended_connect_protocol_enabled: bool, } /// A value to detect which public API has called `poll_reset`. @@ -53,6 +56,7 @@ impl Send { next_stream_id: Ok(config.local_next_stream_id), prioritize: Prioritize::new(config), is_push_enabled: true, + is_extended_connect_protocol_enabled: false, } } @@ -329,14 +333,7 @@ impl Send { /// Current available stream send capacity pub fn capacity(&self, stream: &mut store::Ptr) -> WindowSize { - let available = stream.send_flow.available().as_size(); - let buffered = stream.buffered_send_data; - - if available as usize <= buffered { - 0 - } else { - available - buffered as WindowSize - } + stream.capacity(self.prioritize.max_buffer_size()) } pub fn poll_reset( @@ -429,6 +426,10 @@ impl Send { counts: &mut Counts, task: &mut Option, ) -> Result<(), Error> { + if let Some(val) = settings.is_extended_connect_protocol_enabled() { + self.is_extended_connect_protocol_enabled = val; + } + // Applies an update to the remote endpoint's initial window size. // // Per RFC 7540 §6.9.2: @@ -450,59 +451,61 @@ impl Send { let old_val = self.init_window_sz; self.init_window_sz = val; - if val < old_val { - // We must decrease the (remote) window on every open stream. - let dec = old_val - val; - tracing::trace!("decrementing all windows; dec={}", dec); - - let mut total_reclaimed = 0; - store.for_each(|mut stream| { - let stream = &mut *stream; - - stream.send_flow.dec_send_window(dec); - - // It's possible that decreasing the window causes - // `window_size` (the stream-specific window) to fall below - // `available` (the portion of the connection-level window - // that we have allocated to the stream). - // In this case, we should take that excess allocation away - // and reassign it to other streams. - let window_size = stream.send_flow.window_size(); - let available = stream.send_flow.available().as_size(); - let reclaimed = if available > window_size { - // Drop down to `window_size`. 
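The reset-expiration check near the top of this hunk run switches to `Instant::saturating_duration_since`, since plain `Instant` subtraction can panic on platforms whose clocks are not strictly monotonic (rust-lang/rust#86470). A minimal sketch of the same pattern (the function name is illustrative):

```rust
use std::time::{Duration, Instant};

// `saturating_duration_since` clamps to a zero duration instead of
// panicking when `reset_at` appears to lie in the future.
fn is_expired(now: Instant, reset_at: Instant, reset_duration: Duration) -> bool {
    now.saturating_duration_since(reset_at) > reset_duration
}
```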
- let reclaim = available - window_size; - stream.send_flow.claim_capacity(reclaim); - total_reclaimed += reclaim; - reclaim - } else { - 0 - }; - - tracing::trace!( - "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}", - stream.id, - dec, - reclaimed, - stream.send_flow - ); - - // TODO: Should this notify the producer when the capacity - // of a stream is reduced? Maybe it should if the capacity - // is reduced to zero, allowing the producer to stop work. - - Ok::<_, Error>(()) - })?; - - self.prioritize - .assign_connection_capacity(total_reclaimed, store, counts); - } else if val > old_val { - let inc = val - old_val; - - store.for_each(|mut stream| { - self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) - .map_err(Error::library_go_away) - })?; + match val.cmp(&old_val) { + Ordering::Less => { + // We must decrease the (remote) window on every open stream. + let dec = old_val - val; + tracing::trace!("decrementing all windows; dec={}", dec); + + let mut total_reclaimed = 0; + store.for_each(|mut stream| { + let stream = &mut *stream; + + stream.send_flow.dec_send_window(dec); + + // It's possible that decreasing the window causes + // `window_size` (the stream-specific window) to fall below + // `available` (the portion of the connection-level window + // that we have allocated to the stream). + // In this case, we should take that excess allocation away + // and reassign it to other streams. + let window_size = stream.send_flow.window_size(); + let available = stream.send_flow.available().as_size(); + let reclaimed = if available > window_size { + // Drop down to `window_size`. + let reclaim = available - window_size; + stream.send_flow.claim_capacity(reclaim); + total_reclaimed += reclaim; + reclaim + } else { + 0 + }; + + tracing::trace!( + "decremented stream window; id={:?}; decr={}; reclaimed={}; flow={:?}", + stream.id, + dec, + reclaimed, + stream.send_flow + ); + + // TODO: Should this notify the producer when the capacity + // of a stream is reduced? Maybe it should if the capacity + // is reduced to zero, allowing the producer to stop work. + }); + + self.prioritize + .assign_connection_capacity(total_reclaimed, store, counts); + } + Ordering::Greater => { + let inc = val - old_val; + + store.try_for_each(|mut stream| { + self.recv_stream_window_update(inc, buffer, &mut stream, counts, task) + .map_err(Error::library_go_away) + })?; + } + Ordering::Equal => (), } } @@ -554,4 +557,8 @@ impl Send { } } } + + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.is_extended_connect_protocol_enabled + } } diff --git a/src/proto/streams/state.rs b/src/proto/streams/state.rs index 9931d41b1..76638fc87 100644 --- a/src/proto/streams/state.rs +++ b/src/proto/streams/state.rs @@ -303,7 +303,13 @@ impl State { Closed(..) 
=> {} ref state => { tracing::trace!("recv_eof; state={:?}", state); - self.inner = Closed(Cause::Error(io::ErrorKind::BrokenPipe.into())); + self.inner = Closed(Cause::Error( + io::Error::new( + io::ErrorKind::BrokenPipe, + "stream closed because of a broken pipe", + ) + .into(), + )); } } } @@ -343,16 +349,20 @@ impl State { } pub fn is_scheduled_reset(&self) -> bool { + matches!(self.inner, Closed(Cause::ScheduledLibraryReset(..))) + } + + pub fn is_local_reset(&self) -> bool { match self.inner { + Closed(Cause::Error(ref e)) => e.is_local(), Closed(Cause::ScheduledLibraryReset(..)) => true, _ => false, } } - pub fn is_local_reset(&self) -> bool { + pub fn is_remote_reset(&self) -> bool { match self.inner { - Closed(Cause::Error(ref e)) => e.is_local(), - Closed(Cause::ScheduledLibraryReset(..)) => true, + Closed(Cause::Error(ref e)) => !e.is_local(), _ => false, } } @@ -367,65 +377,57 @@ impl State { } pub fn is_send_streaming(&self) -> bool { - match self.inner { + matches!( + self.inner, Open { - local: Streaming, .. - } => true, - HalfClosedRemote(Streaming) => true, - _ => false, - } + local: Streaming, + .. + } | HalfClosedRemote(Streaming) + ) } /// Returns true when the stream is in a state to receive headers pub fn is_recv_headers(&self) -> bool { - match self.inner { - Idle => true, - Open { + matches!( + self.inner, + Idle | Open { remote: AwaitingHeaders, .. - } => true, - HalfClosedLocal(AwaitingHeaders) => true, - ReservedRemote => true, - _ => false, - } + } | HalfClosedLocal(AwaitingHeaders) + | ReservedRemote + ) } pub fn is_recv_streaming(&self) -> bool { - match self.inner { + matches!( + self.inner, Open { - remote: Streaming, .. - } => true, - HalfClosedLocal(Streaming) => true, - _ => false, - } + remote: Streaming, + .. + } | HalfClosedLocal(Streaming) + ) } pub fn is_closed(&self) -> bool { - match self.inner { - Closed(_) => true, - _ => false, - } + matches!(self.inner, Closed(_)) } pub fn is_recv_closed(&self) -> bool { - match self.inner { - Closed(..) | HalfClosedRemote(..) | ReservedLocal => true, - _ => false, - } + matches!( + self.inner, + Closed(..) | HalfClosedRemote(..) | ReservedLocal + ) } pub fn is_send_closed(&self) -> bool { - match self.inner { - Closed(..) | HalfClosedLocal(..) | ReservedRemote => true, - _ => false, - } + matches!( + self.inner, + Closed(..) | HalfClosedLocal(..) 
| ReservedRemote + ) } pub fn is_idle(&self) -> bool { - match self.inner { - Idle => true, - _ => false, - } + matches!(self.inner, Idle) } pub fn ensure_recv_open(&self) -> Result { diff --git a/src/proto/streams/store.rs b/src/proto/streams/store.rs index ac58f43ac..d33a01cce 100644 --- a/src/proto/streams/store.rs +++ b/src/proto/streams/store.rs @@ -1,9 +1,8 @@ use super::*; -use slab; - use indexmap::{self, IndexMap}; +use std::convert::Infallible; use std::fmt; use std::marker::PhantomData; use std::ops; @@ -128,7 +127,20 @@ impl Store { } } - pub fn for_each(&mut self, mut f: F) -> Result<(), E> + pub(crate) fn for_each(&mut self, mut f: F) + where + F: FnMut(Ptr), + { + match self.try_for_each(|ptr| { + f(ptr); + Ok::<_, Infallible>(()) + }) { + Ok(()) => (), + Err(infallible) => match infallible {}, + } + } + + pub fn try_for_each(&mut self, mut f: F) -> Result<(), E> where F: FnMut(Ptr) -> Result<(), E>, { @@ -288,15 +300,15 @@ where let mut stream = store.resolve(idxs.head); if idxs.head == idxs.tail { - assert!(N::next(&*stream).is_none()); + assert!(N::next(&stream).is_none()); self.indices = None; } else { - idxs.head = N::take_next(&mut *stream).unwrap(); + idxs.head = N::take_next(&mut stream).unwrap(); self.indices = Some(idxs); } - debug_assert!(N::is_queued(&*stream)); - N::set_queued(&mut *stream, false); + debug_assert!(N::is_queued(&stream)); + N::set_queued(&mut stream, false); return Some(stream); } @@ -333,7 +345,7 @@ impl<'a> Ptr<'a> { } pub fn store_mut(&mut self) -> &mut Store { - &mut self.store + self.store } /// Remove the stream from the store diff --git a/src/proto/streams/stream.rs b/src/proto/streams/stream.rs index 79de47a9a..2888d744b 100644 --- a/src/proto/streams/stream.rs +++ b/src/proto/streams/stream.rs @@ -99,6 +99,9 @@ pub(super) struct Stream { /// Frames pending for this stream to read pub pending_recv: buffer::Deque, + /// When the RecvStream drop occurs, no data should be received. 
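The store.rs hunk above splits iteration into an infallible `for_each` built on top of a fallible `try_for_each`, using the uninhabited `Infallible` error type to prove the failure arm can never run. A self-contained sketch of the same shape over a plain slice (not h2's `Store` type):

```rust
use std::convert::Infallible;

fn try_for_each<T, E>(
    items: &mut [T],
    mut f: impl FnMut(&mut T) -> Result<(), E>,
) -> Result<(), E> {
    for item in items {
        f(item)?;
    }
    Ok(())
}

fn for_each<T>(items: &mut [T], mut f: impl FnMut(&mut T)) {
    match try_for_each(items, |item| {
        f(item);
        Ok::<_, Infallible>(())
    }) {
        Ok(()) => (),
        // `Infallible` has no values, so this match is exhaustive with no arms.
        Err(infallible) => match infallible {},
    }
}
```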
+ pub is_recv: bool, + /// Task tracking receiving frames pub recv_task: Option, @@ -180,6 +183,7 @@ impl Stream { reset_at: None, next_reset_expire: None, pending_recv: buffer::Deque::new(), + is_recv: true, recv_task: None, pending_push_promises: store::Queue::new(), content_length: ContentLength::Omitted, @@ -248,7 +252,7 @@ impl Stream { // The stream is not in any queue !self.is_pending_send && !self.is_pending_send_capacity && !self.is_pending_accept && !self.is_pending_window_update && - !self.is_pending_open && !self.reset_at.is_some() + !self.is_pending_open && self.reset_at.is_none() } /// Returns true when the consumer of the stream has dropped all handles @@ -260,25 +264,65 @@ impl Stream { self.ref_count == 0 && !self.state.is_closed() } - pub fn assign_capacity(&mut self, capacity: WindowSize) { + /// Current available stream send capacity + pub fn capacity(&self, max_buffer_size: usize) -> WindowSize { + let available = self.send_flow.available().as_size() as usize; + let buffered = self.buffered_send_data; + + available.min(max_buffer_size).saturating_sub(buffered) as WindowSize + } + + pub fn assign_capacity(&mut self, capacity: WindowSize, max_buffer_size: usize) { + let prev_capacity = self.capacity(max_buffer_size); debug_assert!(capacity > 0); - self.send_capacity_inc = true; self.send_flow.assign_capacity(capacity); tracing::trace!( - " assigned capacity to stream; available={}; buffered={}; id={:?}", + " assigned capacity to stream; available={}; buffered={}; id={:?}; max_buffer_size={} prev={}", + self.send_flow.available(), + self.buffered_send_data, + self.id, + max_buffer_size, + prev_capacity, + ); + + if prev_capacity < self.capacity(max_buffer_size) { + self.notify_capacity(); + } + } + + pub fn send_data(&mut self, len: WindowSize, max_buffer_size: usize) { + let prev_capacity = self.capacity(max_buffer_size); + + self.send_flow.send_data(len); + + // Decrement the stream's buffered data counter + debug_assert!(self.buffered_send_data >= len as usize); + self.buffered_send_data -= len as usize; + self.requested_send_capacity -= len; + + tracing::trace!( + " sent stream data; available={}; buffered={}; id={:?}; max_buffer_size={} prev={}", self.send_flow.available(), self.buffered_send_data, - self.id + self.id, + max_buffer_size, + prev_capacity, ); - // Only notify if the capacity exceeds the amount of buffered data - if self.send_flow.available() > self.buffered_send_data { - tracing::trace!(" notifying task"); - self.notify_send(); + if prev_capacity < self.capacity(max_buffer_size) { + self.notify_capacity(); } } + /// If the capacity was limited because of the max_send_buffer_size, + /// then consider waking the send task again... + pub fn notify_capacity(&mut self) { + self.send_capacity_inc = true; + tracing::trace!(" notifying task"); + self.notify_send(); + } + /// Returns `Err` when the decrement cannot be completed due to overflow. 
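A worked example of the per-stream capacity formula introduced in this hunk: `min(available, max_buffer_size)` minus what is already buffered, floored at zero. The numbers are illustrative and the helper is simplified to plain `usize`:

```rust
fn capacity(available: usize, max_buffer_size: usize, buffered: usize) -> usize {
    available.min(max_buffer_size).saturating_sub(buffered)
}

fn main() {
    // A 400 KiB buffer cap limits a 1 MB flow-control window to 409_600
    // bytes, minus the 100_000 bytes already buffered:
    assert_eq!(capacity(1_000_000, 400 * 1024, 100_000), 309_600);
    // Saturating subtraction floors the result at zero instead of underflowing:
    assert_eq!(capacity(50_000, 400 * 1024, 60_000), 0);
}
```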
pub fn dec_content_length(&mut self, len: usize) -> Result<(), ()> { match self.content_length { @@ -365,7 +409,7 @@ impl store::Next for NextSend { if val { // ensure that stream is not queued for being opened // if it's being put into queue for sending data - debug_assert_eq!(stream.is_pending_open, false); + debug_assert!(!stream.is_pending_open); } stream.is_pending_send = val; } @@ -436,7 +480,7 @@ impl store::Next for NextOpen { if val { // ensure that stream is not queued for being sent // if it's being put into queue for opening the stream - debug_assert_eq!(stream.is_pending_send, false); + debug_assert!(!stream.is_pending_send); } stream.is_pending_open = val; } @@ -472,9 +516,6 @@ impl store::Next for NextResetExpire { impl ContentLength { pub fn is_head(&self) -> bool { - match *self { - ContentLength::Head => true, - _ => false, - } + matches!(*self, Self::Head) } } diff --git a/src/proto/streams/streams.rs b/src/proto/streams/streams.rs index 4962db8d2..dbaebfa7a 100644 --- a/src/proto/streams/streams.rs +++ b/src/proto/streams/streams.rs @@ -2,6 +2,7 @@ use super::recv::RecvHeaderBlockError; use super::store::{self, Entry, Resolve, Store}; use super::{Buffer, Config, Counts, Prioritized, Recv, Send, Stream, StreamId}; use crate::codec::{Codec, SendError, UserError}; +use crate::ext::Protocol; use crate::frame::{self, Frame, Reason}; use crate::proto::{peer, Error, Initiator, Open, Peer, WindowSize}; use crate::{client, proto, server}; @@ -11,7 +12,6 @@ use http::{HeaderMap, Request, Response}; use std::task::{Context, Poll, Waker}; use tokio::io::AsyncWrite; -use crate::PollExt; use std::sync::{Arc, Mutex}; use std::{fmt, io}; @@ -140,6 +140,12 @@ where // TODO: ideally, OpaqueStreamRefs::new would do this, but we're holding // the lock, so it can't. me.refs += 1; + + // Pending-accepted remotely-reset streams are counted. + if stream.state.is_remote_reset() { + me.counts.dec_num_remote_reset_streams(); + } + StreamRef { opaque: OpaqueStreamRef::new(self.inner.clone(), stream), send_buffer: self.send_buffer.clone(), @@ -214,6 +220,8 @@ where use super::stream::ContentLength; use http::Method; + let protocol = request.extensions_mut().remove::(); + // Clear before taking lock, incase extensions contain a StreamRef. 
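Just above, the client pulls a `Protocol` extension off the outgoing request and forwards it to `convert_send_message`, where it becomes the `:protocol` pseudo-header. A hedged client-side sketch of how a caller would use this; the URI, protocol value, and function name are illustrative, and the server must have advertised extended CONNECT support for the request to be accepted:

```rust
use bytes::Bytes;
use h2::{client::SendRequest, ext::Protocol};

// Assumes `client` is an already-connected handle whose readiness has been
// awaited (via `SendRequest::ready`, omitted here).
fn start_websocket(
    client: &mut SendRequest<Bytes>,
) -> Result<(), Box<dyn std::error::Error>> {
    let request = http::Request::builder()
        .method(http::Method::CONNECT)
        .uri("https://example.com/chat")
        // Consumed by `send_request` and emitted as the `:protocol` pseudo-header.
        .extension(Protocol::from("websocket"))
        .body(())?;

    let (_response, _send_stream) = client.send_request(request, false)?;
    Ok(())
}
```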
request.extensions_mut().clear(); @@ -261,7 +269,8 @@ where } // Convert the message - let headers = client::Peer::convert_send_message(stream_id, request, end_of_stream)?; + let headers = + client::Peer::convert_send_message(stream_id, request, protocol, end_of_stream)?; let mut stream = me.store.insert(stream.id, stream); @@ -294,35 +303,44 @@ where send_buffer: self.send_buffer.clone(), }) } + + pub(crate) fn is_extended_connect_protocol_enabled(&self) -> bool { + self.inner + .lock() + .unwrap() + .actions + .send + .is_extended_connect_protocol_enabled() + } } impl DynStreams<'_, B> { pub fn recv_headers(&mut self, frame: frame::Headers) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); - me.recv_headers(self.peer, &self.send_buffer, frame) + me.recv_headers(self.peer, self.send_buffer, frame) } pub fn recv_data(&mut self, frame: frame::Data) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); - me.recv_data(self.peer, &self.send_buffer, frame) + me.recv_data(self.peer, self.send_buffer, frame) } pub fn recv_reset(&mut self, frame: frame::Reset) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); - me.recv_reset(&self.send_buffer, frame) + me.recv_reset(self.send_buffer, frame) } /// Notify all streams that a connection-level error happened. pub fn handle_error(&mut self, err: proto::Error) -> StreamId { let mut me = self.inner.lock().unwrap(); - me.handle_error(&self.send_buffer, err) + me.handle_error(self.send_buffer, err) } pub fn recv_go_away(&mut self, frame: &frame::GoAway) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); - me.recv_go_away(&self.send_buffer, frame) + me.recv_go_away(self.send_buffer, frame) } pub fn last_processed_id(&self) -> StreamId { @@ -331,22 +349,22 @@ impl DynStreams<'_, B> { pub fn recv_window_update(&mut self, frame: frame::WindowUpdate) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); - me.recv_window_update(&self.send_buffer, frame) + me.recv_window_update(self.send_buffer, frame) } pub fn recv_push_promise(&mut self, frame: frame::PushPromise) -> Result<(), Error> { let mut me = self.inner.lock().unwrap(); - me.recv_push_promise(&self.send_buffer, frame) + me.recv_push_promise(self.send_buffer, frame) } pub fn recv_eof(&mut self, clear_pending_accept: bool) -> Result<(), ()> { let mut me = self.inner.lock().map_err(|_| ())?; - me.recv_eof(&self.send_buffer, clear_pending_accept) + me.recv_eof(self.send_buffer, clear_pending_accept) } pub fn send_reset(&mut self, id: StreamId, reason: Reason) { let mut me = self.inner.lock().unwrap(); - me.send_reset(&self.send_buffer, id, reason) + me.send_reset(self.send_buffer, id, reason) } pub fn send_go_away(&mut self, last_processed_id: StreamId) { @@ -589,7 +607,7 @@ impl Inner { let actions = &mut self.actions; self.counts.transition(stream, |counts, stream| { - actions.recv.recv_reset(frame, stream); + actions.recv.recv_reset(frame, stream, counts)?; actions.send.handle_error(send_buffer, stream, counts); assert!(stream.state.is_closed()); Ok(()) @@ -643,15 +661,12 @@ impl Inner { let last_processed_id = actions.recv.last_processed_id(); - self.store - .for_each(|stream| { - counts.transition(stream, |counts, stream| { - actions.recv.handle_error(&err, &mut *stream); - actions.send.handle_error(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) + self.store.for_each(|stream| { + counts.transition(stream, |counts, stream| { + actions.recv.handle_error(&err, &mut *stream); + actions.send.handle_error(send_buffer, stream, counts); }) - 
.unwrap(); + }); actions.conn_error = Some(err); @@ -674,19 +689,14 @@ impl Inner { let err = Error::remote_go_away(frame.debug_data().clone(), frame.reason()); - self.store - .for_each(|stream| { - if stream.id > last_stream_id { - counts.transition(stream, |counts, stream| { - actions.recv.handle_error(&err, &mut *stream); - actions.send.handle_error(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) - } else { - Ok::<_, ()>(()) - } - }) - .unwrap(); + self.store.for_each(|stream| { + if stream.id > last_stream_id { + counts.transition(stream, |counts, stream| { + actions.recv.handle_error(&err, &mut *stream); + actions.send.handle_error(send_buffer, stream, counts); + }) + } + }); actions.conn_error = Some(err); @@ -721,7 +731,7 @@ impl Inner { } None => { proto_err!(conn: "recv_push_promise: initiating stream is in an invalid state"); - return Err(Error::library_go_away(Reason::PROTOCOL_ERROR).into()); + return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); } }; @@ -802,23 +812,26 @@ impl Inner { let send_buffer = &mut *send_buffer; if actions.conn_error.is_none() { - actions.conn_error = Some(io::Error::from(io::ErrorKind::BrokenPipe).into()); + actions.conn_error = Some( + io::Error::new( + io::ErrorKind::BrokenPipe, + "connection closed because of a broken pipe", + ) + .into(), + ); } tracing::trace!("Streams::recv_eof"); - self.store - .for_each(|stream| { - counts.transition(stream, |counts, stream| { - actions.recv.recv_eof(stream); + self.store.for_each(|stream| { + counts.transition(stream, |counts, stream| { + actions.recv.recv_eof(stream); - // This handles resetting send state associated with the - // stream - actions.send.handle_error(send_buffer, stream, counts); - Ok::<_, ()>(()) - }) + // This handles resetting send state associated with the + // stream + actions.send.handle_error(send_buffer, stream, counts); }) - .expect("recv_eof"); + }); actions.clear_queues(clear_pending_accept, &mut self.store, counts); Ok(()) @@ -881,6 +894,10 @@ impl Inner { // We normally would open this stream, so update our // next-send-id record. self.actions.send.maybe_reset_next_stream_id(id); + } else { + // We normally would recv this stream, so update our + // next-recv-id record. + self.actions.recv.maybe_reset_next_stream_id(id); } let stream = Stream::new(id, 0, 0); @@ -1141,7 +1158,7 @@ impl StreamRef { let mut child_stream = me.store.resolve(child_key); child_stream.unlink(); child_stream.remove(); - return Err(err.into()); + return Err(err); } me.refs += 1; @@ -1224,10 +1241,7 @@ impl StreamRef { .map_err(From::from) } - pub fn clone_to_opaque(&self) -> OpaqueStreamRef - where - B: 'static, - { + pub fn clone_to_opaque(&self) -> OpaqueStreamRef { self.opaque.clone() } @@ -1276,7 +1290,7 @@ impl OpaqueStreamRef { me.actions .recv .poll_pushed(cx, &mut stream) - .map_ok_(|(h, key)| { + .map_ok(|(h, key)| { me.refs += 1; let opaque_ref = OpaqueStreamRef::new(self.inner.clone(), &mut me.store.resolve(key)); @@ -1340,12 +1354,13 @@ impl OpaqueStreamRef { .release_capacity(capacity, &mut stream, &mut me.actions.task) } + /// Clear the receive queue and set the status to no longer receive data frames. 
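The GOAWAY handling rewritten above applies the connection error only to streams the peer says it never processed. A standalone restatement of that filter (names illustrative):

```rust
// Streams with an ID above the GOAWAY frame's last_stream_id were not
// processed by the peer and are torn down; lower IDs may complete normally.
fn torn_down_by_go_away(stream_id: u32, last_stream_id: u32) -> bool {
    stream_id > last_stream_id
}
```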
pub(crate) fn clear_recv_buffer(&mut self) { let mut me = self.inner.lock().unwrap(); let me = &mut *me; let mut stream = me.store.resolve(self.key); - + stream.is_recv = false; me.actions.recv.clear_recv_buffer(&mut stream); } @@ -1387,7 +1402,7 @@ impl Clone for OpaqueStreamRef { OpaqueStreamRef { inner: self.inner.clone(), - key: self.key.clone(), + key: self.key, } } } @@ -1456,9 +1471,21 @@ fn drop_stream_ref(inner: &Mutex, key: store::Key) { fn maybe_cancel(stream: &mut store::Ptr, actions: &mut Actions, counts: &mut Counts) { if stream.is_canceled_interest() { + // Server is allowed to early respond without fully consuming the client input stream + // But per the RFC, must send a RST_STREAM(NO_ERROR) in such cases. https://www.rfc-editor.org/rfc/rfc7540#section-8.1 + // Some other http2 implementation may interpret other error code as fatal if not respected (i.e: nginx https://trac.nginx.org/nginx/ticket/2376) + let reason = if counts.peer().is_server() + && stream.state.is_send_closed() + && stream.state.is_recv_streaming() + { + Reason::NO_ERROR + } else { + Reason::CANCEL + }; + actions .send - .schedule_implicit_reset(stream, Reason::CANCEL, counts, &mut actions.task); + .schedule_implicit_reset(stream, reason, counts, &mut actions.task); actions.recv.enqueue_reset_expiration(stream, counts); } } diff --git a/src/server.rs b/src/server.rs index 5e4bba97c..9ab64ad93 100644 --- a/src/server.rs +++ b/src/server.rs @@ -126,7 +126,7 @@ use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use std::time::Duration; -use std::{convert, fmt, io, mem}; +use std::{fmt, io}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tracing::instrument::{Instrument, Instrumented}; @@ -240,11 +240,18 @@ pub struct Builder { /// Maximum number of locally reset streams to keep at a time. reset_stream_max: usize, + /// Maximum number of remotely reset streams to allow in the pending + /// accept queue. + pending_accept_reset_stream_max: usize, + /// Initial `Settings` frame to send as part of the handshake. settings: Settings, /// Initial target window size for new connections. initial_target_connection_window_size: Option, + + /// Maximum amount of bytes to "buffer" for writing per stream. + max_send_buffer_size: usize, } /// Send a response back to the client @@ -298,8 +305,8 @@ enum Handshaking { Flushing(Instrumented>>), /// State 2. Connection is waiting for the client preface. ReadingPreface(Instrumented>>), - /// Dummy state for `mem::replace`. - Empty, + /// State 3. Handshake is done, polling again would panic. + Done, } /// Flush a Sink @@ -361,10 +368,10 @@ where impl Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { fn handshake2(io: T, builder: Builder) -> Handshake { - let span = tracing::trace_span!("server_handshake", io = %std::any::type_name::()); + let span = tracing::trace_span!("server_handshake"); let entered = span.enter(); // Create the codec. @@ -384,7 +391,8 @@ where .expect("invalid SETTINGS frame"); // Create the handshake future. - let state = Handshaking::from(codec); + let state = + Handshaking::Flushing(Flush::new(codec).instrument(tracing::trace_span!("flush"))); drop(entered); @@ -409,7 +417,7 @@ where ) -> Poll, SendResponse), crate::Error>>> { // Always try to advance the internal state. Getting Pending also is // needed to allow this function to return Pending. - if let Poll::Ready(_) = self.poll_closed(cx)? 
{ + if self.poll_closed(cx)?.is_ready() { // If the socket is closed, don't return anything // TODO: drop any pending streams return Poll::Ready(None); @@ -470,6 +478,19 @@ where Ok(()) } + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + /// + /// # Errors + /// + /// Returns an error if a previous call is still pending acknowledgement + /// from the remote endpoint. + pub fn enable_connect_protocol(&mut self) -> Result<(), crate::Error> { + self.connection.set_enable_connect_protocol()?; + Ok(()) + } + /// Returns `Ready` when the underlying connection has closed. /// /// If any new inbound streams are received during a call to `poll_closed`, @@ -537,7 +558,7 @@ where /// /// This limit is configured by the client peer by sending the /// [`SETTINGS_MAX_CONCURRENT_STREAMS` parameter][1] in a `SETTINGS` frame. - /// This method returns the currently acknowledged value recieved from the + /// This method returns the currently acknowledged value received from the /// remote. /// /// [1]: https://tools.ietf.org/html/rfc7540#section-5.1.2 @@ -559,13 +580,20 @@ where pub fn max_concurrent_recv_streams(&self) -> usize { self.connection.max_recv_streams() } + + // Could disappear at anytime. + #[doc(hidden)] + #[cfg(feature = "unstable")] + pub fn num_wired_streams(&self) -> usize { + self.connection.num_wired_streams() + } } #[cfg(feature = "stream")] impl futures_core::Stream for Connection where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { type Item = Result<(Request, SendResponse), crate::Error>; @@ -618,8 +646,10 @@ impl Builder { Builder { reset_stream_duration: Duration::from_secs(proto::DEFAULT_RESET_STREAM_SECS), reset_stream_max: proto::DEFAULT_RESET_STREAM_MAX, + pending_accept_reset_stream_max: proto::DEFAULT_REMOTE_RESET_STREAM_MAX, settings: Settings::default(), initial_target_connection_window_size: None, + max_send_buffer_size: proto::DEFAULT_MAX_SEND_BUFFER_SIZE, } } @@ -857,6 +887,67 @@ impl Builder { self } + /// Sets the maximum number of pending-accept remotely-reset streams. + /// + /// Streams that have been received by the peer, but not accepted by the + /// user, can also receive a RST_STREAM. This is a legitimate pattern: one + /// could send a request and then shortly after, realize it is not needed, + /// sending a CANCEL. + /// + /// However, since those streams are now "closed", they don't count towards + /// the max concurrent streams. So, they will sit in the accept queue, + /// using memory. + /// + /// When the number of remotely-reset streams sitting in the pending-accept + /// queue reaches this maximum value, a connection error with the code of + /// `ENHANCE_YOUR_CALM` will be sent to the peer, and returned by the + /// `Future`. + /// + /// The default value is currently 20, but could change. + /// + /// # Examples + /// + /// + /// ``` + /// # use tokio::io::{AsyncRead, AsyncWrite}; + /// # use h2::server::*; + /// # + /// # fn doc(my_io: T) + /// # -> Handshake + /// # { + /// // `server_fut` is a future representing the completion of the HTTP/2 + /// // handshake. + /// let server_fut = Builder::new() + /// .max_pending_accept_reset_streams(100) + /// .handshake(my_io); + /// # server_fut + /// # } + /// # + /// # pub fn main() {} + /// ``` + pub fn max_pending_accept_reset_streams(&mut self, max: usize) -> &mut Self { + self.pending_accept_reset_stream_max = max; + self + } + + /// Sets the maximum send buffer size per stream. 
+ /// + /// Once a stream has buffered up to (or over) the maximum, the stream's + /// flow control will not "poll" additional capacity. Once bytes for the + /// stream have been written to the connection, the send buffer capacity + /// will be freed up again. + /// + /// The default is currently ~400MB, but may change. + /// + /// # Panics + /// + /// This function panics if `max` is larger than `u32::MAX`. + pub fn max_send_buffer_size(&mut self, max: usize) -> &mut Self { + assert!(max <= std::u32::MAX as usize); + self.max_send_buffer_size = max; + self + } + /// Sets the maximum number of concurrent locally reset streams. /// /// When a stream is explicitly reset by either calling @@ -904,6 +995,14 @@ impl Builder { self } + /// Enables the [extended CONNECT protocol]. + /// + /// [extended CONNECT protocol]: https://datatracker.ietf.org/doc/html/rfc8441#section-4 + pub fn enable_connect_protocol(&mut self) -> &mut Self { + self.settings.set_enable_connect_protocol(Some(1)); + self + } + /// Creates a new configured HTTP/2 server backed by `io`. /// /// It is expected that `io` already be in an appropriate state to commence @@ -963,7 +1062,7 @@ impl Builder { pub fn handshake(&self, io: T) -> Handshake where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { Connection::handshake2(io, self.clone()) } @@ -1218,7 +1317,7 @@ where impl Future for Handshake where T: AsyncRead + AsyncWrite + Unpin, - B: Buf + 'static, + B: Buf, { type Output = Result, crate::Error>; @@ -1226,62 +1325,59 @@ where let span = self.span.clone(); // XXX(eliza): T_T let _e = span.enter(); tracing::trace!(state = ?self.state); - use crate::server::Handshaking::*; - - self.state = if let Flushing(ref mut flush) = self.state { - // We're currently flushing a pending SETTINGS frame. Poll the - // flush future, and, if it's completed, advance our state to wait - // for the client preface. - let codec = match Pin::new(flush).poll(cx)? { - Poll::Pending => { - tracing::trace!(flush.poll = %"Pending"); - return Poll::Pending; + + loop { + match &mut self.state { + Handshaking::Flushing(flush) => { + // We're currently flushing a pending SETTINGS frame. Poll the + // flush future, and, if it's completed, advance our state to wait + // for the client preface. + let codec = match Pin::new(flush).poll(cx)? 
{ + Poll::Pending => { + tracing::trace!(flush.poll = %"Pending"); + return Poll::Pending; + } + Poll::Ready(flushed) => { + tracing::trace!(flush.poll = %"Ready"); + flushed + } + }; + self.state = Handshaking::ReadingPreface( + ReadPreface::new(codec).instrument(tracing::trace_span!("read_preface")), + ); } - Poll::Ready(flushed) => { - tracing::trace!(flush.poll = %"Ready"); - flushed + Handshaking::ReadingPreface(read) => { + let codec = ready!(Pin::new(read).poll(cx)?); + + self.state = Handshaking::Done; + + let connection = proto::Connection::new( + codec, + Config { + next_stream_id: 2.into(), + // Server does not need to locally initiate any streams + initial_max_send_streams: 0, + max_send_buffer_size: self.builder.max_send_buffer_size, + reset_stream_duration: self.builder.reset_stream_duration, + reset_stream_max: self.builder.reset_stream_max, + remote_reset_stream_max: self.builder.pending_accept_reset_stream_max, + settings: self.builder.settings.clone(), + }, + ); + + tracing::trace!("connection established!"); + let mut c = Connection { connection }; + if let Some(sz) = self.builder.initial_target_connection_window_size { + c.set_target_window_size(sz); + } + + return Poll::Ready(Ok(c)); + } + Handshaking::Done => { + panic!("Handshaking::poll() called again after handshaking was complete") } - }; - Handshaking::from(ReadPreface::new(codec)) - } else { - // Otherwise, we haven't actually advanced the state, but we have - // to replace it with itself, because we have to return a value. - // (note that the assignment to `self.state` has to be outside of - // the `if let` block above in order to placate the borrow checker). - mem::replace(&mut self.state, Handshaking::Empty) - }; - let poll = if let ReadingPreface(ref mut read) = self.state { - // We're now waiting for the client preface. Poll the `ReadPreface` - // future. If it has completed, we will create a `Connection` handle - // for the connection. - Pin::new(read).poll(cx) - // Actually creating the `Connection` has to occur outside of this - // `if let` block, because we've borrowed `self` mutably in order - // to poll the state and won't be able to borrow the SETTINGS frame - // as well until we release the borrow for `poll()`. - } else { - unreachable!("Handshake::poll() state was not advanced completely!") - }; - poll?.map(|codec| { - let connection = proto::Connection::new( - codec, - Config { - next_stream_id: 2.into(), - // Server does not need to locally initiate any streams - initial_max_send_streams: 0, - reset_stream_duration: self.builder.reset_stream_duration, - reset_stream_max: self.builder.reset_stream_max, - settings: self.builder.settings.clone(), - }, - ); - - tracing::trace!("connection established!"); - let mut c = Connection { connection }; - if let Some(sz) = self.builder.initial_target_connection_window_size { - c.set_target_window_size(sz); } - Ok(c) - }) + } } } @@ -1360,7 +1456,7 @@ impl Peer { _, ) = request.into_parts(); - let pseudo = Pseudo::request(method, uri); + let pseudo = Pseudo::request(method, uri, None); Ok(frame::PushPromise::new( stream_id, @@ -1410,6 +1506,16 @@ impl proto::Peer for Peer { malformed!("malformed headers: missing method"); } + let has_protocol = pseudo.protocol.is_some(); + if has_protocol { + if is_connect { + // Assert that we have the right type. 
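The `:protocol` validation continuing below only admits the pseudo-header on CONNECT requests, and the receive path earlier in this diff rejects it unless extended CONNECT was enabled. A hedged sketch of turning the feature on, in the same style as the builder doc examples above (`my_io` stands in for an established connection):

```rust
use h2::server::*;
use tokio::io::{AsyncRead, AsyncWrite};

// Advertise SETTINGS_ENABLE_CONNECT_PROTOCOL (RFC 8441) during the
// handshake so clients may send `:protocol` on CONNECT requests.
fn doc<T: AsyncRead + AsyncWrite + Unpin>(my_io: T) -> Handshake<T> {
    Builder::new()
        .enable_connect_protocol()
        .handshake(my_io)
}
```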
+                b = b.extension::<crate::ext::Protocol>(pseudo.protocol.unwrap());
+            } else {
+                malformed!("malformed headers: :protocol on non-CONNECT request");
+            }
+        }
+
         if pseudo.status.is_some() {
             malformed!("malformed headers: :status field on request");
         }
@@ -1432,8 +1538,8 @@ impl proto::Peer for Peer {
 
         // A :scheme is required, except CONNECT.
         if let Some(scheme) = pseudo.scheme {
-            if is_connect {
-                malformed!(":scheme in CONNECT");
+            if is_connect && !has_protocol {
+                malformed!("malformed headers: :scheme in CONNECT");
             }
             let maybe_scheme = scheme.parse();
             let scheme = maybe_scheme.or_else(|why| {
@@ -1450,13 +1556,13 @@ impl proto::Peer for Peer {
             if parts.authority.is_some() {
                 parts.scheme = Some(scheme);
             }
-        } else if !is_connect {
+        } else if !is_connect || has_protocol {
             malformed!("malformed headers: missing scheme");
         }
 
         if let Some(path) = pseudo.path {
-            if is_connect {
-                malformed!(":path in CONNECT");
+            if is_connect && !has_protocol {
+                malformed!("malformed headers: :path in CONNECT");
             }
 
             // This cannot be empty
@@ -1468,6 +1574,8 @@ impl proto::Peer for Peer {
             parts.path_and_query = Some(maybe_path.or_else(|why| {
                 malformed!("malformed headers: malformed path ({:?}): {}", path, why,)
             })?);
+        } else if is_connect && has_protocol {
+            malformed!("malformed headers: missing path in extended CONNECT");
         }
 
         b = b.uri(parts);
@@ -1497,42 +1605,9 @@ where
     #[inline]
     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
         match *self {
-            Handshaking::Flushing(_) => write!(f, "Handshaking::Flushing(_)"),
-            Handshaking::ReadingPreface(_) => write!(f, "Handshaking::ReadingPreface(_)"),
-            Handshaking::Empty => write!(f, "Handshaking::Empty"),
+            Handshaking::Flushing(_) => f.write_str("Flushing(_)"),
+            Handshaking::ReadingPreface(_) => f.write_str("ReadingPreface(_)"),
+            Handshaking::Done => f.write_str("Done"),
         }
     }
 }
-
-impl<T, B> convert::From<Flush<T, Prioritized<B>>> for Handshaking<T, B>
-where
-    T: AsyncRead + AsyncWrite,
-    B: Buf,
-{
-    #[inline]
-    fn from(flush: Flush<T, Prioritized<B>>) -> Self {
-        Handshaking::Flushing(flush.instrument(tracing::trace_span!("flush")))
-    }
-}
-
-impl<T, B> convert::From<ReadPreface<T, Prioritized<B>>> for Handshaking<T, B>
-where
-    T: AsyncRead + AsyncWrite,
-    B: Buf,
-{
-    #[inline]
-    fn from(read: ReadPreface<T, Prioritized<B>>) -> Self {
-        Handshaking::ReadingPreface(read.instrument(tracing::trace_span!("read_preface")))
-    }
-}
-
-impl<T, B> convert::From<Codec<T, Prioritized<B>>> for Handshaking<T, B>
-where
-    T: AsyncRead + AsyncWrite,
-    B: Buf,
-{
-    #[inline]
-    fn from(codec: Codec<T, Prioritized<B>>) -> Self {
-        Handshaking::from(Flush::new(codec))
-    }
-}
diff --git a/src/share.rs b/src/share.rs
index 2a4ff1cdd..26b428797 100644
--- a/src/share.rs
+++ b/src/share.rs
@@ -5,7 +5,6 @@ use crate::proto::{self, WindowSize};
 use bytes::{Buf, Bytes};
 use http::HeaderMap;
 
-use crate::PollExt;
 use std::fmt;
 #[cfg(feature = "stream")]
 use std::pin::Pin;
@@ -95,7 +94,7 @@ use std::task::{Context, Poll};
 /// [`send_trailers`]: #method.send_trailers
 /// [`send_reset`]: #method.send_reset
 #[derive(Debug)]
-pub struct SendStream<B: Buf> {
+pub struct SendStream<B> {
     inner: proto::StreamRef<B>,
 }
 
@@ -109,9 +108,15 @@ pub struct SendStream<B: Buf> {
 /// new stream.
 ///
 /// [Section 5.1.1]: https://tools.ietf.org/html/rfc7540#section-5.1.1
-#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
 pub struct StreamId(u32);
 
+impl From<StreamId> for u32 {
+    fn from(src: StreamId) -> Self {
+        src.0
+    }
+}
+
 /// Receives the body stream and trailers from the remote peer.
 ///
 /// A `RecvStream` is provided by [`client::ResponseFuture`] and
@@ -307,8 +312,8 @@ impl<B: Buf> SendStream<B> {
     pub fn poll_capacity(&mut self, cx: &mut Context) -> Poll<Option<Result<usize, crate::Error>>> {
         self.inner
             .poll_capacity(cx)
-            .map_ok_(|w| w as usize)
-            .map_err_(Into::into)
+            .map_ok(|w| w as usize)
+            .map_err(Into::into)
     }
 
     /// Sends a single data frame to the remote peer.
@@ -383,6 +388,18 @@ impl StreamId {
     pub(crate) fn from_internal(id: crate::frame::StreamId) -> Self {
         StreamId(id.into())
     }
+
+    /// Returns the `u32` corresponding to this `StreamId`
+    ///
+    /// # Note
+    ///
+    /// This is the same as the `From<StreamId>` implementation, but
+    /// included as an inherent method because that implementation doesn't
+    /// appear in rustdocs, as well as a way to force the type instead of
+    /// relying on inference.
+    pub fn as_u32(&self) -> u32 {
+        (*self).into()
+    }
 }
 
 // ===== impl RecvStream =====
@@ -403,7 +420,7 @@ impl RecvStream {
 
     /// Poll for the next data frame.
     pub fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, crate::Error>>> {
-        self.inner.inner.poll_data(cx).map_err_(Into::into)
+        self.inner.inner.poll_data(cx).map_err(Into::into)
     }
 
     #[doc(hidden)]
@@ -539,8 +556,8 @@ impl PingPong {
     pub fn send_ping(&mut self, ping: Ping) -> Result<(), crate::Error> {
         // Passing a `Ping` here is just to be forwards-compatible with
         // eventually allowing choosing a ping payload. For now, we can
-        // just drop it.
-        drop(ping);
+        // just ignore it.
+        let _ = ping;
 
         self.inner.send_ping().map_err(|err| match err {
             Some(err) => err.into(),
diff --git a/tests/h2-fuzz/Cargo.toml b/tests/h2-fuzz/Cargo.toml
index 524627f31..dadb62c92 100644
--- a/tests/h2-fuzz/Cargo.toml
+++ b/tests/h2-fuzz/Cargo.toml
@@ -8,7 +8,7 @@ edition = "2018"
 [dependencies]
 h2 = { path = "../.." }
 
-env_logger = { version = "0.5.3", default-features = false }
+env_logger = { version = "0.9", default-features = false }
 futures = { version = "0.3", default-features = false, features = ["std"] }
 honggfuzz = "0.5"
 http = "0.2"
diff --git a/tests/h2-support/Cargo.toml b/tests/h2-support/Cargo.toml
index e97c6b310..f178178eb 100644
--- a/tests/h2-support/Cargo.toml
+++ b/tests/h2-support/Cargo.toml
@@ -7,9 +7,11 @@ edition = "2018"
 
 [dependencies]
 h2 = { path = "../..", features = ["stream", "unstable"] }
+atty = "0.2"
 bytes = "1"
 tracing = "0.1"
-tracing-subscriber = { version = "0.2", default-features = false, features = ["fmt", "chrono", "ansi"] }
+tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt"] }
+tracing-tree = "0.2"
 futures = { version = "0.3", default-features = false }
 http = "0.2"
 tokio = { version = "1", features = ["time"] }
diff --git a/tests/h2-support/src/assert.rs b/tests/h2-support/src/assert.rs
index 8bc6d25c7..88e3d4f7c 100644
--- a/tests/h2-support/src/assert.rs
+++ b/tests/h2-support/src/assert.rs
@@ -47,6 +47,16 @@
     }};
 }
 
+#[macro_export]
+macro_rules! assert_go_away {
+    ($frame:expr) => {{
+        match $frame {
+            h2::frame::Frame::GoAway(v) => v,
+            f => panic!("expected GO_AWAY; actual={:?}", f),
+        }
+    }};
+}
+
 #[macro_export]
 macro_rules! poll_err {
     ($transport:expr) => {{
@@ -80,6 +90,7 @@ macro_rules!
assert_default_settings { use h2::frame::Frame; +#[track_caller] pub fn assert_frame_eq, U: Into>(t: T, u: U) { let actual: Frame = t.into(); let expected: Frame = u.into(); diff --git a/tests/h2-support/src/client_ext.rs b/tests/h2-support/src/client_ext.rs index a9ab71d99..eebbae98b 100644 --- a/tests/h2-support/src/client_ext.rs +++ b/tests/h2-support/src/client_ext.rs @@ -11,7 +11,7 @@ pub trait SendRequestExt { impl SendRequestExt for SendRequest where - B: Buf + Unpin + 'static, + B: Buf, { fn get(&mut self, uri: &str) -> ResponseFuture { let req = Request::builder() diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index 824bc5c19..bc4e2e708 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -4,10 +4,13 @@ use std::fmt; use bytes::Bytes; use http::{self, HeaderMap, StatusCode}; -use h2::frame::{self, Frame, StreamId}; +use h2::{ + ext::Protocol, + frame::{self, Frame, StreamId}, +}; -pub const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; -pub const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; +pub const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; +pub const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; // ==== helper functions to easily construct h2 Frames ==== @@ -109,7 +112,9 @@ impl Mock { let method = method.try_into().unwrap(); let uri = uri.try_into().unwrap(); let (id, _, fields) = self.into_parts(); - let frame = frame::Headers::new(id, frame::Pseudo::request(method, uri), fields); + let extensions = Default::default(); + let pseudo = frame::Pseudo::request(method, uri, extensions); + let frame = frame::Headers::new(id, pseudo, fields); Mock(frame) } @@ -179,6 +184,15 @@ impl Mock { Mock(frame::Headers::new(id, pseudo, fields)) } + pub fn protocol(self, value: &str) -> Self { + let (id, mut pseudo, fields) = self.into_parts(); + let value = Protocol::from(value); + + pseudo.set_protocol(value); + + Mock(frame::Headers::new(id, pseudo, fields)) + } + pub fn eos(mut self) -> Self { self.0.set_end_stream(); self @@ -230,8 +244,9 @@ impl Mock { let method = method.try_into().unwrap(); let uri = uri.try_into().unwrap(); let (id, promised, _, fields) = self.into_parts(); - let frame = - frame::PushPromise::new(id, promised, frame::Pseudo::request(method, uri), fields); + let extensions = Default::default(); + let pseudo = frame::Pseudo::request(method, uri, extensions); + let frame = frame::PushPromise::new(id, promised, pseudo, fields); Mock(frame) } @@ -282,6 +297,10 @@ impl Mock { self.reason(frame::Reason::FRAME_SIZE_ERROR) } + pub fn calm(self) -> Self { + self.reason(frame::Reason::ENHANCE_YOUR_CALM) + } + pub fn no_error(self) -> Self { self.reason(frame::Reason::NO_ERROR) } @@ -352,6 +371,11 @@ impl Mock { self.0.set_enable_push(false); self } + + pub fn enable_connect_protocol(mut self, val: u32) -> Self { + self.0.set_enable_connect_protocol(Some(val)); + self + } } impl From> for frame::Settings { diff --git a/tests/h2-support/src/mock.rs b/tests/h2-support/src/mock.rs index b5df9ad9b..18d084841 100644 --- a/tests/h2-support/src/mock.rs +++ b/tests/h2-support/src/mock.rs @@ -56,7 +56,7 @@ struct Inner { closed: bool, } -const PREFACE: &'static [u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; +const PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// Create a new mock and handle pub fn new() -> (Mock, Handle) { @@ -148,7 +148,7 @@ impl Handle { poll_fn(move |cx| { while buf.has_remaining() { let res = Pin::new(self.codec.get_mut()) - .poll_write(cx, &mut buf.chunk()) + 
.poll_write(cx, buf.chunk()) .map_err(|e| panic!("write err={:?}", e)); let n = ready!(res).unwrap(); @@ -221,22 +221,15 @@ impl Handle { let settings = settings.into(); self.send(settings.into()).await.unwrap(); - let frame = self.next().await; - let settings = match frame { - Some(frame) => match frame.unwrap() { - Frame::Settings(settings) => { - // Send the ACK - let ack = frame::Settings::ack(); + let frame = self.next().await.expect("unexpected EOF").unwrap(); + let settings = assert_settings!(frame); - // TODO: Don't unwrap? - self.send(ack.into()).await.unwrap(); + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? + self.send(ack.into()).await.unwrap(); - settings - } - frame => panic!("unexpected frame; frame={:?}", frame), - }, - None => panic!("unexpected EOF"), - }; let frame = self.next().await; let f = assert_settings!(frame.unwrap().unwrap()); diff --git a/tests/h2-support/src/prelude.rs b/tests/h2-support/src/prelude.rs index 1fcb0dcc4..c40a518da 100644 --- a/tests/h2-support/src/prelude.rs +++ b/tests/h2-support/src/prelude.rs @@ -2,6 +2,7 @@ pub use h2; pub use h2::client; +pub use h2::ext::Protocol; pub use h2::frame::StreamId; pub use h2::server; pub use h2::*; @@ -20,8 +21,8 @@ pub use super::{Codec, SendFrame}; // Re-export macros pub use super::{ - assert_closed, assert_data, assert_default_settings, assert_headers, assert_ping, poll_err, - poll_frame, raw_codec, + assert_closed, assert_data, assert_default_settings, assert_go_away, assert_headers, + assert_ping, assert_settings, poll_err, poll_frame, raw_codec, }; pub use super::assert::assert_frame_eq; @@ -89,7 +90,7 @@ pub trait ClientExt { impl ClientExt for client::Connection where T: AsyncRead + AsyncWrite + Unpin + 'static, - B: Buf + Unpin + 'static, + B: Buf, { fn run<'a, F: Future + Unpin + 'a>( &'a mut self, @@ -102,7 +103,7 @@ where // Connection is done... 
                     b.await
                 }
-                Right((v, _)) => return v,
+                Right((v, _)) => v,
                 Left((Err(e), _)) => panic!("err: {:?}", e),
             }
         })
diff --git a/tests/h2-support/src/trace.rs b/tests/h2-support/src/trace.rs
index 4ac11742c..87038c350 100644
--- a/tests/h2-support/src/trace.rs
+++ b/tests/h2-support/src/trace.rs
@@ -1,41 +1,17 @@
-use std::{io, str};
 pub use tracing;
 pub use tracing_subscriber;
 
-pub fn init() -> tracing::dispatcher::DefaultGuard {
-    tracing::subscriber::set_default(
-        tracing_subscriber::fmt()
-            .with_max_level(tracing::Level::TRACE)
-            .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
-            .with_writer(PrintlnWriter { _p: () })
-            .finish(),
-    )
-}
-
-struct PrintlnWriter {
-    _p: (),
-}
+use tracing_subscriber::layer::SubscriberExt;
+use tracing_subscriber::util::SubscriberInitExt;
 
-impl tracing_subscriber::fmt::MakeWriter for PrintlnWriter {
-    type Writer = PrintlnWriter;
-    fn make_writer(&self) -> Self::Writer {
-        PrintlnWriter { _p: () }
-    }
-}
-
-impl io::Write for PrintlnWriter {
-    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-        let s = str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
-        println!("{}", s);
-        Ok(s.len())
-    }
-
-    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> {
-        println!("{}", fmt);
-        Ok(())
-    }
+pub fn init() -> tracing::dispatcher::DefaultGuard {
+    let use_colors = atty::is(atty::Stream::Stdout);
+    let layer = tracing_tree::HierarchicalLayer::default()
+        .with_writer(tracing_subscriber::fmt::writer::TestWriter::default())
+        .with_indent_lines(true)
+        .with_ansi(use_colors)
+        .with_targets(true)
+        .with_indent_amount(2);
 
-    fn flush(&mut self) -> io::Result<()> {
-        Ok(())
-    }
+    tracing_subscriber::registry().with(layer).set_default()
 }
diff --git a/tests/h2-support/src/util.rs b/tests/h2-support/src/util.rs
index ec768badc..aa7fb2c54 100644
--- a/tests/h2-support/src/util.rs
+++ b/tests/h2-support/src/util.rs
@@ -32,10 +32,11 @@ pub async fn yield_once() {
     .await;
 }
 
+/// Should only be called after a non-0 capacity was requested for the stream.
 pub fn wait_for_capacity(stream: h2::SendStream<Bytes>, target: usize) -> WaitForCapacity {
     WaitForCapacity {
         stream: Some(stream),
-        target: target,
+        target,
     }
 }
 
@@ -54,14 +55,19 @@ impl Future for WaitForCapacity {
     type Output = h2::SendStream<Bytes>;
 
     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-        let _ = ready!(self.stream().poll_capacity(cx)).unwrap();
+        loop {
+            let _ = ready!(self.stream().poll_capacity(cx)).unwrap();
 
-        let act = self.stream().capacity();
+            let act = self.stream().capacity();
 
-        if act >= self.target {
-            return Poll::Ready(self.stream.take().unwrap().into());
-        }
+            // If a non-0 capacity was requested for the stream before calling
+            // wait_for_capacity, then poll_capacity should return Pending
+            // until there is a non-0 capacity.
+ assert_ne!(act, 0); - Poll::Pending + if act >= self.target { + return Poll::Ready(self.stream.take().unwrap()); + } + } } } diff --git a/tests/h2-tests/tests/client_request.rs b/tests/h2-tests/tests/client_request.rs index 2af0bdeec..aff39f5c1 100644 --- a/tests/h2-tests/tests/client_request.rs +++ b/tests/h2-tests/tests/client_request.rs @@ -371,7 +371,7 @@ async fn send_request_poll_ready_when_connection_error() { resp2.await.expect_err("req2"); })); - while let Some(_) = unordered.next().await {} + while unordered.next().await.is_some() {} }; join(srv, h2).await; @@ -489,9 +489,8 @@ async fn http_2_request_without_scheme_or_authority() { client .send_request(request, true) .expect_err("should be UserError"); - let ret = h2.await.expect("h2"); + let _: () = h2.await.expect("h2"); drop(client); - ret }; join(srv, h2).await; @@ -575,7 +574,7 @@ async fn connection_close_notifies_response_future() { .0 .await; let err = res.expect_err("response"); - assert_eq!(err.to_string(), "broken pipe"); + assert_eq!(err.to_string(), "stream closed because of a broken pipe"); }; join(async move { conn.await.expect("conn") }, req).await; }; @@ -614,7 +613,7 @@ async fn connection_close_notifies_client_poll_ready() { .0 .await; let err = res.expect_err("response"); - assert_eq!(err.to_string(), "broken pipe"); + assert_eq!(err.to_string(), "stream closed because of a broken pipe"); }; conn.drive(req).await; @@ -622,7 +621,10 @@ async fn connection_close_notifies_client_poll_ready() { let err = poll_fn(move |cx| client.poll_ready(cx)) .await .expect_err("poll_ready"); - assert_eq!(err.to_string(), "broken pipe"); + assert_eq!( + err.to_string(), + "connection closed because of a broken pipe" + ); }; join(srv, client).await; @@ -1305,8 +1307,155 @@ async fn informational_while_local_streaming() { join(srv, h2).await; } -const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; -const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; +#[tokio::test] +async fn extended_connect_protocol_disabled_by_default() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. 
+ let request = Request::get("https://example.com/").body(()).unwrap(); + // first request is allowed + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + assert!(!client.is_extended_connect_protocol_enabled()); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_protocol_enabled_during_handshake() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. + let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + + assert!(client.is_extended_connect_protocol_enabled()); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn invalid_connect_protocol_enabled_setting() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + // Send a settings frame + srv.send(frames::settings().enable_connect_protocol(2).into()) + .await + .unwrap(); + srv.read_preface().await.unwrap(); + + let settings = assert_settings!(srv.next().await.expect("unexpected EOF").unwrap()); + assert_default_settings!(settings); + + // Send the ACK + let ack = frame::Settings::ack(); + + // TODO: Don't unwrap? + srv.send(ack.into()).await.unwrap(); + + let frame = srv.next().await.unwrap().unwrap(); + let go_away = assert_go_away!(frame); + assert_eq!(go_away.reason(), Reason::PROTOCOL_ERROR); + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + // we send a simple req here just to drive the connection so we can + // receive the server settings. 
+ let request = Request::get("https://example.com/").body(()).unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + + let error = h2.drive(response).await.unwrap_err(); + assert_eq!(error.reason(), Some(Reason::PROTOCOL_ERROR)); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn extended_connect_request() { + h2_support::trace_init!(); + + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().enable_connect_protocol(1)) + .await; + assert_default_settings!(settings); + + srv.recv_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol") + .eos(), + ) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + + let request = Request::connect("http://bread/baguette") + .extension(Protocol::from("the-bread-protocol")) + .body(()) + .unwrap(); + let (response, _) = client.send_request(request, true).unwrap(); + h2.drive(response).await.unwrap(); + }; + + join(srv, h2).await; +} + +const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; +const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; trait MockH2 { fn handshake(&mut self) -> &mut Self; diff --git a/tests/h2-tests/tests/flow_control.rs b/tests/h2-tests/tests/flow_control.rs index be04a61b7..5caa2ec3a 100644 --- a/tests/h2-tests/tests/flow_control.rs +++ b/tests/h2-tests/tests/flow_control.rs @@ -1561,3 +1561,300 @@ async fn data_padding() { join(srv, h2).await; } + +#[tokio::test] +async fn poll_capacity_after_send_data_and_reserve() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().initial_window_size(5)) + .await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.send_frame(frames::window_update(1, 5)).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::handshake(io).await.unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(5); + + // Initial window size was 5 so current capacity is 0 even if we just reserved. + assert_eq!(stream.capacity(), 0); + + // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity. 
+ let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + + stream.send_data("".into(), true).unwrap(); + + // Wait for the connection to close + h2.await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn poll_capacity_after_send_data_and_reserve_with_max_send_buffer_size() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().initial_window_size(10)) + .await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.send_frame(frames::window_update(1, 10)).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(5); + + // Initial window size was 10 but with a max send buffer size of 10 in the client, + // so current capacity is 0 even if we just reserved. + assert_eq!(stream.capacity(), 0); + + // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity. + let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + + stream.send_data("".into(), true).unwrap(); + + // Wait for the connection to close + h2.await.unwrap(); + }; + + join(srv, h2).await; +} + +#[tokio::test] +async fn max_send_buffer_size_overflow() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::data(1, &[0; 10][..])).await; + srv.recv_frame(frames::data(1, &[][..]).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = conn.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!(stream.capacity(), 0); + stream.reserve_capacity(10); + assert_eq!( + stream.capacity(), + 5, + "polled capacity not over max buffer size" + ); + + stream.send_data([0; 10][..].into(), false).unwrap(); + + stream.reserve_capacity(15); + assert_eq!( + stream.capacity(), + 0, + "now with buffered over the max, don't overflow" + ); + stream.send_data([0; 0][..].into(), true).unwrap(); + + // Wait for the connection to close + conn.await.unwrap(); + }; + + join(srv, client).await; +} + +#[tokio::test] +async fn max_send_buffer_size_poll_capacity_wakes_task() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + 
let settings = srv.assert_client_handshake().await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200).eos()).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[0; 5][..])).await; + srv.recv_frame(frames::data(1, &[][..]).eos()).await; + }; + + let client = async move { + let (mut client, mut conn) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = conn.drive(response).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + assert_eq!(stream.capacity(), 0); + const TO_SEND: usize = 20; + stream.reserve_capacity(TO_SEND); + assert_eq!( + stream.capacity(), + 5, + "polled capacity not over max buffer size" + ); + + let t1 = tokio::spawn(async move { + let mut sent = 0; + let buf = [0; TO_SEND]; + loop { + match poll_fn(|cx| stream.poll_capacity(cx)).await { + None => panic!("no cap"), + Some(Err(e)) => panic!("cap error: {:?}", e), + Some(Ok(cap)) => { + stream + .send_data(buf[sent..(sent + cap)].to_vec().into(), false) + .unwrap(); + sent += cap; + if sent >= TO_SEND { + break; + } + } + } + } + stream.send_data(Bytes::new(), true).unwrap(); + }); + + // Wait for the connection to close + conn.await.unwrap(); + t1.await.unwrap(); + }; + + join(srv, client).await; +} + +#[tokio::test] +async fn poll_capacity_wakeup_after_window_update() { + h2_support::trace_init!(); + let (io, mut srv) = mock::new(); + + let srv = async move { + let settings = srv + .assert_client_handshake_with_settings(frames::settings().initial_window_size(10)) + .await; + assert_default_settings!(settings); + srv.recv_frame(frames::headers(1).request("POST", "https://www.example.com/")) + .await; + srv.send_frame(frames::headers(1).response(200)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.send_frame(frames::window_update(1, 5)).await; + srv.send_frame(frames::window_update(1, 5)).await; + srv.recv_frame(frames::data(1, &b"abcde"[..])).await; + srv.recv_frame(frames::data(1, &b""[..]).eos()).await; + }; + + let h2 = async move { + let (mut client, mut h2) = client::Builder::new() + .max_send_buffer_size(5) + .handshake::<_, Bytes>(io) + .await + .unwrap(); + let request = Request::builder() + .method(Method::POST) + .uri("https://www.example.com/") + .body(()) + .unwrap(); + + let (response, mut stream) = client.send_request(request, false).unwrap(); + + let response = h2.drive(response).await.unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(10); + assert_eq!(stream.capacity(), 0); + + let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + h2.drive(idle_ms(10)).await; + stream.send_data("abcde".into(), false).unwrap(); + + stream.reserve_capacity(5); + assert_eq!(stream.capacity(), 0); + + // This will panic if there is a bug causing h2 to return Ok(0) from poll_capacity. 
+ let mut stream = h2.drive(util::wait_for_capacity(stream, 5)).await; + + stream.send_data("".into(), true).unwrap(); + + // Wait for the connection to close + h2.await.unwrap(); + }; + + join(srv, h2).await; +} diff --git a/tests/h2-tests/tests/hammer.rs b/tests/h2-tests/tests/hammer.rs index 9a200537a..a5cba3dfa 100644 --- a/tests/h2-tests/tests/hammer.rs +++ b/tests/h2-tests/tests/hammer.rs @@ -58,7 +58,7 @@ impl Server { } fn addr(&self) -> SocketAddr { - self.addr.clone() + self.addr } fn request_count(&self) -> usize { diff --git a/tests/h2-tests/tests/ping_pong.rs b/tests/h2-tests/tests/ping_pong.rs index a57f35c17..0f93578cc 100644 --- a/tests/h2-tests/tests/ping_pong.rs +++ b/tests/h2-tests/tests/ping_pong.rs @@ -11,9 +11,8 @@ async fn recv_single_ping() { // Create the handshake let h2 = async move { - let (client, conn) = client::handshake(m).await.unwrap(); - let c = conn.await.unwrap(); - (client, c) + let (_client, conn) = client::handshake(m).await.unwrap(); + let _: () = conn.await.unwrap(); }; let mock = async move { @@ -146,6 +145,7 @@ async fn user_notifies_when_connection_closes() { srv }; + #[allow(clippy::async_yields_async)] let client = async move { let (_client, mut conn) = client::handshake(io).await.expect("client handshake"); // yield once so we can ack server settings diff --git a/tests/h2-tests/tests/push_promise.rs b/tests/h2-tests/tests/push_promise.rs index f52f781d5..94c1154ef 100644 --- a/tests/h2-tests/tests/push_promise.rs +++ b/tests/h2-tests/tests/push_promise.rs @@ -223,7 +223,7 @@ async fn pending_push_promises_reset_when_dropped() { assert_eq!(resp.status(), StatusCode::OK); }; - let _ = conn.drive(req).await; + conn.drive(req).await; conn.await.expect("client"); drop(client); }; diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index e60483d0d..c8c1c9d1c 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -5,8 +5,8 @@ use futures::StreamExt; use h2_support::prelude::*; use tokio::io::AsyncWriteExt; -const SETTINGS: &'static [u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; -const SETTINGS_ACK: &'static [u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; +const SETTINGS: &[u8] = &[0, 0, 0, 4, 0, 0, 0, 0, 0]; +const SETTINGS_ACK: &[u8] = &[0, 0, 0, 4, 1, 0, 0, 0, 0]; #[tokio::test] async fn read_preface_in_multiple_frames() { @@ -566,7 +566,9 @@ async fn sends_reset_cancel_when_req_body_is_dropped() { client .recv_frame(frames::headers(1).response(200).eos()) .await; - client.recv_frame(frames::reset(1).cancel()).await; + client + .recv_frame(frames::reset(1).reason(Reason::NO_ERROR)) + .await; }; let srv = async move { @@ -877,6 +879,40 @@ async fn too_big_headers_sends_reset_after_431_if_not_eos() { join(client, srv).await; } +#[tokio::test] +async fn pending_accept_recv_illegal_content_length_data() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + client + .send_frame( + frames::headers(1) + .request("POST", "https://a.b") + .field("content-length", "1"), + ) + .await; + client + .send_frame(frames::data(1, &b"hello"[..]).eos()) + .await; + client.recv_frame(frames::reset(1).protocol_error()).await; + idle_ms(10).await; + }; + + let srv = async move { + let mut srv = server::Builder::new() + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + let _req = srv.next().await.expect("req").expect("is_ok"); + }; + + join(client, srv).await; +} + #[tokio::test] async 
fn poll_reset() { h2_support::trace_init!(); @@ -1149,3 +1185,196 @@ async fn send_reset_explicitly() { join(client, srv).await; } + +#[tokio::test] +async fn extended_connect_protocol_disabled_by_default() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), None); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn extended_connect_protocol_enabled_during_handshake() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::headers(1).response(200)).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + assert_eq!( + req.extensions().get::(), + Some(&crate::ext::Protocol::from_static("the-bread-protocol")) + ); + + let rsp = Response::new(()); + stream.send_response(rsp, false).unwrap(); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_pseudo_protocol_on_non_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("GET", "http://bread/baguette") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_authority_target_on_extended_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame( + frames::headers(1) + .request("CONNECT", "bread:80") + .protocol("the-bread-protocol"), + ) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move 
|cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} + +#[tokio::test] +async fn reject_non_authority_target_on_connect_request() { + h2_support::trace_init!(); + + let (io, mut client) = mock::new(); + + let client = async move { + let settings = client.assert_server_handshake().await; + + assert_eq!(settings.is_extended_connect_protocol_enabled(), Some(true)); + + client + .send_frame(frames::headers(1).request("CONNECT", "https://bread/baguette")) + .await; + + client.recv_frame(frames::reset(1).protocol_error()).await; + }; + + let srv = async move { + let mut builder = server::Builder::new(); + + builder.enable_connect_protocol(); + + let mut srv = builder.handshake::<_, Bytes>(io).await.expect("handshake"); + + assert!(srv.next().await.is_none()); + + poll_fn(move |cx| srv.poll_closed(cx)) + .await + .expect("server"); + }; + + join(client, srv).await; +} diff --git a/tests/h2-tests/tests/stream_states.rs b/tests/h2-tests/tests/stream_states.rs index f2b2efc1e..138328efa 100644 --- a/tests/h2-tests/tests/stream_states.rs +++ b/tests/h2-tests/tests/stream_states.rs @@ -1,6 +1,6 @@ #![deny(warnings)] -use futures::future::{join, join3, lazy, try_join}; +use futures::future::{join, join3, lazy, poll_fn, try_join}; use futures::{FutureExt, StreamExt, TryStreamExt}; use h2_support::prelude::*; use h2_support::util::yield_once; @@ -194,6 +194,88 @@ async fn closed_streams_are_released() { join(srv, h2).await; } +#[tokio::test] +async fn reset_streams_dont_grow_memory_continuously() { + //h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + const N: u32 = 50; + const MAX: usize = 20; + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + for n in (1..(N * 2)).step_by(2) { + client + .send_frame(frames::headers(n).request("GET", "https://a.b/").eos()) + .await; + client.send_frame(frames::reset(n).protocol_error()).await; + } + tokio::time::timeout( + std::time::Duration::from_secs(1), + client.recv_frame(frames::go_away((MAX * 2 + 1) as u32).calm()), + ) + .await + .expect("client goaway"); + }; + + let srv = async move { + let mut srv = server::Builder::new() + .max_pending_accept_reset_streams(MAX) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + poll_fn(|cx| srv.poll_closed(cx)) + .await + .expect_err("server should error"); + // specifically, not 50; + assert_eq!(21, srv.num_wired_streams()); + }; + join(srv, client).await; +} + +#[tokio::test] +async fn pending_accept_reset_streams_decrement_too() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + // If it didn't decrement internally, this would eventually get + // the count over MAX. 
+ const M: usize = 2; + const N: usize = 5; + const MAX: usize = 6; + + let client = async move { + let settings = client.assert_server_handshake().await; + assert_default_settings!(settings); + let mut id = 1; + for _ in 0..M { + for _ in 0..N { + client + .send_frame(frames::headers(id).request("GET", "https://a.b/").eos()) + .await; + client.send_frame(frames::reset(id).protocol_error()).await; + id += 2; + } + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + } + }; + + let srv = async move { + let mut srv = server::Builder::new() + .max_pending_accept_reset_streams(MAX) + .handshake::<_, Bytes>(io) + .await + .expect("handshake"); + + while let Some(Ok(_)) = srv.accept().await {} + + poll_fn(|cx| srv.poll_closed(cx)).await.expect("server"); + }; + join(srv, client).await; +} + #[tokio::test] async fn errors_if_recv_frame_exceeds_max_frame_size() { h2_support::trace_init!(); @@ -786,7 +868,7 @@ async fn rst_while_closing() { // Enqueue trailers frame. let _ = stream.send_trailers(HeaderMap::new()); // Signal the server mock to send RST_FRAME - let _ = tx.send(()).unwrap(); + let _: () = tx.send(()).unwrap(); drop(stream); yield_once().await; // yield once to allow the server mock to be polled diff --git a/util/genfixture/Cargo.toml b/util/genfixture/Cargo.toml index 694a99496..cce7eb1b1 100644 --- a/util/genfixture/Cargo.toml +++ b/util/genfixture/Cargo.toml @@ -6,4 +6,4 @@ publish = false edition = "2018" [dependencies] -walkdir = "1.0.0" +walkdir = "2.3.2" diff --git a/util/genfixture/src/main.rs b/util/genfixture/src/main.rs index a6d730761..9dc7b00f9 100644 --- a/util/genfixture/src/main.rs +++ b/util/genfixture/src/main.rs @@ -10,7 +10,7 @@ fn main() { let path = args.get(1).expect("usage: genfixture [PATH]"); let path = Path::new(path); - let mut tests = HashMap::new(); + let mut tests: HashMap> = HashMap::new(); for entry in WalkDir::new(path) { let entry = entry.unwrap(); @@ -28,21 +28,21 @@ fn main() { let fixture_path = path.split("fixtures/hpack/").last().unwrap(); // Now, split that into the group and the name - let module = fixture_path.split("/").next().unwrap(); + let module = fixture_path.split('/').next().unwrap(); tests .entry(module.to_string()) - .or_insert(vec![]) + .or_default() .push(fixture_path.to_string()); } let mut one = false; for (module, tests) in tests { - let module = module.replace("-", "_"); + let module = module.replace('-', "_"); if one { - println!(""); + println!(); } one = true; @@ -51,7 +51,7 @@ fn main() { println!(" {} => {{", module); for test in tests { - let ident = test.split("/").nth(1).unwrap().split(".").next().unwrap(); + let ident = test.split('/').nth(1).unwrap().split('.').next().unwrap(); println!(" ({}, {:?});", ident, test); } diff --git a/util/genhuff/src/main.rs b/util/genhuff/src/main.rs index 2d5b0ba75..6418fab8b 100644 --- a/util/genhuff/src/main.rs +++ b/util/genhuff/src/main.rs @@ -112,8 +112,8 @@ impl Node { }; start.transitions.borrow_mut().push(Transition { - target: target, - byte: byte, + target, + byte, maybe_eos: self.maybe_eos, }); @@ -238,7 +238,7 @@ pub fn main() { let (encode, decode) = load_table(); println!("// !!! DO NOT EDIT !!! 
Generated by util/genhuff/src/main.rs"); - println!(""); + println!(); println!("// (num-bits, bits)"); println!("pub const ENCODE_TABLE: [(usize, u64); 257] = ["); @@ -247,7 +247,7 @@ pub fn main() { } println!("];"); - println!(""); + println!(); println!("// (next-state, byte, flags)"); println!("pub const DECODE_TABLE: [[(usize, u8, u8); 16]; 256] = ["); @@ -256,7 +256,7 @@ pub fn main() { println!("];"); } -const TABLE: &'static str = r##" +const TABLE: &str = r##" ( 0) |11111111|11000 1ff8 [13] ( 1) |11111111|11111111|1011000 7fffd8 [23] ( 2) |11111111|11111111|11111110|0010 fffffe2 [28]