diff --git a/engine/src/flutter/fml/message_loop.cc b/engine/src/flutter/fml/message_loop.cc index 652fb659bd2..bface19f9e1 100644 --- a/engine/src/flutter/fml/message_loop.cc +++ b/engine/src/flutter/fml/message_loop.cc @@ -81,4 +81,10 @@ void MessageLoop::RunExpiredTasksNow() { loop_->RunExpiredTasksNow(); } +void MessageLoop::SwapTaskQueues(MessageLoop* other) { + FML_CHECK(loop_); + FML_CHECK(other->loop_); + loop_->SwapTaskQueues(other->loop_); +} + } // namespace fml diff --git a/engine/src/flutter/fml/message_loop.h b/engine/src/flutter/fml/message_loop.h index 3400bf47f59..30db680cb9d 100644 --- a/engine/src/flutter/fml/message_loop.h +++ b/engine/src/flutter/fml/message_loop.h @@ -38,6 +38,8 @@ class MessageLoop { // instead of dedicating a thread to the message loop. void RunExpiredTasksNow(); + void SwapTaskQueues(MessageLoop* other); + static void EnsureInitializedForCurrentThread(); static bool IsInitializedForCurrentThread(); diff --git a/engine/src/flutter/fml/message_loop_impl.cc b/engine/src/flutter/fml/message_loop_impl.cc index 94474fd1af9..e12923d886a 100644 --- a/engine/src/flutter/fml/message_loop_impl.cc +++ b/engine/src/flutter/fml/message_loop_impl.cc @@ -53,6 +53,7 @@ void MessageLoopImpl::AddTaskObserver(intptr_t key, fml::closure callback) { FML_DCHECK(MessageLoop::GetCurrent().GetLoopImpl().get() == this) << "Message loop task observer must be added on the same thread as the " "loop."; + std::lock_guard observers_lock(observers_mutex_); task_observers_[key] = std::move(callback); } @@ -60,6 +61,7 @@ void MessageLoopImpl::RemoveTaskObserver(intptr_t key) { FML_DCHECK(MessageLoop::GetCurrent().GetLoopImpl().get() == this) << "Message loop task observer must be removed from the same thread as " "the loop."; + std::lock_guard observers_lock(observers_mutex_); task_observers_.erase(key); } @@ -95,6 +97,32 @@ void MessageLoopImpl::DoTerminate() { Terminate(); } +// Thread safety analysis disabled as it does not account for defered locks. +void MessageLoopImpl::SwapTaskQueues(const fml::RefPtr& other) + FML_NO_THREAD_SAFETY_ANALYSIS { + if (terminated_ || other->terminated_) { + return; + } + + // task_flushing locks + std::unique_lock t1(tasks_flushing_mutex_, std::defer_lock); + std::unique_lock t2(other->tasks_flushing_mutex_, + std::defer_lock); + + // task_observers locks + std::unique_lock o1(observers_mutex_, std::defer_lock); + std::unique_lock o2(other->observers_mutex_, std::defer_lock); + + // delayed_tasks locks + std::unique_lock d1(delayed_tasks_mutex_, std::defer_lock); + std::unique_lock d2(other->delayed_tasks_mutex_, std::defer_lock); + + std::lock(t1, t2, o1, o2, d1, d2); + + std::swap(task_observers_, other->task_observers_); + std::swap(delayed_tasks_, other->delayed_tasks_); +} + void MessageLoopImpl::RegisterTask(fml::closure task, fml::TimePoint target_time) { FML_DCHECK(task != nullptr); @@ -112,6 +140,14 @@ void MessageLoopImpl::FlushTasks(FlushType type) { TRACE_EVENT0("fml", "MessageLoop::FlushTasks"); std::vector invocations; + // We are grabbing this lock here as a proxy to indicate + // that we are running tasks and will invoke the + // "right" observers, we are trying to avoid the scenario + // where: + // gather invocations -> Swap -> execute invocations + // will lead us to run invocations on the wrong thread. + std::lock_guard task_flush_lock(tasks_flushing_mutex_); + { std::lock_guard lock(delayed_tasks_mutex_); @@ -138,6 +174,7 @@ void MessageLoopImpl::FlushTasks(FlushType type) { for (const auto& invocation : invocations) { invocation(); + std::lock_guard observers_lock(observers_mutex_); for (const auto& observer : task_observers_) { observer.second(); } diff --git a/engine/src/flutter/fml/message_loop_impl.h b/engine/src/flutter/fml/message_loop_impl.h index 4248061ad0d..5b41d66d4d0 100644 --- a/engine/src/flutter/fml/message_loop_impl.h +++ b/engine/src/flutter/fml/message_loop_impl.h @@ -43,6 +43,8 @@ class MessageLoopImpl : public fml::RefCountedThreadSafe { void DoTerminate(); + void SwapTaskQueues(const fml::RefPtr& other); + protected: // Exposed for the embedder shell which allows clients to poll for events // instead of dedicating a thread to the message loop. @@ -80,7 +82,12 @@ class MessageLoopImpl : public fml::RefCountedThreadSafe { using DelayedTaskQueue = std:: priority_queue, DelayedTaskCompare>; - std::map task_observers_; + std::mutex tasks_flushing_mutex_; + + std::mutex observers_mutex_; + std::map task_observers_ + FML_GUARDED_BY(observers_mutex_); + std::mutex delayed_tasks_mutex_; DelayedTaskQueue delayed_tasks_ FML_GUARDED_BY(delayed_tasks_mutex_); size_t order_ FML_GUARDED_BY(delayed_tasks_mutex_); diff --git a/engine/src/flutter/fml/message_loop_unittests.cc b/engine/src/flutter/fml/message_loop_unittests.cc index 55de92bb15e..d7c0a9cb926 100644 --- a/engine/src/flutter/fml/message_loop_unittests.cc +++ b/engine/src/flutter/fml/message_loop_unittests.cc @@ -295,3 +295,139 @@ TEST(MessageLoop, CanCreateConcurrentMessageLoop) { } latch.Wait(); } + +TEST(MessageLoop, CanSwapMessageLoopsAndPreserveThreadConfiguration) { + // synchronization notes: + // 1. term1 and term2 are to wait for Swap. + // 2. task_started_1 is to wait for the task runners + // to signal that they are done. + // 3. loop_init_1 and loop_init_2 are to wait for the message loops to + // get initialized. + + fml::MessageLoop* loop1 = nullptr; + fml::AutoResetWaitableEvent loop_init_1; + fml::AutoResetWaitableEvent task_started_1; + fml::AutoResetWaitableEvent term1; + std::thread thread1([&loop1, &loop_init_1, &term1, &task_started_1]() { + fml::MessageLoop::EnsureInitializedForCurrentThread(); + loop1 = &fml::MessageLoop::GetCurrent(); + // this task will be run on thread1 after Swap. + loop1->GetTaskRunner()->PostTask([&task_started_1]() { + task_started_1.Signal(); + fml::MessageLoop::GetCurrent().Terminate(); + }); + loop_init_1.Signal(); + term1.Wait(); + loop1->Run(); + }); + + loop_init_1.Wait(); + + fml::MessageLoop* loop2 = nullptr; + fml::AutoResetWaitableEvent loop_init_2; + fml::AutoResetWaitableEvent task_started_2; + fml::AutoResetWaitableEvent term2; + std::thread thread2( + [&loop2, &loop_init_2, &term2, &task_started_2, &loop1]() { + fml::MessageLoop::EnsureInitializedForCurrentThread(); + loop2 = &fml::MessageLoop::GetCurrent(); + // this task will be run on thread1 after Swap. + loop2->GetTaskRunner()->PostTask([&task_started_2, &loop1]() { + // ensure that we run the task on loop1 after the swap. + ASSERT_TRUE(loop1 == &fml::MessageLoop::GetCurrent()); + task_started_2.Signal(); + fml::MessageLoop::GetCurrent().Terminate(); + }); + loop_init_2.Signal(); + term2.Wait(); + loop2->Run(); + }); + + loop_init_2.Wait(); + + // swap the loops. + loop1->SwapTaskQueues(loop2); + + // thread_1 should wait for tr_term2 latch. + term1.Signal(); + task_started_2.Wait(); + + // thread_2 should wait for tr_term2 latch. + term2.Signal(); + task_started_1.Wait(); + + thread1.join(); + thread2.join(); +} + +TEST(MessageLoop, TIME_SENSITIVE(DelayedTaskSwap)) { + // Task execution order: + // time (ms): 0 10 20 30 40 + // thread 1: A1 A2 A3 A4 TERM + // thread 2: B1 B2 B3 TERM + + // At time 15, we swap thread 1 and 2, and assert + // that tasks run on the right threads. + + std::thread::id t1, t2; + fml::AutoResetWaitableEvent tid_1, tid_2; + fml::MessageLoop* loop1 = nullptr; + fml::MessageLoop* loop2 = nullptr; + + std::thread thread_1([&loop1, &t1, &t2, &tid_1, &tid_2]() { + t1 = std::this_thread::get_id(); + tid_1.Signal(); + tid_2.Wait(); + fml::MessageLoop::EnsureInitializedForCurrentThread(); + loop1 = &fml::MessageLoop::GetCurrent(); + for (int t = 0; t <= 4; t++) { + loop1->GetTaskRunner()->PostDelayedTask( + [t, &t1, &t2]() { + auto cur_tid = std::this_thread::get_id(); + if (t <= 1) { + ASSERT_EQ(cur_tid, t1); + } else { + ASSERT_EQ(cur_tid, t2); + } + + if (t == 4) { + fml::MessageLoop::GetCurrent().Terminate(); + } + }, + fml::TimeDelta::FromMilliseconds(t * 10)); + } + loop1->Run(); + }); + + std::thread thread_2([&loop2, &t1, &t2, &tid_1, &tid_2]() { + t2 = std::this_thread::get_id(); + tid_2.Signal(); + tid_1.Wait(); + fml::MessageLoop::EnsureInitializedForCurrentThread(); + loop2 = &fml::MessageLoop::GetCurrent(); + for (int t = 1; t <= 4; t++) { + loop2->GetTaskRunner()->PostDelayedTask( + [t, &t1, &t2]() { + auto cur_tid = std::this_thread::get_id(); + if (t <= 1) { + ASSERT_EQ(cur_tid, t2); + } else { + ASSERT_EQ(cur_tid, t1); + } + + if (t == 4) { + fml::MessageLoop::GetCurrent().Terminate(); + } + }, + fml::TimeDelta::FromMilliseconds(t * 10)); + } + loop2->Run(); + }); + + // on main thread we swap the threads at 15 ms. + std::this_thread::sleep_for(std::chrono::milliseconds(15)); + loop1->SwapTaskQueues(loop2); + + thread_1.join(); + thread_2.join(); +}