阿里妹导读
努力成为全网最好理解的「C++20 协程」原理解析文章。
协程概念
这里稍微提一下共享栈协程(Copying the Stack Coroutine),既然每个协程都创建一个额外的栈太浪费了,那就只创建一个。在协程切换的时候,拷贝当前已使用的栈到另外的内存里,腾出来栈给新的协程用即可。需要切换回来的时候,拷贝回来之前的栈即可。共享栈协程解决了预分配的内存浪费问题,但是引入了栈备份和还原的开销。要想性能好的话,还需要尽量减少函数调用深度以及尽量不在栈上分配太大的数据结构。所以共享栈协程只是一种优化手段,所以一般不单独拿出来对比和讨论。
说多了容易让读者糊涂,但是不说的话又担心让读者一叶障目。本文介绍具体的实现原理是为了快速的理解抽象理论,但也容易造成对协程理论的狭隘理解,这一点还请读者注意。
C++20 的协程实现
从一个“简单”的 demo 开始
template <bool READY>
struct Awaiter {
bool await_ready() noexcept {
std::cout << "await_ready: " << READY << std::endl;
return READY;
}
void await_resume() noexcept {
std::cout << "await_resume" << std::endl;
}
void await_suspend(std::coroutine_handle<>) noexcept {
std::cout << "await_suspend" << std::endl;
}
};
struct TaskPromise {
struct promise_type {
TaskPromise get_return_object() {
std::cout << "get_return_object" << std::endl;
return TaskPromise{std::coroutine_handle<promise_type>::from_promise(*this)};
}
Awaiter<true> initial_suspend() noexcept {
std::cout << "initial_suspend" << std::endl;
return {};
}
Awaiter<true> final_suspend() noexcept {
std::cout << "final_suspend" << std::endl;
return {};
}
void unhandled_exception() {
std::cout << "unhandled_exception" << std::endl;
}
void return_void() noexcept {
std::cout << "return_void" << std::endl;
}
};
void resume() {
std::cout << "resume" << std::endl;
handle.resume();
}
std::coroutine_handle<promise_type> handle;
};
TaskPromise task_func() {
std::cout << "task first run" << std::endl;
co_await Awaiter<false>{};
std::cout << "task resume" << std::endl;
}
int main() {
auto promise = task_func();
promise.resume();
return 0;
}
这段代码[4]运行后输出:
get_return_object
initial_suspend
await_ready: 1
await_resume
task first run
await_ready: 0
await_suspend
resume
await_resume
task resume
return_void
final_suspend
await_ready: 1
await_resume
尽管我已经尽力把第一个 demo 写的足够小了,但是依旧比其他语言的协程 demo 长很多。原因也很简单,C++ 想让程序员可以定制协程创建和执行的任意一个阶段的任意步骤的行为,那么就必须定义足够多的回调函数来定义每个阶段的行为。不参考任何文献想看懂上面的代码的话还是有些难度的。但是回想一下上文所说的“编译器展开”代码的实现原理,结合这些函数的执行顺序就可以大致推测出来编译器展开代码后函数调用的顺序。其实 cppreference 给出来了具体的执行过程 [5],但是有些语焉不详,这里我们更详细的讲述流程并解释一些细节:
至于promise.unhandled_exception(),是在协程里出现未捕获的异常时候调用的。但是注意在promise.get_return_object()之前就抛出的异常不会来到这里,比如new导致的std::bad_alloc异常就不会调用到这里。异常处理流程不详细解释了,cppreference 描述的很清楚。
TaskPromise task_func() {
// No parameters and local variables.
auto state = new __TaskPromise_state_(); // has TaskPromise::promise_type promise;
TaskPromise coro = state.promise.get_return_object();
try {
co_await p.inital_suspend();
std::cout << "task first run" << std::endl;
co_await Awaiter<false>{};
std::cout << "task resume" << std::endl;
} catch (...) {
state.promise.unhandled_exception();
}
co_await state.promise.final_suspend();
}
struct TaskPromise {
struct promise_type {
TaskPromise get_return_object() {
std::cout << "get_return_object(), thread_id: " << std::this_thread::get_id() << std::endl;
return TaskPromise{std::coroutine_handle<promise_type>::from_promise(*this)};
}
std::suspend_always initial_suspend() noexcept { return {}; }
std::suspend_always final_suspend() noexcept { return {}; }
void unhandled_exception() {}
void return_void() noexcept {}
size_t data = 0;
};
std::coroutine_handle<promise_type> handle;
};
struct Awaiter {
bool await_ready() noexcept {
std::cout << "await_ready(), thread_id: " << std::this_thread::get_id() << std::endl;
return false;
}
void await_suspend(std::coroutine_handle<TaskPromise::promise_type> handle) noexcept {
std::cout << "await_suspend(), thread_id: " << std::this_thread::get_id() << std::endl;
auto thread = std::thread([=]() {
std::this_thread::sleep_for(std::chrono::seconds(1));
handle.promise().data = 1;
handle.resume();
});
thread.join();
}
void await_resume() noexcept {
std::cout << "await_resume(), thread_id: " << std::this_thread::get_id() << std::endl;
}
};
TaskPromise task_func() {
std::cout << "task_func() step 1, thread_id: " << std::this_thread::get_id() << std::endl;
co_await Awaiter{};
std::cout << "task_func() step 2, thread_id: " << std::this_thread::get_id() << std::endl;
}
int main() {
std::cout << "main(), thread_id: " << std::this_thread::get_id() << std::endl;
auto promise = task_func();
std::cout << "main(), data: " << promise.handle.promise().data << ", thread_id: " << std::this_thread::get_id() << std::endl;
promise.handle.resume();
std::cout << "main(), data: " << promise.handle.promise().data << ", thread_id: " << std::this_thread::get_id() << std::endl;
return 0;
}
执行结果如下:
main(), thread_id: 0x1d9d91ec0
get_return_object(), thread_id: 0x1d9d91ec0
main(), data: 0, thread_id: 0x1d9d91ec0
task_func() step 1, thread_id: 0x1d9d91ec0
await_ready(), thread_id: 0x1d9d91ec0
await_suspend(), thread_id: 0x1d9d91ec0
await_resume(), thread_id: 0x16dce7000
task_func() step 2, thread_id: 0x16dce7000
main(), data: 1, thread_id: 0x1d9d91ec0
结合日志和上面的流程说明很好理解这段代码。代码里 26 行调用await_suspend()的handle参数是编译器帮着传入的(这个 handle 就一个void *指针,值传递的成本很低)。而 31 行代码是在另一个线程上调用的,而后续的协程代码也是在另一个线程运行的。这也揭示了协程的跨线程传递的能力,只要传递协程句柄,就可以实现在任意线程上恢复协程的执行。那么当协程跨线程传递时,线程安全的问题依旧要注意,而且因为执行流可以任意移动,因此带来的其他同步问题需要额外注意。
实现一个简单的 generator
template <typename T>
struct Generator {
struct promise_type {
Generator get_return_object() {
return Generator{std::coroutine_handle<promise_type>::from_promise(*this)};
}
std::suspend_always initial_suspend() noexcept { return {}; }
std::suspend_always final_suspend() noexcept { return {}; }
void unhandled_exception() {}
void return_value(T t) noexcept {
v = t;
}
std::suspend_always yield_value(T t) {
v = t;
return {};
}
T v{};
};
bool has_next() {
return !handle.done();
}
size_t next() {
handle.resume();
return handle.promise().v;
}
std::coroutine_handle<promise_type> handle;
};
Generator<size_t> fib(size_t max_count) {
co_yield 1;
size_t a = 0, b = 1, count = 0;
while (++count < max_count - 1) {
co_yield a + b;
b = a + b;
a = b - a;
}
co_return a + b;
}
int main() {
size_t max_count = 10;
auto generator = fib(max_count);
size_t i = 0;
while (generator.has_next()) {
std::cout << "No." << ++i << ": " << generator.next() << std::endl;
}
return 0;
}
代码运行结果为:
No.1: 1
No.2: 1
No.3: 2
No.4: 3
No.5: 5
No.6: 8
No.7: 13
No.8: 21
No.9: 34
No.10: 55
前文介绍过co_await和co_return,新出现的标识符co_yield算是co_await的语法糖,可以比co_await更方便的传递出来一个值[9]。
template <typename T>
class Generator {
public:
struct promise_type;
using promise_handle_t = std::coroutine_handle<promise_type>;
explicit Generator(promise_handle_t h) : handle(h) {}
explicit Generator(Generator &&generator) : handle(std::exchange(generator.handle, {})) {}
{
if (handle) {
handle.destroy();
}
}
&) = delete;
Generator &operator=(Generator &) = delete;
struct promise_type {
Generator get_return_object() {
return Generator{std::coroutine_handle<promise_type>::from_promise(*this)};
}
std::suspend_always initial_suspend() noexcept { return {}; }
std::suspend_always final_suspend() noexcept { return {}; }
void unhandled_exception() {}
void return_value(T t) noexcept {
v = t;
}
std::suspend_always yield_value(T t) {
v = t;
return {};
}
T v{};
};
bool has_next() {
return !handle.done();
}
size_t next() {
handle.resume();
return handle.promise().v;
}
private:
std::coroutine_handle<promise_type> handle;
};
通用的协程返回类 Task
template <typename T>
class Task {
public:
struct promise_type;
using promise_handle_t = std::coroutine_handle<promise_type>;
explicit Task(promise_handle_t h) : handle(h) {}
Task(Task &&task) noexcept : handle(std::exchange(task.handle, {})) {}
~Task() { if (handle) { handle.destroy(); } }
template <typename R>
struct task_awaiter {
explicit task_awaiter(Task<R> &&task) noexcept : task(std::move(task)) {}
task_awaiter(task_awaiter &) = delete;
task_awaiter &operator=(task_awaiter &) = delete;
bool await_ready() noexcept { return false; }
void await_suspend(std::coroutine_handle<> handle) noexcept {
task.finally([handle]() { handle.resume(); });
}
R await_resume() noexcept { return task.get_result(); }
private:
Task<R> task;
};
struct promise_type {
Task get_return_object() {
return Task(promise_handle_t::from_promise(*this));
}
std::suspend_never initial_suspend() { return {}; }
std::suspend_always final_suspend() noexcept { return {}; }
template <typename U>
task_awaiter<U> await_transform(Task<U> &&task) {
return task_awaiter<U>(std::move(task));
}
void unhandled_exception() {}
void return_value(T t) {
data_ = t;
notify_callbacks();
}
void on_completed(std::function<void(T)> &&callback) {
if (data_.has_value()) {
callback(data_.value());
} else {
callbacks_.push_back(callback);
}
}
T get() {
return data_.value();
}
private:
void notify_callbacks() {
for (auto &callback : callbacks_) {
callback(data_.value());
}
callbacks_.clear();
}
std::optional<T> data_;
std::deque<std::function<void(T)>> callbacks_;
};
T get_result() {
return handle.promise().get();
}
void then(std::function<void(T)> &&callback) {
handle.promise().on_completed([callback](auto data) {
callback(data);
});
}
void finally(std::function<void()> &&callback) {
handle.promise().on_completed([callback](auto result) {
callback();
});
}
private:
promise_handle_t handle;
};
Task<int> task1() {
std::cout << "task1 run" << std::endl;
co_return 1;
}
Task<int> task2() {
std::cout << "task2 run" << std::endl;
co_return 2;
}
Task<int> call_task() {
std::cout << "call_task" << std::endl;
int data1 = co_await task1();
std::cout << "call_task task1 data: " << data1 << std::endl;
int data2 = co_await task2();
std::cout << "call_task task2 data: " << data2 << std::endl;
co_return data1 + data2;
}
int main() {
Task<int> task = call_task();
task.then([](int data) {
std::cout << "call_task data: " << data << std::endl;
});
return 0;
}
代码执行结果如下:
call_task
task1 run
call_task task1 data: 1
task2 run
call_task task2 data: 2
call_task data: 3
可以在源码的函数里像之前的 demo 里加点日志去理解流程(或者直接使用调试器单步跟踪)。实际上这个 demo 没有实现真正意义上的等待和唤醒(甚至有些唤醒的逻辑都没运行到)。特别地,如果这些协程任务需要被调度到其他线程执行的话,还要考虑这些对象内部数据结构的并发安全性(直接使用std::mutex写不好就比普通的函数更容易出现死锁了 )。另外这个 demo 也揭示了一旦使用了C++20 的协程异步手段,从入口开始要一直改造下去,最终「传染」到整个项目的所有异步函数。
写在后面
你曾经担任的角色是 CodeReviewer 还是 被 CodeReviewer ?
CodeReview 是开发过程不可或缺的重要一环,如果将代码发布比作一个工厂的流水线,那么 CodeReview 就是流水线接近于终点的质检员,他要担负着对产品质量的保障工作,将“缺陷”从众多的“产品”中挑出,反向推动“生产方”改进生产质量。
截止2024年1月1日24时,参与本期话题讨论,将会选出 3 名幸运用户和 3 个优质回答分别获得阿里云开发者家用蓝牙新款智能电子秤人体体脂称一个。快点击阅读原文参加讨论吧~
【阿里云开发者公众号】读者群,是一个专门面向公众号的读者交流空间,你可以探讨技术和实践或参与群活动。欢迎添加微信:argentinaliu (备注读者群)入群。