29 #ifndef PIRANHA_DETAIL_BASE_SERIES_MULTIPLIER_HPP 30 #define PIRANHA_DETAIL_BASE_SERIES_MULTIPLIER_HPP 34 #include <boost/numeric/conversion/cast.hpp> 42 #include <type_traits> 46 #include <piranha/config.hpp> 47 #include <piranha/detail/atomic_flag_array.hpp> 48 #include <piranha/detail/atomic_lock_guard.hpp> 50 #include <piranha/key_is_multipliable.hpp> 51 #include <piranha/math.hpp> 52 #include <piranha/mp_integer.hpp> 53 #include <piranha/mp_rational.hpp> 54 #include <piranha/safe_cast.hpp> 55 #include <piranha/series.hpp> 56 #include <piranha/settings.hpp> 57 #include <piranha/symbol_utils.hpp> 58 #include <piranha/thread_pool.hpp> 59 #include <piranha/tuning.hpp> 68 template <
typename Series,
typename Derived,
typename =
void>
69 struct base_series_multiplier_impl {
// Primary template of the implementation helper. The third template parameter
// defaults to void so that an enable_if-based partial specialisation can be
// selected instead of this one.
// NOTE(review): this listing carries the original line numbers inline and skips
// some original lines (the numbering has gaps), so the bodies below are partial.
70 using term_type =
typename Series::term_type;
71 using container_type =
typename std::decay<decltype(std::declval<Series>()._container())>::type;
72 using c_size_type =
typename container_type::size_type;
73 using v_size_type =
typename std::vector<term_type const *>::size_type;
// Overload selected when the key type is NOT less-than comparable: appends the
// address of every term of c1 to v1 and of every term of c2 to v2, following the
// iteration order of the containers.
74 template <
typename Term,
75 typename std::enable_if<!is_less_than_comparable<typename Term::key_type>::value,
int>::type = 0>
76 void fill_term_pointers(
const container_type &c1,
const container_type &c2, std::vector<Term const *> &v1,
77 std::vector<Term const *> &v2)
80 std::transform(c1.begin(), c1.end(), std::back_inserter(v1), [](
const term_type &t) {
return &t; });
81 std::transform(c2.begin(), c2.end(), std::back_inserter(v2), [](
const term_type &t) {
return &t; });
// Overload selected when the key type IS less-than comparable: the containers are
// walked bucket by bucket and a range of the collected pointers is stable-sorted
// by key for each bucket. The work is split across threads by bucket range.
83 template <
typename Term,
84 typename std::enable_if<is_less_than_comparable<typename Term::key_type>::value,
int>::type = 0>
85 void fill_term_pointers(
const container_type &c1,
const container_type &c2, std::vector<Term const *> &v1,
86 std::vector<Term const *> &v2)
// The thread count is read from the derived multiplier object.
89 const unsigned n_threads =
static_cast<Derived *
>(
this)->m_n_threads;
90 piranha_assert(n_threads > 0u);
// Worker: handles the slice of buckets of container c assigned to thread_idx and
// writes term pointers into v.
93 = [n_threads](
unsigned thread_idx,
const container_type *c, std::vector<term_type const *> *v) {
94 piranha_assert(thread_idx < n_threads);
96 const auto b_count = c->bucket_count();
// Buckets per thread (integer division). The upper bound computed next gives the
// last thread everything up to b_count, so it also takes the remainder.
98 const auto bpt = b_count / n_threads;
101 =
static_cast<c_size_type
>((thread_idx == n_threads - 1u) ? b_count : (bpt * (thread_idx + 1u)));
// Orders term pointers by the keys of the pointed-to terms.
103 auto sorter = [](term_type
const *p1, term_type
const *p2) {
return p1->m_key < p2->m_key; };
// Walk the buckets of this thread, starting at bpt * thread_idx.
105 for (
auto start = static_cast<c_size_type>(bpt * thread_idx); start < end; ++start) {
106 const auto &b = c->_get_bucket_list(start);
// tmp starts at zero for each bucket and is the length of the range sorted below;
// presumably it counts the terms of the current bucket. The statements that update
// it, append to v and define j are not visible in this listing - TODO confirm.
107 v_size_type tmp = 0u;
108 for (
const auto &t : b) {
112 std::stable_sort(v->data() + j, v->data() + j + tmp, sorter);
// Single-threaded case: run the worker directly on both containers.
116 if (n_threads == 1u) {
117 thread_func(0u, &c1, &v1);
118 thread_func(0u, &c2, &v2);
// Multi-threaded case: one task and one private output vector per thread.
121 auto thread_wrapper = [&thread_func, n_threads](
const container_type *c, std::vector<term_type const *> *v) {
124 using vv_t = std::vector<std::vector<Term const *>>;
125 using vv_size_t =
typename vv_t::size_type;
126 vv_t vv(safe_cast<vv_size_t>(n_threads));
128 future_list<void> ff_list;
// Task i is enqueued on thread i of the pool and writes into vv[i].
130 for (
unsigned i = 0u; i < n_threads; ++i) {
131 ff_list.push_back(
thread_pool::enqueue(i, thread_func, i, c, &(vv[static_cast<vv_size_t>(i)])));
// Concatenate the per-thread vectors, in thread order, into the output vector.
142 for (
const auto &vi : vv) {
143 v->insert(v->end(), vi.begin(), vi.end());
146 thread_wrapper(&c1, &v1);
147 thread_wrapper(&c2, &v2);
151 template <
typename Series,
typename Derived>
152 struct base_series_multiplier_impl<
153 Series, Derived, typename
std::enable_if<is_mp_rational<typename Series::term_type::cf_type>::value>::type> {
// Specialisation selected when the coefficient type of the series is an
// mp_rational. Instead of pointing at the original terms, it builds rescaled
// copies whose coefficients have denominator 1 and hands out pointers to those.
155 using term_type =
typename Series::term_type;
156 using rat_type =
typename term_type::cf_type;
// Integral type of the numerator of the rational coefficient.
157 using int_type =
typename std::decay<decltype(std::declval<rat_type>().num())>::type;
158 using container_type =
typename std::decay<decltype(std::declval<Series>()._container())>::type;
// Fills v1/v2 with pointers to rescaled copies of the terms of c1/c2.
159 void fill_term_pointers(
const container_type &c1,
const container_type &c2, std::vector<term_type const *> &v1,
160 std::vector<term_type const *> &v2)
// NOTE(review): it_f is initialised from c1.end() here but also bounds the loops
// over c2 below. The listing skips original lines right before those loops, so
// presumably it_f is reassigned there - confirm against the full file.
164 auto it_f = c1.end();
// First pass over c1 and then c2: m_lcm is updated once per term. Only the final
// exact division by g is visible; the statements computing g and scaling m_lcm
// are not shown (presumably a running least common multiple of the denominators).
166 for (
auto it = c1.begin(); it != it_f; ++it) {
169 divexact(m_lcm, m_lcm, g);
172 for (
auto it = c2.begin(); it != it_f; ++it) {
175 divexact(m_lcm, m_lcm, g);
// The accumulated m_lcm is expected to be strictly positive.
179 piranha_assert(m_lcm.sgn() == 1);
// Second pass: copy every term with its coefficient replaced by
// (m_lcm / den * num) over a denominator of 1, keeping the key unchanged.
182 for (
auto it = c1.begin(); it != it_f; ++it) {
184 m_terms1.push_back(term_type(rat_type(m_lcm / it->m_cf.den() * it->m_cf.num(), int_type(1)), it->m_key));
187 for (
auto it = c2.begin(); it != it_f; ++it) {
188 m_terms2.push_back(term_type(rat_type(m_lcm / it->m_cf.den() * it->m_cf.num(), int_type(1)), it->m_key));
// The output vectors receive the addresses of the rescaled copies, not of the
// terms stored in c1/c2.
191 std::transform(m_terms1.begin(), m_terms1.end(), std::back_inserter(v1), [](
const term_type &t) {
return &t; });
192 std::transform(m_terms2.begin(), m_terms2.end(), std::back_inserter(v2), [](
const term_type &t) {
return &t; });
// One pointer per original term is expected in each output vector.
193 piranha_assert(v1.size() == c1.size());
194 piranha_assert(v2.size() == c2.size());
// Storage for the rescaled copies; v1/v2 point into these vectors, so they must
// stay alive and unmodified while those pointers are in use.
196 std::vector<term_type> m_terms1;
197 std::vector<term_type> m_terms2;
// NOTE(review): the class declaration line that follows this template header in
// the original file is not present in this listing.
227 template <
typename Series>
// Type returned by _container() on a const Series, stripped of cv/ref qualifiers.
234 using container_type = uncvref_t<decltype(std::declval<const Series &>()._container())>;
// Vector of const pointers to terms of the series.
238 using v_ptr = std::vector<typename Series::term_type const *>;
// Default limit functor; only its opening line is present in this listing.
246 struct default_limit_functor {
// Overload for terms whose coefficient type is not a series: returns a reference
// to a Term. The body is not visible in this listing (presumably it returns t
// itself - TODO confirm).
257 template <typename Term, typename std::enable_if<!is_series<typename Term::cf_type>::value,
int>::type = 0>
258 static Term &term_insertion(Term &t)
// Overload for terms whose coefficient type is itself a series: returns a new
// term by value, with the coefficient moved out of t and the key copied from t.
// After the call t.m_cf is in a moved-from state.
262 template <typename Term, typename std::enable_if<is_series<typename Term::cf_type>::value,
int>::type = 0>
263 static Term term_insertion(Term &t)
265 return Term{std::move(t.m_cf), t.m_key};
// finalise_impl overload selected when the coefficient type of T is an
// mp_rational. It visits every term of s and canonicalises its coefficient,
// either in one loop or split by bucket range across worker tasks.
268 template <typename T, typename std::enable_if<is_mp_rational<typename T::term_type::cf_type>::value,
int>::type = 0>
269 void finalise_impl(T &s)
// The leading const on the next line belongs to the signature above.
// l2 is the square of m_lcm.
const 277 const auto l2 = this->m_lcm * this->m_lcm;
278 auto &container = s._container();
// Loop over all terms of the container. The statement preceding canonicalise()
// is not visible here; presumably it installs l2 as the denominator - TODO confirm.
281 for (
const auto &t : container) {
283 t.m_cf.canonicalise();
// Worker for the threaded path: processes the buckets in [start_idx, end_idx)
// for thread t_idx. The visible fragment of the end index selects
// container.bucket_count() in one branch of a conditional.
290 auto thread_func = [l2, &container,
this, bpt](
unsigned t_idx) {
294 ? container.bucket_count()
296 for (; start_idx != end_idx; ++start_idx) {
297 auto &list = container._get_bucket_list(start_idx);
298 for (
const auto &t : list) {
300 t.m_cf.canonicalise();
// Futures of the worker tasks; the element type is the return type of thread_func.
305 future_list<decltype(thread_func(0u))> ff_list;
// finalise_impl overload selected when the coefficient type of T is not an
// mp_rational. The parameter is unnamed and the body is not visible in this
// listing (presumably a no-op - TODO confirm).
319 template <
typename T,
320 typename std::enable_if<!is_mp_rational<typename T::term_type::cf_type>::value,
int>::type = 0>
321 void finalise_impl(T &)
// NOTE(review): the leading const on the next line belongs to the finalise_impl
// signature above. What follows it is the body of a different function, whose own
// signature line is missing from this listing; it reads the two series s1 and s2
// and fills the members m_v1, m_v2, m_zero_f1 and m_zero_f2.
// Operands with different symbol sets are rejected with std::invalid_argument.
const 357 if (unlikely(s1.get_symbol_set() != s2.get_symbol_set())) {
358 piranha_throw(std::invalid_argument,
"incompatible arguments sets");
// p1/p2 start as (s1, s2). The body of the size comparison below is not visible;
// presumably it swaps them so that p1 is the larger series - TODO confirm.
361 const Series *p1 = &s1, *p2 = &s2;
362 if (s1.size() < s2.size()) {
// Reserve room for one pointer per term of each operand.
366 m_v1.reserve(static_cast<size_type>(p1->size()));
367 m_v2.reserve(static_cast<size_type>(p2->size()));
368 container_type
const *ctr1 = &p1->_container(), *ctr2 = &p2->_container();
375 using term_type =
typename Series::term_type;
376 using cf_type =
typename term_type::cf_type;
377 using key_type =
typename term_type::key_type;
// A term with coefficient 0 is inserted into m_zero_f1 / m_zero_f2. Both use the
// symbol set of s1, which equals that of s2 after the check above. The conditions
// guarding these insertions are not visible in this listing.
379 m_zero_f1.insert(term_type{cf_type(0), key_type(s1.get_symbol_set())});
383 m_zero_f2.insert(term_type{cf_type(0), key_type(s1.get_symbol_set())});
// Collect the term pointers through the implementation base class.
392 this->fill_term_pointers(*ctr1, *ctr2,
m_v1,
m_v2);
// Blocked multiplication over index pairs (i, j): i runs over [start1, end1) of
// the first operand and j over the second operand, both cut into blocks of bsize
// indices. For every i the upper bound of j is capped by lf(i).
// NOTE(review): most of the signature and the innermost loop bodies (presumably
// the calls to the multiplication functor) are not visible in this listing.
441 template <
typename MultFunctor,
typename LimitFunctor>
443 const LimitFunctor &lf)
// The leading const on the next line belongs to the signature above.
// Bounds check of [start1, end1) against the size of m_v1.
const 446 if (unlikely(start1 > end1 || start1 >
m_v1.size() || end1 >
m_v1.size())) {
447 piranha_throw(std::invalid_argument,
"invalid bounds in blocked_multiplication");
// Number of whole blocks along each operand (integer division by bsize).
451 nblocks1 = static_cast<size_type>((end1 - start1) / bsize),
452 nblocks2 = static_cast<size_type>(
m_v2.size() / bsize);
// Leftover range of the first operand that does not fill a whole block.
454 const size_type i_ir_start =
static_cast<size_type>(nblocks1 * bsize + start1), i_ir_end = end1;
// Whole blocks of the first operand...
456 for (
size_type n1 = 0u; n1 < nblocks1; ++n1) {
458 i_end = static_cast<size_type>(i_start + bsize);
// ...against whole blocks of the second operand.
460 for (
size_type n2 = 0u; n2 < nblocks2; ++n2) {
462 j_end = static_cast<size_type>(j_start + bsize);
463 for (
size_type i = i_start; i < i_end; ++i) {
// j stops at the block end or at lf(i), whichever comes first.
464 const size_type limit = std::min<size_type>(lf(i), j_end);
465 for (
size_type j = j_start; j < limit; ++j) {
// ...against the range [j_ir_start, j_ir_end) of the second operand (its
// definition is not visible; presumably the leftover range of that operand).
471 for (
size_type i = i_start; i < i_end; ++i) {
472 const size_type limit = std::min<size_type>(lf(i), j_ir_end);
473 for (
size_type j = j_ir_start; j < limit; ++j) {
// Leftover range of the first operand against whole blocks of the second.
479 for (
size_type n2 = 0u; n2 < nblocks2; ++n2) {
481 j_end = static_cast<size_type>(j_start + bsize);
482 for (
size_type i = i_ir_start; i < i_ir_end; ++i) {
483 const size_type limit = std::min<size_type>(lf(i), j_end);
484 for (
size_type j = j_start; j < limit; ++j) {
// Leftover range of the first operand against [j_ir_start, j_ir_end).
490 for (
size_type i = i_ir_start; i < i_ir_end; ++i) {
491 const size_type limit = std::min<size_type>(lf(i), j_ir_end);
492 for (
size_type j = j_ir_start; j < limit; ++j) {
// Template header of an overload taking only a MultFunctor; the rest of that
// definition is not present in this listing.
508 template <
typename MultFunctor>
// Estimates the size of the product series with a fixed number of random trials,
// optionally spread over several threads. MultArity is the number of terms each
// term-by-term multiplication contributes (see result_size below).
// NOTE(review): the signature and many statements are missing from this listing.
555 template <std::
size_t MultArity,
typename MultFunctor,
typename LimitFunctor>
563 constexpr std::size_t result_size = MultArity;
// Special cases: an empty operand, or an operand with a single term. The bodies
// of both branches are not visible here.
565 if (unlikely(!size1 || !size2)) {
569 if (size1 == 1u || size2 == 1u) {
// Total number of random trials.
575 const unsigned n_trials = 15u;
// Factor applied to the squared count of each trial (see add below).
578 const unsigned multiplier = 2u;
585 piranha_assert(n_threads > 0u);
// Trials per thread; at least one is required.
587 const unsigned tpt = n_trials / n_threads;
588 piranha_assert(tpt >= 1u);
// Per-thread estimator; it shares c_estimate and mut with the other threads.
594 auto estimator = [&lf, size1, n_threads, tpt,
this, &c_estimate, &mut](
unsigned thread_idx) {
595 piranha_assert(thread_idx < n_threads);
// Indices 0 .. size1 - 1 into the first operand, plus a pristine copy used to
// reset the order before every trial.
597 std::vector<size_type> v_idx1(
safe_cast<
typename std::vector<size_type>::size_type>(size1));
598 std::iota(v_idx1.begin(), v_idx1.end(),
size_type(0));
600 const auto v_idx1_copy = v_idx1;
604 using dist_type = std::uniform_int_distribution<size_type>;
// The last thread also runs the trials left over by the integer division.
609 const unsigned cur_trials = (thread_idx == n_threads - 1u) ? (n_trials - thread_idx * tpt) : tpt;
611 piranha_assert(cur_trials > 0u);
// Scratch series (tmp) with the symbol set of the multiplier, and a
// multiplication functor writing into it.
614 tmp.set_symbol_set(
m_ss);
616 MultFunctor mf(*
this, tmp);
618 for (
auto n = 0u; n < cur_trials; ++n) {
// The seed is tpt * thread_idx + n, i.e. it is fixed by the thread index and the
// trial number rather than by any run-time entropy source.
621 engine.seed(static_cast<std::mt19937::result_type>(tpt * thread_idx + n));
// Restart from the identity order and shuffle the indices of the first operand.
625 v_idx1 = v_idx1_copy;
626 std::shuffle(v_idx1.begin(), v_idx1.end(), engine);
633 auto it1 = v_idx1.begin();
634 for (; it1 != v_idx1.end(); ++it1) {
// Random index drawn uniformly from [0, limit - 1]; the variable receiving it and
// the definition of limit are not visible in this listing.
646 = dist(engine,
typename dist_type::param_type(static_cast<size_type>(0u),
647 static_cast<size_type>(limit - 1u)));
// Overflow guard before result_size is added to count.
651 if (unlikely(result_size > std::numeric_limits<size_type>::max()
652 || count > std::numeric_limits<size_type>::max() - result_size)) {
// A size mismatch presumably means that a generated term landed on one already
// present in tmp, which ends the trial - TODO confirm (branch body not shown).
655 if (tmp.size() != count + result_size) {
659 count =
static_cast<size_type>(count + result_size);
662 if (it1 == v_idx1.end()) {
// Contribution of this trial: multiplier * count * count.
670 add =
integer(multiplier) * count * count;
673 if (add.
sgn() == 0) {
// Empty the scratch series.
678 tmp._container().clear();
// The single-thread and multi-thread paths differ here; a lock on mut is taken in
// the visible code (the surrounding branch structure is not fully shown).
681 if (n_threads == 1u) {
685 std::lock_guard<std::mutex> lock(mut);
690 if (n_threads == 1u) {
695 for (
unsigned i = 0u; i < n_threads; ++i) {
// The accumulated estimate is expected to be at least n_trials.
707 piranha_assert(c_estimate >= n_trials);
// Overload without a limit functor argument: forwards to the LimitFunctor
// overload, passing a default_limit_functor built from this object.
718 template <std::
size_t MultArity,
typename MultFunctor>
721 return estimate_final_series_size<MultArity, MultFunctor>(default_limit_functor{*
this});
// Functor multiplying term i of the first operand by term j of the second and
// accumulating the results into a destination series (m_retval).
// NOTE(review): the class line, the call operator signature and the branch
// structure around the insertion code are not visible in this listing.
743 template <
bool FastMode>
746 using term_type =
typename Series::term_type;
747 using key_type =
typename term_type::key_type;
// Iterator type of the container of a non-const Series.
749 using it_type = decltype(std::declval<Series &>()._container().end());
// Number of terms produced by one key multiplication; it sizes the scratch array
// and bounds the loop over the results below.
750 static constexpr std::size_t m_arity = key_type::multiply_arity;
// Constructor initialisers: bind the pointer vectors of the multiplier and the
// destination series, and cache the end iterator of the destination container.
762 : m_v1(bsm.m_v1), m_v2(bsm.m_v2), m_retval(retval), m_c_end(retval._container().end())
// Multiply the two terms into the scratch array, then merge each result.
790 key_type::multiply(m_tmp_t, *m_v1[i], *m_v2[j], m_retval.get_symbol_set());
791 for (std::size_t n = 0u; n < m_arity; ++n) {
792 auto &tmp_term = m_tmp_t[n];
// Two merge strategies are visible: locating the bucket and then either inserting
// the term or adding its coefficient to the term found, and a plain
// m_retval.insert(). Which one FastMode selects, and the test deciding between
// insert and accumulate, are not shown (presumably a comparison of it with
// m_c_end - TODO confirm).
794 auto &container = m_retval._container();
796 auto bucket_idx = container._bucket(tmp_term);
797 const auto it = container._find(tmp_term, bucket_idx);
799 container._unique_insert(term_insertion(tmp_term), bucket_idx);
801 it->m_cf += tmp_term.m_cf;
804 m_retval.insert(term_insertion(tmp_term));
// Scratch buffer for the multiplication results, declared mutable.
810 mutable std::array<term_type, m_arity> m_tmp_t;
// References to the term pointer vectors of the multiplier (not owned).
811 const std::vector<term_type const *> &m_v1;
812 const std::vector<term_type const *> &m_v2;
// End iterator of the destination container, captured at construction.
814 const it_type m_c_end;
// Body of the series sanitisation routine (signature not visible in this
// listing): recounts the terms of retval, checks their compatibility with the
// symbol set and removes zero terms, using n_threads threads.
847 using term_type =
typename Series::term_type;
// A thread count of zero is rejected with std::invalid_argument.
848 if (unlikely(n_threads == 0u)) {
849 piranha_throw(std::invalid_argument,
"invalid number of threads");
851 auto &container = retval._container();
852 const auto &args = retval.get_symbol_set();
// Reset the recorded size of the container to zero before recounting.
854 container._update_size(static_cast<bucket_size_type>(0u));
// Single-threaded path: one pass over all terms.
856 if (n_threads == 1u) {
857 const auto it_end = container.end();
// The iterator is advanced inside the body, since erasing returns the next one.
858 for (
auto it = container.begin(); it != it_end;) {
// Handling of an incompatible term is not visible (presumably a throw).
859 if (unlikely(!it->is_compatible(args))) {
// Guard against overflow of the term count before incrementing it.
862 if (unlikely(container.size() == std::numeric_limits<bucket_size_type>::max())) {
863 piranha_throw(std::overflow_error,
"overflow error in the number of terms of a series");
866 container._update_size(static_cast<bucket_size_type>(container.size() + 1u));
// Zero terms are erased; the returned iterator continues the traversal.
867 if (unlikely(it->is_zero(args))) {
868 it = container.erase(it);
// Multi-threaded path: each task handles the buckets in [start, end).
876 const auto b_count = container.bucket_count();
879 auto eraser = [b_count, &container, &m, &args, &global_count](
const bucket_size_type &start,
881 piranha_assert(start <= end && end <= b_count);
// Zero terms are copied into term_list while the bucket is traversed and erased
// in a separate loop afterwards.
884 std::vector<term_type> term_list;
888 const auto &bl = container._get_bucket_list(i);
889 const auto it_f = bl.end();
890 for (
auto it = bl.begin(); it != it_f; ++it) {
892 if (unlikely(!it->is_compatible(args))) {
896 if (unlikely(it->is_zero(args))) {
897 term_list.push_back(*it);
900 if (unlikely(count == std::numeric_limits<bucket_size_type>::max())) {
901 piranha_throw(std::overflow_error,
"overflow error in the number of terms of a series");
// Erase the zero terms collected for bucket i.
905 for (
auto it = term_list.begin(); it != term_list.end(); ++it) {
908 container._erase(container._find(*it, i));
910 piranha_assert(count > 0u);
// The local count is added to the shared total under the mutex m.
915 std::lock_guard<std::mutex> lock(m);
916 global_count += count;
// Bucket range of thread i: b_count / n_threads buckets each, with the last
// thread running up to b_count.
920 for (
unsigned i = 0u; i < n_threads; ++i) {
921 const auto start =
static_cast<bucket_size_type>((b_count / n_threads) * i),
923 (i == n_threads - 1u) ? b_count : (b_count / n_threads) * (i + 1u));
// Record the total count as the size of the container.
938 container._update_size(static_cast<bucket_size_type>(global_count));
// Plain series multiplication: builds the product in retval, either with a
// single thread or with one task per thread writing into a shared, pre-sized
// hash container guarded by per-bucket lock flags.
// NOTE(review): the signature, the return statements and the exception handling
// are not visible in this listing.
977 template <
typename LimitFunctor>
981 using term_type =
typename Series::term_type;
982 using cf_type =
typename term_type::cf_type;
983 using key_type =
typename term_type::key_type;
// Number of terms produced by one key multiplication.
985 constexpr std::size_t m_arity = key_type::multiply_arity;
// The result uses the symbol set stored in the multiplier.
988 retval.set_symbol_set(
m_ss);
// Empty operand: the body of this branch is not shown (presumably an early return).
990 if (unlikely(
m_v1.empty() ||
m_v2.empty())) {
995 piranha_assert(size1 && size2);
998 piranha_assert(n_threads);
// Estimation is enabled by default; any code that may disable it is not shown.
1001 bool estimate =
true;
// Estimate the final size with the non-fast plain multiplier, turn it into a
// bucket count via the max load factor (rounded up), and rehash ahead of time.
1008 const auto est = estimate_final_series_size<m_arity, plain_multiplier<false>>(lf);
1012 std::ceil(static_cast<double>(est) / retval._container().max_load_factor()));
1013 piranha_assert(n_buckets > 0u);
1018 retval._container().rehash(n_buckets, n_threads_rehash);
// Single-threaded path. The clear() below presumably belongs to an error handler
// discarding a partial result - its context is not visible, TODO confirm.
1020 if (n_threads == 1u) {
1033 retval._container().clear();
// The multi-threaded path requires the estimate to have been computed.
1038 piranha_assert(estimate);
// One atomic flag per bucket of the result container, used as a per-bucket lock.
1040 detail::atomic_flag_array sl_array(safe_cast<std::size_t>(retval._container().bucket_count()));
// Number of terms of the first operand handled by each thread.
1044 const auto block_size = size1 / n_threads;
1046 for (
size_type idx = 0u; idx < n_threads; ++idx) {
// Task run by thread idx.
1048 auto tf = [idx,
this, block_size, n_threads, &sl_array, &retval, &lf]() {
// Thread-local scratch array for the multiplication results.
1050 std::array<term_type, key_type::multiply_arity> tmp_t;
1052 const auto c_end = retval._container().end();
// Multiplication functor for the index pair (i, j).
1056 auto f = [&c_end, &tmp_t,
this, &retval, &sl_array](
const size_type &i,
const size_type &j) {
1058 key_type::multiply(tmp_t, *(this->
m_v1[i]), *(this->
m_v2[j]), retval.get_symbol_set());
1059 for (std::size_t n = 0u; n < key_type::multiply_arity; ++n) {
1060 auto &container = retval._container();
1061 auto &tmp_term = tmp_t[n];
1063 auto bucket_idx = container._bucket(tmp_term);
// Hold the lock flag of the destination bucket while the term is looked up and
// then inserted or accumulated. The test choosing between the two is not shown.
1065 detail::atomic_lock_guard alg(sl_array[static_cast<std::size_t>(bucket_idx)]);
1066 const auto it = container._find(tmp_term, bucket_idx);
1068 container._unique_insert(term_insertion(tmp_term), bucket_idx);
1070 it->m_cf += tmp_term.m_cf;
// Upper bound of the slice of this thread: the last thread runs to the end of
// m_v1, so it also takes the remainder of the integer division.
1076 = (idx == n_threads - 1u) ? this->
m_v1.size() :
static_cast<size_type>((idx + 1u) * block_size);
// Empty the result container; the surrounding context (presumably an error
// handler) is not visible in this listing.
1088 retval._container().clear();
// Containers that receive a zero-coefficient term in the code visible earlier in
// this file; the conditions under which that happens are not shown in this listing.
1142 container_type m_zero_f1;
1143 container_type m_zero_f2;
typename Series::size_type bucket_size_type
The size type of Series.
void operator()(const size_type &i, const size_type &j) const
Call operator.
void wait_all()
Wait on all the futures.
Type trait for multipliable key.
void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1, const LimitFunctor &lf) const
Blocked multiplication.
Function object type trait.
v_ptr m_v2
Vector of const pointers to the terms in the smaller series.
void blocked_multiplication(const MultFunctor &mf, const size_type &start1, const size_type &end1) const
Blocked multiplication (convenience overload).
Multiprecision integer class.
A plain multiplier functor.
mp_integer< 1 > integer
Alias for piranha::mp_integer with 1 limb of static storage.
Series plain_multiplication(const LimitFunctor &lf) const
A plain series multiplication routine.
void finalise_series(Series &s) const
Finalise series.
bucket_size_type estimate_final_series_size() const
Estimate size of series multiplication (convenience overload).
static unsigned use_threads(const Int &work_size, const Int &min_work_per_thread)
Compute number of threads to use.
bucket_size_type estimate_final_series_size(const LimitFunctor &lf) const
Estimate size of series multiplication.
Series plain_multiplication() const
A plain series multiplication routine (convenience overload).
void get_all()
Get all the futures.
static enqueue_t< F &&, Args &&... > enqueue(unsigned n, F &&f, Args &&... args)
Enqueue task.
base_series_multiplier(const Series &s1, const Series &s2)
Constructor.
const symbol_fset m_ss
The symbol set of the series used during construction.
typename v_ptr::size_type size_type
The size type of base_series_multiplier::v_ptr.
#define piranha_throw(exception_type,...)
Exception-throwing macro.
boost::container::flat_set< std::string > symbol_fset
Flat set of symbols.
Detect if zero is a multiplicative absorber.
v_ptr m_v1
Vector of const pointers to the terms in the larger series.
Class to store a list of futures.
static unsigned long get_estimate_threshold()
Get the series estimation threshold.
auto mul3(T &a, const T &b, const T &c) -> decltype(mul3_impl< T >()(a, b, c))
Ternary multiplication.
plain_multiplier(const base_series_multiplier &bsm, Series &retval)
Constructor.
static unsigned long get_multiplication_block_size()
Get the multiplication block size.
static bool get_parallel_memory_set()
Get the parallel_memory_set flag.
static void sanitise_series(Series &retval, unsigned n_threads)
Sanitise series.
Type trait to detect series types.
unsigned m_n_threads
Number of threads.
static unsigned long long get_min_work_per_thread()
Get the minimum work per thread.
auto gcd3(T &out, const T &a, const T &b) -> decltype(gcd3_impl< T >()(out, a, b))
Ternary GCD.
std::vector< typename Series::term_type const * > v_ptr
Alias for a vector of const pointers to series terms.
safe_cast_type< To, From > safe_cast(const From &x)
Safe cast.
bool is_unitary(const T &x)
Unitary test.
void push_back(std::future< T > &&f)
Move-insert a future.