use std::future::Future;
use tokio::task::{JoinError, JoinSet};
/// Handle to exactly one task spawned onto a Tokio runtime.
///
/// The task is stored in a private [`JoinSet`] holding a single entry. Per tokio's
/// `JoinSet` documentation, dropping the set aborts the tasks it owns, so dropping a
/// `SpawnedTask` before joining it cancels the task.
#[derive(Debug)]
pub struct SpawnedTask<R> {
    /// Holds exactly one task; it is only taken out by [`SpawnedTask::join`],
    /// which consumes `self`.
    inner: JoinSet<R>,
}

impl<R: 'static> SpawnedTask<R> {
    /// Spawns the future `task` and returns a handle that owns it.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime (behavior of tokio's
    /// `JoinSet::spawn`).
    #[must_use = "dropping a `SpawnedTask` aborts the spawned task"]
    pub fn spawn<T>(task: T) -> Self
    where
        T: Future<Output = R>,
        T: Send + 'static,
        R: Send,
    {
        let mut inner = JoinSet::new();
        inner.spawn(task);
        Self { inner }
    }

    /// Runs the closure `task` on the runtime's blocking thread pool and returns a
    /// handle that owns it.
    ///
    /// NOTE(review): tokio documents that blocking tasks cannot be aborted once they
    /// have started running. Dropping this handle therefore discards the result but
    /// may not stop a closure that is already executing.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime (behavior of tokio's
    /// `JoinSet::spawn_blocking`).
    #[must_use = "dropping a `SpawnedTask` discards the task's result and may abort it"]
    pub fn spawn_blocking<T>(task: T) -> Self
    where
        T: FnOnce() -> R,
        T: Send + 'static,
        R: Send,
    {
        let mut inner = JoinSet::new();
        inner.spawn_blocking(task);
        Self { inner }
    }

    /// Waits for the task to finish and returns its output.
    ///
    /// # Errors
    ///
    /// Returns the task's [`JoinError`] if it panicked or was cancelled.
    pub async fn join(mut self) -> Result<R, JoinError> {
        // Both constructors insert exactly one task, and `self` is consumed here,
        // so `join_next` cannot observe an empty set.
        self.inner
            .join_next()
            .await
            .expect("`SpawnedTask` instance always contains exactly 1 task")
    }

    /// Like [`SpawnedTask::join`], but re-raises a panic from the task on the
    /// current thread instead of returning it as an error.
    ///
    /// # Panics
    ///
    /// Resumes unwinding with the original panic payload if the task panicked.
    ///
    /// # Errors
    ///
    /// Returns the [`JoinError`] if the task was cancelled. A warning is logged
    /// in that case.
    pub async fn join_unwind(self) -> Result<R, JoinError> {
        self.join().await.map_err(|e| {
            if e.is_panic() {
                std::panic::resume_unwind(e.into_panic());
            } else {
                // The `JoinSet` is private and `self` was consumed, so nobody else can
                // abort the task. A cancellation therefore means the runtime is
                // shutting down.
                log::warn!("SpawnedTask was polled during shutdown");
                e
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::pending;
    use tokio::runtime::Runtime;

    /// A never-completing task whose runtime is shut down must surface as a
    /// cancellation error from `join_unwind`, not as a panic or a hang.
    #[tokio::test]
    async fn runtime_shutdown() {
        let rt = Runtime::new().unwrap();

        // Create the `SpawnedTask` from inside `rt` so its inner task lives on that
        // runtime, then hand the wrapper back to this test's own runtime.
        let never_finishes = rt
            .spawn(async {
                SpawnedTask::spawn(async {
                    pending::<()>().await;
                    unreachable!("should never return");
                })
            })
            .await
            .unwrap();

        rt.shutdown_background();

        let outcome = never_finishes.join_unwind().await;
        assert!(matches!(outcome, Err(e) if e.is_cancelled()));
    }

    /// A panic inside the spawned task must be re-raised on the joining side with
    /// its original message.
    #[tokio::test]
    #[should_panic(expected = "foo")]
    async fn panic_resume() {
        let task = SpawnedTask::spawn(async { panic!("foo") });
        // The result is irrelevant: `join_unwind` is expected to unwind before it returns.
        let _ = task.join_unwind().await;
    }
}