/* * Copyright (C) 2019 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H #define ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H #include "NeuralNetworksExtensions.h" #include "NeuralNetworksWrapper.h" #include <variant> namespace android { namespace nn { namespace extension_wrapper { using wrapper::SymmPerChannelQuantParams; using wrapper::Type; struct ExtensionOperandParams { std::vector<uint8_t> data; ExtensionOperandParams(std::vector<uint8_t> data) : data(std::move(data)) {} template <typename T> ExtensionOperandParams(const T& data) : ExtensionOperandParams( std::vector(reinterpret_cast<const uint8_t*>(&data), reinterpret_cast<const uint8_t*>(&data) + sizeof(data))) { static_assert(std::is_trivially_copyable<T>::value, "data must be trivially copyable"); } }; struct OperandType { using ExtraParams = std::variant<std::monostate, SymmPerChannelQuantParams, ExtensionOperandParams>; ANeuralNetworksOperandType operandType; std::vector<uint32_t> dimensions; ExtraParams extraParams; OperandType(const OperandType& other) : operandType(other.operandType), dimensions(other.dimensions), extraParams(other.extraParams) { operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; } OperandType& operator=(const OperandType& other) { if (this != &other) { operandType = other.operandType; dimensions = other.dimensions; extraParams = other.extraParams; operandType.dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr; } return *this; } OperandType(Type type, std::vector<uint32_t> d, float scale = 0.0f, int32_t zeroPoint = 0, ExtraParams&& extraParams = std::monostate()) : dimensions(std::move(d)), extraParams(std::move(extraParams)) { operandType = { .type = static_cast<int32_t>(type), .dimensionCount = static_cast<uint32_t>(dimensions.size()), .dimensions = dimensions.size() > 0 ? dimensions.data() : nullptr, .scale = scale, .zeroPoint = zeroPoint, }; } OperandType(Type type, std::vector<uint32_t> dimensions, float scale, int32_t zeroPoint, SymmPerChannelQuantParams&& channelQuant) : OperandType(type, dimensions, scale, zeroPoint, ExtraParams(std::move(channelQuant))) {} OperandType(Type type, std::vector<uint32_t> dimensions, ExtraParams&& extraParams) : OperandType(type, dimensions, 0.0f, 0, std::move(extraParams)) {} }; class Model : public wrapper::Model { public: using wrapper::Model::Model; // Inherit constructors. int32_t getExtensionOperandType(const char* extensionName, uint16_t typeWithinExtension) { int32_t result; if (ANeuralNetworksModel_getExtensionOperandType(mModel, extensionName, typeWithinExtension, &result) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } return result; } ANeuralNetworksOperationType getExtensionOperationType(const char* extensionName, uint16_t typeWithinExtension) { ANeuralNetworksOperationType result; if (ANeuralNetworksModel_getExtensionOperationType(mModel, extensionName, typeWithinExtension, &result) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } return result; } uint32_t addOperand(const OperandType* type) { if (ANeuralNetworksModel_addOperand(mModel, &(type->operandType)) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } if (std::holds_alternative<SymmPerChannelQuantParams>(type->extraParams)) { const auto& channelQuant = std::get<SymmPerChannelQuantParams>(type->extraParams); if (ANeuralNetworksModel_setOperandSymmPerChannelQuantParams( mModel, mNextOperandId, &channelQuant.params) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } } else if (std::holds_alternative<ExtensionOperandParams>(type->extraParams)) { const auto& extension = std::get<ExtensionOperandParams>(type->extraParams); if (ANeuralNetworksModel_setOperandExtensionData( mModel, mNextOperandId, extension.data.data(), extension.data.size()) != ANEURALNETWORKS_NO_ERROR) { mValid = false; } } return mNextOperandId++; } }; } // namespace extension_wrapper namespace wrapper { using ExtensionModel = extension_wrapper::Model; using ExtensionOperandType = extension_wrapper::OperandType; using ExtensionOperandParams = extension_wrapper::ExtensionOperandParams; } // namespace wrapper } // namespace nn } // namespace android #endif // ANDROID_ML_NN_RUNTIME_NEURAL_NETWORKS_WRAPPER_EXTENSIONS_H