use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::task::{JoinError, JoinHandle};
use crate::trace_utils::{trace_block, trace_future};
/// A handle to a task spawned on the tokio runtime that aborts the task when
/// the handle is dropped.
///
/// Wraps a [`JoinHandle`]. A bare `JoinHandle` detaches its task when dropped.
/// Dropping a `SpawnedTask` instead calls [`JoinHandle::abort`] on it (see the
/// `Drop` impl), so the task cannot silently outlive its owner. Awaiting a
/// `SpawnedTask` yields the task's output as `Result<R, JoinError>`. You can
/// await it directly or through `join` / `join_unwind`.
#[derive(Debug)]
pub struct SpawnedTask<R> {
// The underlying tokio handle; polling and aborting both delegate to it.
inner: JoinHandle<R>,
}
impl<R: 'static> SpawnedTask<R> {
    /// Spawns the future `task` onto the tokio runtime via
    /// `tokio::task::spawn` and returns a handle that aborts it on drop.
    ///
    /// The future is first wrapped with `trace_future` from
    /// `crate::trace_utils`. That presumably carries tracing context into the
    /// new task; confirm in `trace_utils`.
    pub fn spawn<T>(task: T) -> Self
    where
        T: Future<Output = R>,
        T: Send + 'static,
        R: Send,
    {
        let traced = trace_future(task);
        // Calling tokio's spawn directly is presumably disallowed by the
        // project's clippy config so that callers use this wrapper instead;
        // this wrapper is the one sanctioned call site.
        #[allow(clippy::disallowed_methods)]
        let handle = tokio::task::spawn(traced);
        Self { inner: handle }
    }

    /// Runs the blocking closure `task` via `tokio::task::spawn_blocking` and
    /// returns a handle that aborts it on drop.
    ///
    /// The closure is first wrapped with `trace_block` from
    /// `crate::trace_utils`. That presumably carries tracing context onto the
    /// blocking thread; confirm in `trace_utils`.
    pub fn spawn_blocking<T>(task: T) -> Self
    where
        T: FnOnce() -> R,
        T: Send + 'static,
        R: Send,
    {
        let traced = trace_block(task);
        // See `spawn`: direct tokio spawning is presumably clippy-disallowed
        // outside this wrapper.
        #[allow(clippy::disallowed_methods)]
        let handle = tokio::task::spawn_blocking(traced);
        Self { inner: handle }
    }

    /// Waits for the task to finish and returns its output.
    ///
    /// This is the same as awaiting the `SpawnedTask` itself. The handle is
    /// consumed. If the returned future is dropped before completion, the
    /// handle is dropped too and the task is aborted.
    ///
    /// # Errors
    ///
    /// Returns the task's `JoinError` if it panicked or was cancelled. A panic
    /// is *not* re-raised; use [`Self::join_unwind`] for that.
    pub async fn join(self) -> Result<R, JoinError> {
        self.await
    }

    /// Like [`Self::join`], but re-raises a panic from the task on the
    /// awaiting thread instead of returning it as an error.
    ///
    /// # Panics
    ///
    /// If the task panicked, its panic payload is resumed here via
    /// [`std::panic::resume_unwind`].
    ///
    /// # Errors
    ///
    /// Returns the `JoinError` when the task was cancelled rather than
    /// panicking, after logging a warning. `self` is consumed, and within this
    /// file `abort` is only called from `Drop`. So such a cancellation is
    /// expected to come from the runtime shutting down, which is the scenario
    /// the `runtime_shutdown` test exercises.
    pub async fn join_unwind(self) -> Result<R, JoinError> {
        match self.await {
            Ok(output) => Ok(output),
            Err(join_err) if join_err.is_panic() => {
                std::panic::resume_unwind(join_err.into_panic())
            }
            Err(join_err) => {
                log::warn!("SpawnedTask was polled during shutdown");
                Err(join_err)
            }
        }
    }
}
impl<R> Future for SpawnedTask<R> {
    type Output = Result<R, JoinError>;

    /// Polls the wrapped `JoinHandle`, forwarding its result unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `SpawnedTask` is `Unpin` because its only field, the tokio
        // `JoinHandle`, is. So we can take a plain `&mut` out of the pin and
        // re-pin just the inner handle.
        let this = self.get_mut();
        let pinned_handle = Pin::new(&mut this.inner);
        pinned_handle.poll(cx)
    }
}
impl<R> Drop for SpawnedTask<R> {
    /// Aborts the underlying task so it does not keep running after its
    /// handle is gone.
    fn drop(&mut self) {
        let handle = &self.inner;
        handle.abort();
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::future::{pending, Pending};
use tokio::{runtime::Runtime, sync::oneshot};
/// When a task's runtime is shut down before the task finishes, `join_unwind`
/// must surface a cancellation. The error is returned as `Err`, not resumed
/// as a panic.
#[tokio::test]
async fn runtime_shutdown() {
// A separate runtime that can be shut down independently of the test's own.
let rt = Runtime::new().unwrap();
// The async block below deliberately returns a `SpawnedTask`, which is itself
// a future, so clippy's `async_yields_async` lint is silenced.
#[allow(clippy::async_yields_async)]
let task = rt
.spawn(async {
// Spawned from inside `rt`, so this never-completing task presumably lives
// on `rt` rather than on the test runtime.
SpawnedTask::spawn(async {
let fut: Pending<()> = pending();
fut.await;
unreachable!("should never return");
})
})
.await
.unwrap();
// Shut `rt` down without waiting for its tasks. The pending task is dropped,
// so awaiting its handle must report cancellation.
rt.shutdown_background();
assert!(matches!(
task.join_unwind().await,
Err(e) if e.is_cancelled()
));
}
/// A panic inside the task must be re-raised by `join_unwind` on the awaiting
/// thread, with the original payload (hence `should_panic(expected = "foo")`).
#[tokio::test]
#[should_panic(expected = "foo")]
async fn panic_resume() {
SpawnedTask::spawn(async { panic!("foo") })
.join_unwind()
.await
.ok();
}
/// Dropping a task before it has run prevents it from running at all. The
/// oneshot sender is dropped unsent, so the receiver reports an error.
/// Presumably this relies on `#[tokio::test]`'s default single-threaded
/// runtime: with no `.await` between spawn and `drop`, the task cannot be
/// polled first.
#[tokio::test]
async fn cancel_not_started_task() {
let (sender, receiver) = oneshot::channel::<i32>();
let task = SpawnedTask::spawn(async {
sender.send(42).unwrap();
});
drop(task);
assert!(receiver.await.is_err());
}
/// Dropping a task that is suspended mid-way aborts it. The channel has
/// capacity 1, so after the first value is received the task is waiting on
/// the second `send`. Once the task is dropped, the second value is never
/// delivered and the channel closes (`recv` yields `None`).
#[tokio::test]
async fn cancel_ongoing_task() {
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
let task = SpawnedTask::spawn(async move {
sender.send(1).await.unwrap();
sender.send(2).await.unwrap();
});
assert_eq!(receiver.recv().await.unwrap(), 1);
drop(task);
assert!(receiver.recv().await.is_none());
}
}