diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 008158fb0..431e17748 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -53,7 +53,9 @@ jobs: steps: - uses: actions/checkout@v3 - uses: sfackler/actions/rustup@master - - run: echo "version=$(rustc --version)" >> $GITHUB_OUTPUT + with: + version: 1.67.0 + - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - run: rustup target add wasm32-unknown-unknown - uses: actions/cache@v3 diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 1b5be1098..da267101c 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -72,6 +72,7 @@ impl Header { } /// An enum representing Postgres backend messages. +#[derive(Debug, PartialEq)] #[non_exhaustive] pub enum Message { AuthenticationCleartextPassword, @@ -333,6 +334,7 @@ impl Read for Buffer { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationMd5PasswordBody { salt: [u8; 4], } @@ -344,6 +346,7 @@ impl AuthenticationMd5PasswordBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationGssContinueBody(Bytes); impl AuthenticationGssContinueBody { @@ -353,6 +356,7 @@ impl AuthenticationGssContinueBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslBody(Bytes); impl AuthenticationSaslBody { @@ -362,6 +366,7 @@ impl AuthenticationSaslBody { } } +#[derive(Debug, PartialEq)] pub struct SaslMechanisms<'a>(&'a [u8]); impl<'a> FallibleIterator for SaslMechanisms<'a> { @@ -387,6 +392,7 @@ impl<'a> FallibleIterator for SaslMechanisms<'a> { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslContinueBody(Bytes); impl AuthenticationSaslContinueBody { @@ -396,6 +402,7 @@ impl AuthenticationSaslContinueBody { } } +#[derive(Debug, PartialEq)] pub struct AuthenticationSaslFinalBody(Bytes); impl AuthenticationSaslFinalBody { @@ -405,6 +412,7 @@ impl AuthenticationSaslFinalBody { } } +#[derive(Debug, PartialEq)] pub struct BackendKeyDataBody { process_id: i32, secret_key: i32, @@ -422,6 +430,7 @@ impl BackendKeyDataBody { } } +#[derive(Debug, PartialEq)] pub struct CommandCompleteBody { tag: Bytes, } @@ -433,6 +442,7 @@ impl CommandCompleteBody { } } +#[derive(Debug, PartialEq)] pub struct CopyDataBody { storage: Bytes, } @@ -449,6 +459,7 @@ impl CopyDataBody { } } +#[derive(Debug, PartialEq)] pub struct CopyInResponseBody { format: u8, len: u16, @@ -470,6 +481,7 @@ impl CopyInResponseBody { } } +#[derive(Debug, PartialEq)] pub struct ColumnFormats<'a> { buf: &'a [u8], remaining: u16, @@ -503,6 +515,7 @@ impl<'a> FallibleIterator for ColumnFormats<'a> { } } +#[derive(Debug, PartialEq)] pub struct CopyOutResponseBody { format: u8, len: u16, @@ -524,7 +537,7 @@ impl CopyOutResponseBody { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct DataRowBody { storage: Bytes, len: u16, @@ -599,6 +612,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> { } } +#[derive(Debug, PartialEq)] pub struct ErrorResponseBody { storage: Bytes, } @@ -657,6 +671,7 @@ impl<'a> ErrorField<'a> { } } +#[derive(Debug, PartialEq)] pub struct NoticeResponseBody { storage: Bytes, } @@ -668,6 +683,7 @@ impl NoticeResponseBody { } } +#[derive(Debug, PartialEq)] pub struct NotificationResponseBody { process_id: i32, channel: Bytes, @@ -691,6 +707,7 @@ impl NotificationResponseBody { } } +#[derive(Debug, PartialEq)] pub struct ParameterDescriptionBody { storage: Bytes, len: u16, @@ -706,6 +723,7 @@ impl ParameterDescriptionBody { } } +#[derive(Debug, PartialEq)] pub struct Parameters<'a> { buf: &'a [u8], remaining: u16, @@ -739,6 +757,7 @@ impl<'a> FallibleIterator for Parameters<'a> { } } +#[derive(Debug, PartialEq)] pub struct ParameterStatusBody { name: Bytes, value: Bytes, @@ -756,6 +775,7 @@ impl ParameterStatusBody { } } +#[derive(Debug, PartialEq)] pub struct ReadyForQueryBody { status: u8, } @@ -767,6 +787,7 @@ impl ReadyForQueryBody { } } +#[derive(Debug, PartialEq)] pub struct RowDescriptionBody { storage: Bytes, len: u16, diff --git a/postgres-types/src/chrono_04.rs b/postgres-types/src/chrono_04.rs index 0ec92437d..aef7549d0 100644 --- a/postgres-types/src/chrono_04.rs +++ b/postgres-types/src/chrono_04.rs @@ -40,7 +40,7 @@ impl ToSql for NaiveDateTime { impl<'a> FromSql<'a> for DateTime { fn from_sql(type_: &Type, raw: &[u8]) -> Result, Box> { let naive = NaiveDateTime::from_sql(type_, raw)?; - Ok(DateTime::from_utc(naive, Utc)) + Ok(DateTime::from_naive_utc_and_offset(naive, Utc)) } accepts!(TIMESTAMPTZ); diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index 52b5c773a..531a9f719 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -442,6 +442,22 @@ impl WrongType { } } +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + /// A trait for types that can be created from a Postgres value. /// /// # Types @@ -893,7 +909,7 @@ pub trait ToSql: fmt::Debug { /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Format { /// Text format (UTF-8) Text, diff --git a/tokio-postgres/Cargo.toml b/tokio-postgres/Cargo.toml index ec5e3cbec..c11de2e2b 100644 --- a/tokio-postgres/Cargo.toml +++ b/tokio-postgres/Cargo.toml @@ -59,7 +59,7 @@ postgres-types = { version = "0.2.5", path = "../postgres-types" } tokio = { version = "1.27", features = ["io-util"] } tokio-util = { version = "0.7", features = ["codec"] } rand = "0.8.5" -whoami = "1.4.1" +whoami = "1.4" [target.'cfg(not(target_arch = "wasm32"))'.dependencies] socket2 = { version = "0.5", features = ["all"] } diff --git a/tokio-postgres/src/bind.rs b/tokio-postgres/src/bind.rs index 9c5c49218..dac1a3c06 100644 --- a/tokio-postgres/src/bind.rs +++ b/tokio-postgres/src/bind.rs @@ -31,7 +31,7 @@ where match responses.next().await? { Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(Portal::new(client, name, statement)) diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index 427a05049..187c3fa43 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -234,6 +234,10 @@ impl Client { prepare::prepare(&self.inner, query, parameter_types).await } + pub(crate) async fn prepare_unnamed(&self, query: &str) -> Result { + prepare::prepare(&self.inner, query, &[]).await + } + /// Executes a statement, returning a vector of the resulting rows. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list @@ -368,6 +372,23 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result + where + T: ?Sized + ToStatement, + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::query_txt(&self.inner, statement, params).await + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index b178eac80..2547469ec 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -207,6 +207,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) load_balance_hosts: LoadBalanceHosts, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -240,6 +241,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, load_balance_hosts: LoadBalanceHosts::Disable, + max_backend_message_size: None, } } @@ -520,6 +522,17 @@ impl Config { self.load_balance_hosts } + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { @@ -655,6 +668,14 @@ impl Config { }; self.load_balance_hosts(load_balance_hosts); } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect.rs b/tokio-postgres/src/connect.rs index ca57b9cdd..e697e5bc6 100644 --- a/tokio-postgres/src/connect.rs +++ b/tokio-postgres/src/connect.rs @@ -195,7 +195,7 @@ where } } Some(_) => {} - None => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), } } } diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 19be9eb01..b468c5f32 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -92,7 +92,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls, has_hostname).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; @@ -190,14 +195,14 @@ where )) } Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), } match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => Ok(()), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), - Some(_) => Err(Error::unexpected_message()), + Some(m) => Err(Error::unexpected_message(m)), None => Err(Error::closed()), } } @@ -291,7 +296,7 @@ where let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslContinue(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), }; @@ -309,7 +314,7 @@ where let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslFinal(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), }; @@ -348,7 +353,7 @@ where } Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), None => return Err(Error::closed()), } } diff --git a/tokio-postgres/src/connection.rs b/tokio-postgres/src/connection.rs index 414335955..652038cc0 100644 --- a/tokio-postgres/src/connection.rs +++ b/tokio-postgres/src/connection.rs @@ -139,7 +139,8 @@ where Some(response) => response, None => match messages.next().map_err(Error::parse)? { Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), - _ => return Err(Error::unexpected_message()), + Some(m) => return Err(Error::unexpected_message(m)), + None => return Err(Error::closed()), }, }; diff --git a/tokio-postgres/src/copy_in.rs b/tokio-postgres/src/copy_in.rs index 59e31fea6..2092495b2 100644 --- a/tokio-postgres/src/copy_in.rs +++ b/tokio-postgres/src/copy_in.rs @@ -114,7 +114,7 @@ where let rows = extract_row_affected(&body)?; return Poll::Ready(Ok(rows)); } - _ => return Poll::Ready(Err(Error::unexpected_message())), + m => return Poll::Ready(Err(Error::unexpected_message(m))), } } } @@ -192,7 +192,7 @@ pub async fn copy_in(client: &InnerClient, statement: Statement) -> Result match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } match responses.next().await? { Message::CopyInResponse(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(CopyInSink { diff --git a/tokio-postgres/src/copy_out.rs b/tokio-postgres/src/copy_out.rs index 1e6949252..c822a0152 100644 --- a/tokio-postgres/src/copy_out.rs +++ b/tokio-postgres/src/copy_out.rs @@ -12,7 +12,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; pub async fn copy_out(client: &InnerClient, statement: Statement) -> Result { - debug!("executing copy out statement {}", statement.name()); + debug!("executing copy out statement"); let buf = query::encode(client, &statement, slice_iter(&[]))?; let responses = start(client, buf).await?; @@ -26,13 +26,17 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { + Message::ParseComplete => match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } match responses.next().await? { Message::CopyOutResponse(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(responses) @@ -56,7 +60,7 @@ impl Stream for CopyOutStream { match ready!(this.responses.poll_next(cx)?) { Message::CopyData(body) => Poll::Ready(Some(Ok(body.into_bytes()))), Message::CopyDone => Poll::Ready(None), - _ => Poll::Ready(Some(Err(Error::unexpected_message()))), + m => Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/error/mod.rs b/tokio-postgres/src/error/mod.rs index f1e2644c6..764f77f9c 100644 --- a/tokio-postgres/src/error/mod.rs +++ b/tokio-postgres/src/error/mod.rs @@ -1,7 +1,7 @@ //! Errors. use fallible_iterator::FallibleIterator; -use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody}; +use postgres_protocol::message::backend::{ErrorFields, ErrorResponseBody, Message}; use std::error::{self, Error as _Error}; use std::fmt; use std::io; @@ -339,7 +339,7 @@ pub enum ErrorPosition { #[derive(Debug, PartialEq)] enum Kind { Io, - UnexpectedMessage, + UnexpectedMessage(Message), Tls, ToSql(usize), FromSql(usize), @@ -379,7 +379,9 @@ impl fmt::Display for Error { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.0.kind { Kind::Io => fmt.write_str("error communicating with the server")?, - Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::UnexpectedMessage(msg) => { + write!(fmt, "unexpected message from server: {:?}", msg)? + } Kind::Tls => fmt.write_str("error performing TLS handshake")?, Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, @@ -445,8 +447,8 @@ impl Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { - Error::new(Kind::UnexpectedMessage, None) + pub(crate) fn unexpected_message(message: Message) -> Error { + Error::new(Kind::UnexpectedMessage(message), None) } #[allow(clippy::needless_pass_by_value)] diff --git a/tokio-postgres/src/generic_client.rs b/tokio-postgres/src/generic_client.rs index 50cff9712..a4ee4808b 100644 --- a/tokio-postgres/src/generic_client.rs +++ b/tokio-postgres/src/generic_client.rs @@ -56,6 +56,18 @@ pub trait GenericClient: private::Sealed { I: IntoIterator + Sync + Send, I::IntoIter: ExactSizeIterator; + /// Like `Client::query_raw_txt`. + async fn query_raw_txt<'a, T, S, I>( + &self, + statement: &T, + params: I, + ) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + /// Like `Client::prepare`. async fn prepare(&self, query: &str) -> Result; @@ -136,6 +148,16 @@ impl GenericClient for Client { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } @@ -222,6 +244,16 @@ impl GenericClient for Transaction<'_> { self.query_raw(statement, params).await } + async fn query_raw_txt<'a, T, S, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement + Sync + Send, + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + async fn prepare(&self, query: &str) -> Result { self.prepare(query).await } diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..e8b787f36 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -13,7 +13,6 @@ use postgres_protocol::message::backend::Message; use postgres_protocol::message::frontend; use std::future::Future; use std::pin::Pin; -use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; const TYPEINFO_QUERY: &str = "\ @@ -56,31 +55,29 @@ AND attnum > 0 ORDER BY attnum "; -static NEXT_ID: AtomicUsize = AtomicUsize::new(0); - pub async fn prepare( client: &Arc, query: &str, types: &[Type], ) -> Result { - let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); - let buf = encode(client, &name, query, types)?; + let name = ""; + let buf = encode(client, name, query, types)?; let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } let parameter_description = match responses.next().await? { Message::ParameterDescription(body) => body, - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), }; let row_description = match responses.next().await? { Message::RowDescription(body) => Some(body), Message::NoData => None, - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), }; let mut parameters = vec![]; @@ -95,12 +92,12 @@ pub async fn prepare( let mut it = row_description.fields(); while let Some(field) = it.next().map_err(Error::parse)? { let type_ = get_type(client, field.type_oid()).await?; - let column = Column::new(field.name().to_string(), type_); + let column = Column::new(field.name().to_string(), type_, field); columns.push(column); } } - Ok(Statement::new(client, name, parameters, columns)) + Ok(Statement::new(query.to_owned(), parameters, columns)) } fn prepare_rec<'a>( @@ -126,7 +123,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } @@ -142,7 +139,7 @@ async fn get_type(client: &Arc, oid: Oid) -> Result { let row = match rows.try_next().await? { Some(row) => row, - None => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), }; let name: String = row.try_get(0)?; diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index e6e1d00a8..2be715015 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -3,15 +3,17 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::{BorrowToSql, IsNull}; use crate::{Error, Portal, Row, Statement}; -use bytes::{Bytes, BytesMut}; +use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{ready, Stream}; use log::{debug, log_enabled, Level}; use pin_project_lite::pin_project; use postgres_protocol::message::backend::{CommandCompleteBody, Message}; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::fmt; use std::marker::PhantomPinned; use std::pin::Pin; +use std::sync::Arc; use std::task::{Context, Poll}; struct BorrowToSqlParamsDebug<'a, T>(&'a [T]); @@ -40,8 +42,7 @@ where let buf = if log_enabled!(Level::Debug) { let params = params.into_iter().collect::>(); debug!( - "executing statement {} with parameters: {:?}", - statement.name(), + "executing statement with parameters: {:?}", BorrowToSqlParamsDebug(params.as_slice()), ); encode(client, &statement, params)? @@ -53,10 +54,68 @@ where statement, responses, rows_affected: None, + command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } +pub async fn query_txt( + client: &Arc, + statement: Statement, + params: I, +) -> Result +where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + + let buf = client.with_buf(|buf| { + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // named prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + } + None => Ok(postgres_protocol::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + // now read the responses + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: None, + output_format: Format::Text, + _p: PhantomPinned, + rows_affected: None, + }) +} + pub async fn query_portal( client: &InnerClient, portal: &Portal, @@ -74,6 +133,9 @@ pub async fn query_portal( statement: portal.statement().clone(), responses, rows_affected: None, + command_tag: None, + status: None, + output_format: Format::Binary, _p: PhantomPinned, }) } @@ -104,8 +166,7 @@ where let buf = if log_enabled!(Level::Debug) { let params = params.into_iter().collect::>(); debug!( - "executing statement {} with parameters: {:?}", - statement.name(), + "executing statement with parameters: {:?}", BorrowToSqlParamsDebug(params.as_slice()), ); encode(client, &statement, params)? @@ -123,7 +184,7 @@ where } Message::EmptyQueryResponse => rows = 0, Message::ReadyForQuery(_) => return Ok(rows), - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } } } @@ -132,8 +193,12 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result { let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; match responses.next().await? { + Message::ParseComplete => match responses.next().await? { + Message::BindComplete => {} + m => return Err(Error::unexpected_message(m)), + }, Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } Ok(responses) @@ -146,6 +211,7 @@ where I::IntoIter: ExactSizeIterator, { client.with_buf(|buf| { + frontend::parse("", statement.query(), [], buf).unwrap(); encode_bind(statement, params, "", buf)?; frontend::execute("", 0, buf).map_err(Error::encode)?; frontend::sync(buf); @@ -181,7 +247,7 @@ where let mut error_idx = 0; let r = frontend::bind( portal, - statement.name(), + "", // statement name param_formats, params.zip(param_types).enumerate(), |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) { @@ -208,6 +274,9 @@ pin_project! { statement: Statement, responses: Responses, rows_affected: Option, + command_tag: Option, + output_format: Format, + status: Option, #[pin] _p: PhantomPinned, } @@ -221,14 +290,25 @@ impl Stream for RowStream { loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { - return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) } Message::CommandComplete(body) => { *this.rows_affected = Some(extract_row_affected(&body)?); + + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } } Message::EmptyQueryResponse | Message::PortalSuspended => {} - Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + Message::ReadyForQuery(status) => { + *this.status = Some(status.status()); + return Poll::Ready(None); + } + m => return Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } @@ -241,4 +321,18 @@ impl RowStream { pub fn rows_affected(&self) -> Option { self.rows_affected } + + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> Option { + self.status + } } diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..754b5f28c 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -97,6 +98,7 @@ where /// A row of data returned from the database by a query. pub struct Row { statement: Statement, + output_format: Format, body: DataRowBody, ranges: Vec>>, } @@ -110,12 +112,17 @@ impl fmt::Debug for Row { } impl Row { - pub(crate) fn new(statement: Statement, body: DataRowBody) -> Result { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { let ranges = body.ranges().collect().map_err(Error::parse)?; Ok(Row { statement, body, ranges, + output_format, }) } @@ -187,6 +194,27 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.output_format == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/simple_query.rs b/tokio-postgres/src/simple_query.rs index bcc6d928b..9838b0809 100644 --- a/tokio-postgres/src/simple_query.rs +++ b/tokio-postgres/src/simple_query.rs @@ -58,7 +58,7 @@ pub async fn batch_execute(client: &InnerClient, query: &str) -> Result<(), Erro | Message::EmptyQueryResponse | Message::RowDescription(_) | Message::DataRow(_) => {} - _ => return Err(Error::unexpected_message()), + m => return Err(Error::unexpected_message(m)), } } } @@ -107,12 +107,12 @@ impl Stream for SimpleQueryStream { Message::DataRow(body) => { let row = match &this.columns { Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, - None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + None => return Poll::Ready(Some(Err(Error::closed()))), }; return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); } Message::ReadyForQuery(_) => return Poll::Ready(None), - _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + m => return Poll::Ready(Some(Err(Error::unexpected_message(m)))), } } } diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..16b3ca992 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -1,33 +1,13 @@ -use crate::client::InnerClient; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::types::Type; -use postgres_protocol::message::frontend; -use std::{ - fmt, - sync::{Arc, Weak}, -}; +use postgres_protocol::{message::backend::Field, Oid}; +use std::{fmt, sync::Arc}; struct StatementInner { - client: Weak, - name: String, + query: String, params: Vec, columns: Vec, } -impl Drop for StatementInner { - fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - let buf = client.with_buf(|buf| { - frontend::close(b'S', &self.name, buf).unwrap(); - frontend::sync(buf); - buf.split().freeze() - }); - let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); - } - } -} - /// A prepared statement. /// /// Prepared statements can only be used with the connection that created them. @@ -35,32 +15,26 @@ impl Drop for StatementInner { pub struct Statement(Arc); impl Statement { - pub(crate) fn new( - inner: &Arc, - name: String, - params: Vec, - columns: Vec, - ) -> Statement { + pub(crate) fn new(query: String, params: Vec, columns: Vec) -> Self { Statement(Arc::new(StatementInner { - client: Arc::downgrade(inner), - name, + query, params, columns, })) } - pub(crate) fn name(&self) -> &str { - &self.0.name + pub(crate) fn query(&self) -> &str { + &*self.0.query } /// Returns the expected types of the statement's parameters. pub fn params(&self) -> &[Type] { - &self.0.params + &*self.0.params } /// Returns information about the columns returned when the statement is queried. pub fn columns(&self) -> &[Column] { - &self.0.columns + &*self.0.columns } } @@ -68,11 +42,30 @@ impl Statement { pub struct Column { name: String, type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, } impl Column { - pub(crate) fn new(name: String, type_: Type) -> Column { - Column { name, type_ } + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } } /// Returns the name of the column. @@ -84,6 +77,36 @@ impl Column { pub fn type_(&self) -> &Type { &self.type_ } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } } impl fmt::Debug for Column { diff --git a/tokio-postgres/src/to_statement.rs b/tokio-postgres/src/to_statement.rs index 427f77dd7..ef1e65272 100644 --- a/tokio-postgres/src/to_statement.rs +++ b/tokio-postgres/src/to_statement.rs @@ -15,7 +15,7 @@ mod private { pub async fn into_statement(self, client: &Client) -> Result { match self { ToStatementType::Statement(s) => Ok(s.clone()), - ToStatementType::Query(s) => client.prepare(s).await, + ToStatementType::Query(s) => client.prepare_unnamed(s).await, } } } diff --git a/tokio-postgres/src/transaction.rs b/tokio-postgres/src/transaction.rs index 96a324652..ca386974e 100644 --- a/tokio-postgres/src/transaction.rs +++ b/tokio-postgres/src/transaction.rs @@ -149,6 +149,17 @@ impl<'a> Transaction<'a> { self.client.query_raw(statement, params).await } + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + self.client.query_raw_txt(statement, params).await + } + /// Like `Client::execute`. pub async fn execute( &self, diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 0ab4a7bab..565984271 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -249,6 +249,161 @@ async fn custom_array() { } } +#[tokio::test] +async fn query_raw_txt() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("SELECT 55 * $1", [Some("42")]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::().unwrap(); + assert_eq!(res, 55 * 42); + + let rows: Vec = client + .query_raw_txt("SELECT $1", [Some("42")]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "42"); + assert!(rows[0].body_len() > 0); +} + +#[tokio::test] +async fn query_raw_txt_nulls() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt( + "SELECT $1 as str, $2 as n, 'null' as str2, null as n2", + [Some("null"), None], + ) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + + let res = rows[0].as_text(0).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(1).unwrap(); + assert_eq!(res, None); + + let res = rows[0].as_text(2).unwrap(); + assert_eq!(res, Some("null")); + + let res = rows[0].as_text(3).unwrap(); + assert_eq!(res, None); +} + +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", [] as [Option<&str>; 0]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", [] as [Option<&str>; 0]) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + +#[tokio::test] +async fn command_tag() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("select unnest('{1,2,3}'::int[]);", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + + let mut rows: Vec = Vec::new(); + while let Some(row) = row_stream.next().await { + rows.push(row.unwrap()); + } + + assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); +} + +#[tokio::test] +async fn ready_for_query() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("START TRANSACTION", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'T')); + + let row_stream = client + .query_raw_txt("ROLLBACK", [] as [Option<&str>; 0]) + .await + .unwrap(); + + pin_mut!(row_stream); + while row_stream.next().await.is_none() {} + + assert_eq!(row_stream.ready_status(), Some(b'I')); +} + +#[tokio::test] +async fn column_extras() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt( + "select relacl, relname from pg_class limit 1", + [] as [Option<&str>; 0], + ) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + let column = rows[0].columns().get(1).unwrap(); + assert_eq!(column.name(), "relname"); + assert_eq!(column.type_(), &Type::NAME); + + assert!(column.table_oid() > 0); + assert_eq!(column.column_id(), 2); + assert_eq!(column.format(), 0); + + assert_eq!(column.type_oid(), 19); + assert_eq!(column.type_size(), 64); + assert_eq!(column.type_modifier(), -1); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; diff --git a/tokio-postgres/tests/test/types/chrono_04.rs b/tokio-postgres/tests/test/types/chrono_04.rs index a8e9e5afa..c0229377d 100644 --- a/tokio-postgres/tests/test/types/chrono_04.rs +++ b/tokio-postgres/tests/test/types/chrono_04.rs @@ -1,4 +1,4 @@ -use chrono_04::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; +use chrono_04::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use std::fmt; use tokio_postgres::types::{Date, FromSqlOwned, Timestamp}; use tokio_postgres::Client; @@ -51,12 +51,9 @@ async fn test_with_special_naive_date_time_params() { #[tokio::test] async fn test_date_time_params() { - fn make_check(time: &str) -> (Option>, &str) { + fn make_check(time: &str) -> (Option>, &str) { ( - Some( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), + Some(DateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), time, ) } @@ -74,12 +71,9 @@ async fn test_date_time_params() { #[tokio::test] async fn test_with_special_date_time_params() { - fn make_check(time: &str) -> (Timestamp>, &str) { + fn make_check(time: &str) -> (Timestamp>, &str) { ( - Timestamp::Value( - Utc.datetime_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'") - .unwrap(), - ), + Timestamp::Value(DateTime::parse_from_str(time, "'%Y-%m-%d %H:%M:%S.%f'").unwrap()), time, ) }