piranha  0.10
base_series_multiplier.hpp
1 /* Copyright 2009-2017 Francesco Biscani (bluescarni@gmail.com)
2 
3 This file is part of the Piranha library.
4 
5 The Piranha library is free software; you can redistribute it and/or modify
6 it under the terms of either:
7 
8  * the GNU Lesser General Public License as published by the Free
9  Software Foundation; either version 3 of the License, or (at your
10  option) any later version.
11 
12 or
13 
14  * the GNU General Public License as published by the Free Software
15  Foundation; either version 3 of the License, or (at your option) any
16  later version.
17 
18 or both in parallel, as here.
19 
20 The Piranha library is distributed in the hope that it will be useful, but
21 WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
22 or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
23 for more details.
24 
25 You should have received copies of the GNU General Public License and the
26 GNU Lesser General Public License along with the Piranha library. If not,
27 see https://www.gnu.org/licenses/. */
28 
29 #ifndef PIRANHA_DETAIL_BASE_SERIES_MULTIPLIER_HPP
30 #define PIRANHA_DETAIL_BASE_SERIES_MULTIPLIER_HPP
31 
#include <algorithm>
#include <array>
#include <boost/numeric/conversion/cast.hpp>
#include <cmath>
#include <cstddef>
#include <iterator>
#include <limits>
#include <mutex>
#include <numeric>
#include <random>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
45 
46 #include <piranha/config.hpp>
47 #include <piranha/detail/atomic_flag_array.hpp>
48 #include <piranha/detail/atomic_lock_guard.hpp>
49 #include <piranha/exceptions.hpp>
50 #include <piranha/key_is_multipliable.hpp>
51 #include <piranha/math.hpp>
52 #include <piranha/mp_integer.hpp>
53 #include <piranha/mp_rational.hpp>
54 #include <piranha/safe_cast.hpp>
55 #include <piranha/series.hpp>
56 #include <piranha/settings.hpp>
57 #include <piranha/symbol_utils.hpp>
58 #include <piranha/thread_pool.hpp>
59 #include <piranha/tuning.hpp>
60 #include <piranha/type_traits.hpp>
61 
62 namespace piranha
63 {
64 
65 namespace detail
66 {
67 
// Default implementation of the term-pointer filling logic used by base_series_multiplier.
// It extracts const pointers to the terms stored in the two input hash containers into the
// flat vectors v1/v2. When the key type supports operator<(), the terms of each bucket are
// additionally sorted, which makes the pointer order deterministic.
template <typename Series, typename Derived, typename = void>
struct base_series_multiplier_impl {
    using term_type = typename Series::term_type;
    using container_type = typename std::decay<decltype(std::declval<Series>()._container())>::type;
    using c_size_type = typename container_type::size_type;
    using v_size_type = typename std::vector<term_type const *>::size_type;
    // Overload selected when the key type is NOT less-than comparable.
    template <typename Term,
              typename std::enable_if<!is_less_than_comparable<typename Term::key_type>::value, int>::type = 0>
    void fill_term_pointers(const container_type &c1, const container_type &c2, std::vector<Term const *> &v1,
                            std::vector<Term const *> &v2)
    {
        // If the key is not less-than comparable, we can only copy over the pointers as they are.
        std::transform(c1.begin(), c1.end(), std::back_inserter(v1), [](const term_type &t) { return &t; });
        std::transform(c2.begin(), c2.end(), std::back_inserter(v2), [](const term_type &t) { return &t; });
    }
    // Overload selected when the key type IS less-than comparable: terms are extracted
    // bucket-by-bucket and sorted within each bucket, possibly in parallel.
    template <typename Term,
              typename std::enable_if<is_less_than_comparable<typename Term::key_type>::value, int>::type = 0>
    void fill_term_pointers(const container_type &c1, const container_type &c2, std::vector<Term const *> &v1,
                            std::vector<Term const *> &v2)
    {
        // Fetch the number of threads from the derived class.
        const unsigned n_threads = static_cast<Derived *>(this)->m_n_threads;
        piranha_assert(n_threads > 0u);
        // Threading functor: thread thread_idx extracts (and sorts, per bucket) the terms of
        // its assigned bucket range of *c into *v. In the multithreaded case each thread
        // receives its own output vector, so no synchronisation is needed here.
        auto thread_func
            = [n_threads](unsigned thread_idx, const container_type *c, std::vector<term_type const *> *v) {
                  piranha_assert(thread_idx < n_threads);
                  // Total bucket count.
                  const auto b_count = c->bucket_count();
                  // Buckets per thread.
                  const auto bpt = b_count / n_threads;
                  // End index: the last thread also picks up the remainder buckets.
                  const auto end
                      = static_cast<c_size_type>((thread_idx == n_threads - 1u) ? b_count : (bpt * (thread_idx + 1u)));
                  // Sorter.
                  auto sorter = [](term_type const *p1, term_type const *p2) { return p1->m_key < p2->m_key; };
                  v_size_type j = 0u;
                  for (auto start = static_cast<c_size_type>(bpt * thread_idx); start < end; ++start) {
                      const auto &b = c->_get_bucket_list(start);
                      v_size_type tmp = 0u;
                      for (const auto &t : b) {
                          v->push_back(&t);
                          ++tmp;
                      }
                      // Sort only the pointers appended for this bucket ([j, j + tmp)).
                      std::stable_sort(v->data() + j, v->data() + j + tmp, sorter);
                      j += tmp;
                  }
              };
        if (n_threads == 1u) {
            thread_func(0u, &c1, &v1);
            thread_func(0u, &c2, &v2);
            return;
        }
        auto thread_wrapper = [&thread_func, n_threads](const container_type *c, std::vector<term_type const *> *v) {
            // In the multi-threaded case, each thread needs to work on a separate vector.
            // We will merge the vectors later.
            using vv_t = std::vector<std::vector<Term const *>>;
            using vv_size_t = typename vv_t::size_type;
            vv_t vv(safe_cast<vv_size_t>(n_threads));
            // Go with the threads.
            future_list<void> ff_list;
            try {
                for (unsigned i = 0u; i < n_threads; ++i) {
                    ff_list.push_back(thread_pool::enqueue(i, thread_func, i, c, &(vv[static_cast<vv_size_t>(i)])));
                }
                // First let's wait for everything to finish.
                ff_list.wait_all();
                // Then, let's handle the exceptions.
                ff_list.get_all();
            } catch (...) {
                // Make sure no task is still running before propagating.
                ff_list.wait_all();
                throw;
            }
            // Last, we need to merge everything into v.
            for (const auto &vi : vv) {
                v->insert(v->end(), vi.begin(), vi.end());
            }
        };
        thread_wrapper(&c1, &v1);
        thread_wrapper(&c2, &v2);
    }
};
150 
// Specialisation for series with rational coefficients. Instead of pointing at the original
// terms, it builds local copies of the terms whose coefficients are rescaled to a common
// denominator (the lcm of all denominators in both series, stored in m_lcm). This lets the
// multiplication operate on integral numerators only; the denominator (lcm squared) is put
// back by base_series_multiplier::finalise_impl() after the multiplication.
template <typename Series, typename Derived>
struct base_series_multiplier_impl<
    Series, Derived, typename std::enable_if<is_mp_rational<typename Series::term_type::cf_type>::value>::type> {
    // Useful shortcuts.
    using term_type = typename Series::term_type;
    using rat_type = typename term_type::cf_type;
    using int_type = typename std::decay<decltype(std::declval<rat_type>().num())>::type;
    using container_type = typename std::decay<decltype(std::declval<Series>()._container())>::type;
    void fill_term_pointers(const container_type &c1, const container_type &c2, std::vector<term_type const *> &v1,
                            std::vector<term_type const *> &v2)
    {
        // Compute the least common multiplier of all the denominators:
        // lcm(a, b) = a * b / gcd(a, b), folded over every coefficient of both series.
        m_lcm = 1;
        auto it_f = c1.end();
        int_type g;
        for (auto it = c1.begin(); it != it_f; ++it) {
            math::gcd3(g, m_lcm, it->m_cf.den());
            math::mul3(m_lcm, m_lcm, it->m_cf.den());
            divexact(m_lcm, m_lcm, g);
        }
        it_f = c2.end();
        for (auto it = c2.begin(); it != it_f; ++it) {
            math::gcd3(g, m_lcm, it->m_cf.den());
            math::mul3(m_lcm, m_lcm, it->m_cf.den());
            divexact(m_lcm, m_lcm, g);
        }
        // All these computations involve only positive numbers,
        // the GCD must always be positive.
        piranha_assert(m_lcm.sgn() == 1);
        // Copy over the terms and renormalise to lcm: each coefficient a/b becomes
        // (lcm / b * a) / 1, i.e. an integral numerator over a unitary denominator.
        it_f = c1.end();
        for (auto it = c1.begin(); it != it_f; ++it) {
            // NOTE: these divisions are exact, we could take advantage of that.
            m_terms1.push_back(term_type(rat_type(m_lcm / it->m_cf.den() * it->m_cf.num(), int_type(1)), it->m_key));
        }
        it_f = c2.end();
        for (auto it = c2.begin(); it != it_f; ++it) {
            m_terms2.push_back(term_type(rat_type(m_lcm / it->m_cf.den() * it->m_cf.num(), int_type(1)), it->m_key));
        }
        // Copy over the pointers (they point into the locally-owned m_terms1/m_terms2).
        std::transform(m_terms1.begin(), m_terms1.end(), std::back_inserter(v1), [](const term_type &t) { return &t; });
        std::transform(m_terms2.begin(), m_terms2.end(), std::back_inserter(v2), [](const term_type &t) { return &t; });
        piranha_assert(v1.size() == c1.size());
        piranha_assert(v2.size() == c2.size());
    }
    // Local storage for the rescaled terms pointed to by v1/v2.
    std::vector<term_type> m_terms1;
    std::vector<term_type> m_terms2;
    // The lcm of all denominators of both series.
    int_type m_lcm;
};
200 }
201 
203 
222 // Some performance ideas:
223 // - optimisation in case one series has 1 term with unitary key and both series same type: multiply directly
224 // coefficients;
225 // - optimisation for coefficient series that merges all args, similar to the rational optimisation;
226 // - optimisation for load balancing similar to the poly multiplier.
227 template <typename Series>
228 class base_series_multiplier : private detail::base_series_multiplier_impl<Series, base_series_multiplier<Series>>
229 {
230  PIRANHA_TT_CHECK(is_series, Series);
231  // Make friends with the base, so it can access protected/private members of this.
232  friend struct detail::base_series_multiplier_impl<Series, base_series_multiplier<Series>>;
233  // Alias for the series' container type.
234  using container_type = uncvref_t<decltype(std::declval<const Series &>()._container())>;
235 
236 public:
238  using v_ptr = std::vector<typename Series::term_type const *>;
240  using size_type = typename v_ptr::size_type;
242  using bucket_size_type = typename Series::size_type;
243 
244 private:
245  // The default limit functor: it will include all terms in the second series.
246  struct default_limit_functor {
247  default_limit_functor(const base_series_multiplier &m) : m_size2(m.m_v2.size()) {}
248  size_type operator()(const size_type &) const
249  {
250  return m_size2;
251  }
252  const size_type m_size2;
253  };
254  // The purpose of this helper is to move in a coefficient series during insertion. For series,
255  // we know that moves leave the series in a valid state, and series multiplications do not benefit
256  // from an already-constructed destination - hence it is convenient to move them rather than copy.
    // For non-series coefficients, insertion can use the term as-is (returned by reference).
    template <typename Term, typename std::enable_if<!is_series<typename Term::cf_type>::value, int>::type = 0>
    static Term &term_insertion(Term &t)
    {
        return t;
    }
    // For series coefficients, build a new term moving the coefficient in: moves leave a
    // series in a valid state and avoid a costly deep copy during insertion.
    template <typename Term, typename std::enable_if<is_series<typename Term::cf_type>::value, int>::type = 0>
    static Term term_insertion(Term &t)
    {
        return Term{std::move(t.m_cf), t.m_key};
    }
    // Implementation of finalise().
    // Rational-coefficient overload: the multiplication was carried out on integral
    // numerators only (the common denominator m_lcm was factored out when filling the term
    // pointers), so here the denominator lcm**2 is put back into every coefficient of the
    // result and the coefficients are re-canonicalised.
    template <typename T, typename std::enable_if<is_mp_rational<typename T::term_type::cf_type>::value, int>::type = 0>
    void finalise_impl(T &s) const
    {
        // Nothing to do if the lcm is unitary.
        if (math::is_unitary(this->m_lcm)) {
            return;
        }
        // NOTE: this has to be the square of the lcm, as in addition to uniformising
        // the denominators in each series we are also multiplying the two series.
        const auto l2 = this->m_lcm * this->m_lcm;
        auto &container = s._container();
        // Single thread implementation.
        if (m_n_threads == 1u) {
            for (const auto &t : container) {
                t.m_cf._set_den(l2);
                t.m_cf.canonicalise();
            }
            return;
        }
        // Multi-thread implementation: each thread fixes up a disjoint range of buckets,
        // so no synchronisation is needed.
        // Buckets per thread.
        const bucket_size_type bpt = static_cast<bucket_size_type>(container.bucket_count() / m_n_threads);
        auto thread_func = [l2, &container, this, bpt](unsigned t_idx) {
            bucket_size_type start_idx = static_cast<bucket_size_type>(t_idx * bpt);
            // Special handling for the last thread.
            const bucket_size_type end_idx = t_idx == (this->m_n_threads - 1u)
                                                 ? container.bucket_count()
                                                 : static_cast<bucket_size_type>((t_idx + 1u) * bpt);
            for (; start_idx != end_idx; ++start_idx) {
                auto &list = container._get_bucket_list(start_idx);
                for (const auto &t : list) {
                    t.m_cf._set_den(l2);
                    t.m_cf.canonicalise();
                }
            }
        };
        // Go with the threads.
        future_list<decltype(thread_func(0u))> ff_list;
        try {
            for (unsigned i = 0u; i < m_n_threads; ++i) {
                ff_list.push_back(thread_pool::enqueue(i, thread_func, i));
            }
            // First let's wait for everything to finish.
            ff_list.wait_all();
            // Then, let's handle the exceptions.
            ff_list.get_all();
        } catch (...) {
            ff_list.wait_all();
            throw;
        }
    }
    // No-op overload for non-rational coefficients.
    template <typename T,
              typename std::enable_if<!is_mp_rational<typename T::term_type::cf_type>::value, int>::type = 0>
    void finalise_impl(T &) const
    {
    }
324 
325 public:
327 
355  explicit base_series_multiplier(const Series &s1, const Series &s2) : m_ss(s1.get_symbol_set())
356  {
357  if (unlikely(s1.get_symbol_set() != s2.get_symbol_set())) {
358  piranha_throw(std::invalid_argument, "incompatible arguments sets");
359  }
360  // The largest series goes first.
361  const Series *p1 = &s1, *p2 = &s2;
362  if (s1.size() < s2.size()) {
363  std::swap(p1, p2);
364  }
365  // This is just an optimisation, no troubles if there is a truncation due to static_cast.
366  m_v1.reserve(static_cast<size_type>(p1->size()));
367  m_v2.reserve(static_cast<size_type>(p2->size()));
368  container_type const *ctr1 = &p1->_container(), *ctr2 = &p2->_container();
369  // NOTE: if the zero element of Series is not absorbing, we need to create a temporary zero series in place
370  // of any factor that is zero, and then use it in the multiplication. This ensures a correct series
371  // multiplication result for coefficient types (such as IEEE floats) for which 0 times x is not necessarily
372  // always 0. The temporary zero series is stored in the m_zero_f member as a collection of 1 term with
373  // zero coefficient.
375  using term_type = typename Series::term_type;
376  using cf_type = typename term_type::cf_type;
377  using key_type = typename term_type::key_type;
378  if (p1->empty()) {
379  m_zero_f1.insert(term_type{cf_type(0), key_type(s1.get_symbol_set())});
380  ctr1 = &m_zero_f1;
381  }
382  if (p2->empty()) {
383  m_zero_f2.insert(term_type{cf_type(0), key_type(s1.get_symbol_set())});
384  ctr2 = &m_zero_f2;
385  }
386  }
387  // Set the number of threads.
388  m_n_threads = (ctr1->size() && ctr2->size())
389  ? thread_pool::use_threads(integer(ctr1->size()) * ctr2->size(),
391  : 1u;
392  this->fill_term_pointers(*ctr1, *ctr2, m_v1, m_v2);
393  }
394 
395 private:
396  base_series_multiplier() = delete;
399  base_series_multiplier &operator=(const base_series_multiplier &) = delete;
400  base_series_multiplier &operator=(base_series_multiplier &&) = delete;
401 
402 protected:
404 
441  template <typename MultFunctor, typename LimitFunctor>
442  void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1,
443  const LimitFunctor &lf) const
444  {
445  PIRANHA_TT_CHECK(is_function_object, MultFunctor, void, const size_type &, const size_type &);
446  if (unlikely(start1 > end1 || start1 > m_v1.size() || end1 > m_v1.size())) {
447  piranha_throw(std::invalid_argument, "invalid bounds in blocked_multiplication");
448  }
449  // Block size and number of regular blocks.
451  nblocks1 = static_cast<size_type>((end1 - start1) / bsize),
452  nblocks2 = static_cast<size_type>(m_v2.size() / bsize);
453  // Start and end of last (possibly irregular) blocks.
454  const size_type i_ir_start = static_cast<size_type>(nblocks1 * bsize + start1), i_ir_end = end1;
455  const size_type j_ir_start = static_cast<size_type>(nblocks2 * bsize), j_ir_end = m_v2.size();
456  for (size_type n1 = 0u; n1 < nblocks1; ++n1) {
457  const size_type i_start = static_cast<size_type>(n1 * bsize + start1),
458  i_end = static_cast<size_type>(i_start + bsize);
459  // regulars1 * regulars2
460  for (size_type n2 = 0u; n2 < nblocks2; ++n2) {
461  const size_type j_start = static_cast<size_type>(n2 * bsize),
462  j_end = static_cast<size_type>(j_start + bsize);
463  for (size_type i = i_start; i < i_end; ++i) {
464  const size_type limit = std::min<size_type>(lf(i), j_end);
465  for (size_type j = j_start; j < limit; ++j) {
466  mf(i, j);
467  }
468  }
469  }
470  // regulars1 * rem2
471  for (size_type i = i_start; i < i_end; ++i) {
472  const size_type limit = std::min<size_type>(lf(i), j_ir_end);
473  for (size_type j = j_ir_start; j < limit; ++j) {
474  mf(i, j);
475  }
476  }
477  }
478  // rem1 * regulars2
479  for (size_type n2 = 0u; n2 < nblocks2; ++n2) {
480  const size_type j_start = static_cast<size_type>(n2 * bsize),
481  j_end = static_cast<size_type>(j_start + bsize);
482  for (size_type i = i_ir_start; i < i_ir_end; ++i) {
483  const size_type limit = std::min<size_type>(lf(i), j_end);
484  for (size_type j = j_start; j < limit; ++j) {
485  mf(i, j);
486  }
487  }
488  }
489  // rem1 * rem2.
490  for (size_type i = i_ir_start; i < i_ir_end; ++i) {
491  const size_type limit = std::min<size_type>(lf(i), j_ir_end);
492  for (size_type j = j_ir_start; j < limit; ++j) {
493  mf(i, j);
494  }
495  }
496  }
498 
    // Blocked multiplication (convenience overload): no truncation, every term of the
    // second series participates in the multiplication.
    template <typename MultFunctor>
    void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1) const
    {
        blocked_multiplication(mf, start1, end1, default_limit_functor{*this});
    }
