diff options
-rw-r--r-- | src/core/hle/kernel/address_arbiter.cpp | 66 | ||||
-rw-r--r-- | src/core/hle/kernel/address_arbiter.h | 19 | ||||
-rw-r--r-- | src/core/hle/kernel/kernel.cpp | 6 |
3 files changed, 67 insertions, 24 deletions
diff --git a/src/core/hle/kernel/address_arbiter.cpp b/src/core/hle/kernel/address_arbiter.cpp index 8422d05e0..db189c8e3 100644 --- a/src/core/hle/kernel/address_arbiter.cpp +++ b/src/core/hle/kernel/address_arbiter.cpp @@ -17,10 +17,10 @@ #include "core/memory.h" namespace Kernel { -namespace { + // Wake up num_to_wake (or all) threads in a vector. -void WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads, s32 num_to_wake) { - auto& system = Core::System::GetInstance(); +void AddressArbiter::WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads, + s32 num_to_wake) { // Only process up to 'target' threads, unless 'target' is <= 0, in which case process // them all. std::size_t last = waiting_threads.size(); @@ -32,12 +32,12 @@ void WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads, s3 for (std::size_t i = 0; i < last; i++) { ASSERT(waiting_threads[i]->GetStatus() == ThreadStatus::WaitArb); waiting_threads[i]->SetWaitSynchronizationResult(RESULT_SUCCESS); + RemoveThread(waiting_threads[i]); waiting_threads[i]->SetArbiterWaitAddress(0); waiting_threads[i]->ResumeFromWait(); system.PrepareReschedule(waiting_threads[i]->GetProcessorID()); } } -} // Anonymous namespace AddressArbiter::AddressArbiter(Core::System& system) : system{system} {} AddressArbiter::~AddressArbiter() = default; @@ -184,6 +184,7 @@ ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 t ResultCode AddressArbiter::WaitForAddressImpl(VAddr address, s64 timeout) { Thread* current_thread = system.CurrentScheduler().GetCurrentThread(); current_thread->SetArbiterWaitAddress(address); + InsertThread(SharedFrom(current_thread)); current_thread->SetStatus(ThreadStatus::WaitArb); current_thread->InvalidateWakeupCallback(); current_thread->WakeAfterDelay(timeout); @@ -192,26 +193,51 @@ ResultCode AddressArbiter::WaitForAddressImpl(VAddr address, s64 timeout) { return RESULT_TIMEOUT; } -std::vector<std::shared_ptr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress( - VAddr address) const { - - // Retrieve all threads that are waiting for this address. - std::vector<std::shared_ptr<Thread>> threads; - const auto& scheduler = system.GlobalScheduler(); - const auto& thread_list = scheduler.GetThreadList(); +void AddressArbiter::HandleWakeupThread(std::shared_ptr<Thread> thread) { + ASSERT(thread->GetStatus() == ThreadStatus::WaitArb); + RemoveThread(thread); + thread->SetArbiterWaitAddress(0); +} - for (const auto& thread : thread_list) { - if (thread->GetArbiterWaitAddress() == address) { - threads.push_back(thread); +void AddressArbiter::InsertThread(std::shared_ptr<Thread> thread) { + const VAddr arb_addr = thread->GetArbiterWaitAddress(); + std::list<std::shared_ptr<Thread>>& thread_list = arb_threads[arb_addr]; + auto it = thread_list.begin(); + while (it != thread_list.end()) { + const std::shared_ptr<Thread>& current_thread = *it; + if (current_thread->GetPriority() >= thread->GetPriority()) { + thread_list.insert(it, thread); + return; } + ++it; } + thread_list.push_back(std::move(thread)); +} - // Sort them by priority, such that the highest priority ones come first. - std::sort(threads.begin(), threads.end(), - [](const std::shared_ptr<Thread>& lhs, const std::shared_ptr<Thread>& rhs) { - return lhs->GetPriority() < rhs->GetPriority(); - }); +void AddressArbiter::RemoveThread(std::shared_ptr<Thread> thread) { + const VAddr arb_addr = thread->GetArbiterWaitAddress(); + std::list<std::shared_ptr<Thread>>& thread_list = arb_threads[arb_addr]; + auto it = thread_list.begin(); + while (it != thread_list.end()) { + const std::shared_ptr<Thread>& current_thread = *it; + if (current_thread.get() == thread.get()) { + thread_list.erase(it); + return; + } + ++it; + } + UNREACHABLE(); +} - return threads; +std::vector<std::shared_ptr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(VAddr address) { + std::vector<std::shared_ptr<Thread>> result; + std::list<std::shared_ptr<Thread>>& thread_list = arb_threads[address]; + auto it = thread_list.begin(); + while (it != thread_list.end()) { + std::shared_ptr<Thread> current_thread = *it; + result.push_back(std::move(current_thread)); + ++it; + } + return result; } } // namespace Kernel diff --git a/src/core/hle/kernel/address_arbiter.h b/src/core/hle/kernel/address_arbiter.h index 1e1f00e60..386983e54 100644 --- a/src/core/hle/kernel/address_arbiter.h +++ b/src/core/hle/kernel/address_arbiter.h @@ -4,7 +4,9 @@ #pragma once +#include <list> #include <memory> +#include <unordered_map> #include <vector> #include "common/common_types.h" @@ -48,6 +50,9 @@ public: /// Waits on an address with a particular arbitration type. ResultCode WaitForAddress(VAddr address, ArbitrationType type, s32 value, s64 timeout_ns); + /// Removes a thread from the container and resets its address arbiter adress to 0 + void HandleWakeupThread(std::shared_ptr<Thread> thread); + private: /// Signals an address being waited on. ResultCode SignalToAddressOnly(VAddr address, s32 num_to_wake); @@ -71,8 +76,20 @@ private: // Waits on the given address with a timeout in nanoseconds ResultCode WaitForAddressImpl(VAddr address, s64 timeout); + /// Wake up num_to_wake (or all) threads in a vector. + void WakeThreads(const std::vector<std::shared_ptr<Thread>>& waiting_threads, s32 num_to_wake); + + /// Insert a thread into the address arbiter container + void InsertThread(std::shared_ptr<Thread> thread); + + /// Removes a thread from the address arbiter container + void RemoveThread(std::shared_ptr<Thread> thread); + // Gets the threads waiting on an address. - std::vector<std::shared_ptr<Thread>> GetThreadsWaitingOnAddress(VAddr address) const; + std::vector<std::shared_ptr<Thread>> GetThreadsWaitingOnAddress(VAddr address); + + /// List of threads waiting for a address arbiter + std::unordered_map<VAddr, std::list<std::shared_ptr<Thread>>> arb_threads; Core::System& system; }; diff --git a/src/core/hle/kernel/kernel.cpp b/src/core/hle/kernel/kernel.cpp index 0b149067a..1d0783bd3 100644 --- a/src/core/hle/kernel/kernel.cpp +++ b/src/core/hle/kernel/kernel.cpp @@ -78,9 +78,9 @@ static void ThreadWakeupCallback(u64 thread_handle, [[maybe_unused]] s64 cycles_ } } - if (thread->GetArbiterWaitAddress() != 0) { - ASSERT(thread->GetStatus() == ThreadStatus::WaitArb); - thread->SetArbiterWaitAddress(0); + if (thread->GetStatus() == ThreadStatus::WaitArb) { + auto& address_arbiter = thread->GetOwnerProcess()->GetAddressArbiter(); + address_arbiter.HandleWakeupThread(thread); } if (resume) { |