/* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H #define ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H #include "HalInterfaces.h" #include <android-base/macros.h> #include <fmq/MessageQueue.h> #include <hidl/MQDescriptor.h> #include <atomic> #include <map> #include <memory> #include <mutex> #include <stack> #include <tuple> namespace android::nn { /** * Number of elements in the FMQ. */ constexpr const size_t kExecutionBurstChannelLength = 1024; /** * Function to serialize a request. * * Prefer calling RequestChannelSender::send. * * @param request Request object without the pool information. * @param measure Whether to collect timing information for the execution. * @param memoryIds Slot identifiers corresponding to memory resources for the * request. * @return Serialized FMQ request data. */ std::vector<FmqRequestDatum> serialize(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots); /** * Deserialize the FMQ result data. * * The three resulting fields are the status of the execution, the dynamic * shapes of the output tensors, and the timing information of the execution. * * @param data Serialized FMQ result data. * @return Result object if successfully deserialized, std::nullopt otherwise. */ std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> deserialize( const std::vector<FmqResultDatum>& data); /** * ResultChannelReceiver is responsible for waiting on the channel until the * packet is available, extracting the packet from the channel, and * deserializing the packet. * * Because the receiver can wait on a packet that may never come (e.g., because * the sending side of the packet has been closed), this object can be * invalidating, unblocking the receiver. */ class ResultChannelReceiver { using FmqResultDescriptor = ::android::hardware::MQDescriptorSync<FmqResultDatum>; using FmqResultChannel = hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>; public: /** * Create the receiving end of a result channel. * * Prefer this call over the constructor. * * @param channelLength Number of elements in the FMQ. * @param blocking 'true' if FMQ should use futex, 'false' if it should * spin-wait. * @return A pair of ResultChannelReceiver and the FMQ descriptor on * successful creation, both nullptr otherwise. */ static std::pair<std::unique_ptr<ResultChannelReceiver>, const FmqResultDescriptor*> create( size_t channelLength, bool blocking); /** * Get the result from the channel. * * This method will block until either: * 1) The packet has been retrieved, or * 2) The receiver has been invalidated * * @return Result object if successfully received, std::nullopt if error or * if the receiver object was invalidated. */ std::optional<std::tuple<ErrorStatus, std::vector<OutputShape>, Timing>> getBlocking(); /** * Method to mark the channel as invalid, unblocking any current or future * calls to ResultChannelReceiver::getBlocking. */ void invalidate(); // prefer calling ResultChannelReceiver::getBlocking std::optional<std::vector<FmqResultDatum>> getPacketBlocking(); ResultChannelReceiver(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking); private: const std::unique_ptr<FmqResultChannel> mFmqResultChannel; std::atomic<bool> mValid{true}; const bool mBlocking; }; /** * RequestChannelSender is responsible for serializing the result packet of * information, sending it on the result channel, and signaling that the data is * available. */ class RequestChannelSender { using FmqRequestDescriptor = ::android::hardware::MQDescriptorSync<FmqRequestDatum>; using FmqRequestChannel = hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>; public: /** * Create the sending end of a request channel. * * Prefer this call over the constructor. * * @param channelLength Number of elements in the FMQ. * @param blocking 'true' if FMQ should use futex, 'false' if it should * spin-wait. * @return A pair of ResultChannelReceiver and the FMQ descriptor on * successful creation, both nullptr otherwise. */ static std::pair<std::unique_ptr<RequestChannelSender>, const FmqRequestDescriptor*> create( size_t channelLength, bool blocking); /** * Send the request to the channel. * * @param request Request object without the pool information. * @param measure Whether to collect timing information for the execution. * @param memoryIds Slot identifiers corresponding to memory resources for * the request. * @return 'true' on successful send, 'false' otherwise. */ bool send(const Request& request, MeasureTiming measure, const std::vector<int32_t>& slots); /** * Method to mark the channel as invalid, causing all future calls to * RequestChannelSender::send to immediately return false without attempting * to send a message across the FMQ. */ void invalidate(); // prefer calling RequestChannelSender::send bool sendPacket(const std::vector<FmqRequestDatum>& packet); RequestChannelSender(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking); private: const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel; std::atomic<bool> mValid{true}; const bool mBlocking; }; /** * The ExecutionBurstController class manages both the serialization and * deserialization of data across FMQ, making it appear to the runtime as a * regular synchronous inference. Additionally, this class manages the burst's * memory cache. */ class ExecutionBurstController { DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstController); public: /** * NN runtime burst callback object and memory cache. * * ExecutionBurstCallback associates a hidl_memory object with a slot number * to be passed across FMQ. The ExecutionBurstServer can use this callback * to retrieve this hidl_memory corresponding to the slot via HIDL. * * Whenever a hidl_memory object is copied, it will duplicate the underlying * file descriptor. Because the NN runtime currently copies the hidl_memory * on each execution, it is difficult to associate hidl_memory objects with * previously cached hidl_memory objects. For this reason, callers of this * class must pair each hidl_memory object with an associated key. For * efficiency, if two hidl_memory objects represent the same underlying * buffer, they must use the same key. */ class ExecutionBurstCallback : public IBurstCallback { DISALLOW_COPY_AND_ASSIGN(ExecutionBurstCallback); public: ExecutionBurstCallback() = default; Return<void> getMemories(const hidl_vec<int32_t>& slots, getMemories_cb cb) override; /** * This function performs one of two different actions: * 1) If a key corresponding to a memory resource is unrecognized by the * ExecutionBurstCallback object, the ExecutionBurstCallback object * will allocate a slot, bind the memory to the slot, and return the * slot identifier. * 2) If a key corresponding to a memory resource is recognized by the * ExecutionBurstCallback object, the ExecutionBurstCallback object * will return the existing slot identifier. * * @param memories Memory resources used in an inference. * @param keys Unique identifiers where each element corresponds to a * memory resource element in "memories". * @return Unique slot identifiers where each returned slot element * corresponds to a memory resource element in "memories". */ std::vector<int32_t> getSlots(const hidl_vec<hidl_memory>& memories, const std::vector<intptr_t>& keys); /* * This function performs two different actions: * 1) Removes an entry from the cache (if present), including the local * storage of the hidl_memory object. Note that this call does not * free any corresponding hidl_memory object in ExecutionBurstServer, * which is separately freed via IBurstContext::freeMemory. * 2) Return whether a cache entry was removed and which slot was removed if * found. If the key did not to correspond to any entry in the cache, a * slot number of 0 is returned. The slot number and whether the entry * existed is useful so the same slot can be freed in the * ExecutionBurstServer's cache via IBurstContext::freeMemory. */ std::pair<bool, int32_t> freeMemory(intptr_t key); private: int32_t getSlotLocked(const hidl_memory& memory, intptr_t key); int32_t allocateSlotLocked(); std::mutex mMutex; std::stack<int32_t, std::vector<int32_t>> mFreeSlots; std::map<intptr_t, int32_t> mMemoryIdToSlot; std::vector<hidl_memory> mMemoryCache; }; /** * Creates a burst controller on a prepared model. * * Prefer this over ExecutionBurstController's constructor. * * @param preparedModel Model prepared for execution to execute on. * @param blocking 'true' if the FMQ should use a futex to perform blocking * until data is available in a less responsive, but more energy * efficient manner. 'false' if the FMQ should use spin-looping to * wait until data is available in a more responsive, but less energy * efficient manner. * @return ExecutionBurstController Execution burst controller object. */ static std::unique_ptr<ExecutionBurstController> create(const sp<IPreparedModel>& preparedModel, bool blocking); // prefer calling ExecutionBurstController::create ExecutionBurstController(const std::shared_ptr<RequestChannelSender>& requestChannelSender, const std::shared_ptr<ResultChannelReceiver>& resultChannelReceiver, const sp<IBurstContext>& burstContext, const sp<ExecutionBurstCallback>& callback, const sp<hardware::hidl_death_recipient>& deathHandler = nullptr); // explicit destructor to unregister the death recipient ~ExecutionBurstController(); /** * Execute a request on a model. * * @param request Arguments to be executed on a model. * @param measure Whether to collect timing measurements, either YES or NO * @param memoryIds Identifiers corresponding to each memory object in the * request's pools. * @return A tuple of: * - status of the execution * - dynamic output shapes from the execution * - any execution time measurements of the execution */ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing> compute( const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); // TODO: combine "compute" and "tryCompute" back into a single function. // "tryCompute" was created later to return the "fallback" boolean. This // could not be done directly in "compute" because the VTS test cases (which // test burst using "compute") had already been locked down and could not be // changed. /** * Execute a request on a model. * * @param request Arguments to be executed on a model. * @param measure Whether to collect timing measurements, either YES or NO * @param memoryIds Identifiers corresponding to each memory object in the * request's pools. * @return A tuple of: * - status of the execution * - dynamic output shapes from the execution * - any execution time measurements of the execution * - whether or not a failed burst execution should be re-run using a * different path (e.g., IPreparedModel::executeSynchronously) */ std::tuple<ErrorStatus, std::vector<OutputShape>, Timing, bool> tryCompute( const Request& request, MeasureTiming measure, const std::vector<intptr_t>& memoryIds); /** * Propagate a user's freeing of memory to the service. * * @param key Key corresponding to the memory object. */ void freeMemory(intptr_t key); private: std::mutex mMutex; const std::shared_ptr<RequestChannelSender> mRequestChannelSender; const std::shared_ptr<ResultChannelReceiver> mResultChannelReceiver; const sp<IBurstContext> mBurstContext; const sp<ExecutionBurstCallback> mMemoryCache; const sp<hardware::hidl_death_recipient> mDeathHandler; }; } // namespace android::nn #endif // ANDROID_ML_NN_RUNTIME_EXECUTION_BURST_CONTROLLER_H