514 
    // Estimate the size of the final series via random sampling: for a number of trials,
    // terms of the first series are visited in random order and each is multiplied by one
    // randomly-chosen admissible term of the second series, until a duplicate term is
    // generated in the temporary output. The trial results are averaged into the estimate.
    // MultArity is the number of terms produced by a single term-by-term multiplication.
    template <std::size_t MultArity, typename MultFunctor, typename LimitFunctor>
    bucket_size_type estimate_final_series_size(const LimitFunctor &lf) const
    {
        PIRANHA_TT_CHECK(is_function_object, MultFunctor, void, const size_type &, const size_type &);
        PIRANHA_TT_CHECK(std::is_constructible, MultFunctor, const base_series_multiplier &, Series &);
        PIRANHA_TT_CHECK(is_function_object, LimitFunctor, size_type, const size_type &);
        // Cache these.
        const size_type size1 = m_v1.size(), size2 = m_v2.size();
        constexpr std::size_t result_size = MultArity;
        // If one of the two series is empty, return 1: the estimate is never zero, as it
        // is used downstream as a bucket count for rehashing.
        if (unlikely(!size1 || !size2)) {
            return 1u;
        }
        // If either series has a size of 1, just return size1 * size2 * result_size.
        if (size1 == 1u || size2 == 1u) {
            return static_cast<bucket_size_type>(integer(size1) * size2 * result_size);
        }
        // NOTE: Hard-coded number of trials.
        // NOTE: here consider that in case of extremely sparse series with few terms this will incur in noticeable
        // overhead, since we will need many term-by-term before encountering the first duplicate.
        const unsigned n_trials = 15u;
        // NOTE: Hard-coded value for the estimation multiplier.
        // NOTE: This value should be tuned for performance/memory usage tradeoffs.
        const unsigned multiplier = 2u;
        // Number of threads to use. If there are more threads than trials, then reduce
        // the number of actual threads to use.
        // NOTE: this is a bit different from usual, where we do not care if the workload per thread is zero.
        // We do like this because n_trials is a small number and there still seems to be benefit in running
        // just 1 trial per thread.
        const unsigned n_threads = (n_trials >= m_n_threads) ? m_n_threads : n_trials;
        piranha_assert(n_threads > 0u);
        // Trials per thread. This will always be at least 1.
        const unsigned tpt = n_trials / n_threads;
        piranha_assert(tpt >= 1u);
        // The cumulative estimate.
        integer c_estimate(0);
        // Sync mutex - actually used only in multithreading.
        std::mutex mut;
        // The estimation functor.
        auto estimator = [&lf, size1, n_threads, tpt, this, &c_estimate, &mut](unsigned thread_idx) {
            piranha_assert(thread_idx < n_threads);
            // Vectors of indices into m_v1.
            std::vector<size_type> v_idx1(safe_cast<typename std::vector<size_type>::size_type>(size1));
            std::iota(v_idx1.begin(), v_idx1.end(), size_type(0));
            // Copy in order to reset to initial state later.
            const auto v_idx1_copy = v_idx1;
            // Random number engine.
            std::mt19937 engine;
            // Uniform int distribution.
            using dist_type = std::uniform_int_distribution<size_type>;
            dist_type dist;
            // Init the accumulated estimation for averaging later.
            integer acc(0);
            // Number of trials for this thread - usual special casing for the last thread.
            const unsigned cur_trials = (thread_idx == n_threads - 1u) ? (n_trials - thread_idx * tpt) : tpt;
            // This should always be guaranteed because tpt is never 0.
            piranha_assert(cur_trials > 0u);
            // Create and setup the temp series.
            Series tmp;
            tmp.set_symbol_set(m_ss);
            // Create the multiplier.
            MultFunctor mf(*this, tmp);
            // Go with the trials.
            for (auto n = 0u; n < cur_trials; ++n) {
                // Seed the engine. The seed should be the global trial number, accounting for multiple
                // threads. This way the estimation will not depend on the number of threads.
                engine.seed(static_cast<std::mt19937::result_type>(tpt * thread_idx + n));
                // Reset the indices vector and re-randomise it.
                // NOTE: we need to do this as every run inside this for loop must be completely independent
                // of any previous run, we cannot keep any state.
                v_idx1 = v_idx1_copy;
                std::shuffle(v_idx1.begin(), v_idx1.end(), engine);
                // The counter. This will be increased each time a term-by-term multiplication
                // does not generate a duplicate term.
                size_type count = 0u;
                // This will be used to determine the average number of terms in s2
                // that participate in the multiplication.
                integer acc_s2(0);
                auto it1 = v_idx1.begin();
                for (; it1 != v_idx1.end(); ++it1) {
                    // Get the limit idx in s2.
                    const size_type limit = lf(*it1);
                    // This is the upper limit of an open ended interval, so it needs
                    // to be decreased by one in order to be used in dist. If zero, it means
                    // there are no terms in v2 that can be multiplied by the current term in t1.
                    if (limit == 0u) {
                        continue;
                    }
                    acc_s2 += limit;
                    // Pick a random index in m_v2 within the limit.
                    const size_type idx2
                        = dist(engine, typename dist_type::param_type(static_cast<size_type>(0u),
                                                                      static_cast<size_type>(limit - 1u)));
                    // Perform term multiplication.
                    mf(*it1, idx2);
                    // Check for unlikely overflows when increasing count.
                    if (unlikely(result_size > std::numeric_limits<size_type>::max()
                                 || count > std::numeric_limits<size_type>::max() - result_size)) {
                        piranha_throw(std::overflow_error, "overflow error");
                    }
                    // A size mismatch means at least one of the just-inserted terms
                    // collided with an existing one: stop this trial.
                    if (tmp.size() != count + result_size) {
                        break;
                    }
                    // Increase cycle variables.
                    count = static_cast<size_type>(count + result_size);
                }
                integer add;
                if (it1 == v_idx1.end()) {
                    // We never found a duplicate. count is now the number of terms in s1
                    // which actually participate in the multiplication, while acc_s2 / count
                    // is the average number of terms in s2 that participate in the multiplication.
                    // The result will be then count * acc_s2 / count = acc_s2.
                    add = acc_s2;
                } else {
                    // If we found a duplicate, we use the heuristic.
                    add = integer(multiplier) * count * count;
                }
                // Fix if zero, so that the average later never results in zero.
                if (add.sgn() == 0) {
                    add = 1;
                }
                acc += add;
                // Reset tmp.
                tmp._container().clear();
            }
            // Accumulate in the shared variable.
            if (n_threads == 1u) {
                // No locking needed.
                c_estimate += acc;
            } else {
                std::lock_guard<std::mutex> lock(mut);
                c_estimate += acc;
            }
        };
        // Run the estimation functor.
        if (n_threads == 1u) {
            estimator(0u);
        } else {
            future_list<void> f_list;
            try {
                for (unsigned i = 0u; i < n_threads; ++i) {
                    f_list.push_back(thread_pool::enqueue(i, estimator, i));
                }
                // First let's wait for everything to finish.
                f_list.wait_all();
                // Then, let's handle the exceptions.
                f_list.get_all();
            } catch (...) {
                f_list.wait_all();
                throw;
            }
        }
        piranha_assert(c_estimate >= n_trials);
        // Return the mean.
        return static_cast<bucket_size_type>(c_estimate / n_trials);
    }
