rustc_thread_pool/registry.rs

use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::{fmt, io, mem, ptr, thread};

use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use smallvec::SmallVec;

use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::sleep::Sleep;
use crate::tlv::Tlv;
use crate::{
    AcquireThreadHandler, DeadlockHandler, ErrorKind, ExitHandler, PanicHandler,
    ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield, unwind,
};

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
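///
/// A minimal sketch of a custom spawn handler (marked `ignore`;
/// `pool_builder` is an assumed `ThreadPoolBuilder`). Every `ThreadBuilder`
/// handed to the closure must eventually have `run()` called on it:
///
/// ```ignore
/// let pool = pool_builder
///     .spawn_handler(|thread| {
///         // Hand the builder off to a plain OS thread and enter the
///         // worker's main loop there.
///         std::thread::spawn(move || thread.run());
///         Ok(())
///     })
///     .build()?;
/// ```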
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

pub struct Registry {
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    pub(crate) deadlock_handler: Option<Box<DeadlockHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,
    pub(crate) acquire_thread_handler: Option<Box<AcquireThreadHandler>>,
    pub(crate) release_thread_handler: Option<Box<ReleaseThreadHandler>>,

    // When this counter reaches 0, it means that all work on this
    // registry must be complete. This is ensured in the following ways:
    //
    // - if this is the global registry, there is a ref-count that never
    //   gets released.
    // - if this is a user-created thread-pool, then so long as the thread-pool
    //   exists, it holds a reference.
    // - when we inject a "blocking job" into the registry with `ThreadPool::install()`,
    //   no adjustment is needed; the `ThreadPool` holds the reference, and since we won't
    //   return until the blocking job is complete, that ref will continue to be held.
    // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed.
    //   These are always owned by some other job (e.g., one injected by `ThreadPool::install()`)
    //   and that job will keep the pool alive.
    terminate_count: AtomicUsize,
}

/// ////////////////////////////////////////////////////////////////////////
/// Initialization

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| {
            // SAFETY: we only create a shared reference to `THE_REGISTRY` after the `call_once`
            // that initializes it, and there will be no more mutable accesses at all.
            debug_assert!(THE_REGISTRY_SET.is_completed());
            let the_registry = unsafe { &*ptr::addr_of!(THE_REGISTRY) };
            the_registry.as_ref().ok_or(err)
        })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(ErrorKind::GlobalPoolAlreadyInitialized));

    THE_REGISTRY_SET.call_once(|| {
        result = registry().map(|registry: Arc<Registry>| {
            // SAFETY: this is the only mutable access to `THE_REGISTRY`, thanks to `Once`, and
            // `global_registry()` only takes a shared reference **after** this `call_once`.
            unsafe {
                ptr::addr_of_mut!(THE_REGISTRY).write(Some(registry));
                (*ptr::addr_of!(THE_REGISTRY)).as_ref().unwrap_unchecked()
            }
        })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all, we can fall back to
    // using the current thread alone. This is crude, and probably won't work for non-blocking
    // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine.
    //
    // Notably, this allows current WebAssembly targets to work even though their threading support
    // is stubbed out, and we won't have to change anything if they do add real threading.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new().num_threads(1).spawn_handler(|thread| {
            // Rather than starting a new thread, we're just taking over the current thread
            // *without* running the main loop, so we can still return from here.
            // The WorkerThread is leaked, but we never shut down the global pool anyway.
            let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
            let registry = &*worker_thread.registry;
            let index = worker_thread.index;

            unsafe {
                WorkerThread::set_current(worker_thread);

                // let registry know we are ready to do work
                Latch::set(&registry.thread_infos[index].primed);
            }

            Ok(())
        });

        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first { Worker::new_fifo() } else { Worker::new_lifo() };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let registry = Arc::new(Registry {
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            deadlock_handler: builder.take_deadlock_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
            acquire_thread_handler: builder.take_acquire_thread_handler(),
            release_thread_handler: builder.take_release_thread_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

    pub fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry. This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc`'s reference count.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() { Some(worker) } else { None }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId { addr: self as *const Self as usize }
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running. This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    pub(super) fn wait_until_stopped(&self) {
        self.release_thread();
        for info in &self.thread_infos {
            info.stopped.wait();
        }
        self.acquire_thread();
    }

    pub(crate) fn acquire_thread(&self) {
        if let Some(ref acquire_thread_handler) = self.acquire_thread_handler {
            acquire_thread_handler();
        }
    }

    pub(crate) fn release_thread(&self) {
        if let Some(ref release_thread_handler) = self.release_thread_handler {
            release_thread_handler();
        }
    }

    /// ////////////////////////////////////////////////////////////////////////
    /// MAIN LOOP
    ///
    /// So long as all of the worker threads are hanging out in their
    /// top-level loop, there is no work to be done.

    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_job: JobRef) {
        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(1, queue_was_empty);
    }

    pub(crate) fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.
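    ///
    /// A hedged sketch of the calling convention (marked `ignore`; `make_job`
    /// is a hypothetical constructor for a `JobRef`):
    ///
    /// ```ignore
    /// // One job per worker thread, in index order; `map` preserves
    /// // `ExactSizeIterator`, which the length assertion relies on.
    /// let jobs = (0..registry.num_threads()).map(|i| make_job(i));
    /// registry.inject_broadcast(jobs);
    /// ```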
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
    /// completes and return its return value. If `op` panics, that panic will
    /// be propagated as well. The second argument indicates `true` if injection
    /// was performed, `false` if executed directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                Tlv::null(),
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(unsafe { &*worker_thread }, true)
                },
                LatchRef::new(l),
            );
            self.inject(unsafe { job.as_job_ref() });
            self.release_thread();
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
            self.acquire_thread();

            unsafe { job.into_result() }
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            Tlv::null(),
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(unsafe { &*worker_thread }, true)
            },
            latch,
        );
        self.inject(unsafe { job.as_job_ref() });
        unsafe { current_thread.wait_until(&job.latch) };
        unsafe { job.into_result() }
    }

    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.
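    ///
    /// A hedged sketch of the balancing rule (marked `ignore`; the elided
    /// job hand-off is hypothetical):
    ///
    /// ```ignore
    /// registry.increment_terminate_count();
    /// // ... hand off a job that may outlive the current scope; the job
    /// // itself must call `registry.terminate()` exactly once when done,
    /// // or the worker threads will never shut down.
    /// ```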
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(previous != usize::MAX, "overflow in registry ref count");
    }

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

/// Mark a Rayon worker thread as blocked. This triggers the deadlock handler
/// if no other worker thread is active.
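///
/// A hedged usage sketch (marked `ignore`; `blocking_wait` is hypothetical).
/// Every `mark_blocked` must be paired with a later `mark_unblocked` on the
/// same registry:
///
/// ```ignore
/// // From inside a worker thread of `registry`:
/// mark_blocked();
/// blocking_wait(); // e.g. waiting on a condvar or an external event
/// mark_unblocked(&registry);
/// ```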
#[inline]
pub fn mark_blocked() {
    let worker_thread = WorkerThread::current();
    assert!(!worker_thread.is_null());
    unsafe {
        let registry = &(*worker_thread).registry;
        registry.sleep.mark_blocked(&registry.deadlock_handler)
    }
}

/// Mark a previously blocked Rayon worker thread as unblocked.
#[inline]
pub fn mark_unblocked(registry: &Registry) {
    registry.sleep.mark_unblocked()
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    terminate: OnceLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: OnceLatch::new(),
            stealer,
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers

pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    pub(crate) index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    pub(crate) registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` for the current thread; returns null
    /// if this is not a worker thread. This pointer is valid
    /// anywhere on the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `thread` as the current thread's `WorkerThread`.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry.sleep.new_internal_jobs(1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        unsafe { self.push(self.fifo.push(job)) };
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    pub(super) fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Does not itself pop or steal other work
    /// in the meantime; use `wait_or_steal_until` with `steal == true` to
    /// keep busy while waiting.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        unsafe { self.wait_or_steal_until(latch, false) };
    }

    /// Wait until the latch is set, first draining the local queue while
    /// `all_jobs_started` returns false: each popped job is executed via
    /// `execute_job` if `is_job` returns true for it; all other jobs are set
    /// aside and pushed back onto the queue afterwards.
    #[inline]
    pub(super) unsafe fn wait_for_jobs<L: AsCoreLatch + ?Sized, const BROADCAST_JOBS: bool>(
        &self,
        latch: &L,
        mut all_jobs_started: impl FnMut() -> bool,
        mut is_job: impl FnMut(&JobRef) -> bool,
        mut execute_job: impl FnMut(JobRef),
    ) {
        let mut jobs = SmallVec::<[JobRef; 8]>::new();
        let mut broadcast_jobs = SmallVec::<[JobRef; 8]>::new();

        while !all_jobs_started() {
            if let Some(job) = self.worker.pop() {
                if is_job(&job) {
                    execute_job(job);
                } else {
                    jobs.push(job);
                }
            } else {
                if BROADCAST_JOBS {
                    let broadcast_job = loop {
                        match self.stealer.steal() {
                            Steal::Success(job) => break Some(job),
                            Steal::Empty => break None,
                            Steal::Retry => continue,
                        }
                    };
                    if let Some(job) = broadcast_job {
                        if is_job(&job) {
                            execute_job(job);
                        } else {
                            broadcast_jobs.push(job);
                        }
                    }
                }
                break;
            }
        }

        // Restore the jobs that we weren't looking for.
        for job in jobs {
            self.worker.push(job);
        }
        if BROADCAST_JOBS {
            let broadcasts = self.registry.broadcasts.lock().unwrap();
            for job in broadcast_jobs {
                broadcasts[self.index].push(job);
            }
        }

        // Wait for the jobs to finish.
        unsafe { self.wait_until(latch) };
        debug_assert!(latch.as_core_latch().probe());
    }

    pub(super) unsafe fn wait_or_steal_until<L: AsCoreLatch + ?Sized>(
        &self,
        latch: &L,
        steal: bool,
    ) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            if steal {
                unsafe { self.wait_or_steal_until_cold(latch) };
            } else {
                unsafe { self.wait_until_cold(latch) };
            }
        }
    }

    #[cold]
    unsafe fn wait_or_steal_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        'outer: while !latch.probe() {
            // Check for local work *before* we start marking ourselves idle,
            // especially to avoid modifying shared sleep state.
            if let Some(job) = self.take_local_job() {
                unsafe { self.execute(job) };
                continue;
            }

            let mut idle_state = self.registry.sleep.start_looking(self.index);
            while !latch.probe() {
                if let Some(job) = self.find_work() {
                    self.registry.sleep.work_found();
                    unsafe { self.execute(job) };
                    // The job might have injected local work, so go back to the outer loop.
                    continue 'outer;
                } else {
                    self.registry.sleep.no_work_found(&mut idle_state, latch, self, true)
                }
            }

            // If we were sleepy, we are not anymore. We "found work" --
            // whatever the surrounding thread was doing before it had to wait.
            self.registry.sleep.work_found();
            break;
        }

        mem::forget(abort_guard); // successful execution, do not abort
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        let mut idle_state = self.registry.sleep.start_looking(self.index);
        while !latch.probe() {
            self.registry.sleep.no_work_found(&mut idle_state, latch, self, false);
        }

        // If we were sleepy, we are not anymore. We "found work" --
        // whatever the surrounding thread was doing before it had to wait.
        self.registry.sleep.work_found();

        mem::forget(abort_guard); // successful execution, do not abort
    }

    unsafe fn wait_until_out_of_work(&self) {
        debug_assert_eq!(self as *const _, WorkerThread::current());
        let registry = &*self.registry;
        let index = self.index;

        registry.acquire_thread();
        unsafe { self.wait_or_steal_until(&registry.thread_infos[index].terminate, true) };

        // Should not be any work left in our queue.
        debug_assert!(self.take_local_job().is_none());

        // Let registry know we are done
        unsafe { Latch::set(&registry.thread_infos[index].stopped) };
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the
        // outside. The idea is to finish what we started before
        // we take on something new.
        self.take_local_job().or_else(|| self.steal()).or_else(|| self.registry.pop_injected_job())
    }

    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        unsafe { job.execute() };
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => Some(job),
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    unsafe { WorkerThread::set_current(worker_thread) };
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    unsafe { Latch::set(&registry.thread_infos[index].primed) };

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    unsafe { worker_thread.wait_until_out_of_work() };

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }

    registry.release_thread();
}

/// If already in a worker-thread, just execute `op`. Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well. The second argument indicates
/// `true` if injection was performed, `false` if executed directly.
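///
/// A hedged usage sketch (marked `ignore`):
///
/// ```ignore
/// // Runs on the current thread if it is already a worker; otherwise the
/// // closure is injected into the global pool and the caller blocks.
/// let sum = in_worker(|_worker, _was_injected| 1 + 2);
/// assert_eq!(sum, 3);
/// ```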
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star { state: Cell::new(seed) }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}
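
// A hedged sanity check, not present in the original source: it only asserts
// range and non-degeneracy of the xorshift* generator above.
#[cfg(test)]
mod xorshift_tests {
    use super::XorShift64Star;

    #[test]
    fn next_usize_stays_in_range() {
        let rng = XorShift64Star::new();
        let mut seen_nonzero = false;
        for _ in 0..1_000 {
            let v = rng.next_usize(10);
            assert!(v < 10);
            seen_nonzero |= v != 0;
        }
        // With a non-zero seed, 1000 draws from 0..10 should not all be zero.
        assert!(seen_nonzero);
    }
}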