From 0c366cf372d31786384f0649955b8eda93676989 Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Fri, 6 Feb 2026 14:59:43 -0800 Subject: [PATCH 1/3] test(sinktools): LazySinkSource add tests, make existing tests more robust --- sinktools/Cargo.toml | 2 +- sinktools/src/lazy_sink_source.rs | 89 +++++++++++++++++++++++++++++-- 2 files changed, 86 insertions(+), 5 deletions(-) diff --git a/sinktools/Cargo.toml b/sinktools/Cargo.toml index 6677b4cd0fb4..402ba1899ab8 100644 --- a/sinktools/Cargo.toml +++ b/sinktools/Cargo.toml @@ -30,6 +30,6 @@ variadics = { optional = true, path = "../variadics", default-features = false, [dev-dependencies] bytes = "1.1.0" futures-task = { version = "0.3" } -tokio = { version = "1.29.0", default-features = false, features = ["macros", "rt"] } +tokio = { version = "1.29.0", default-features = false, features = ["macros", "rt", "time"] } tokio-stream = { version = "0.1.3", default-features = false } tokio-util = { version = "0.7.5", default-features = false, features = ["net", "codec"] } diff --git a/sinktools/src/lazy_sink_source.rs b/sinktools/src/lazy_sink_source.rs index f73797e7a334..45c99362322a 100644 --- a/sinktools/src/lazy_sink_source.rs +++ b/sinktools/src/lazy_sink_source.rs @@ -390,9 +390,92 @@ where #[cfg(test)] mod test { use futures_util::{SinkExt, StreamExt}; + use tokio_util::sync::PollSendError; use super::*; + #[tokio::test(flavor = "current_thread")] + async fn stream_drives_initialization() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>(); + + let sink_source = LazySinkSource::new(async move { + let () = init_lazy_recv.await.unwrap(); + let (send, recv) = tokio::sync::mpsc::channel(1); + let sink = tokio_util::sync::PollSender::new(send); + let stream = tokio_stream::wrappers::ReceiverStream::new(recv); + Ok::<_, PollSendError<_>>((stream, sink)) + }); + + let (mut sink, mut stream) = sink_source.split(); + + // Ensures stream starts the lazy. + let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>(); + let stream_task = tokio::task::spawn_local(async move { + stream_init_send.send(()).unwrap(); + (stream.next().await.unwrap(), stream.next().await.unwrap()) + }); + let sink_task = tokio::task::spawn_local(async move { + stream_init_recv.await.unwrap(); + SinkExt::send(&mut sink, "test1").await.unwrap(); + SinkExt::send(&mut sink, "test2").await.unwrap(); + }); + + // finish the future. + init_lazy_send.send(()).unwrap(); + + tokio::task::yield_now().await; + + assert!(sink_task.is_finished()); + assert_eq!(("test1", "test2"), stream_task.await.unwrap()); + sink_task.await.unwrap(); + }) + .await; + } + + #[tokio::test(flavor = "current_thread")] + async fn sink_drives_initialization() { + let local = tokio::task::LocalSet::new(); + local + .run_until(async { + let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>(); + + let sink_source = LazySinkSource::new(async move { + let () = init_lazy_recv.await.unwrap(); + let (send, recv) = tokio::sync::mpsc::channel(1); + let sink = tokio_util::sync::PollSender::new(send); + let stream = tokio_stream::wrappers::ReceiverStream::new(recv); + Ok::<_, PollSendError<_>>((stream, sink)) + }); + + let (mut sink, mut stream) = sink_source.split(); + + // Ensures stream starts the lazy. + let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>(); + let stream_task = tokio::task::spawn_local(async move { + sink_init_recv.await.unwrap(); + (stream.next().await.unwrap(), stream.next().await.unwrap()) + }); + let sink_task = tokio::task::spawn_local(async move { + sink_init_send.send(()).unwrap(); + SinkExt::send(&mut sink, "test1").await.unwrap(); + SinkExt::send(&mut sink, "test2").await.unwrap(); + }); + + // finish the future. + init_lazy_send.send(()).unwrap(); + + tokio::task::yield_now().await; + + assert!(sink_task.is_finished()); + assert_eq!(("test1", "test2"), stream_task.await.unwrap()); + sink_task.await.unwrap(); + }) + .await; + } + #[tokio::test(flavor = "current_thread")] async fn tcp_stream_drives_initialization() { use tokio::net::{TcpListener, TcpStream}; @@ -510,11 +593,9 @@ mod test { let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new()); // try to be really sure that the effects of the above initialization completing are propagated. - for _ in 0..20 { - tokio::task::yield_now().await - } + tokio::time::sleep(std::time::Duration::from_millis(10)).await; - assert!(sink_task.is_finished()); // We haven't sent anything yet, so the stream should definitely not be resolved now. + assert!(sink_task.is_finished()); // Sink should have sent its item. assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2"); From 68920b670ce0d5c48ed6d01629decc6ed706525e Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Fri, 6 Feb 2026 13:55:57 -0800 Subject: [PATCH 2/3] implement dualwaker --- sinktools/src/lazy_sink_source.rs | 183 ++++++++++-------------------- 1 file changed, 63 insertions(+), 120 deletions(-) diff --git a/sinktools/src/lazy_sink_source.rs b/sinktools/src/lazy_sink_source.rs index 45c99362322a..d2a318a5ee07 100644 --- a/sinktools/src/lazy_sink_source.rs +++ b/sinktools/src/lazy_sink_source.rs @@ -3,42 +3,34 @@ use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll, Waker}; -use std::cell::RefCell; -use std::rc::Rc; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use std::task::Wake; +use futures_util::task::AtomicWaker; use futures_util::{Sink, Stream, ready}; -struct MultiWaker { - wakers: Mutex>, +#[derive(Default)] +struct DualWaker { + sink: AtomicWaker, + stream: AtomicWaker, } -impl MultiWaker { - fn new(waker: &Waker) -> Self { - MultiWaker { - wakers: Mutex::new(vec![waker.clone()]), - } - } - - fn push(&self, waker: &Waker) { - let mut guard = self.wakers.lock().unwrap(); - guard.push(waker.clone()); +impl DualWaker { + fn new() -> (Arc, Waker) { + let dual_waker = Arc::new(Self::default()); + let waker = Waker::from(dual_waker.clone()); + (dual_waker, waker) } } -impl Wake for MultiWaker { +impl Wake for DualWaker { fn wake(self: Arc) { - let mut wakers = Vec::new(); - - { - let mut guard = self.wakers.lock().unwrap(); - std::mem::swap(&mut wakers, &mut *guard); - } + self.wake_by_ref(); + } - for waker in wakers { - waker.wake(); - } + fn wake_by_ref(self: &Arc) { + self.sink.wake(); + self.stream.wake(); } } @@ -49,7 +41,8 @@ enum SharedState { Thunkulating { future: Pin>, item: Option, - multi_waker: Option>, + dual_waker_state: Arc, + dual_waker_waker: Waker, }, Done { stream: Pin>, @@ -59,9 +52,11 @@ enum SharedState { Taken, } -/// A lazy sink-source that can be split into a sink and a source. The internal state is initialized when the first item is attempted to be pulled from the source half, or when the first item is sent to the sink half. +/// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the +/// source, or when the first item is sent to the sink. To split into separate source and sink halves, use +/// [`futures_util::StreamExt::split`]. pub struct LazySinkSource { - state: Rc>>, + state: SharedState, _phantom: PhantomData, } @@ -69,50 +64,17 @@ impl LazySinkSource { /// Creates a new `LazySinkSource` with the given initialization future. pub fn new(future: Fut) -> Self { Self { - state: Rc::new(RefCell::new(SharedState::Uninit { + state: SharedState::Uninit { future: Box::pin(future), - })), + }, _phantom: PhantomData, } } - - #[expect( - clippy::type_complexity, - reason = "this type is actually fine and not too complex." - )] - /// Splits into a sink and stream that share the same underlying connection. - pub fn split( - self, - ) -> ( - LazySinkHalf, - LazySourceHalf, - ) { - let sink = LazySinkHalf { - state: Rc::clone(&self.state), - _phantom: PhantomData, - }; - let stream = LazySourceHalf { - state: self.state, - _phantom: PhantomData, - }; - (sink, stream) - } } -/// Sink half of the SinkSource -pub struct LazySinkHalf { - state: Rc>>, - _phantom: PhantomData, -} - -/// Stream half of the SinkSource -pub struct LazySourceHalf { - state: Rc>>, - _phantom: PhantomData, -} - -impl Sink for LazySinkHalf +impl Sink for LazySinkSource where + Self: Unpin, Fut: Future>, St: Stream, Si: Sink, @@ -121,7 +83,7 @@ where type Error = Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut state = self.state.borrow_mut(); + let state = &mut self.get_mut().state; if let SharedState::Uninit { .. } = &*state { return Poll::Ready(Ok(())); @@ -130,21 +92,15 @@ where if let SharedState::Thunkulating { future, item, - multi_waker, + dual_waker_state, + dual_waker_waker, } = &mut *state { - let waker = if let Some(waker) = multi_waker { - waker.push(cx.waker()); - Waker::from(waker.clone()) - } else { - let waker = Arc::new(MultiWaker::new(cx.waker())); - *multi_waker = Some(waker.clone()); - Waker::from(waker) - }; + dual_waker_state.sink.register(cx.waker()); - let mut new_context = Context::from_waker(&waker); + let mut dual_context = Context::from_waker(dual_waker_waker); - match future.as_mut().poll(&mut new_context) { + match future.as_mut().poll(&mut dual_context) { Poll::Ready(Ok((stream, sink))) => { let buf = item.take(); *state = SharedState::Done { @@ -171,19 +127,21 @@ where return result; } - panic!("LazySinkHalf in invalid state."); + panic!("LazySinkSource in invalid state."); } fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { - let mut state = self.state.borrow_mut(); + let state = &mut self.get_mut().state; if let SharedState::Uninit { .. } = &*state { let old_state = std::mem::replace(&mut *state, SharedState::Taken); if let SharedState::Uninit { future } = old_state { + let (dual_waker_state, dual_waker_waker) = DualWaker::new(); *state = SharedState::Thunkulating { future, item: Some(item), - multi_waker: None, + dual_waker_state, + dual_waker_waker, }; return Ok(()); @@ -191,7 +149,7 @@ where } if let SharedState::Thunkulating { .. } = &mut *state { - panic!("LazySinkHalf not ready."); + panic!("LazySinkSource not ready."); } if let SharedState::Done { sink, buf, .. } = &mut *state { @@ -200,11 +158,11 @@ where return result; } - panic!("LazySinkHalf not ready."); + panic!("LazySinkSource not ready."); } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut state = self.state.borrow_mut(); + let state = &mut self.get_mut().state; if let SharedState::Uninit { .. } = &*state { return Poll::Ready(Ok(())); @@ -213,19 +171,13 @@ where if let SharedState::Thunkulating { future, item, - multi_waker, + dual_waker_state, + dual_waker_waker, } = &mut *state { - let waker = if let Some(waker) = multi_waker { - waker.push(cx.waker()); - Waker::from(waker.clone()) - } else { - let waker = Arc::new(MultiWaker::new(cx.waker())); - *multi_waker = Some(waker.clone()); - Waker::from(waker) - }; + dual_waker_state.sink.register(cx.waker()); - let mut new_context = Context::from_waker(&waker); + let mut new_context = Context::from_waker(dual_waker_waker); match future.as_mut().poll(&mut new_context) { Poll::Ready(Ok((stream, sink))) => { @@ -258,7 +210,7 @@ where } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut state = self.state.borrow_mut(); + let state = &mut self.get_mut().state; if let SharedState::Uninit { .. } = &*state { return Poll::Ready(Ok(())); @@ -267,19 +219,13 @@ where if let SharedState::Thunkulating { future, item, - multi_waker, + dual_waker_state, + dual_waker_waker, } = &mut *state { - let waker = if let Some(waker) = multi_waker { - waker.push(cx.waker()); - Waker::from(waker.clone()) - } else { - let waker = Arc::new(MultiWaker::new(cx.waker())); - *multi_waker = Some(waker.clone()); - Waker::from(waker) - }; + dual_waker_state.sink.register(cx.waker()); - let mut new_context = Context::from_waker(&waker); + let mut new_context = Context::from_waker(dual_waker_waker); match future.as_mut().poll(&mut new_context) { Poll::Ready(Ok((stream, sink))) => { @@ -312,8 +258,9 @@ where } } -impl Stream for LazySourceHalf +impl Stream for LazySinkSource where + Self: Unpin, Fut: Future>, St: Stream, Si: Sink, @@ -321,15 +268,17 @@ where type Item = St::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut state = self.state.borrow_mut(); + let state = &mut self.get_mut().state; if let SharedState::Uninit { .. } = &*state { let old_state = std::mem::replace(&mut *state, SharedState::Taken); if let SharedState::Uninit { future } = old_state { + let (dual_waker_state, dual_waker_waker) = DualWaker::new(); *state = SharedState::Thunkulating { future, item: None, - multi_waker: None, + dual_waker_state, + dual_waker_waker, }; } else { unreachable!(); @@ -339,19 +288,13 @@ where if let SharedState::Thunkulating { future, item, - multi_waker, + dual_waker_state, + dual_waker_waker, } = &mut *state { - let waker = if let Some(waker) = multi_waker { - waker.push(cx.waker()); - Waker::from(waker.clone()) - } else { - let waker = Arc::new(MultiWaker::new(cx.waker())); - *multi_waker = Some(waker.clone()); - Waker::from(waker) - }; + dual_waker_state.stream.register(cx.waker()); - let mut new_context = Context::from_waker(&waker); + let mut new_context = Context::from_waker(dual_waker_waker); match future.as_mut().poll(&mut new_context) { Poll::Ready(Ok((stream, sink))) => { @@ -383,7 +326,7 @@ where return result; } - panic!("LazySourceHalf in invalid state."); + panic!("LazySinkSource in invalid state."); } } @@ -452,7 +395,7 @@ mod test { let (mut sink, mut stream) = sink_source.split(); - // Ensures stream starts the lazy. + // Ensures sink starts the lazy. let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>(); let stream_task = tokio::task::spawn_local(async move { sink_init_recv.await.unwrap(); @@ -584,7 +527,7 @@ mod test { tokio::task::yield_now().await } - assert!(!sink_task.is_finished()); // We haven't sent anything yet, so the stream should definitely not be resolved now. + assert!(!sink_task.is_finished(), "We haven't sent anything yet, so the sink should definitely not be resolved now."); // trigger further initialization of the future. let mut socket = TcpStream::connect(addr).await.unwrap(); From 74e307b905c92b99acda4fcb208af2202d4e2d73 Mon Sep 17 00:00:00 2001 From: Mingwei Samuel Date: Fri, 13 Feb 2026 00:15:22 +0000 Subject: [PATCH 3/3] refactor(sinktools): use `pin_project` in `LazySinkSource` --- sinktools/src/lazy_sink_source.rs | 284 +++++++++++------------------- 1 file changed, 101 insertions(+), 183 deletions(-) diff --git a/sinktools/src/lazy_sink_source.rs b/sinktools/src/lazy_sink_source.rs index d2a318a5ee07..f6897fb2a9ee 100644 --- a/sinktools/src/lazy_sink_source.rs +++ b/sinktools/src/lazy_sink_source.rs @@ -1,5 +1,6 @@ //! [`LazySinkSource`], and related items. +use core::future::Future; use core::marker::PhantomData; use core::pin::Pin; use core::task::{Context, Poll, Waker}; @@ -8,6 +9,7 @@ use std::task::Wake; use futures_util::task::AtomicWaker; use futures_util::{Sink, Stream, ready}; +use pin_project_lite::pin_project; #[derive(Default)] struct DualWaker { @@ -34,30 +36,39 @@ impl Wake for DualWaker { } } -enum SharedState { - Uninit { - future: Pin>, - }, - Thunkulating { - future: Pin>, - item: Option, - dual_waker_state: Arc, - dual_waker_waker: Waker, - }, - Done { - stream: Pin>, - sink: Pin>, - buf: Option, - }, - Taken, +pin_project! { + #[project = SharedStateProj] + enum SharedState { + Uninit { + // The future, always `Some` in this state. + future: Option, + }, + Thunkulating { + #[pin] + future: Fut, + item: Option, + dual_waker_state: Arc, + dual_waker_waker: Waker, + }, + Done { + #[pin] + stream: St, + #[pin] + sink: Si, + buf: Option, + }, + } } -/// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the -/// source, or when the first item is sent to the sink. To split into separate source and sink halves, use -/// [`futures_util::StreamExt::split`]. -pub struct LazySinkSource { - state: SharedState, - _phantom: PhantomData, +pin_project! { + /// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the + /// source, or when the first item is sent to the sink. To split into separate source and sink halves, use + /// [`futures_util::StreamExt::split`]. + pub struct LazySinkSource { + #[pin] + state: SharedState, + _phantom: PhantomData, + } } impl LazySinkSource { @@ -65,49 +76,48 @@ impl LazySinkSource { pub fn new(future: Fut) -> Self { Self { state: SharedState::Uninit { - future: Box::pin(future), + future: Some(future), }, _phantom: PhantomData, } } } -impl Sink for LazySinkSource +impl LazySinkSource where - Self: Unpin, Fut: Future>, St: Stream, Si: Sink, Error: From, { - type Error = Error; - - fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let state = &mut self.get_mut().state; - - if let SharedState::Uninit { .. } = &*state { + fn poll_sink_op( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + sink_op: impl FnOnce(Pin<&mut Si>, &mut Context<'_>) -> Poll>, + ) -> Poll> { + let mut this = self.project(); + + if let SharedStateProj::Uninit { .. } = this.state.as_mut().project() { return Poll::Ready(Ok(())); } - if let SharedState::Thunkulating { + if let SharedStateProj::Thunkulating { future, item, dual_waker_state, dual_waker_waker, - } = &mut *state + } = this.state.as_mut().project() { dual_waker_state.sink.register(cx.waker()); let mut dual_context = Context::from_waker(dual_waker_waker); - match future.as_mut().poll(&mut dual_context) { + match future.poll(&mut dual_context) { Poll::Ready(Ok((stream, sink))) => { let buf = item.take(); - *state = SharedState::Done { - stream: Box::pin(stream), - sink: Box::pin(sink), - buf, - }; + this.state + .as_mut() + .set(SharedState::Done { stream, sink, buf }); } Poll::Ready(Err(e)) => { return Poll::Ready(Err(e)); @@ -118,149 +128,69 @@ where } } - if let SharedState::Done { sink, buf, .. } = &mut *state { + if let SharedStateProj::Done { mut sink, buf, .. } = this.state.as_mut().project() { if buf.is_some() { ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?); sink.as_mut().start_send(buf.take().unwrap())?; } - let result = sink.as_mut().poll_ready(cx).map_err(From::from); - return result; + return (sink_op)(sink, cx).map_err(From::from); } panic!("LazySinkSource in invalid state."); } +} + +impl Sink for LazySinkSource +where + Fut: Future>, + St: Stream, + Si: Sink, + Error: From, +{ + type Error = Error; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_sink_op(cx, Sink::poll_ready) + } fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { - let state = &mut self.get_mut().state; - - if let SharedState::Uninit { .. } = &*state { - let old_state = std::mem::replace(&mut *state, SharedState::Taken); - if let SharedState::Uninit { future } = old_state { - let (dual_waker_state, dual_waker_waker) = DualWaker::new(); - *state = SharedState::Thunkulating { - future, - item: Some(item), - dual_waker_state, - dual_waker_waker, - }; - - return Ok(()); - } + let mut this = self.project(); + + if let SharedStateProj::Uninit { future } = this.state.as_mut().project() { + let future = future.take().unwrap(); + let (dual_waker_state, dual_waker_waker) = DualWaker::new(); + this.state.as_mut().set(SharedState::Thunkulating { + future, + item: Some(item), + dual_waker_state, + dual_waker_waker, + }); + return Ok(()); } - if let SharedState::Thunkulating { .. } = &mut *state { + if let SharedStateProj::Thunkulating { .. } = this.state.as_mut().project() { panic!("LazySinkSource not ready."); } - if let SharedState::Done { sink, buf, .. } = &mut *state { + if let SharedStateProj::Done { sink, buf, .. } = this.state.as_mut().project() { debug_assert!(buf.is_none()); - let result = sink.as_mut().start_send(item).map_err(From::from); - return result; + return sink.start_send(item).map_err(From::from); } panic!("LazySinkSource not ready."); } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let state = &mut self.get_mut().state; - - if let SharedState::Uninit { .. } = &*state { - return Poll::Ready(Ok(())); - } - - if let SharedState::Thunkulating { - future, - item, - dual_waker_state, - dual_waker_waker, - } = &mut *state - { - dual_waker_state.sink.register(cx.waker()); - - let mut new_context = Context::from_waker(dual_waker_waker); - - match future.as_mut().poll(&mut new_context) { - Poll::Ready(Ok((stream, sink))) => { - let buf = item.take(); - *state = SharedState::Done { - stream: Box::pin(stream), - sink: Box::pin(sink), - buf, - }; - } - Poll::Ready(Err(e)) => { - return Poll::Ready(Err(e)); - } - Poll::Pending => { - return Poll::Pending; - } - } - } - - if let SharedState::Done { sink, buf, .. } = &mut *state { - if buf.is_some() { - ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?); - sink.as_mut().start_send(buf.take().unwrap())?; - } - let result = sink.as_mut().poll_flush(cx).map_err(From::from); - return result; - } - - panic!("LazySinkHalf in invalid state."); + self.poll_sink_op(cx, Sink::poll_flush) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let state = &mut self.get_mut().state; - - if let SharedState::Uninit { .. } = &*state { - return Poll::Ready(Ok(())); - } - - if let SharedState::Thunkulating { - future, - item, - dual_waker_state, - dual_waker_waker, - } = &mut *state - { - dual_waker_state.sink.register(cx.waker()); - - let mut new_context = Context::from_waker(dual_waker_waker); - - match future.as_mut().poll(&mut new_context) { - Poll::Ready(Ok((stream, sink))) => { - let buf = item.take(); - *state = SharedState::Done { - stream: Box::pin(stream), - sink: Box::pin(sink), - buf, - }; - } - Poll::Ready(Err(e)) => { - return Poll::Ready(Err(e)); - } - Poll::Pending => { - return Poll::Pending; - } - } - } - - if let SharedState::Done { sink, buf, .. } = &mut *state { - if buf.is_some() { - ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?); - sink.as_mut().start_send(buf.take().unwrap())?; - } - let result = sink.as_mut().poll_close(cx).map_err(From::from); - return result; - } - - panic!("LazySinkHalf in invalid state."); + self.poll_sink_op(cx, Sink::poll_close) } } impl Stream for LazySinkSource where - Self: Unpin, Fut: Future>, St: Stream, Si: Sink, @@ -268,42 +198,36 @@ where type Item = St::Item; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let state = &mut self.get_mut().state; - - if let SharedState::Uninit { .. } = &*state { - let old_state = std::mem::replace(&mut *state, SharedState::Taken); - if let SharedState::Uninit { future } = old_state { - let (dual_waker_state, dual_waker_waker) = DualWaker::new(); - *state = SharedState::Thunkulating { - future, - item: None, - dual_waker_state, - dual_waker_waker, - }; - } else { - unreachable!(); - } + let mut this = self.project(); + + if let SharedStateProj::Uninit { future } = this.state.as_mut().project() { + let future = future.take().unwrap(); + let (dual_waker_state, dual_waker_waker) = DualWaker::new(); + this.state.as_mut().set(SharedState::Thunkulating { + future, + item: None, + dual_waker_state, + dual_waker_waker, + }); } - if let SharedState::Thunkulating { + if let SharedStateProj::Thunkulating { future, item, dual_waker_state, dual_waker_waker, - } = &mut *state + } = this.state.as_mut().project() { dual_waker_state.stream.register(cx.waker()); let mut new_context = Context::from_waker(dual_waker_waker); - match future.as_mut().poll(&mut new_context) { + match future.poll(&mut new_context) { Poll::Ready(Ok((stream, sink))) => { let buf = item.take(); - *state = SharedState::Done { - stream: Box::pin(stream), - sink: Box::pin(sink), - buf, - }; + this.state + .as_mut() + .set(SharedState::Done { stream, sink, buf }); } Poll::Ready(Err(_)) => { @@ -316,14 +240,8 @@ where } } - if let SharedState::Done { stream, .. } = &mut *state { - let result = stream.as_mut().poll_next(cx); - match &result { - Poll::Ready(Some(_)) => {} - Poll::Ready(None) => {} - Poll::Pending => {} - } - return result; + if let SharedStateProj::Done { stream, .. } = this.state.as_mut().project() { + return stream.poll_next(cx); } panic!("LazySinkSource in invalid state.");