712 
718  template <std::size_t MultArity, typename MultFunctor>
720  {
721  return estimate_final_series_size<MultArity, MultFunctor>(default_limit_functor{*this});
722  }
724 
743  template <bool FastMode>
745  {
746  using term_type = typename Series::term_type;
747  using key_type = typename term_type::key_type;
748  PIRANHA_TT_CHECK(key_is_multipliable, typename term_type::cf_type, key_type);
749  using it_type = decltype(std::declval<Series &>()._container().end());
750  static constexpr std::size_t m_arity = key_type::multiply_arity;
751 
752  public:
754 
761  explicit plain_multiplier(const base_series_multiplier &bsm, Series &retval)
762  : m_v1(bsm.m_v1), m_v2(bsm.m_v2), m_retval(retval), m_c_end(retval._container().end())
763  {
764  }
765 
766  private:
767  plain_multiplier(const plain_multiplier &) = delete;
768  plain_multiplier(plain_multiplier &&) = delete;
769  plain_multiplier &operator=(const plain_multiplier &) = delete;
770  plain_multiplier &operator=(plain_multiplier &&) = delete;
771 
772  public:
774 
787  void operator()(const size_type &i, const size_type &j) const
788  {
789  // First perform the multiplication.
790  key_type::multiply(m_tmp_t, *m_v1[i], *m_v2[j], m_retval.get_symbol_set());
791  for (std::size_t n = 0u; n < m_arity; ++n) {
792  auto &tmp_term = m_tmp_t[n];
793  if (FastMode) {
794  auto &container = m_retval._container();
795  // Try to locate the term into retval.
796  auto bucket_idx = container._bucket(tmp_term);
797  const auto it = container._find(tmp_term, bucket_idx);
798  if (it == m_c_end) {
799  container._unique_insert(term_insertion(tmp_term), bucket_idx);
800  } else {
801  it->m_cf += tmp_term.m_cf;
802  }
803  } else {
804  m_retval.insert(term_insertion(tmp_term));
805  }
806  }
807  }
808 
809  private:
810  mutable std::array<term_type, m_arity> m_tmp_t;
811  const std::vector<term_type const *> &m_v1;
812  const std::vector<term_type const *> &m_v2;
813  Series &m_retval;
814  const it_type m_c_end;
815  };
817 
845  static void sanitise_series(Series &retval, unsigned n_threads)
846  {
847  using term_type = typename Series::term_type;
848  if (unlikely(n_threads == 0u)) {
849  piranha_throw(std::invalid_argument, "invalid number of threads");
850  }
851  auto &container = retval._container();
852  const auto &args = retval.get_symbol_set();
853  // Reset the size to zero before doing anything.
854  container._update_size(static_cast<bucket_size_type>(0u));
855  // Single-thread implementation.
856  if (n_threads == 1u) {
857  const auto it_end = container.end();
858  for (auto it = container.begin(); it != it_end;) {
859  if (unlikely(!it->is_compatible(args))) {
860  piranha_throw(std::invalid_argument, "incompatible term");
861  }
862  if (unlikely(container.size() == std::numeric_limits<bucket_size_type>::max())) {
863  piranha_throw(std::overflow_error, "overflow error in the number of terms of a series");
864  }
865  // First update the size, it will be scaled back in the erase() method if necessary.
866  container._update_size(static_cast<bucket_size_type>(container.size() + 1u));
867  if (unlikely(it->is_zero(args))) {
868  it = container.erase(it);
869  } else {
870  ++it;
871  }
872  }
873  return;
874  }
875  // Multi-thread implementation.
876  const auto b_count = container.bucket_count();
877  std::mutex m;
878  integer global_count(0);
879  auto eraser = [b_count, &container, &m, &args, &global_count](const bucket_size_type &start,
880  const bucket_size_type &end) {
881  piranha_assert(start <= end && end <= b_count);
882  (void)b_count;
883  bucket_size_type count = 0u;
884  std::vector<term_type> term_list;
885  // Examine and count the terms bucket-by-bucket.
886  for (bucket_size_type i = start; i != end; ++i) {
887  term_list.clear();
888  const auto &bl = container._get_bucket_list(i);
889  const auto it_f = bl.end();
890  for (auto it = bl.begin(); it != it_f; ++it) {
891  // Check first for compatibility.
892  if (unlikely(!it->is_compatible(args))) {
893  piranha_throw(std::invalid_argument, "incompatible term");
894  }
895  // Check for ignorability.
896  if (unlikely(it->is_zero(args))) {
897  term_list.push_back(*it);
898  }
899  // Update the count of terms.
900  if (unlikely(count == std::numeric_limits<bucket_size_type>::max())) {
901  piranha_throw(std::overflow_error, "overflow error in the number of terms of a series");
902  }
903  count = static_cast<bucket_size_type>(count + 1u);
904  }
905  for (auto it = term_list.begin(); it != term_list.end(); ++it) {
906  // NOTE: must use _erase to avoid concurrent modifications
907  // to the number of elements in the table.
908  container._erase(container._find(*it, i));
909  // Account for the erased term.
910  piranha_assert(count > 0u);
911  count = static_cast<bucket_size_type>(count - 1u);
912  }
913  }
914  // Update the global count.
915  std::lock_guard<std::mutex> lock(m);
916  global_count += count;
917  };
919  try {
920  for (unsigned i = 0u; i < n_threads; ++i) {
921  const auto start = static_cast<bucket_size_type>((b_count / n_threads) * i),
922  end = static_cast<bucket_size_type>(
923  (i == n_threads - 1u) ? b_count : (b_count / n_threads) * (i + 1u));
924  f_list.push_back(thread_pool::enqueue(i, eraser, start, end));
925  }
926  // First let's wait for everything to finish.
927  f_list.wait_all();
928  // Then, let's handle the exceptions.
929  f_list.get_all();
930  } catch (...) {
931  f_list.wait_all();
932  // NOTE: there's not need to clear retval here - it was already in an inconsistent
933  // state coming into this method. We rather need to make sure sanitise_series() is always
934  // called in a try/catch block that clears retval in case of errors.
935  throw;
936  }
937  // Final update of the total count.
938  container._update_size(static_cast<bucket_size_type>(global_count));
939  }
941 
    // A plain series multiplication routine: optionally estimates the final size and
    // rehashes the result container accordingly, then runs the blocked multiplication
    // either single-threaded or split across threads (with per-bucket spinlocks guarding
    // concurrent insertions), sanitises the result if needed, and finalises it.
    template <typename LimitFunctor>
    Series plain_multiplication(const LimitFunctor &lf) const
    {
        // Shortcuts.
        using term_type = typename Series::term_type;
        using cf_type = typename term_type::cf_type;
        using key_type = typename term_type::key_type;
        PIRANHA_TT_CHECK(key_is_multipliable, cf_type, key_type);
        constexpr std::size_t m_arity = key_type::multiply_arity;
        // Setup the return value with the merged symbol set.
        Series retval;
        retval.set_symbol_set(m_ss);
        // Do not do anything if one of the two series is empty.
        if (unlikely(m_v1.empty() || m_v2.empty())) {
            return retval;
        }
        const size_type size1 = m_v1.size(), size2 = m_v2.size();
        (void)size2;
        piranha_assert(size1 && size2);
        // Convert n_threads to size_type for convenience.
        const size_type n_threads = safe_cast<size_type>(m_n_threads);
        piranha_assert(n_threads);
        // Determine if we should estimate the size. We check the threshold, but we always
        // need to estimate in multithreaded mode.
        bool estimate = true;
        const auto e_thr = tuning::get_estimate_threshold();
        if (integer(m_v1.size()) * m_v2.size() < integer(e_thr) * e_thr && n_threads == 1u) {
            estimate = false;
        }
        if (estimate) {
            // Estimate and rehash.
            const auto est = estimate_final_series_size<m_arity, plain_multiplier<false>>(lf);
            // NOTE: use numeric cast here as safe_cast is expensive, going through an integer-double conversion,
            // and in this case the behaviour of numeric_cast is appropriate.
            const auto n_buckets = boost::numeric_cast<bucket_size_type>(
                std::ceil(static_cast<double>(est) / retval._container().max_load_factor()));
            piranha_assert(n_buckets > 0u);
            // Check if we want to use the parallel memory set.
            // NOTE: it is important here that we use the same n_threads for multiplication and memset as
            // we tie together pinned threads with potentially different NUMA regions.
            const unsigned n_threads_rehash = tuning::get_parallel_memory_set() ? static_cast<unsigned>(n_threads) : 1u;
            retval._container().rehash(n_buckets, n_threads_rehash);
        }
        if (n_threads == 1u) {
            try {
                // Single-thread case.
                if (estimate) {
                    // Fast mode: we sized the table ourselves, so we can use the unchecked
                    // insertion path, at the cost of a sanitisation pass afterwards.
                    blocked_multiplication(plain_multiplier<true>(*this, retval), 0u, size1, lf);
                    // If we estimated beforehand, we need to sanitise the series.
                    sanitise_series(retval, static_cast<unsigned>(n_threads));
                } else {
                    blocked_multiplication(plain_multiplier<false>(*this, retval), 0u, size1, lf);
                }
                finalise_series(retval);
                return retval;
            } catch (...) {
                retval._container().clear();
                throw;
            }
        }
        // Multi-threaded case.
        piranha_assert(estimate);
        // Init the vector of spinlocks, one per bucket of the (pre-sized) result container.
        detail::atomic_flag_array sl_array(safe_cast<std::size_t>(retval._container().bucket_count()));
        // Init the future list.
        future_list<void> f_list;
        // Thread block size.
        const auto block_size = size1 / n_threads;
        try {
            for (size_type idx = 0u; idx < n_threads; ++idx) {
                // Thread functor.
                auto tf = [idx, this, block_size, n_threads, &sl_array, &retval, &lf]() {
                    // Used to store the result of term multiplication.
                    std::array<term_type, key_type::multiply_arity> tmp_t;
                    // End of retval container (thread-safe).
                    const auto c_end = retval._container().end();
                    // Block functor.
                    // NOTE: this is very similar to the plain functor, but it does the bucket locking
                    // additionally.
                    auto f = [&c_end, &tmp_t, this, &retval, &sl_array](const size_type &i, const size_type &j) {
                        // Run the term multiplication.
                        key_type::multiply(tmp_t, *(this->m_v1[i]), *(this->m_v2[j]), retval.get_symbol_set());
                        for (std::size_t n = 0u; n < key_type::multiply_arity; ++n) {
                            auto &container = retval._container();
                            auto &tmp_term = tmp_t[n];
                            // Try to locate the term into retval.
                            auto bucket_idx = container._bucket(tmp_term);
                            // Lock the bucket.
                            detail::atomic_lock_guard alg(sl_array[static_cast<std::size_t>(bucket_idx)]);
                            const auto it = container._find(tmp_term, bucket_idx);
                            if (it == c_end) {
                                container._unique_insert(term_insertion(tmp_term), bucket_idx);
                            } else {
                                it->m_cf += tmp_term.m_cf;
                            }
                        }
                    };
                    // Thread block limit: the last thread also picks up the remainder.
                    const auto e1
                        = (idx == n_threads - 1u) ? this->m_v1.size() : static_cast<size_type>((idx + 1u) * block_size);
                    this->blocked_multiplication(f, static_cast<size_type>(idx * block_size), e1, lf);
                };
                f_list.push_back(thread_pool::enqueue(static_cast<unsigned>(idx), tf));
            }
            f_list.wait_all();
            f_list.get_all();
            sanitise_series(retval, static_cast<unsigned>(n_threads));
            finalise_series(retval);
        } catch (...) {
            f_list.wait_all();
            // Clean up retval as it might be in an inconsistent state.
            retval._container().clear();
            throw;
        }
        return retval;
    }
