LCOV - code coverage report
Current view: top level - capy - when_all.hpp (source / functions) Coverage Total Hit Missed
Test: coverage_remapped.info Lines: 98.0 % 98 96 2
Test Date: 2026-02-17 18:14:47 Functions: 89.0 % 617 549 68

           TLA  Line data    Source code
       1                 : //
       2                 : // Copyright (c) 2026 Steve Gerbino
       3                 : //
       4                 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
       5                 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
       6                 : //
       7                 : // Official repository: https://github.com/cppalliance/capy
       8                 : //
       9                 : 
      10                 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
      11                 : #define BOOST_CAPY_WHEN_ALL_HPP
      12                 : 
      13                 : #include <boost/capy/detail/config.hpp>
      14                 : #include <boost/capy/concept/executor.hpp>
      15                 : #include <boost/capy/concept/io_awaitable.hpp>
      16                 : #include <coroutine>
      17                 : #include <boost/capy/ex/io_env.hpp>
      18                 : #include <boost/capy/ex/frame_allocator.hpp>
      19                 : #include <boost/capy/task.hpp>
      20                 : 
      21                 : #include <array>
      22                 : #include <atomic>
      23                 : #include <exception>
      24                 : #include <optional>
      25                 : #include <stop_token>
      26                 : #include <tuple>
      27                 : #include <type_traits>
      28                 : #include <utility>
      29                 : 
      30                 : namespace boost {
      31                 : namespace capy {
      32                 : 
      33                 : namespace detail {
      34                 : 
      35                 : /** Type trait to filter void types from a tuple.
      36                 : 
      37                 :     Void-returning tasks do not contribute a value to the result tuple.
      38                 :     wrap_non_void_t maps void to tuple<> and any other T to tuple<T>;
      39                 :     filter_void_tuple_t concatenates those to give the filtered result type.
      40                 :     Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
      41                 : */
      42                 : template<typename T>
      43                 : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
      44                 : 
      45                 : template<typename... Ts>
      46                 : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
      47                 : 
      48                 : /** Holds the result of a single task within when_all.
      49                 : */
      50                 : template<typename T>
      51                 : struct result_holder
      52                 : {
      53                 :     std::optional<T> value_;
      54                 : 
      55 HIT          60 :     void set(T v)
      56                 :     {
      57              60 :         value_ = std::move(v);
      58              60 :     }
      59                 : 
      60              53 :     T get() &&
      61                 :     {
      62              53 :         return std::move(*value_);
      63                 :     }
      64                 : };
      65                 : 
      66                 : /** Specialization for void tasks - no value storage needed.
      67                 : */
      68                 : template<>
      69                 : struct result_holder<void>
      70                 : {
      71                 : };
      72                 : 
      73                 : /** Shared state for when_all operation.
      74                 : 
      75                 :     @tparam Ts The result types of the tasks.
      76                 : */
      77                 : template<typename... Ts>
      78                 : struct when_all_state
      79                 : {
      80                 :     static constexpr std::size_t task_count = sizeof...(Ts);
      81                 : 
      82                 :     // Completion tracking - when_all waits for all children
      83                 :     std::atomic<std::size_t> remaining_count_;
      84                 : 
      85                 :     // Result storage in input order
      86                 :     std::tuple<result_holder<Ts>...> results_;
      87                 : 
      88                 :     // Runner handles - recorded at launch; runners destroy themselves in final_suspend
      89                 :     std::array<std::coroutine_handle<>, task_count> runner_handles_{};
      90                 : 
      91                 :     // Exception storage - first error wins, others discarded
      92                 :     std::atomic<bool> has_exception_{false};
      93                 :     std::exception_ptr first_exception_;
      94                 : 
      95                 :     // Stop propagation - on error, request stop for siblings
      96                 :     std::stop_source stop_source_;
      97                 : 
      98                 :     // Connects parent's stop_token to our stop_source
      99                 :     struct stop_callback_fn
     100                 :     {
     101                 :         std::stop_source* source_;
     102               4 :         void operator()() const { source_->request_stop(); }
     103                 :     };
     104                 :     using stop_callback_t = std::stop_callback<stop_callback_fn>;
     105                 :     std::optional<stop_callback_t> parent_stop_callback_;
     106                 : 
     107                 :     // Parent resumption
     108                 :     std::coroutine_handle<> continuation_;
     109                 :     io_env const* caller_env_ = nullptr;
     110                 : 
     111              59 :     when_all_state()
     112              59 :         : remaining_count_(task_count)
     113                 :     {
     114              59 :     }
     115                 : 
     116                 :     // Runners self-destruct in final_suspend. No destruction needed here.
     117                 : 
     118                 :     /** Capture an exception (first one wins).
     119                 :     */
     120              20 :     void capture_exception(std::exception_ptr ep)
     121                 :     {
     122              20 :         bool expected = false;
     123              20 :         if(has_exception_.compare_exchange_strong(
     124                 :             expected, true, std::memory_order_relaxed))
     125              17 :             first_exception_ = ep;
     126              20 :     }
     127                 : 
     128                 : };
     129                 : 
     130                 : /** Wrapper coroutine that intercepts task completion.
     131                 : 
     132                 :     This runner awaits its assigned task and stores the result in
     133                 :     the shared state, or captures the exception and requests stop.
     134                 : */
     135                 : template<typename T, typename... Ts>
     136                 : struct when_all_runner
     137                 : {
     138                 :     struct promise_type // : frame_allocating_base  // DISABLED FOR TESTING
     139                 :     {
     140                 :         when_all_state<Ts...>* state_ = nullptr;
     141                 :         io_env env_;
     142                 : 
     143             130 :         when_all_runner get_return_object()
     144                 :         {
     145             130 :             return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
     146                 :         }
     147                 : 
     148             130 :         std::suspend_always initial_suspend() noexcept
     149                 :         {
     150             130 :             return {};
     151                 :         }
     152                 : 
     153             130 :         auto final_suspend() noexcept
     154                 :         {
     155                 :             struct awaiter
     156                 :             {
     157                 :                 promise_type* p_;
     158                 : 
     159             130 :                 bool await_ready() const noexcept
     160                 :                 {
     161             130 :                     return false;
     162                 :                 }
     163                 : 
     164             130 :                 std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept
     165                 :                 {
     166                 :                     // Extract everything needed before self-destruction.
     167             130 :                     auto* state = p_->state_;
     168             130 :                     auto* counter = &state->remaining_count_;
     169             130 :                     auto* caller_env = state->caller_env_;
     170             130 :                     auto cont = state->continuation_;
     171                 : 
     172             130 :                     h.destroy();
     173                 : 
     174                 :                     // If last runner, dispatch parent for symmetric transfer.
     175             130 :                     auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
     176             130 :                     if(remaining == 1)
     177              59 :                         return caller_env->executor.dispatch(cont);
     178              71 :                     return std::noop_coroutine();
     179                 :                 }
     180                 : 
     181 MIS           0 :                 void await_resume() const noexcept
     182                 :                 {
     183               0 :                 }
     184                 :             };
     185 HIT         130 :             return awaiter{this};
     186                 :         }
     187                 : 
     188             110 :         void return_void()
     189                 :         {
     190             110 :         }
     191                 : 
     192              20 :         void unhandled_exception()
     193                 :         {
     194              20 :             state_->capture_exception(std::current_exception());
     195                 :             // Request stop for sibling tasks
     196              20 :             state_->stop_source_.request_stop();
     197              20 :         }
     198                 : 
     199                 :         template<class Awaitable>
     200                 :         struct transform_awaiter
     201                 :         {
     202                 :             std::decay_t<Awaitable> a_;
     203                 :             promise_type* p_;
     204                 : 
     205             130 :             bool await_ready()
     206                 :             {
     207             130 :                 return a_.await_ready();
     208                 :             }
     209                 : 
     210             130 :             decltype(auto) await_resume()
     211                 :             {
     212             130 :                 return a_.await_resume();
     213                 :             }
     214                 : 
     215                 :             template<class Promise>
     216             129 :             auto await_suspend(std::coroutine_handle<Promise> h)
     217                 :             {
     218             129 :                 return a_.await_suspend(h, &p_->env_);
     219                 :             }
     220                 :         };
     221                 : 
     222                 :         template<class Awaitable>
     223             130 :         auto await_transform(Awaitable&& a)
     224                 :         {
     225                 :             using A = std::decay_t<Awaitable>;
     226                 :             if constexpr (IoAwaitable<A>)
     227                 :             {
     228                 :                 return transform_awaiter<Awaitable>{
     229             260 :                     std::forward<Awaitable>(a), this};
     230                 :             }
     231                 :             else
     232                 :             {
     233                 :                 static_assert(sizeof(A) == 0, "requires IoAwaitable");
     234                 :             }
     235             130 :         }
     236                 :     };
     237                 : 
     238                 :     std::coroutine_handle<promise_type> h_;
     239                 : 
     240             130 :     explicit when_all_runner(std::coroutine_handle<promise_type> h)
     241             130 :         : h_(h)
     242                 :     {
     243             130 :     }
     244                 : 
     245                 :     // Enable move for all clang versions - some versions need it
     246                 :     when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
     247                 : 
     248                 :     // Non-copyable
     249                 :     when_all_runner(when_all_runner const&) = delete;
     250                 :     when_all_runner& operator=(when_all_runner const&) = delete;
     251                 :     when_all_runner& operator=(when_all_runner&&) = delete;
     252                 : 
     253             130 :     auto release() noexcept
     254                 :     {
     255             130 :         return std::exchange(h_, nullptr);
     256                 :     }
     257                 : };
     258                 : 
     259                 : /** Create a runner coroutine for a single awaitable.
     260                 : 
     261                 :     Awaitable is passed directly to ensure proper coroutine frame storage.
     262                 : */
     263                 : template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
     264                 : when_all_runner<awaitable_result_t<Awaitable>, Ts...>
     265             130 : make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
     266                 : {
     267                 :     using T = awaitable_result_t<Awaitable>;
     268                 :     if constexpr (std::is_void_v<T>)
     269                 :     {
     270                 :         co_await std::move(inner);
     271                 :     }
     272                 :     else
     273                 :     {
     274                 :         std::get<Index>(state->results_).set(co_await std::move(inner));
     275                 :     }
     276             260 : }
     277                 : 
     278                 : /** Internal awaitable that launches all runner coroutines and waits.
     279                 : 
     280                 :     This awaitable is used inside the when_all coroutine to handle
     281                 :     the concurrent execution of child awaitables.
     282                 : */
     283                 : template<IoAwaitable... Awaitables>
     284                 : class when_all_launcher
     285                 : {
     286                 :     using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
     287                 : 
     288                 :     std::tuple<Awaitables...>* awaitables_;
     289                 :     state_type* state_;
     290                 : 
     291                 : public:
     292              59 :     when_all_launcher(
     293                 :         std::tuple<Awaitables...>* awaitables,
     294                 :         state_type* state)
     295              59 :         : awaitables_(awaitables)
     296              59 :         , state_(state)
     297                 :     {
     298              59 :     }
     299                 : 
     300              59 :     bool await_ready() const noexcept
     301                 :     {
     302              59 :         return sizeof...(Awaitables) == 0;
     303                 :     }
     304                 : 
     305              59 :     std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
     306                 :     {
     307              59 :         state_->continuation_ = continuation;
     308              59 :         state_->caller_env_ = caller_env;
     309                 : 
     310                 :         // Forward parent's stop requests to children
     311              59 :         if(caller_env->stop_token.stop_possible())
     312                 :         {
     313              16 :             state_->parent_stop_callback_.emplace(
     314               8 :                 caller_env->stop_token,
     315               8 :                 typename state_type::stop_callback_fn{&state_->stop_source_});
     316                 : 
     317               8 :             if(caller_env->stop_token.stop_requested())
     318               4 :                 state_->stop_source_.request_stop();
     319                 :         }
     320                 : 
     321                 :         // CRITICAL: If the last task finishes synchronously then the parent
     322                 :         // coroutine resumes, destroying its frame, and destroying this object
     323                 :         // prior to the completion of await_suspend. Therefore, await_suspend
     324                 :         // must ensure `this` cannot be referenced after calling `launch_one`
     325                 :         // for the last time.
     326              59 :         auto token = state_->stop_source_.get_token();
     327              58 :         [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     328              59 :             (..., launch_one<Is>(caller_env->executor, token));
     329              59 :         }(std::index_sequence_for<Awaitables...>{});
     330                 : 
     331                 :         // The last runner to finish resumes the parent from its final_suspend awaiter
     332             118 :         return std::noop_coroutine();
     333              59 :     }
     334                 : 
     335              59 :     void await_resume() const noexcept
     336                 :     {
     337                 :         // Results are extracted by the when_all coroutine from state
     338              59 :     }
     339                 : 
     340                 : private:
     341                 :     template<std::size_t I>
     342             130 :     void launch_one(executor_ref caller_ex, std::stop_token token)
     343                 :     {
     344             130 :         auto runner = make_when_all_runner<I>(
     345             130 :             std::move(std::get<I>(*awaitables_)), state_);
     346                 : 
     347             130 :         auto h = runner.release();
     348             130 :         h.promise().state_ = state_;
     349             130 :         h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->allocator};
     350                 : 
     351             130 :         std::coroutine_handle<> ch{h};
     352             130 :         state_->runner_handles_[I] = ch;
     353             130 :         state_->caller_env_->executor.post(ch);
     354             260 :     }
     355                 : };
     356                 : 
     357                 : /** Compute the result type for when_all.
     358                 : 
     359                 :     Returns void when all tasks are void (P2300 aligned),
     360                 :     otherwise returns a tuple with void types filtered out.
     361                 : */
     362                 : template<typename... Ts>
     363                 : using when_all_result_t = std::conditional_t<
     364                 :     std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
     365                 :     void,
     366                 :     filter_void_tuple_t<Ts...>>;
     367                 : 
     368                 : /** Helper to extract a single result, returning empty tuple for void.
     369                 :     This is a separate function to work around a GCC-11 ICE that occurs
     370                 :     when using nested immediately-invoked lambdas with pack expansion.
     371                 : */
     372                 : template<std::size_t I, typename... Ts>
     373              57 : auto extract_single_result(when_all_state<Ts...>& state)
     374                 : {
     375                 :     using T = std::tuple_element_t<I, std::tuple<Ts...>>;
     376                 :     if constexpr (std::is_void_v<T>)
     377               4 :         return std::tuple<>();
     378                 :     else
     379              53 :         return std::make_tuple(std::move(std::get<I>(state.results_)).get());
     380                 : }
     381                 : 
     382                 : /** Extract results from state, filtering void types.
     383                 : */
     384                 : template<typename... Ts>
     385              24 : auto extract_results(when_all_state<Ts...>& state)
     386                 : {
     387              43 :     return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
     388              25 :         return std::tuple_cat(extract_single_result<Is>(state)...);
     389              48 :     }(std::index_sequence_for<Ts...>{});
     390                 : }
     391                 : 
     392                 : } // namespace detail
     393                 : 
     394                 : /** Execute multiple awaitables concurrently and collect their results.
     395                 : 
     396                 :     Launches all awaitables simultaneously and waits for all to complete
     397                 :     before returning. Results are collected in input order. If any
     398                 :     awaitable throws, cancellation is requested for siblings and the first
     399                 :     exception is rethrown after all awaitables complete.
     400                 : 
     401                 :     @li All child awaitables run concurrently on the caller's executor
     402                 :     @li Results are returned as a tuple in input order
     403                 :     @li Void-returning awaitables do not contribute to the result tuple
     404                 :     @li If all awaitables return void, `when_all` returns `task<void>`
     405                 :     @li First exception wins; subsequent exceptions are discarded
     406                 :     @li Stop is requested for siblings on first error
     407                 :     @li Completes only after all children have finished
     408                 : 
     409                 :     @par Thread Safety
     410                 :     The returned task must be awaited from a single execution context.
     411                 :     Child awaitables execute concurrently but complete through the caller's
     412                 :     executor.
     413                 : 
     414                 :     @param awaitables The awaitables to execute concurrently. Each must
     415                 :         satisfy @ref IoAwaitable and is consumed (moved-from) when
     416                 :         `when_all` is awaited.
     417                 : 
     418                 :     @return A task yielding a tuple of non-void results. Returns
     419                 :         `task<void>` when all input awaitables return void.
     420                 : 
     421                 :     @par Example
     422                 : 
     423                 :     @code
     424                 :     task<> example()
     425                 :     {
     426                 :         // Concurrent fetch, results collected in order
     427                 :         auto [user, posts] = co_await when_all(
     428                 :             fetch_user( id ),      // task<User>
     429                 :             fetch_posts( id )      // task<std::vector<Post>>
     430                 :         );
     431                 : 
     432                 :         // Void awaitables don't contribute to result
     433                 :         co_await when_all(
     434                 :             log_event( "start" ),  // task<void>
     435                 :             notify_user( id )      // task<void>
     436                 :         );
     437                 :         // Returns task<void>, no result tuple
     438                 :     }
     439                 :     @endcode
     440                 : 
     441                 :     @see IoAwaitable, task
     442                 : */
     443                 : template<IoAwaitable... As>
     444              59 : [[nodiscard]] auto when_all(As... awaitables)
     445                 :     -> task<detail::when_all_result_t<detail::awaitable_result_t<As>...>>
     446                 : {
     447                 :     using result_type = detail::when_all_result_t<detail::awaitable_result_t<As>...>;
     448                 : 
     449                 :     // State is stored in the coroutine frame, using the frame allocator
     450                 :     detail::when_all_state<detail::awaitable_result_t<As>...> state;
     451                 : 
     452                 :     // Store awaitables in the frame
     453                 :     std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
     454                 : 
     455                 :     // Launch all awaitables and wait for completion
     456                 :     co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
     457                 : 
     458                 :     // Propagate first exception if any.
     459                 :     // Safe without explicit acquire: capture_exception() is sequenced-before
     460                 :     // the runner's acq_rel fetch_sub in final_suspend, which synchronizes-with
     461                 :     // the last runner's decrement that resumes this coroutine.
     462                 :     if(state.first_exception_)
     463                 :         std::rethrow_exception(state.first_exception_);
     464                 : 
     465                 :     // Extract and return results
     466                 :     if constexpr (std::is_void_v<result_type>)
     467                 :         co_return;
     468                 :     else
     469                 :         co_return detail::extract_results(state);
     470             118 : }
     471                 : 
     472                 : /// Compute the result type of `when_all` for the given awaitable result types.
     473                 : template<typename... Ts>
     474                 : using when_all_result_type = detail::when_all_result_t<Ts...>;
     475                 : 
     476                 : } // namespace capy
     477                 : } // namespace boost
     478                 : 
     479                 : #endif
        

Generated by: LCOV version 2.3