piranha  0.10
32 #include <algorithm>
33 #include <array>
34 #include <boost/numeric/conversion/cast.hpp>
35 #include <cmath>
36 #include <cstddef>
37 #include <iterator>
38 #include <limits>
39 #include <mutex>
#include <numeric>
40 #include <random>
41 #include <stdexcept>
42 #include <type_traits>
43 #include <utility>
44 #include <vector>
46 #include <piranha/config.hpp>
47 #include <piranha/detail/atomic_flag_array.hpp>
48 #include <piranha/detail/atomic_lock_guard.hpp>
49 #include <piranha/exceptions.hpp>
50 #include <piranha/key_is_multipliable.hpp>
51 #include <piranha/math.hpp>
52 #include <piranha/mp_integer.hpp>
53 #include <piranha/mp_rational.hpp>
54 #include <piranha/safe_cast.hpp>
55 #include <piranha/series.hpp>
56 #include <piranha/settings.hpp>
57 #include <piranha/symbol_utils.hpp>
58 #include <piranha/thread_pool.hpp>
59 #include <piranha/tuning.hpp>
60 #include <piranha/type_traits.hpp>
62 namespace piranha
63 {
65 namespace detail
66 {
68 template <typename Series, typename Derived, typename = void>
69 struct base_series_multiplier_impl {
70  using term_type = typename Series::term_type;
71  using container_type = typename std::decay<decltype(std::declval<Series>()._container())>::type;
72  using c_size_type = typename container_type::size_type;
73  using v_size_type = typename std::vector<term_type const *>::size_type;
74  template <typename Term,
75  typename std::enable_if<!is_less_than_comparable<typename Term::key_type>::value, int>::type = 0>
76  void fill_term_pointers(const container_type &c1, const container_type &c2, std::vector<Term const *> &v1,
77  std::vector<Term const *> &v2)
78  {
79  // If the key is not less-than comparable, we can only copy over the pointers as they are.
80  std::transform(c1.begin(), c1.end(), std::back_inserter(v1), [](const term_type &t) { return &t; });
81  std::transform(c2.begin(), c2.end(), std::back_inserter(v2), [](const term_type &t) { return &t; });
82  }
83  template <typename Term,
84  typename std::enable_if<is_less_than_comparable<typename Term::key_type>::value, int>::type = 0>
85  void fill_term_pointers(const container_type &c1, const container_type &c2, std::vector<Term const *> &v1,
86  std::vector<Term const *> &v2)
87  {
88  // Fetch the number of threads from the derived class.
89  const unsigned n_threads = static_cast<Derived *>(this)->m_n_threads;
90  piranha_assert(n_threads > 0u);
91  // Threading functor.
92  auto thread_func
93  = [n_threads](unsigned thread_idx, const container_type *c, std::vector<term_type const *> *v) {
94  piranha_assert(thread_idx < n_threads);
95  // Total bucket count.
96  const auto b_count = c->bucket_count();
97  // Buckets per thread.
98  const auto bpt = b_count / n_threads;
99  // End index.
100  const auto end
101  = static_cast<c_size_type>((thread_idx == n_threads - 1u) ? b_count : (bpt * (thread_idx + 1u)));
102  // Sorter.
103  auto sorter = [](term_type const *p1, term_type const *p2) { return p1->m_key < p2->m_key; };
104  v_size_type j = 0u;
105  for (auto start = static_cast<c_size_type>(bpt * thread_idx); start < end; ++start) {
106  const auto &b = c->_get_bucket_list(start);
107  v_size_type tmp = 0u;
108  for (const auto &t : b) {
109  v->push_back(&t);
110  ++tmp;
111  }
112  std::stable_sort(v->data() + j, v->data() + j + tmp, sorter);
113  j += tmp;
114  }
115  };
116  if (n_threads == 1u) {
117  thread_func(0u, &c1, &v1);
118  thread_func(0u, &c2, &v2);
119  return;
120  }
121  auto thread_wrapper = [&thread_func, n_threads](const container_type *c, std::vector<term_type const *> *v) {
122  // In the multi-threaded case, each thread needs to work on a separate vector.
123  // We will merge the vectors later.
124  using vv_t = std::vector<std::vector<Term const *>>;
125  using vv_size_t = typename vv_t::size_type;
126  vv_t vv(safe_cast<vv_size_t>(n_threads));
127  // Go with the threads.
128  future_list<void> ff_list;
129  try {
130  for (unsigned i = 0u; i < n_threads; ++i) {
131  ff_list.push_back(thread_pool::enqueue(i, thread_func, i, c, &(vv[static_cast<vv_size_t>(i)])));
132  }
133  // First let's wait for everything to finish.
134  ff_list.wait_all();
135  // Then, let's handle the exceptions.
136  ff_list.get_all();
137  } catch (...) {
138  ff_list.wait_all();
139  throw;
140  }
141  // Last, we need to merge everything into v.
142  for (const auto &vi : vv) {
143  v->insert(v->end(), vi.begin(), vi.end());
144  }
145  };
146  thread_wrapper(&c1, &v1);
147  thread_wrapper(&c2, &v2);
148  }
149 };
151 template <typename Series, typename Derived>
152 struct base_series_multiplier_impl<
153  Series, Derived, typename std::enable_if<is_mp_rational<typename Series::term_type::cf_type>::value>::type> {
154  // Useful shortcuts.
155  using term_type = typename Series::term_type;
156  using rat_type = typename term_type::cf_type;
157  using int_type = typename std::decay<decltype(std::declval<rat_type>().num())>::type;
158  using container_type = typename std::decay<decltype(std::declval<Series>()._container())>::type;
159  void fill_term_pointers(const container_type &c1, const container_type &c2, std::vector<term_type const *> &v1,
160  std::vector<term_type const *> &v2)
161  {
162  // Compute the least common multiplier.
163  m_lcm = 1;
164  auto it_f = c1.end();
165  int_type g;
166  for (auto it = c1.begin(); it != it_f; ++it) {
167  math::gcd3(g, m_lcm, it->m_cf.den());
168  math::mul3(m_lcm, m_lcm, it->m_cf.den());
169  divexact(m_lcm, m_lcm, g);
170  }
171  it_f = c2.end();
172  for (auto it = c2.begin(); it != it_f; ++it) {
173  math::gcd3(g, m_lcm, it->m_cf.den());
174  math::mul3(m_lcm, m_lcm, it->m_cf.den());
175  divexact(m_lcm, m_lcm, g);
176  }
177  // All these computations involve only positive numbers,
178  // the GCD must always be positive.
179  piranha_assert(m_lcm.sgn() == 1);
180  // Copy over the terms and renormalise to lcm.
181  it_f = c1.end();
182  for (auto it = c1.begin(); it != it_f; ++it) {
183  // NOTE: these divisions are exact, we could take advantage of that.
184  m_terms1.push_back(term_type(rat_type(m_lcm / it->m_cf.den() * it->m_cf.num(), int_type(1)), it->m_key));
185  }
186  it_f = c2.end();
187  for (auto it = c2.begin(); it != it_f; ++it) {
188  m_terms2.push_back(term_type(rat_type(m_lcm / it->m_cf.den() * it->m_cf.num(), int_type(1)), it->m_key));
189  }
190  // Copy over the pointers.
191  std::transform(m_terms1.begin(), m_terms1.end(), std::back_inserter(v1), [](const term_type &t) { return &t; });
192  std::transform(m_terms2.begin(), m_terms2.end(), std::back_inserter(v2), [](const term_type &t) { return &t; });
193  piranha_assert(v1.size() == c1.size());
194  piranha_assert(v2.size() == c2.size());
195  }
196  std::vector<term_type> m_terms1;
197  std::vector<term_type> m_terms2;
198  int_type m_lcm;
199 };
200 }
222 // Some performance ideas:
223 // - optimisation in case one series has 1 term with unitary key and both series same type: multiply directly
224 // coefficients;
225 // - optimisation for coefficient series that merges all args, similar to the rational optimisation;
226 // - optimisation for load balancing similar to the poly multiplier.
227 template <typename Series>
228 class base_series_multiplier : private detail::base_series_multiplier_impl<Series, base_series_multiplier<Series>>
229 {
230  PIRANHA_TT_CHECK(is_series, Series);
231  // Make friends with the base, so it can access protected/private members of this.
232  friend struct detail::base_series_multiplier_impl<Series, base_series_multiplier<Series>>;
233  // Alias for the series' container type.
234  using container_type = uncvref_t<decltype(std::declval<const Series &>()._container())>;
236 public:
238  using v_ptr = std::vector<typename Series::term_type const *>;
240  using size_type = typename v_ptr::size_type;
242  using bucket_size_type = typename Series::size_type;
244 private:
245  // The default limit functor: it will include all terms in the second series.
246  struct default_limit_functor {
247  default_limit_functor(const base_series_multiplier &m) : m_size2(m.m_v2.size()) {}
248  size_type operator()(const size_type &) const
249  {
250  return m_size2;
251  }
252  const size_type m_size2;
253  };
254  // The purpose of this helper is to move in a coefficient series during insertion. For series,
255  // we know that moves leave the series in a valid state, and series multiplications do not benefit
256  // from an already-constructed destination - hence it is convenient to move them rather than copy.
257  template <typename Term, typename std::enable_if<!is_series<typename Term::cf_type>::value, int>::type = 0>
258  static Term &term_insertion(Term &t)
259  {
260  return t;
261  }
262  template <typename Term, typename std::enable_if<is_series<typename Term::cf_type>::value, int>::type = 0>
263  static Term term_insertion(Term &t)
264  {
265  return Term{std::move(t.m_cf), t.m_key};
266  }
267  // Implementation of finalise().
268  template <typename T, typename std::enable_if<is_mp_rational<typename T::term_type::cf_type>::value, int>::type = 0>
269  void finalise_impl(T &s) const
270  {
271  // Nothing to do if the lcm is unitary.
272  if (math::is_unitary(this->m_lcm)) {
273  return;
274  }
275  // NOTE: this has to be the square of the lcm, as in addition to uniformising
276  // the denominators in each series we are also multiplying the two series.
277  const auto l2 = this->m_lcm * this->m_lcm;
278  auto &container = s._container();
279  // Single thread implementation.
280  if (m_n_threads == 1u) {
281  for (const auto &t : container) {
282  t.m_cf._set_den(l2);
283  t.m_cf.canonicalise();
284  }
285  return;
286  }
287  // Multi-thread implementation.
288  // Buckets per thread.
289  const bucket_size_type bpt = static_cast<bucket_size_type>(container.bucket_count() / m_n_threads);
290  auto thread_func = [l2, &container, this, bpt](unsigned t_idx) {
291  bucket_size_type start_idx = static_cast<bucket_size_type>(t_idx * bpt);
292  // Special handling for the last thread.
293  const bucket_size_type end_idx = t_idx == (this->m_n_threads - 1u)
294  ? container.bucket_count()
295  : static_cast<bucket_size_type>((t_idx + 1u) * bpt);
296  for (; start_idx != end_idx; ++start_idx) {
297  auto &list = container._get_bucket_list(start_idx);
298  for (const auto &t : list) {
299  t.m_cf._set_den(l2);
300  t.m_cf.canonicalise();
301  }
302  }
303  };
304  // Go with the threads.
305  future_list<decltype(thread_func(0u))> ff_list;
306  try {
307  for (unsigned i = 0u; i < m_n_threads; ++i) {
308  ff_list.push_back(thread_pool::enqueue(i, thread_func, i));
309  }
310  // First let's wait for everything to finish.
311  ff_list.wait_all();
312  // Then, let's handle the exceptions.
313  ff_list.get_all();
314  } catch (...) {
315  ff_list.wait_all();
316  throw;
317  }
318  }
319  template <typename T,
320  typename std::enable_if<!is_mp_rational<typename T::term_type::cf_type>::value, int>::type = 0>
321  void finalise_impl(T &) const
322  {
323  }
325 public:
355  explicit base_series_multiplier(const Series &s1, const Series &s2) : m_ss(s1.get_symbol_set())
356  {
357  if (unlikely(s1.get_symbol_set() != s2.get_symbol_set())) {
358  piranha_throw(std::invalid_argument, "incompatible arguments sets");
359  }
360  // The largest series goes first.
361  const Series *p1 = &s1, *p2 = &s2;
362  if (s1.size() < s2.size()) {
363  std::swap(p1, p2);
364  }
365  // This is just an optimisation, no troubles if there is a truncation due to static_cast.
366  m_v1.reserve(static_cast<size_type>(p1->size()));
367  m_v2.reserve(static_cast<size_type>(p2->size()));
368  container_type const *ctr1 = &p1->_container(), *ctr2 = &p2->_container();
369  // NOTE: if the zero element of Series is not absorbing, we need to create a temporary zero series in place
370  // of any factor that is zero, and then use it in the multiplication. This ensures a correct series
371  // multiplication result for coefficient types (such as IEEE floats) for which 0 times x is not necessarily
372  // always 0. The temporary zero series is stored in the m_zero_f member as a collection of 1 term with
373  // zero coefficient.
375  using term_type = typename Series::term_type;
376  using cf_type = typename term_type::cf_type;
377  using key_type = typename term_type::key_type;
378  if (p1->empty()) {
379  m_zero_f1.insert(term_type{cf_type(0), key_type(s1.get_symbol_set())});
380  ctr1 = &m_zero_f1;
381  }
382  if (p2->empty()) {
383  m_zero_f2.insert(term_type{cf_type(0), key_type(s1.get_symbol_set())});
384  ctr2 = &m_zero_f2;
385  }
386  }
387  // Set the number of threads.
388  m_n_threads = (ctr1->size() && ctr2->size())
389  ? thread_pool::use_threads(integer(ctr1->size()) * ctr2->size(),
391  : 1u;
392  this->fill_term_pointers(*ctr1, *ctr2, m_v1, m_v2);
393  }
395 private:
396  base_series_multiplier() = delete;
399  base_series_multiplier &operator=(const base_series_multiplier &) = delete;
400  base_series_multiplier &operator=(base_series_multiplier &&) = delete;
402 protected:
441  template <typename MultFunctor, typename LimitFunctor>
442  void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1,
443  const LimitFunctor &lf) const
444  {
445  PIRANHA_TT_CHECK(is_function_object, MultFunctor, void, const size_type &, const size_type &);
446  if (unlikely(start1 > end1 || start1 > m_v1.size() || end1 > m_v1.size())) {
447  piranha_throw(std::invalid_argument, "invalid bounds in blocked_multiplication");
448  }
449  // Block size and number of regular blocks.
451  nblocks1 = static_cast<size_type>((end1 - start1) / bsize),
452  nblocks2 = static_cast<size_type>(m_v2.size() / bsize);
453  // Start and end of last (possibly irregular) blocks.
454  const size_type i_ir_start = static_cast<size_type>(nblocks1 * bsize + start1), i_ir_end = end1;
455  const size_type j_ir_start = static_cast<size_type>(nblocks2 * bsize), j_ir_end = m_v2.size();
456  for (size_type n1 = 0u; n1 < nblocks1; ++n1) {
457  const size_type i_start = static_cast<size_type>(n1 * bsize + start1),
458  i_end = static_cast<size_type>(i_start + bsize);
459  // regulars1 * regulars2
460  for (size_type n2 = 0u; n2 < nblocks2; ++n2) {
461  const size_type j_start = static_cast<size_type>(n2 * bsize),
462  j_end = static_cast<size_type>(j_start + bsize);
463  for (size_type i = i_start; i < i_end; ++i) {
464  const size_type limit = std::min<size_type>(lf(i), j_end);
465  for (size_type j = j_start; j < limit; ++j) {
466  mf(i, j);
467  }
468  }
469  }
470  // regulars1 * rem2
471  for (size_type i = i_start; i < i_end; ++i) {
472  const size_type limit = std::min<size_type>(lf(i), j_ir_end);
473  for (size_type j = j_ir_start; j < limit; ++j) {
474  mf(i, j);
475  }
476  }
477  }
478  // rem1 * regulars2
479  for (size_type n2 = 0u; n2 < nblocks2; ++n2) {
480  const size_type j_start = static_cast<size_type>(n2 * bsize),
481  j_end = static_cast<size_type>(j_start + bsize);
482  for (size_type i = i_ir_start; i < i_ir_end; ++i) {
483  const size_type limit = std::min<size_type>(lf(i), j_end);
484  for (size_type j = j_start; j < limit; ++j) {
485  mf(i, j);
486  }
487  }
488  }
489  // rem1 * rem2.
490  for (size_type i = i_ir_start; i < i_ir_end; ++i) {
491  const size_type limit = std::min<size_type>(lf(i), j_ir_end);
492  for (size_type j = j_ir_start; j < limit; ++j) {
493  mf(i, j);
494  }
495  }
496  }
508  template <typename MultFunctor>
509  void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1) const
510  {
511  blocked_multiplication(mf, start1, end1, default_limit_functor{*this});
512  }
555  template <std::size_t MultArity, typename MultFunctor, typename LimitFunctor>
556  bucket_size_type estimate_final_series_size(const LimitFunctor &lf) const
557  {
558  PIRANHA_TT_CHECK(is_function_object, MultFunctor, void, const size_type &, const size_type &);
559  PIRANHA_TT_CHECK(std::is_constructible, MultFunctor, const base_series_multiplier &, Series &);
560  PIRANHA_TT_CHECK(is_function_object, LimitFunctor, size_type, const size_type &);
561  // Cache these.
562  const size_type size1 = m_v1.size(), size2 = m_v2.size();
563  constexpr std::size_t result_size = MultArity;
564  // If one of the two series is empty, just return 0.
565  if (unlikely(!size1 || !size2)) {
566  return 1u;
567  }
568  // If either series has a size of 1, just return size1 * size2 * result_size.
569  if (size1 == 1u || size2 == 1u) {
570  return static_cast<bucket_size_type>(integer(size1) * size2 * result_size);
571  }
572  // NOTE: Hard-coded number of trials.
573  // NOTE: here consider that in case of extremely sparse series with few terms this will incur in noticeable
574  // overhead, since we will need many term-by-term before encountering the first duplicate.
575  const unsigned n_trials = 15u;
576  // NOTE: Hard-coded value for the estimation multiplier.
577  // NOTE: This value should be tuned for performance/memory usage tradeoffs.
578  const unsigned multiplier = 2u;
579  // Number of threads to use. If there are more threads than trials, then reduce
580  // the number of actual threads to use.
581  // NOTE: this is a bit different from usual, where we do not care if the workload per thread is zero.
582  // We do like this because n_trials is a small number and there still seems to be benefit in running
583  // just 1 trial per thread.
584  const unsigned n_threads = (n_trials >= m_n_threads) ? m_n_threads : n_trials;
585  piranha_assert(n_threads > 0u);
586  // Trials per thread. This will always be at least 1.
587  const unsigned tpt = n_trials / n_threads;
588  piranha_assert(tpt >= 1u);
589  // The cumulative estimate.
590  integer c_estimate(0);
591  // Sync mutex - actually used only in multithreading.
592  std::mutex mut;
593  // The estimation functor.
594  auto estimator = [&lf, size1, n_threads, tpt, this, &c_estimate, &mut](unsigned thread_idx) {
595  piranha_assert(thread_idx < n_threads);
596  // Vectors of indices into m_v1.
597  std::vector<size_type> v_idx1(safe_cast<typename std::vector<size_type>::size_type>(size1));
598  std::iota(v_idx1.begin(), v_idx1.end(), size_type(0));
599  // Copy in order to reset to initial state later.
600  const auto v_idx1_copy = v_idx1;
601  // Random number engine.
602  std::mt19937 engine;
603  // Uniform int distribution.
604  using dist_type = std::uniform_int_distribution<size_type>;
605  dist_type dist;
606  // Init the accumulated estimation for averaging later.
607  integer acc(0);
608  // Number of trials for this thread - usual special casing for the last thread.
609  const unsigned cur_trials = (thread_idx == n_threads - 1u) ? (n_trials - thread_idx * tpt) : tpt;
610  // This should always be guaranteed because tpt is never 0.
611  piranha_assert(cur_trials > 0u);
612  // Create and setup the temp series.
613  Series tmp;
614  tmp.set_symbol_set(m_ss);
615  // Create the multiplier.
616  MultFunctor mf(*this, tmp);
617  // Go with the trials.
618  for (auto n = 0u; n < cur_trials; ++n) {
619  // Seed the engine. The seed should be the global trial number, accounting for multiple
620  // threads. This way the estimation will not depend on the number of threads.
621  engine.seed(static_cast<std::mt19937::result_type>(tpt * thread_idx + n));
622  // Reset the indices vector and re-randomise it.
623  // NOTE: we need to do this as every run inside this for loop must be completely independent
624  // of any previous run, we cannot keep any state.
625  v_idx1 = v_idx1_copy;
626  std::shuffle(v_idx1.begin(), v_idx1.end(), engine);
627  // The counter. This will be increased each time a term-by-term multiplication
628  // does not generate a duplicate term.
629  size_type count = 0u;
630  // This will be used to determine the average number of terms in s2
631  // that participate in the multiplication.
632  integer acc_s2(0);
633  auto it1 = v_idx1.begin();
634  for (; it1 != v_idx1.end(); ++it1) {
635  // Get the limit idx in s2.
636  const size_type limit = lf(*it1);
637  // This is the upper limit of an open ended interval, so it needs
638  // to be decreased by one in order to be used in dist. If zero, it means
639  // there are no terms in v2 that can be multiplied by the current term in t1.
640  if (limit == 0u) {
641  continue;
642  }
643  acc_s2 += limit;
644  // Pick a random index in m_v2 within the limit.
645  const size_type idx2
646  = dist(engine, typename dist_type::param_type(static_cast<size_type>(0u),
647  static_cast<size_type>(limit - 1u)));
648  // Perform term multiplication.
649  mf(*it1, idx2);
650  // Check for unlikely overflows when increasing count.
651  if (unlikely(result_size > std::numeric_limits<size_type>::max()
652  || count > std::numeric_limits<size_type>::max() - result_size)) {
653  piranha_throw(std::overflow_error, "overflow error");
654  }
655  if (tmp.size() != count + result_size) {
656  break;
657  }
658  // Increase cycle variables.
659  count = static_cast<size_type>(count + result_size);
660  }
661  integer add;
662  if (it1 == v_idx1.end()) {
663  // We never found a duplicate. count is now the number of terms in s1
664  // which actually participate in the multiplication, while acc_s2 / count
665  // is the average number of terms in s2 that participate in the multiplication.
666  // The result will be then count * acc_s2 / count = acc_s2.
667  add = acc_s2;
668  } else {
669  // If we found a duplicate, we use the heuristic.
670  add = integer(multiplier) * count * count;
671  }
672  // Fix if zero, so that the average later never results in zero.
673  if (add.sgn() == 0) {
674  add = 1;
675  }
676  acc += add;
677  // Reset tmp.
678  tmp._container().clear();
679  }
680  // Accumulate in the shared variable.
681  if (n_threads == 1u) {
682  // No locking needed.
683  c_estimate += acc;
684  } else {
685  std::lock_guard<std::mutex> lock(mut);
686  c_estimate += acc;
687  }
688  };
689  // Run the estimation functor.
690  if (n_threads == 1u) {
691  estimator(0u);
692  } else {
693  future_list<void> f_list;
694  try {
695  for (unsigned i = 0u; i < n_threads; ++i) {
696  f_list.push_back(thread_pool::enqueue(i, estimator, i));
697  }
698  // First let's wait for everything to finish.
699  f_list.wait_all();
700  // Then, let's handle the exceptions.
701  f_list.get_all();
702  } catch (...) {
703  f_list.wait_all();
704  throw;
705  }
706  }
707  piranha_assert(c_estimate >= n_trials);
708  // Return the mean.
709  return static_cast<bucket_size_type>(c_estimate / n_trials);
710  }
718  template <std::size_t MultArity, typename MultFunctor>
720  {
721  return estimate_final_series_size<MultArity, MultFunctor>(default_limit_functor{*this});
722  }
743  template <bool FastMode>
745  {
746  using term_type = typename Series::term_type;
747  using key_type = typename term_type::key_type;
748  PIRANHA_TT_CHECK(key_is_multipliable, typename term_type::cf_type, key_type);
749  using it_type = decltype(std::declval<Series &>()._container().end());
750  static constexpr std::size_t m_arity = key_type::multiply_arity;
752  public:
761  explicit plain_multiplier(const base_series_multiplier &bsm, Series &retval)
762  : m_v1(bsm.m_v1), m_v2(bsm.m_v2), m_retval(retval), m_c_end(retval._container().end())
763  {
764  }
766  private:
767  plain_multiplier(const plain_multiplier &) = delete;
768  plain_multiplier(plain_multiplier &&) = delete;
769  plain_multiplier &operator=(const plain_multiplier &) = delete;
770  plain_multiplier &operator=(plain_multiplier &&) = delete;
772  public:
787  void operator()(const size_type &i, const size_type &j) const
788  {
789  // First perform the multiplication.
790  key_type::multiply(m_tmp_t, *m_v1[i], *m_v2[j], m_retval.get_symbol_set());
791  for (std::size_t n = 0u; n < m_arity; ++n) {
792  auto &tmp_term = m_tmp_t[n];
793  if (FastMode) {
794  auto &container = m_retval._container();
795  // Try to locate the term into retval.
796  auto bucket_idx = container._bucket(tmp_term);
797  const auto it = container._find(tmp_term, bucket_idx);
798  if (it == m_c_end) {
799  container._unique_insert(term_insertion(tmp_term), bucket_idx);
800  } else {
801  it->m_cf += tmp_term.m_cf;
802  }
803  } else {
804  m_retval.insert(term_insertion(tmp_term));
805  }
806  }
807  }
809  private:
810  mutable std::array<term_type, m_arity> m_tmp_t;
811  const std::vector<term_type const *> &m_v1;
812  const std::vector<term_type const *> &m_v2;
813  Series &m_retval;
814  const it_type m_c_end;
815  };
845  static void sanitise_series(Series &retval, unsigned n_threads)
846  {
847  using term_type = typename Series::term_type;
848  if (unlikely(n_threads == 0u)) {
849  piranha_throw(std::invalid_argument, "invalid number of threads");
850  }
851  auto &container = retval._container();
852  const auto &args = retval.get_symbol_set();
853  // Reset the size to zero before doing anything.
854  container._update_size(static_cast<bucket_size_type>(0u));
855  // Single-thread implementation.
856  if (n_threads == 1u) {
857  const auto it_end = container.end();
858  for (auto it = container.begin(); it != it_end;) {
859  if (unlikely(!it->is_compatible(args))) {
860  piranha_throw(std::invalid_argument, "incompatible term");
861  }
862  if (unlikely(container.size() == std::numeric_limits<bucket_size_type>::max())) {
863  piranha_throw(std::overflow_error, "overflow error in the number of terms of a series");
864  }
865  // First update the size, it will be scaled back in the erase() method if necessary.
866  container._update_size(static_cast<bucket_size_type>(container.size() + 1u));
867  if (unlikely(it->is_zero(args))) {
868  it = container.erase(it);
869  } else {
870  ++it;
871  }
872  }
873  return;
874  }
875  // Multi-thread implementation.
876  const auto b_count = container.bucket_count();
877  std::mutex m;
878  integer global_count(0);
879  auto eraser = [b_count, &container, &m, &args, &global_count](const bucket_size_type &start,
880  const bucket_size_type &end) {
881  piranha_assert(start <= end && end <= b_count);
882  (void)b_count;
883  bucket_size_type count = 0u;
884  std::vector<term_type> term_list;
885  // Examine and count the terms bucket-by-bucket.
886  for (bucket_size_type i = start; i != end; ++i) {
887  term_list.clear();
888  const auto &bl = container._get_bucket_list(i);
889  const auto it_f = bl.end();
890  for (auto it = bl.begin(); it != it_f; ++it) {
891  // Check first for compatibility.
892  if (unlikely(!it->is_compatible(args))) {
893  piranha_throw(std::invalid_argument, "incompatible term");
894  }
895  // Check for ignorability.
896  if (unlikely(it->is_zero(args))) {
897  term_list.push_back(*it);
898  }
899  // Update the count of terms.
900  if (unlikely(count == std::numeric_limits<bucket_size_type>::max())) {
901  piranha_throw(std::overflow_error, "overflow error in the number of terms of a series");
902  }
903  count = static_cast<bucket_size_type>(count + 1u);
904  }
905  for (auto it = term_list.begin(); it != term_list.end(); ++it) {
906  // NOTE: must use _erase to avoid concurrent modifications
907  // to the number of elements in the table.
908  container._erase(container._find(*it, i));
909  // Account for the erased term.
910  piranha_assert(count > 0u);
911  count = static_cast<bucket_size_type>(count - 1u);
912  }
913  }
914  // Update the global count.
915  std::lock_guard<std::mutex> lock(m);
916  global_count += count;
917  };
919  try {
920  for (unsigned i = 0u; i < n_threads; ++i) {
921  const auto start = static_cast<bucket_size_type>((b_count / n_threads) * i),
922  end = static_cast<bucket_size_type>(
923  (i == n_threads - 1u) ? b_count : (b_count / n_threads) * (i + 1u));
924  f_list.push_back(thread_pool::enqueue(i, eraser, start, end));
925  }
926  // First let's wait for everything to finish.
927  f_list.wait_all();
928  // Then, let's handle the exceptions.
929  f_list.get_all();
930  } catch (...) {
931  f_list.wait_all();
932  // NOTE: there's not need to clear retval here - it was already in an inconsistent
933  // state coming into this method. We rather need to make sure sanitise_series() is always
934  // called in a try/catch block that clears retval in case of errors.
935  throw;
936  }
937  // Final update of the total count.
938  container._update_size(static_cast<bucket_size_type>(global_count));
939  }
977  template <typename LimitFunctor>
978  Series plain_multiplication(const LimitFunctor &lf) const
979  {
980  // Shortcuts.
981  using term_type = typename Series::term_type;
982  using cf_type = typename term_type::cf_type;
983  using key_type = typename term_type::key_type;
984  PIRANHA_TT_CHECK(key_is_multipliable, cf_type, key_type);
985  constexpr std::size_t m_arity = key_type::multiply_arity;
986  // Setup the return value with the merged symbol set.
987  Series retval;
988  retval.set_symbol_set(m_ss);
989  // Do not do anything if one of the two series is empty.
990  if (unlikely(m_v1.empty() || m_v2.empty())) {
991  return retval;
992  }
993  const size_type size1 = m_v1.size(), size2 = m_v2.size();
994  (void)size2;
995  piranha_assert(size1 && size2);
996  // Convert n_threads to size_type for convenience.
997  const size_type n_threads = safe_cast<size_type>(m_n_threads);
998  piranha_assert(n_threads);
999  // Determine if we should estimate the size. We check the threshold, but we always
1000  // need to estimate in multithreaded mode.
1001  bool estimate = true;
1002  const auto e_thr = tuning::get_estimate_threshold();
1003  if (integer(m_v1.size()) * m_v2.size() < integer(e_thr) * e_thr && n_threads == 1u) {
1004  estimate = false;
1005  }
1006  if (estimate) {
1007  // Estimate and rehash.
1008  const auto est = estimate_final_series_size<m_arity, plain_multiplier<false>>(lf);
1009  // NOTE: use numeric cast here as safe_cast is expensive, going through an integer-double conversion,
1010  // and in this case the behaviour of numeric_cast is appropriate.
1011  const auto n_buckets = boost::numeric_cast<bucket_size_type>(
1012  std::ceil(static_cast<double>(est) / retval._container().max_load_factor()));
1013  piranha_assert(n_buckets > 0u);
1014  // Check if we want to use the parallel memory set.
1015  // NOTE: it is important here that we use the same n_threads for multiplication and memset as
1016  // we tie together pinned threads with potentially different NUMA regions.
1017  const unsigned n_threads_rehash = tuning::get_parallel_memory_set() ? static_cast<unsigned>(n_threads) : 1u;
1018  retval._container().rehash(n_buckets, n_threads_rehash);
1019  }
1020  if (n_threads == 1u) {
1021  try {
1022  // Single-thread case.
1023  if (estimate) {
1024  blocked_multiplication(plain_multiplier<true>(*this, retval), 0u, size1, lf);
1025  // If we estimated beforehand, we need to sanitise the series.
1026  sanitise_series(retval, static_cast<unsigned>(n_threads));
1027  } else {
1028  blocked_multiplication(plain_multiplier<false>(*this, retval), 0u, size1, lf);
1029  }
1030  finalise_series(retval);
1031  return retval;
1032  } catch (...) {
1033  retval._container().clear();
1034  throw;
1035  }
1036  }
1037  // Multi-threaded case.
1038  piranha_assert(estimate);
1039  // Init the vector of spinlocks.
1040  detail::atomic_flag_array sl_array(safe_cast<std::size_t>(retval._container().bucket_count()));
1041  // Init the future list.
1042  future_list<void> f_list;
1043  // Thread block size.
1044  const auto block_size = size1 / n_threads;
1045  try {
1046  for (size_type idx = 0u; idx < n_threads; ++idx) {
1047  // Thread functor.
1048  auto tf = [idx, this, block_size, n_threads, &sl_array, &retval, &lf]() {
1049  // Used to store the result of term multiplication.
1050  std::array<term_type, key_type::multiply_arity> tmp_t;
1051  // End of retval container (thread-safe).
1052  const auto c_end = retval._container().end();
1053  // Block functor.
1054  // NOTE: this is very similar to the plain functor, but it does the bucket locking
1055  // additionally.
1056  auto f = [&c_end, &tmp_t, this, &retval, &sl_array](const size_type &i, const size_type &j) {
1057  // Run the term multiplication.
1058  key_type::multiply(tmp_t, *(this->m_v1[i]), *(this->m_v2[j]), retval.get_symbol_set());
1059  for (std::size_t n = 0u; n < key_type::multiply_arity; ++n) {
1060  auto &container = retval._container();
1061  auto &tmp_term = tmp_t[n];
1062  // Try to locate the term into retval.
1063  auto bucket_idx = container._bucket(tmp_term);
1064  // Lock the bucket.
1065  detail::atomic_lock_guard alg(sl_array[static_cast<std::size_t>(bucket_idx)]);
1066  const auto it = container._find(tmp_term, bucket_idx);
1067  if (it == c_end) {
1068  container._unique_insert(term_insertion(tmp_term), bucket_idx);
1069  } else {
1070  it->m_cf += tmp_term.m_cf;
1071  }
1072  }
1073  };
1074  // Thread block limit.
1075  const auto e1
1076  = (idx == n_threads - 1u) ? this->m_v1.size() : static_cast<size_type>((idx + 1u) * block_size);
1077  this->blocked_multiplication(f, static_cast<size_type>(idx * block_size), e1, lf);
1078  };
1079  f_list.push_back(thread_pool::enqueue(static_cast<unsigned>(idx), tf));
1080  }
1081  f_list.wait_all();
1082  f_list.get_all();
1083  sanitise_series(retval, static_cast<unsigned>(n_threads));
1084  finalise_series(retval);
1085  } catch (...) {
1086  f_list.wait_all();
1087  // Clean up retval as it might be in an inconsistent state.
1088  retval._container().clear();
1089  throw;
1090  }
1091  return retval;
1092  }
1100  Series plain_multiplication() const
1101  {
1102  return plain_multiplication(default_limit_functor{*this});
1103  }
1120  void finalise_series(Series &s) const
1121  {
1122  finalise_impl(s);
1123  }
1125 protected:
1127  mutable v_ptr m_v1;
1129  mutable v_ptr m_v2;
1138  unsigned m_n_threads;
1140 private:
1141  // See the constructor for an explanation.
1142  container_type m_zero_f1;
1143  container_type m_zero_f2;
1144 };
1145 }
1147 #endif
