// rustc_thread_pool/broadcast/mod.rs

1use std::fmt;
2use std::marker::PhantomData;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5
6use crate::job::{ArcJob, StackJob};
7use crate::latch::{CountLatch, LatchRef};
8use crate::registry::{Registry, WorkerThread};
9
10mod tests;
11
12/// Executes `op` within every thread in the current threadpool. If this is
13/// called from a non-Rayon thread, it will execute in the global threadpool.
14/// Any attempts to use `join`, `scope`, or parallel iterators will then operate
15/// within that threadpool. When the call has completed on each thread, returns
16/// a vector containing all of their return values.
17///
18/// For more information, see the [`ThreadPool::broadcast()`][m] method.
19///
20/// [m]: struct.ThreadPool.html#method.broadcast
21pub fn broadcast<OP, R>(op: OP) -> Vec<R>
22where
23    OP: Fn(BroadcastContext<'_>) -> R + Sync,
24    R: Send,
25{
26    // We assert that current registry has not terminated.
27    unsafe { broadcast_in(op, &Registry::current()) }
28}
29
30/// Spawns an asynchronous task on every thread in this thread-pool. This task
31/// will run in the implicit, global scope, which means that it may outlast the
32/// current stack frame -- therefore, it cannot capture any references onto the
33/// stack (you will likely need a `move` closure).
34///
35/// For more information, see the [`ThreadPool::spawn_broadcast()`][m] method.
36///
37/// [m]: struct.ThreadPool.html#method.spawn_broadcast
38pub fn spawn_broadcast<OP>(op: OP)
39where
40    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
41{
42    // We assert that current registry has not terminated.
43    unsafe { spawn_broadcast_in(op, &Registry::current()) }
44}
45
46/// Provides context to a closure called by `broadcast`.
47pub struct BroadcastContext<'a> {
48    worker: &'a WorkerThread,
49
50    /// Make sure to prevent auto-traits like `Send` and `Sync`.
51    _marker: PhantomData<&'a mut dyn Fn()>,
52}
53
54impl<'a> BroadcastContext<'a> {
55    pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
56        let worker_thread = WorkerThread::current();
57        assert!(!worker_thread.is_null());
58        f(BroadcastContext { worker: unsafe { &*worker_thread }, _marker: PhantomData })
59    }
60
61    /// Our index amongst the broadcast threads (ranges from `0..self.num_threads()`).
62    #[inline]
63    pub fn index(&self) -> usize {
64        self.worker.index()
65    }
66
67    /// The number of threads receiving the broadcast in the thread pool.
68    ///
69    /// # Future compatibility note
70    ///
71    /// Future versions of Rayon might vary the number of threads over time, but
72    /// this method will always return the number of threads which are actually
73    /// receiving your particular `broadcast` call.
74    #[inline]
75    pub fn num_threads(&self) -> usize {
76        self.worker.registry().num_threads()
77    }
78}
79
80impl<'a> fmt::Debug for BroadcastContext<'a> {
81    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
82        fmt.debug_struct("BroadcastContext")
83            .field("index", &self.index())
84            .field("num_threads", &self.num_threads())
85            .field("pool_id", &self.worker.registry().id())
86            .finish()
87    }
88}
89
90/// Execute `op` on every thread in the pool. It will be executed on each
91/// thread when they have nothing else to do locally, before they try to
92/// steal work from other threads. This function will not return until all
93/// threads have completed the `op`.
94///
95/// Unsafe because `registry` must not yet have terminated.
96pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
97where
98    OP: Fn(BroadcastContext<'_>) -> R + Sync,
99    R: Send,
100{
101    let current_thread = WorkerThread::current();
102    let current_thread_addr = current_thread.expose_provenance();
103    let started = &AtomicBool::new(false);
104    let f = move |injected: bool| {
105        debug_assert!(injected);
106
107        // Mark as started if we are the thread that initiated that broadcast.
108        if current_thread_addr == WorkerThread::current().expose_provenance() {
109            started.store(true, Ordering::Relaxed);
110        }
111
112        BroadcastContext::with(&op)
113    };
114
115    let n_threads = registry.num_threads();
116    let current_thread = unsafe { current_thread.as_ref() };
117    let tlv = crate::tlv::get();
118    let latch = CountLatch::with_count(n_threads, current_thread);
119    let jobs: Vec<_> =
120        (0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect();
121    let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() });
122
123    registry.inject_broadcast(job_refs);
124
125    let current_thread_job_id = current_thread
126        .and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
127        .map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id());
128
129    // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
130    latch.wait(
131        current_thread,
132        || started.load(Ordering::Relaxed),
133        |job| Some(job.id()) == current_thread_job_id,
134    );
135    jobs.into_iter().map(|job| unsafe { job.into_result() }).collect()
136}
137
138/// Execute `op` on every thread in the pool. It will be executed on each
139/// thread when they have nothing else to do locally, before they try to
140/// steal work from other threads. This function returns immediately after
141/// injecting the jobs.
142///
143/// Unsafe because `registry` must not yet have terminated.
144pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
145where
146    OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
147{
148    let job = ArcJob::new({
149        let registry = Arc::clone(registry);
150        move |_| {
151            registry.catch_unwind(|| BroadcastContext::with(&op));
152            registry.terminate(); // (*) permit registry to terminate now
153        }
154    });
155
156    let n_threads = registry.num_threads();
157    let job_refs = (0..n_threads).map(|_| {
158        // Ensure that registry cannot terminate until this job has executed
159        // on each thread. This ref is decremented at the (*) above.
160        registry.increment_terminate_count();
161
162        ArcJob::as_static_job_ref(&job)
163    });
164
165    registry.inject_broadcast(job_refs);
166}