// rustc_thread_pool/registry.rs

use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::{fmt, io, mem, ptr, thread};

use crossbeam_deque::{Injector, Steal, Stealer, Worker};

use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::sleep::Sleep;
use crate::tlv::Tlv;
use crate::{
    AcquireThreadHandler, DeadlockHandler, ErrorKind, ExitHandler, PanicHandler,
    ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield, unwind,
};

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
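///
/// A minimal usage sketch (hedged: assumes the conventional `build()`
/// finisher; the thread name format is illustrative):
///
/// ```ignore
/// let pool = ThreadPoolBuilder::new()
///     .spawn_handler(|thread| {
///         // Hand the `ThreadBuilder` off to a plain `std::thread`;
///         // `run()` does not return until the pool shuts down.
///         std::thread::Builder::new()
///             .name(format!("worker-{}", thread.index()))
///             .spawn(|| thread.run())?;
///         Ok(())
///     })
///     .build()?;
/// ```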
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified by `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified by `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until the
    /// thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawn a thread with the `ThreadBuilder` parameters, and then
    /// call `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

pub struct Registry {
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    pub(crate) deadlock_handler: Option<Box<DeadlockHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,
    pub(crate) acquire_thread_handler: Option<Box<AcquireThreadHandler>>,
    pub(crate) release_thread_handler: Option<Box<ReleaseThreadHandler>>,

    // When this latch reaches 0, it means that all work on this
    // registry must be complete. This is ensured in the following ways:
    //
    // - if this is the global registry, there is a ref-count that never
    //   gets released.
    // - if this is a user-created thread-pool, then so long as the thread-pool
    //   exists, it holds a reference.
    // - when we inject a "blocking job" into the registry with `ThreadPool::install()`,
    //   no adjustment is needed; the `ThreadPool` holds the reference, and since we won't
    //   return until the blocking job is complete, that ref will continue to be held.
    // - when `join()` or `scope()` is invoked, similarly, no adjustments are needed.
    //   These are always owned by some other job (e.g., one injected by `ThreadPool::install()`)
    //   and that job will keep the pool alive.
    terminate_count: AtomicUsize,
}

/// ////////////////////////////////////////////////////////////////////////
/// Initialization

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

/// Starts the worker threads (if that has not already happened). If
/// initialization has not already occurred, use the default
/// configuration.
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| {
            // SAFETY: we only create a shared reference to `THE_REGISTRY` after the `call_once`
            // that initializes it, and there will be no more mutable accesses at all.
            debug_assert!(THE_REGISTRY_SET.is_completed());
            let the_registry = unsafe { &*ptr::addr_of!(THE_REGISTRY) };
            the_registry.as_ref().ok_or(err)
        })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
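///
/// A hedged usage sketch (crate-internal; `num_threads` is the standard
/// builder option):
///
/// ```ignore
/// let registry = init_global_registry(ThreadPoolBuilder::new().num_threads(4))?;
/// assert_eq!(registry.num_threads(), 4);
/// ```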
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(ErrorKind::GlobalPoolAlreadyInitialized));

    THE_REGISTRY_SET.call_once(|| {
        result = registry().map(|registry: Arc<Registry>| {
            // SAFETY: this is the only mutable access to `THE_REGISTRY`, thanks to `Once`, and
            // `global_registry()` only takes a shared reference **after** this `call_once`.
            unsafe {
                ptr::addr_of_mut!(THE_REGISTRY).write(Some(registry));
                (*ptr::addr_of!(THE_REGISTRY)).as_ref().unwrap_unchecked()
            }
        })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at all, we can fall back to
    // using the current thread alone. This is crude, and probably won't work for non-blocking
    // calls like `spawn` or `broadcast_spawn`, but a lot of stuff does work fine.
    //
    // Notably, this allows current WebAssembly targets to work even though their threading support
    // is stubbed out, and we won't have to change anything if they do add real threading.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new().num_threads(1).spawn_handler(|thread| {
            // Rather than starting a new thread, we're just taking over the current thread
            // *without* running the main loop, so we can still return from here.
            // The WorkerThread is leaked, but we never shut down the global pool anyway.
            let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
            let registry = &*worker_thread.registry;
            let index = worker_thread.index;

            unsafe {
                WorkerThread::set_current(worker_thread);

                // let registry know we are ready to do work
                Latch::set(&registry.thread_infos[index].primed);
            }

            Ok(())
        });

        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first { Worker::new_fifo() } else { Worker::new_lifo() };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let registry = Arc::new(Registry {
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            deadlock_handler: builder.take_deadlock_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
            acquire_thread_handler: builder.take_acquire_thread_handler(),
            release_thread_handler: builder.take_release_thread_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

    pub fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }
    /// Returns the number of threads in the current registry. This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc` reference count.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() { Some(worker) } else { None }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId { addr: self as *const Self as usize }
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if that handler itself panics, then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running. This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    pub(super) fn wait_until_stopped(&self) {
        self.release_thread();
        for info in &self.thread_infos {
            info.stopped.wait();
        }
        self.acquire_thread();
    }

    pub(crate) fn acquire_thread(&self) {
        if let Some(ref acquire_thread_handler) = self.acquire_thread_handler {
            acquire_thread_handler();
        }
    }

    pub(crate) fn release_thread(&self) {
        if let Some(ref release_thread_handler) = self.release_thread_handler {
            release_thread_handler();
        }
    }

    /// ////////////////////////////////////////////////////////////////////////
    /// MAIN LOOP
    ///
    /// So long as all of the worker threads are hanging out in their
    /// top-level loop, there is no work to be done.

    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Else, it will inject from the outside (which is slower).
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(job_ref);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_job: JobRef) {
        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        self.injected_jobs.push(injected_job);
        self.sleep.new_injected_jobs(1, queue_was_empty);
    }

    pub(crate) fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until `op`
    /// completes and return its return value. If `op` panics, that panic will
    /// be propagated as well. The second argument indicates `true` if injection
    /// was performed, `false` if executed directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                Tlv::null(),
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(unsafe { &*worker_thread }, true)
                },
                LatchRef::new(l),
            );
            self.inject(unsafe { job.as_job_ref() });
            self.release_thread();
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
            self.acquire_thread();

            unsafe { job.into_result() }
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a member of a different pool, so let it process
        // other work while waiting for this `op` to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            Tlv::null(),
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(unsafe { &*worker_thread }, true)
            },
            latch,
        );
        self.inject(unsafe { job.as_job_ref() });
        unsafe { current_thread.wait_until(&job.latch) };
        unsafe { job.into_result() }
    }

    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.
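    ///
    /// A hedged sketch of that pairing (illustrative pseudocode for a
    /// detached job; `spawn_detached_job` is not an actual API in this file):
    ///
    /// ```ignore
    /// registry.increment_terminate_count(); // keep the pool alive
    /// spawn_detached_job(move || {
    ///     do_the_work();
    ///     registry.terminate(); // balance the increment when done
    /// });
    /// ```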
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(previous != usize::MAX, "overflow in registry ref count");
    }

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { OnceLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been "set".
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

/// Mark a Rayon worker thread as blocked. This triggers the deadlock handler
/// if no other worker thread is active.
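///
/// A hedged usage sketch; the blocking operation is illustrative, and in
/// practice the code performing the wake-up may be the one that calls
/// `mark_unblocked`:
///
/// ```ignore
/// mark_blocked();                       // this worker is about to stall
/// let msg = receiver.recv();            // block on something external
/// mark_unblocked(&Registry::current()); // active again
/// ```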
#[inline]
pub fn mark_blocked() {
    let worker_thread = WorkerThread::current();
    assert!(!worker_thread.is_null());
    unsafe {
        let registry = &(*worker_thread).registry;
        registry.sleep.mark_blocked(&registry.deadlock_handler)
    }
}

/// Mark a previously blocked Rayon worker thread as unblocked.
#[inline]
pub fn mark_unblocked(registry: &Registry) {
    registry.sleep.mark_unblocked()
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method on the
    /// `Registry`, once the registry's main "terminate" counter
    /// reaches zero.
    terminate: OnceLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: OnceLatch::new(),
            stealer,
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////
/// WorkerThread identifiers

pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" half of the worker's broadcast deque
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    pub(crate) index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    pub(crate) registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is fully unwound. Using an unsafe pointer avoids the need
// for a RefCell<T> etc.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = const { Cell::new(ptr::null()) };
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets a pointer to the current thread's `WorkerThread`; returns
    /// null if this is not a worker thread. The pointer is valid
    /// anywhere on the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `thread` as the current thread's `WorkerThread`.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry.sleep.new_internal_jobs(1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        unsafe { self.push(self.fifo.push(job)) };
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            return popped_job;
        }

        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    pub(super) fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            unsafe { self.wait_until_cold(latch) };
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        'outer: while !latch.probe() {
            // Check for local work *before* we start marking ourselves idle,
            // especially to avoid modifying shared sleep state.
            if let Some(job) = self.take_local_job() {
                unsafe { self.execute(job) };
                continue;
            }

            let mut idle_state = self.registry.sleep.start_looking(self.index);
            while !latch.probe() {
                if let Some(job) = self.find_work() {
                    self.registry.sleep.work_found();
                    unsafe { self.execute(job) };
                    // The job might have injected local work, so go back to the outer loop.
                    continue 'outer;
                } else {
                    self.registry.sleep.no_work_found(&mut idle_state, latch, &self)
                }
            }

            // If we were sleepy, we are not anymore. We "found work" --
            // whatever the surrounding thread was doing before it had to wait.
            self.registry.sleep.work_found();
            break;
        }

        mem::forget(abort_guard); // successful execution, do not abort
    }

    unsafe fn wait_until_out_of_work(&self) {
        debug_assert_eq!(self as *const _, WorkerThread::current());
        let registry = &*self.registry;
        let index = self.index;

        registry.acquire_thread();
        unsafe { self.wait_until(&registry.thread_infos[index].terminate) };

        // Should not be any work left in our queue.
        debug_assert!(self.take_local_job().is_none());

        // Let registry know we are done
        unsafe { Latch::set(&registry.thread_infos[index].stopped) };
    }

    fn find_work(&self) -> Option<JobRef> {
        // Try to find some work to do. We give preference first
        // to things in our local deque, then in other workers'
        // deques, and finally to injected jobs from the
        // outside. The idea is to finish what we started before
        // we take on something new.
        self.take_local_job().or_else(|| self.steal()).or_else(|| self.registry.pop_injected_job())
    }

    /// Executes one unit of work if any can be found (local, stolen, or
    /// injected), returning whether a job was actually run.
    pub(super) fn yield_now(&self) -> Yield {
        match self.find_work() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    /// Like `yield_now`, but only considers this worker's own queues (its
    /// local deque and its broadcast queue); it never steals from other
    /// workers or the global injector.
    pub(super) fn yield_local(&self) -> Yield {
        match self.take_local_job() {
            Some(job) => unsafe {
                self.execute(job);
                Yield::Executed
            },
            None => Yield::Idle,
        }
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        unsafe { job.execute() };
    }

    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.
    fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => Some(job),
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

/// ////////////////////////////////////////////////////////////////////////

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    unsafe { WorkerThread::set_current(worker_thread) };
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    unsafe { Latch::set(&registry.thread_infos[index].primed) };

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    unsafe { worker_thread.wait_until_out_of_work() };

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }

    registry.release_thread();
}

/// If already in a worker-thread, just execute `op`. Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well. The second argument indicates
/// `true` if injection was performed, `false` if executed directly.
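///
/// A hedged usage sketch (crate-internal; the closure body is illustrative):
///
/// ```ignore
/// let n = in_worker(|worker, injected| {
///     // `injected` is true when the call had to be queued from outside a
///     // worker thread; either way, we now run on `worker`.
///     worker.index()
/// });
/// ```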
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*
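///
/// An illustrative sketch (not a statistical claim): each call advances the
/// non-zero state through the xorshift step, so successive draws differ.
///
/// ```ignore
/// let rng = XorShift64Star::new();
/// assert_ne!(rng.next(), rng.next());
/// ```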
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses the hash of a global counter.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star { state: Cell::new(seed) }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
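    ///
    /// Uses a simple modulo reduction; the slight bias for large `n` is
    /// acceptable here since this is only used to pick steal victims.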
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}