Skip to content

Commit

Permalink
iterator: narrow error type of internal items
Browse files Browse the repository at this point in the history
Narrowed the error types in multiple places in internal API of
iterator module. Now the error type we manipulate on mainly is
`NextPageError` (instead of `QueryError`).

I did not change the return type of public methods yet.
I want to do it in a separate commit.
  • Loading branch information
muzarski committed Dec 27, 2024
1 parent 43e71cc commit 6c416a4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 35 deletions.
6 changes: 3 additions & 3 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use std::{
};

use super::errors::{ProtocolError, SchemaVersionFetchError, UseKeyspaceProtocolError};
use super::iterator::QueryPager;
use super::iterator::{NextRowError, QueryPager};
use super::locator::tablets::{RawTablet, TabletParsingError};
use super::query_result::QueryResult;
use super::session::AddressTranslator;
Expand Down Expand Up @@ -1186,7 +1186,7 @@ impl Connection {
pub(crate) async fn query_iter(
self: Arc<Self>,
query: Query,
) -> Result<QueryPager, QueryError> {
) -> Result<QueryPager, NextRowError> {
let consistency = query
.config
.determine_consistency(self.config.default_consistency);
Expand All @@ -1202,7 +1202,7 @@ impl Connection {
self: Arc<Self>,
prepared_statement: PreparedStatement,
values: SerializedValues,
) -> Result<QueryPager, QueryError> {
) -> Result<QueryPager, NextRowError> {
let consistency = prepared_statement
.config
.determine_consistency(self.config.default_consistency);
Expand Down
69 changes: 39 additions & 30 deletions scylla/src/transport/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use crate::frame::response::{
result::{ColumnSpec, Row},
};
use crate::history::{self, HistoryListener};
use crate::prepared_statement::PartitionKeyError;
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
use crate::statement::{Consistency, PagingState, SerialConsistency};
use crate::transport::cluster::ClusterData;
Expand Down Expand Up @@ -79,9 +80,7 @@ mod checked_channel_sender {
use tokio::sync::mpsc;
use uuid::Uuid;

use crate::transport::errors::QueryError;

use super::ReceivedPage;
use super::{NextPageError, ReceivedPage};

/// A value whose existence proves that there was an attempt
/// to send an item of type T through a channel.
Expand All @@ -106,7 +105,7 @@ mod checked_channel_sender {
}
}

type ResultPage = Result<ReceivedPage, QueryError>;
type ResultPage = Result<ReceivedPage, NextPageError>;

impl ProvingSender<ResultPage> {
pub(crate) async fn send_empty_page(
Expand All @@ -127,12 +126,12 @@ mod checked_channel_sender {

use checked_channel_sender::{ProvingSender, SendAttemptedProof};

type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, QueryError>>;
type PageSendAttemptedProof = SendAttemptedProof<Result<ReceivedPage, NextPageError>>;

// PagerWorker works in the background to fetch pages
// QueryPager receives them through a channel
struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,

// Closure used to perform a single page query
// AsyncFn(Arc<Connection>, Option<Arc<[u8]>>) -> Result<QueryResponse, UserRequestError>
Expand Down Expand Up @@ -275,7 +274,10 @@ where
}
TimeoutableRequestError::RequestFailure(err) => err,
};
let (proof, _) = self.sender.send(Err(error.into_query_error())).await;
let (proof, _) = self
.sender
.send(Err(NextPageError::RequestFailure(error)))
.await;
proof
}

Expand Down Expand Up @@ -485,7 +487,7 @@ where
/// any complicated logic related to retries, it just fetches pages from
/// a single connection.
struct SingleConnectionPagerWorker<Fetcher> {
sender: ProvingSender<Result<ReceivedPage, QueryError>>,
sender: ProvingSender<Result<ReceivedPage, NextPageError>>,
fetcher: Fetcher,
}

Expand All @@ -498,7 +500,12 @@ where
match self.do_work().await {
Ok(proof) => proof,
Err(err) => {
let (proof, _) = self.sender.send(Err(err.into_query_error())).await;
let (proof, _) = self
.sender
.send(Err(NextPageError::RequestFailure(
RetriableRequestError::RequestFailure(err),
)))
.await;
proof
}
}
Expand Down Expand Up @@ -567,7 +574,7 @@ where
/// [Row] is not the intended target type.
pub struct QueryPager {
current_page: RawRowLendingIterator,
page_receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
page_receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
tracing_ids: Vec<Uuid>,
}

Expand All @@ -589,7 +596,7 @@ impl QueryPager {
let res = std::future::poll_fn(|cx| Pin::new(&mut *self).poll_fill_page(cx)).await;
match res {
Some(Ok(())) => {}
Some(Err(err)) => return Some(Err(err)),
Some(Err(err)) => return Some(Err(err.into())),
None => return None,
}

Expand All @@ -606,7 +613,7 @@ impl QueryPager {
fn poll_fill_page<'r>(
mut self: Pin<&'r mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), QueryError>>> {
) -> Poll<Option<Result<(), NextRowError>>> {
if !self.is_current_page_exhausted() {
return Poll::Ready(Some(Ok(())));
}
Expand All @@ -629,14 +636,11 @@ impl QueryPager {
fn poll_next_page<'r>(
mut self: Pin<&'r mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<(), QueryError>>> {
) -> Poll<Option<Result<(), NextRowError>>> {
let mut s = self.as_mut();

let received_page = ready_some_ok!(Pin::new(&mut s.page_receiver).poll_recv(cx));

// TODO: see my other comment next to QueryError::NextRowError
// This is the place where conversion happens. To fix this, we need to refactor error types in iterator API.
// The `page_receiver`'s error type should be narrowed from QueryError to some other error type.
let raw_rows_with_deserialized_metadata =
received_page.rows.deserialize_metadata().map_err(|err| {
NextRowError::NextPageError(NextPageError::ResultMetadataParseError(err))
Expand Down Expand Up @@ -691,8 +695,8 @@ impl QueryPager {
execution_profile: Arc<ExecutionProfileInner>,
cluster_data: Arc<ClusterData>,
metrics: Arc<Metrics>,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let consistency = query
.config
Expand Down Expand Up @@ -770,8 +774,8 @@ impl QueryPager {

pub(crate) async fn new_for_prepared_statement(
config: PreparedIteratorConfig,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let consistency = config
.prepared
Expand Down Expand Up @@ -806,7 +810,7 @@ impl QueryPager {
Ok(res) => res.unzip(),
Err(err) => {
let (proof, _res) = ProvingSender::from(sender)
.send(Err(err.into_query_error()))
.send(Err(NextPageError::PartitionKeyError(err)))
.await;
return proof;
}
Expand Down Expand Up @@ -893,8 +897,8 @@ impl QueryPager {
connection: Arc<Connection>,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let page_size = query.get_validated_page_size();

Expand Down Expand Up @@ -923,8 +927,8 @@ impl QueryPager {
connection: Arc<Connection>,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<Self, QueryError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);
) -> Result<Self, NextRowError> {
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, NextPageError>>(1);

let page_size = prepared.get_validated_page_size();

Expand All @@ -950,8 +954,8 @@ impl QueryPager {

async fn new_from_worker_future(
worker_task: impl Future<Output = PageSendAttemptedProof> + Send + 'static,
mut receiver: mpsc::Receiver<Result<ReceivedPage, QueryError>>,
) -> Result<Self, QueryError> {
mut receiver: mpsc::Receiver<Result<ReceivedPage, NextPageError>>,
) -> Result<Self, NextRowError> {
tokio::task::spawn(worker_task);

// This unwrap is safe because:
Expand Down Expand Up @@ -1061,12 +1065,17 @@ where
#[derive(Error, Debug, Clone)]
#[non_exhaustive]
pub enum NextPageError {
/// PK extraction and/or token calculation error. Applies only for prepared statements.
#[error("Failed to extract PK and compute token required for routing: {0}")]
PartitionKeyError(#[from] PartitionKeyError),

/// Failed to run a request responsible for fetching new page.
#[error(transparent)]
RequestFailure(#[from] RetriableRequestError),

/// Failed to deserialize result metadata associated with next page response.
#[error("Failed to deserialize result metadata associated with next page response: {0}")]
ResultMetadataParseError(#[from] ResultMetadataAndRowsCountParseError),
// TODO: This should also include a variant representing an error that occurred during
// query that fetches the next page. However, as of now, it would require that we include QueryError here.
// This would introduce a cyclic dependency: QueryError -> NextRowError -> NextPageError -> QueryError.
}

/// An error returned by async iterator API.
Expand Down
3 changes: 3 additions & 0 deletions scylla/src/transport/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1328,6 +1328,7 @@ where
self.metrics.clone(),
)
.await
.map_err(QueryError::from)
} else {
// Making QueryPager::new_for_query work with values is too hard (if even possible)
// so instead of sending one prepare to a specific connection on each iterator query,
Expand All @@ -1342,6 +1343,7 @@ where
metrics: self.metrics.clone(),
})
.await
.map_err(QueryError::from)
}
}

Expand Down Expand Up @@ -1599,6 +1601,7 @@ where
metrics: self.metrics.clone(),
})
.await
.map_err(QueryError::from)
}

async fn do_batch(
Expand Down
6 changes: 4 additions & 2 deletions scylla/src/transport/topology.rs
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ where
let mut query = Query::new(query_str);
query.set_page_size(METADATA_QUERY_PAGE_SIZE);

conn.query_iter(query).await
conn.query_iter(query).await.map_err(QueryError::from)
} else {
let keyspaces = &[keyspaces_to_fetch] as &[&[String]];
let query_str = format!("{query_str} where keyspace_name in ?");
Expand All @@ -974,7 +974,9 @@ where
.await
.map_err(UserRequestError::into_query_error)?;
let serialized_values = prepared.serialize_values(&keyspaces)?;
conn.execute_iter(prepared, serialized_values).await
conn.execute_iter(prepared, serialized_values)
.await
.map_err(QueryError::from)
}
}

Expand Down

0 comments on commit 6c416a4

Please sign in to comment.