piranha 0.10
thread_pool.hpp
1 /* Copyright 2009-2017 Francesco Biscani (bluescarni@gmail.com)
2 
3 This file is part of the Piranha library.
4 
5 The Piranha library is free software; you can redistribute it and/or modify
6 it under the terms of either:
7 
8  * the GNU Lesser General Public License as published by the Free
9  Software Foundation; either version 3 of the License, or (at your
10  option) any later version.
11 
12 or
13 
14  * the GNU General Public License as published by the Free Software
15  Foundation; either version 3 of the License, or (at your option) any
16  later version.
17 
18 or both in parallel, as here.
19 
20 The Piranha library is distributed in the hope that it will be useful, but
21 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
22 or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
23 for more details.
24 
25 You should have received copies of the GNU General Public License and the
26 GNU Lesser General Public License along with the Piranha library. If not,
27 see https://www.gnu.org/licenses/. */
28 
29 #ifndef PIRANHA_THREAD_POOL_HPP
30 #define PIRANHA_THREAD_POOL_HPP
31 
32 #include <algorithm>
33 #include <atomic>
34 #include <boost/lexical_cast.hpp>
35 #include <condition_variable>
36 #include <cstdlib>
37 #include <functional>
38 #include <future>
39 // See old usage of cout below.
40 // #include <iostream>
41 #include <list>
42 #include <memory>
43 #include <mutex>
44 #include <queue>
45 #include <stdexcept>
46 #include <string>
47 #include <thread>
48 #include <type_traits>
49 #include <unordered_set>
50 #include <utility>
51 #include <vector>
52 
53 #include <piranha/config.hpp>
54 #include <piranha/detail/atomic_lock_guard.hpp>
55 #include <piranha/detail/mpfr.hpp>
56 #include <piranha/exceptions.hpp>
57 #include <piranha/mp_integer.hpp>
58 #include <piranha/runtime_info.hpp>
59 #include <piranha/thread_management.hpp>
60 #include <piranha/type_traits.hpp>
61 
62 namespace piranha
63 {
64 
65 inline namespace impl
66 {
67 
68 // Task queue class. Inspired by:
69 // https://github.com/progschj/ThreadPool
70 struct task_queue {
71  task_queue(unsigned n, bool bind) : m_stop(false)
72  {
73  auto runner = [this, n, bind]() {
74  if (bind) {
75  try {
76  bind_to_proc(n);
77  } catch (...) {
78  // Don't stop if we cannot bind.
79  // NOTE: logging candidate.
80  }
81  }
82  try {
83  while (true) {
84  std::unique_lock<std::mutex> lock(this->m_mutex);
85  while (!this->m_stop && this->m_tasks.empty()) {
86  // Need to wait for something to happen only if the task
87  // list is empty and we are not stopping.
88  // NOTE: wait will be noexcept in C++14.
89  this->m_cond.wait(lock);
90  }
91  if (this->m_stop && this->m_tasks.empty()) {
92  // If the stop flag was set, and we do not have more tasks,
93  // just exit.
94  break;
95  }
96  // NOTE: move constructor of std::function could throw, unfortunately.
97  std::function<void()> task(std::move(this->m_tasks.front()));
98  this->m_tasks.pop();
99  lock.unlock();
100  task();
101  }
102  } catch (...) {
103  // The errors we could get here are:
104  // - threading primitives,
105  // - move-construction of std::function,
106  // - queue popping (I guess unlikely, as the destructor of std::function
107  // is noexcept).
108  // In any case, not much that can be done to recover from this, better to abort.
109  // NOTE: logging candidate.
110  std::abort();
111  }
112  // Free the MPFR caches.
113  ::mpfr_free_cache();
114  };
115  m_thread = std::thread(std::move(runner));
116  }
117  ~task_queue()
118  {
119  // NOTE: logging candidate (catch any exception,
120  // log it and abort as there is not much we can do).
121  try {
122  stop();
123  } catch (...) {
124  std::abort();
125  }
126  }
127  // Small utility to remove reference_wrapper.
128  template <typename T>
129  struct unwrap_ref {
130  using type = T;
131  };
132  template <typename T>
133  struct unwrap_ref<std::reference_wrapper<T>> {
134  using type = T;
135  };
136  template <typename T>
137  using unwrap_ref_t = typename unwrap_ref<T>::type;
138  // NOTE: the functor F will be forwarded to std::bind in order to create a nullary wrapper. The nullary wrapper
139  // will create copies of the input arguments, and it will then pass these copies as lvalue refs to a copy of the
140  // original functor when the call operator is invoked (with special handling of reference wrappers). Thus, the
141  // real invocation of F is not simply F(args), but this more complicated type below.
142  // NOTE: this is one place where it seems we really want decay instead of uncvref, as decay is applied
143  // also by std::bind() to F.
144  template <typename F, typename... Args>
145  using f_ret_type = decltype(std::declval<decay_t<F> &>()(std::declval<unwrap_ref_t<uncvref_t<Args>> &>()...));
146  // enqueue() will be enabled if:
147  // - f_ret_type is a valid type (checked in the return type),
148  // - we can construct the nullary wrapper via std::bind() (this requires F and Args to be ctible from the input
149  // arguments),
150  // - we can build a packaged_task from the nullary wrapper (requires F and Args to be move/copy ctible),
151  // - the return type of F is returnable.
152  template <typename F, typename... Args>
153  using enabler
154  = enable_if_t<conjunction<std::is_constructible<decay_t<F>, F>, std::is_constructible<uncvref_t<Args>, Args>...,
155  disjunction<std::is_copy_constructible<decay_t<F>>,
156  std::is_move_constructible<decay_t<F>>>,
157  conjunction<disjunction<std::is_copy_constructible<uncvref_t<Args>>,
158  std::is_move_constructible<uncvref_t<Args>>>>...,
159  is_returnable<f_ret_type<F, Args...>>>::value,
160  int>;
161  // Main enqueue function.
162  template <typename F, typename... Args, enabler<F &&, Args &&...> = 0>
163  std::future<f_ret_type<F &&, Args &&...>> enqueue(F &&f, Args &&... args)
164  {
165  using ret_type = f_ret_type<F &&, Args &&...>;
166  using p_task_type = std::packaged_task<ret_type()>;
167  // NOTE: here we have a multi-stage construction of the task:
168  // - std::bind() turns F into a nullary functor,
169  // - std::packaged_task gives us the std::future machinery,
170  // - std::function (in m_tasks) gives the uniform type interface via type erasure.
171  auto task = std::make_shared<p_task_type>(std::bind(std::forward<F>(f), std::forward<Args>(args)...));
172  std::future<ret_type> res = task->get_future();
173  {
174  std::unique_lock<std::mutex> lock(m_mutex);
175  if (unlikely(m_stop)) {
176  // Enqueueing is not allowed if the queue is stopped.
177  piranha_throw(std::runtime_error, "cannot enqueue task while the task queue is stopping");
178  }
179  m_tasks.push([task]() { (*task)(); });
180  }
181  // NOTE: notify_one is noexcept.
182  m_cond.notify_one();
183  return res;
184  }
185  // NOTE: we call this only from dtor, it is here in order to be able to test it.
186  // So the exception handling in dtor will suffice, keep it in mind if things change.
187  void stop()
188  {
189  {
190  std::unique_lock<std::mutex> lock(m_mutex);
191  if (m_stop) {
192  // Already stopped.
193  return;
194  }
195  m_stop = true;
196  }
197  // Notify the thread that the queue has been stopped, and wait for it
198  // to consume the remaining tasks and exit.
199  m_cond.notify_one();
200  m_thread.join();
201  }
202  // Data members.
203  bool m_stop;
204  std::condition_variable m_cond;
205  std::mutex m_mutex;
206  std::queue<std::function<void()>> m_tasks;
207  std::thread m_thread;
208 };
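As a quick illustration of the three-stage task construction described in the enqueue() comments (std::bind for the nullary wrapper, std::packaged_task for the future, std::function for type erasure), here is a minimal sketch of driving a task_queue directly. task_queue lives in the impl inline namespace and is an internal type, so the snippet below is purely illustrative and not part of the documented API:

    #include <cassert>
    #include <piranha/thread_pool.hpp>

    int main()
    {
        // Worker for queue index 0, with thread binding disabled (as in the initial pool setup).
        piranha::task_queue tq(0, false);
        // enqueue() builds the bind wrapper + packaged_task and returns the associated future.
        auto fut = tq.enqueue([](int a, int b) { return a + b; }, 1, 2);
        // get() blocks until the worker thread has run the task.
        assert(fut.get() == 3);
        // On destruction, stop() drains any remaining tasks and joins the worker thread.
    }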
209 
210 // Type to represent thread queues: a vector of task queues paired with a set of thread ids.
211 using thread_queues_t = std::pair<std::vector<std::unique_ptr<task_queue>>, std::unordered_set<std::thread::id>>;
212 
213 inline thread_queues_t get_initial_thread_queues()
214 {
215  // NOTE: we used to have this print statement here, but it turns out that
216  // in certain setups the cout object is not yet constructed at this point,
217  // and a segfault is generated. I *think* it is possible to enforce the creation
218  // of cout via construction of an init object:
219  // http://en.cppreference.com/w/cpp/io/ios_base/Init
220  // However, this is hardly essential. Let's leave this disabled for the moment.
221  // std::cout << "Initializing the thread pool.\n";
222  thread_queues_t retval;
223  // Create the vector of queues.
224  const unsigned candidate = runtime_info::get_hardware_concurrency(), hc = (candidate > 0u) ? candidate : 1u;
225  retval.first.reserve(static_cast<decltype(retval.first.size())>(hc));
226  for (unsigned i = 0u; i < hc; ++i) {
227  // NOTE: thread binding is disabled on startup.
228  retval.first.emplace_back(::new task_queue(i, false));
229  }
230  // Generate the set of thread IDs.
231  for (const auto &ptr : retval.first) {
232  auto p = retval.second.insert(ptr->m_thread.get_id());
233  (void)p;
234  piranha_assert(p.second);
235  }
236  return retval;
237 }
238 
239 template <typename = void>
240 struct thread_pool_base {
241  static thread_queues_t s_queues;
242  static bool s_bind;
243  static std::atomic_flag s_atf;
244 };
245 
246 template <typename T>
247 thread_queues_t thread_pool_base<T>::s_queues = get_initial_thread_queues();
248 
249 template <typename T>
250 std::atomic_flag thread_pool_base<T>::s_atf = ATOMIC_FLAG_INIT;
251 
252 template <typename T>
253 bool thread_pool_base<T>::s_bind = false;
254 
255 template <typename>
256 void thread_pool_shutdown();
257 }
258 
260 
274 // \todo work around MSVC bug in destruction of statically allocated threads (if needed once we support MSVC), as per:
275 // http://stackoverflow.com/questions/10915233/stdthreadjoin-hangs-if-called-after-main-exits-when-using-vs2012-rc
276 // detach() and wait as a workaround?
277 // \todo try to understand if we can suppress the future list class below in favour of STL-like algorithms.
278 template <typename T = void>
279 class thread_pool_ : private thread_pool_base<>
280 {
281  friend void piranha::impl::thread_pool_shutdown<T>();
282  using base = thread_pool_base<>;
283  // Enabler for use_threads.
284  template <typename Int>
285  using use_threads_enabler
286  = enable_if_t<disjunction<std::is_same<Int, integer>,
287  conjunction<std::is_integral<Int>, std::is_unsigned<Int>>>::value,
288  int>;
289  // The return type for enqueue().
290  template <typename F, typename... Args>
291  using enqueue_t = decltype(std::declval<task_queue &>().enqueue(std::declval<F>(), std::declval<Args>()...));
292 
293 public:
295 
321  template <typename F, typename... Args>
322  static enqueue_t<F &&, Args &&...> enqueue(unsigned n, F &&f, Args &&... args)
323  {
324  detail::atomic_lock_guard lock(s_atf);
325  if (unlikely(n >= s_queues.first.size())) {
326  piranha_throw(std::invalid_argument, "the thread index " + std::to_string(n)
327  + " is out of range, the thread pool contains only "
328  + std::to_string(s_queues.first.size()) + " threads");
329  }
330  return base::s_queues.first[static_cast<decltype(base::s_queues.first.size())>(n)]->enqueue(
331  std::forward<F>(f), std::forward<Args>(args)...);
332  }
334 
337  static unsigned size()
338  {
339  detail::atomic_lock_guard lock(s_atf);
340  return static_cast<unsigned>(base::s_queues.first.size());
341  }
342 
343 private:
344  // Helper function to create 'new_size' new queues with thread binding set to 'bind'.
345  static thread_queues_t create_new_queues(unsigned new_size, bool bind)
346  {
347  thread_queues_t new_queues;
348  // Create the task queues.
349  new_queues.first.reserve(static_cast<decltype(new_queues.first.size())>(new_size));
350  for (auto i = 0u; i < new_size; ++i) {
351  new_queues.first.emplace_back(::new task_queue(i, bind));
352  }
353  // Fill in the thread ids set.
354  for (const auto &ptr : new_queues.first) {
355  auto p = new_queues.second.insert(ptr->m_thread.get_id());
356  (void)p;
357  piranha_assert(p.second);
358  }
359  return new_queues;
360  }
361  // Shutdown. This can be used to stop the threads at program shutdown.
362  static void shutdown()
363  {
364  thread_queues_t new_queues;
365  detail::atomic_lock_guard lock(s_atf);
366  new_queues.swap(base::s_queues);
367  }
368 
369 public:
371 
383  static void resize(unsigned new_size)
384  {
385  if (unlikely(new_size == 0u)) {
386  piranha_throw(std::invalid_argument, "cannot resize the thread pool to zero");
387  }
388  // NOTE: need to lock here as we are reading the s_bind member.
389  detail::atomic_lock_guard lock(s_atf);
390  auto new_queues = create_new_queues(new_size, base::s_bind);
391  // NOTE: here the allocator is not swapped, as std::allocator won't propagate on swap.
392  // Besides, all instances of std::allocator are equal, so the operation is well-defined.
393  // http://en.cppreference.com/w/cpp/container/vector/swap
394  // This holds for both std::vector and std::unordered_set.
395  // If an exception actually gets thrown, no big deal.
396  // NOTE: the dtor of the queues is effectively noexcept, as the program will just abort in case of errors
397  // in the dtor.
398  new_queues.swap(base::s_queues);
399  }
401 
416  static void set_binding(bool flag)
417  {
418  detail::atomic_lock_guard lock(s_atf);
419  if (flag == base::s_bind) {
420  // Don't do anything if we are not changing the binding policy.
421  return;
422  }
423  auto new_queues = create_new_queues(static_cast<unsigned>(base::s_queues.first.size()), flag);
424  new_queues.swap(base::s_queues);
425  base::s_bind = flag;
426  }
428 
434  static bool get_binding()
435  {
436  detail::atomic_lock_guard lock(s_atf);
437  return base::s_bind;
438  }
440 
459  template <typename Int, use_threads_enabler<Int> = 0>
460  static unsigned use_threads(const Int &work_size, const Int &min_work_per_thread)
461  {
462  // Check input params.
463  if (unlikely(work_size <= Int(0))) {
464  piranha_throw(std::invalid_argument, "invalid value of " + boost::lexical_cast<std::string>(work_size)
465  + " for work size (it must be strictly positive)");
466  }
467  if (unlikely(min_work_per_thread <= Int(0))) {
468  piranha_throw(std::invalid_argument, "invalid value of "
469  + boost::lexical_cast<std::string>(min_work_per_thread)
470  + " for minimum work per thread (it must be strictly positive)");
471  }
472  detail::atomic_lock_guard lock(s_atf);
473  // Don't use threads if the calling thread belongs to the pool.
474  if (base::s_queues.second.find(std::this_thread::get_id()) != base::s_queues.second.end()) {
475  return 1u;
476  }
477  const auto n_threads = static_cast<unsigned>(base::s_queues.first.size());
478  piranha_assert(n_threads);
479  if (work_size / n_threads >= min_work_per_thread) {
480  // Enough work per thread, use them all.
481  return n_threads;
482  }
483  // Return a number of threads such that each thread consumes at least min_work_per_thread.
484  // Never return 0.
485  return static_cast<unsigned>(std::max(Int(1), static_cast<Int>(work_size / min_work_per_thread)));
486  }
487 };
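As a usage sketch (assuming the thread_pool alias to thread_pool_<> defined just below): enqueue() dispatches a callable to the queue with the given index, size() and resize() query and change the number of queues, and use_threads() applies the work-splitting rule above. For example, with a pool of 8 threads, work_size = 100 and min_work_per_thread = 30, 100 / 8 = 12 is below the threshold, so use_threads() returns max(1, 100 / 30) = 3.

    #include <iostream>
    #include <piranha/thread_pool.hpp>

    int main()
    {
        using piranha::thread_pool;
        std::cout << "pool size: " << thread_pool::size() << '\n';
        // Dispatch a task to queue 0 and wait for its result.
        auto fut = thread_pool::enqueue(0, [](unsigned x) { return 2u * x; }, 21u);
        std::cout << "result: " << fut.get() << '\n';
        // Work-splitting heuristic: 100 units of work, at least 30 units per thread.
        std::cout << "threads to use: " << thread_pool::use_threads(100u, 30u) << '\n';
    }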
488 
490 
493 using thread_pool = thread_pool_<>;
494 
495 inline namespace impl
496 {
497 
498 template <typename>
499 inline void thread_pool_shutdown()
500 {
501  thread_pool::shutdown();
502 }
503 }
504 
506 
510 // NOTE: we could provide a method to retrieve the future values from get_all() using a vector (when the future type
511 // is not void or a reference; if it is, the get_all() method stays as it is).
512 template <typename T>
513 class future_list
514 {
515  // Wait on a valid future, or abort.
516  static void wait_or_abort(const std::future<T> &fut)
517  {
518  piranha_assert(fut.valid());
519  try {
520  fut.wait();
521  } catch (...) {
522  // NOTE: logging candidate, with info from exception.
523  std::abort();
524  }
525  }
526 
527 public:
529 
532  future_list() = default;
534  future_list(const future_list &) = delete;
536  future_list(future_list &&) = delete;
537 
538 private:
539  future_list &operator=(const future_list &) = delete;
540  future_list &operator=(future_list &&) = delete;
541 
542 public:
544 
547  ~future_list()
548  {
549  wait_all();
550  }
552 
561  void push_back(std::future<T> &&f)
562  {
563  // Push back empty future.
564  try {
565  m_list.emplace_back();
566  } catch (...) {
567  // If we get some error here, we want to make sure we wait on the future
568  // before escaping out.
569  // NOTE: calling wait() on an invalid future is UB.
570  if (f.valid()) {
571  wait_or_abort(f);
572  }
573  throw;
574  }
575  // This cannot throw.
576  m_list.back() = std::move(f);
577  }
579 
582  void wait_all()
583  {
584  for (auto &f : m_list) {
585  if (f.valid()) {
586  wait_or_abort(f);
587  }
588  }
589  }
591 
597  void get_all()
598  {
599  for (auto &f : m_list) {
600  // NOTE: std::future's valid() method is noexcept.
601  if (f.valid()) {
602  (void)f.get();
603  }
604  }
605  }
606 
607 private:
608  std::list<std::future<T>> m_list;
609 };
610 }
611 
612 #endif
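Finally, a sketch of how future_list is typically paired with thread_pool::enqueue(); the lambda body is a placeholder:

    #include <piranha/thread_pool.hpp>

    int main()
    {
        using piranha::future_list;
        using piranha::thread_pool;
        future_list<void> flist;
        // Spread one task per pool thread; push_back() waits on the future if the insertion throws.
        for (unsigned i = 0u; i < thread_pool::size(); ++i) {
            flist.push_back(thread_pool::enqueue(i, []() { /* do some work */ }));
        }
        // Wait on all the stored futures; the destructor would call wait_all() anyway.
        flist.wait_all();
    }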