rustc_thread_pool/broadcast/mod.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::sync::Arc;
4
5use crate::job::{ArcJob, StackJob};
6use crate::latch::{CountLatch, LatchRef};
7use crate::registry::{Registry, WorkerThread};
8
9mod tests;
10
11/// Executes `op` within every thread in the current threadpool. If this is
12/// called from a non-Rayon thread, it will execute in the global threadpool.
13/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
14/// within that threadpool. When the call has completed on each thread, returns
15/// a vector containing all of their return values.
16///
17/// For more information, see the [`ThreadPool::broadcast()`][m] method.
18///
19/// [m]: struct.ThreadPool.html#method.broadcast
20pub fn broadcast<OP, R>(op: OP) -> Vec<R>
21where
22    OP: Fn(BroadcastContext<'_>) -> R + Sync,
23    R: Send,
24{
25    // We assert that current registry has not terminated.
26    unsafe { broadcast_in(op, &Registry::current()) }
27}
28
29/// Spawns an asynchronous task on every thread in this thread-pool. This task
30/// will run in the implicit, global scope, which means that it may outlast the
31/// current stack frame -- therefore, it cannot capture any references onto the
32/// stack (you will likely need a `move` closure).
33///
34/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
35///
36/// [m]: struct.ThreadPool.html#method.spawn_broadcast
37pub fn spawn_broadcast<OP>(op: OP)
38where
39    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
40{
41    // We assert that current registry has not terminated.
42    unsafe { spawn_broadcast_in(op, &Registry::current()) }
43}
44
45/// Provides context to a closure called by `broadcast`.
46pub struct BroadcastContext<'a> {
47    worker: &'a WorkerThread,
48
49    /// Make sure to prevent auto-traits like `Send` and `Sync`.
50    _marker: PhantomData<&'a mut dyn Fn()>,
51}
52
53impl<'a> BroadcastContext<'a> {
54    pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
55        let worker_thread = WorkerThread::current();
56        assert!(!worker_thread.is_null());
57        f(BroadcastContext { worker: unsafe { &*worker_thread }, _marker: PhantomData })
58    }
59
60    /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
61    #[inline]
62    pub fn index(&self) -> usize {
63        self.worker.index()
64    }
65
66    /// The number of threads receiving the broadcast in the thread pool.
67    ///
68    /// # Future compatibility note
69    ///
70    /// Future versions of Rayon might vary the number of threads over time, but
71    /// this method will always return the number of threads which are actually
72    /// receiving your particular `broadcast` call.
73    #[inline]
74    pub fn num_threads(&self) -> usize {
75        self.worker.registry().num_threads()
76    }
77}
78
79impl<'a> fmt::Debug for BroadcastContext<'a> {
80    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
81        fmt.debug_struct("BroadcastContext")
82            .field("index", &self.index())
83            .field("num_threads", &self.num_threads())
84            .field("pool_id", &self.worker.registry().id())
85            .finish()
86    }
87}
88
89/// Execute `op` on every thread in the pool. It will be executed on each
90/// thread when they have nothing else to do locally, before they try to
91/// steal work from other threads. This function will not return until all
92/// threads have completed the `op`.
93///
94/// Unsafe because `registry` must not yet have terminated.
95pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
96where
97    OP: Fn(BroadcastContext<'_>) -> R + Sync,
98    R: Send,
99{
100    let f = move |injected: bool| {
101        debug_assert!(injected);
102        BroadcastContext::with(&op)
103    };
104
105    let n_threads = registry.num_threads();
106    let current_thread = unsafe { WorkerThread::current().as_ref() };
107    let tlv = crate::tlv::get();
108    let latch = CountLatch::with_count(n_threads, current_thread);
109    let jobs: Vec<_> =
110        (0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect();
111    let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() });
112
113    registry.inject_broadcast(job_refs);
114
115    // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
116    latch.wait(current_thread);
117    jobs.into_iter().map(|job| unsafe { job.into_result() }).collect()
118}
119
120/// Execute `op` on every thread in the pool. It will be executed on each
121/// thread when they have nothing else to do locally, before they try to
122/// steal work from other threads. This function returns immediately after
123/// injecting the jobs.
124///
125/// Unsafe because `registry` must not yet have terminated.
126pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
127where
128    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
129{
130    let job = ArcJob::new({
131        let registry = Arc::clone(registry);
132        move || {
133            registry.catch_unwind(|| BroadcastContext::with(&op));
134            registry.terminate(); // (*) permit registry to terminate now
135        }
136    });
137
138    let n_threads = registry.num_threads();
139    let job_refs = (0..n_threads).map(|_| {
140        // Ensure that registry cannot terminate until this job has executed
141        // on each thread. This ref is decremented at the (*) above.
142        registry.increment_terminate_count();
143
144        ArcJob::as_static_job_ref(&job)
145    });
146
147    registry.inject_broadcast(job_refs);
148}