rustc_thread_pool/broadcast/
mod.rs1use std::fmt;
2use std::marker::PhantomData;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5
6use crate::job::{ArcJob, StackJob};
7use crate::latch::{CountLatch, LatchRef};
8use crate::registry::{Registry, WorkerThread};
9
10mod tests;
11
12pub fn broadcast<OP, R>(op: OP) -> Vec<R>
22where
23 OP: Fn(BroadcastContext<'_>) -> R + Sync,
24 R: Send,
25{
26 unsafe { broadcast_in(op, &Registry::current()) }
28}
29
30pub fn spawn_broadcast<OP>(op: OP)
39where
40 OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
41{
42 unsafe { spawn_broadcast_in(op, &Registry::current()) }
44}
45
46pub struct BroadcastContext<'a> {
48 worker: &'a WorkerThread,
49
50 _marker: PhantomData<&'a mut dyn Fn()>,
52}
53
54impl<'a> BroadcastContext<'a> {
55 pub(super) fn with<R>(f: impl FnOnce(BroadcastContext<'_>) -> R) -> R {
56 let worker_thread = WorkerThread::current();
57 assert!(!worker_thread.is_null());
58 f(BroadcastContext { worker: unsafe { &*worker_thread }, _marker: PhantomData })
59 }
60
61 #[inline]
63 pub fn index(&self) -> usize {
64 self.worker.index()
65 }
66
67 #[inline]
75 pub fn num_threads(&self) -> usize {
76 self.worker.registry().num_threads()
77 }
78}
79
80impl<'a> fmt::Debug for BroadcastContext<'a> {
81 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
82 fmt.debug_struct("BroadcastContext")
83 .field("index", &self.index())
84 .field("num_threads", &self.num_threads())
85 .field("pool_id", &self.worker.registry().id())
86 .finish()
87 }
88}
89
90pub(super) unsafe fn broadcast_in<OP, R>(op: OP, registry: &Arc<Registry>) -> Vec<R>
97where
98 OP: Fn(BroadcastContext<'_>) -> R + Sync,
99 R: Send,
100{
101 let current_thread = WorkerThread::current();
102 let current_thread_addr = current_thread.expose_provenance();
103 let started = &AtomicBool::new(false);
104 let f = move |injected: bool| {
105 debug_assert!(injected);
106
107 if current_thread_addr == WorkerThread::current().expose_provenance() {
109 started.store(true, Ordering::Relaxed);
110 }
111
112 BroadcastContext::with(&op)
113 };
114
115 let n_threads = registry.num_threads();
116 let current_thread = unsafe { current_thread.as_ref() };
117 let tlv = crate::tlv::get();
118 let latch = CountLatch::with_count(n_threads, current_thread);
119 let jobs: Vec<_> =
120 (0..n_threads).map(|_| StackJob::new(tlv, &f, LatchRef::new(&latch))).collect();
121 let job_refs = jobs.iter().map(|job| unsafe { job.as_job_ref() });
122
123 registry.inject_broadcast(job_refs);
124
125 let current_thread_job_id = current_thread
126 .and_then(|worker| (registry.id() == worker.registry.id()).then(|| worker))
127 .map(|worker| unsafe { jobs[worker.index()].as_job_ref() }.id());
128
129 latch.wait(
131 current_thread,
132 || started.load(Ordering::Relaxed),
133 |job| Some(job.id()) == current_thread_job_id,
134 );
135 jobs.into_iter().map(|job| unsafe { job.into_result() }).collect()
136}
137
138pub(super) unsafe fn spawn_broadcast_in<OP>(op: OP, registry: &Arc<Registry>)
145where
146 OP: Fn(BroadcastContext<'_>) + Send + Sync + 'static,
147{
148 let job = ArcJob::new({
149 let registry = Arc::clone(registry);
150 move |_| {
151 registry.catch_unwind(|| BroadcastContext::with(&op));
152 registry.terminate(); }
154 });
155
156 let n_threads = registry.num_threads();
157 let job_refs = (0..n_threads).map(|_| {
158 registry.increment_terminate_count();
161
162 ArcJob::as_static_job_ref(&job)
163 });
164
165 registry.inject_broadcast(job_refs);
166}