1094 
    // A plain series multiplication routine (convenience overload): no truncation,
    // every term of the second series participates.
    Series plain_multiplication() const
    {
        return plain_multiplication(default_limit_functor{*this});
    }
1105 
    // Finalise series: dispatch to the coefficient-type-specific implementation (for
    // rational coefficients this restores the common denominator; otherwise a no-op).
    void finalise_series(Series &s) const
    {
        finalise_impl(s);
    }
1124 
1125 protected:
1127  mutable v_ptr m_v1;
1129  mutable v_ptr m_v2;
1133 
1138  unsigned m_n_threads;
1139 
1140 private:
1141  // See the constructor for an explanation.
1142  container_type m_zero_f1;
1143  container_type m_zero_f2;
1144 };
1145 }
1146 
1147 #endif
typename Series::size_type bucket_size_type
The size type of Series.
void operator()(const size_type &i, const size_type &j) const
Call operator.
void wait_all()
Wait on all the futures.
Type trait for multipliable key.
void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1, const LimitFunctor &lf) const
Blocked multiplication.
Function object type trait.
v_ptr m_v2
Vector of const pointers to the terms in the smaller series.
void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1) const
Blocked multiplication (convenience overload).
Multiprecision integer class.
Definition: mp++.hpp:869
mp_integer< 1 > integer
Alias for piranha::mp_integer with 1 limb of static storage.
Definition: mp_integer.hpp:63
Series plain_multiplication(const LimitFunctor &lf) const
A plain series multiplication routine.
void finalise_series(Series &s) const
Finalise series.
bucket_size_type estimate_final_series_size() const
Estimate size of series multiplication (convenience overload)
static unsigned use_threads(const Int &work_size, const Int &min_work_per_thread)
Compute number of threads to use.
bucket_size_type estimate_final_series_size(const LimitFunctor &lf) const
Estimate size of series multiplication.
Exceptions.
Series plain_multiplication() const
A plain series multiplication routine (convenience overload).
STL namespace.
void get_all()
Get all the futures.
static enqueue_t< F &&, Args &&... > enqueue(unsigned n, F &&f, Args &&... args)
Enqueue task.
base_series_multiplier(const Series &s1, const Series &s2)
Constructor.
const symbol_fset m_ss
The symbol set of the series used during construction.
typename v_ptr::size_type size_type
The size type of base_series_multiplier::v_ptr.
#define piranha_throw(exception_type,...)
Exception-throwing macro.
Definition: exceptions.hpp:118
boost::container::flat_set< std::string > symbol_fset
Flat set of symbols.
Detect if zero is a multiplicative absorber.
v_ptr m_v1
Vector of const pointers to the terms in the larger series.
Class to store a list of futures.
static unsigned long get_estimate_threshold()
Get the series estimation threshold.
Definition: tuning.hpp:159
Root piranha namespace.
Definition: array_key.hpp:52
Type traits.
auto mul3(T &a, const T &b, const T &c) -> decltype(mul3_impl< T >()(a, b, c))
Ternary multiplication.
Definition: math.hpp:2726
plain_multiplier(const base_series_multiplier &bsm, Series &retval)
Constructor.
static unsigned long get_multiplication_block_size()
Get the multiplication block size.
Definition: tuning.hpp:117
static bool get_parallel_memory_set()
Get the parallel_memory_set flag.
Definition: tuning.hpp:81
static void sanitise_series(Series &retval, unsigned n_threads)
Sanitise series.
int sgn() const
Sign.
Definition: mp++.hpp:1611
Type trait to detect series types.
Definition: series_fwd.hpp:49
unsigned m_n_threads
Number of threads.
static unsigned long long get_min_work_per_thread()
Get the minimum work per thread.
Definition: settings.hpp:233
auto gcd3(T &out, const T &a, const T &b) -> decltype(gcd3_impl< T >()(out, a, b))
Ternary GCD.
Definition: math.hpp:2947
std::vector< typename Series::term_type const * > v_ptr
Alias for a vector of const pointers to series terms.
safe_cast_type< To, From > safe_cast(const From &x)
Safe cast.
Definition: safe_cast.hpp:219
bool is_unitary(const T &x)
Unitary test.
Definition: math.hpp:242
void push_back(std::future< T > &&f)
Move-insert a future.