/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
#define ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
#include <functional>
#include <utility>

#include "HalInterfaces.h"
#include "OperationsUtils.h"
namespace android {
namespace nn {
// Encapsulates an operation implementation: its type, a human-readable name,
// and the validate/prepare/execute callbacks plus behavior flags.
struct OperationRegistration {
    OperationType type;
    const char* name;

    // Validates operand types, shapes, and any values known during graph creation.
    std::function<bool(const IOperationValidationContext*)> validate;

    // prepare is called when the inputs this operation depends on have been
    // computed. Typically, prepare does any remaining validation and sets
    // output shapes via context->setOutputShape(...).
    std::function<bool(IOperationExecutionContext*)> prepare;

    // Executes the operation, reading from context->getInputBuffer(...)
    // and writing to context->getOutputBuffer(...).
    std::function<bool(IOperationExecutionContext*)> execute;

    struct Flag {
        // Whether the operation allows at least one operand to be omitted.
        bool allowOmittedOperand = false;
        // Whether the operation allows at least one input operand to be a zero-sized tensor.
        bool allowZeroSizedInput = false;
    } flags;

    // The callbacks are taken by value (sink parameters) and moved into the
    // members. Moving avoids a second copy of each std::function, which can
    // heap-allocate when it wraps a callable with state. prepare/execute may
    // be empty: the non-CPU NN_REGISTER_OPERATION variant passes nullptr.
    OperationRegistration(OperationType type, const char* name,
                          std::function<bool(const IOperationValidationContext*)> validate,
                          std::function<bool(IOperationExecutionContext*)> prepare,
                          std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
        : type(type),
          name(name),
          validate(std::move(validate)),
          prepare(std::move(prepare)),
          execute(std::move(execute)),
          flags(flags) {}
};
// A registry of operation implementations.
//
// Abstract lookup interface: given an OperationType, an implementation hands
// back the matching OperationRegistration.
class IOperationResolver {
   public:
    // Looks up the registration for |operationType|. Callers in this codebase
    // check the result against nullptr, so presumably nullptr means "no
    // registration" (TODO confirm for each implementation). The returned
    // pointer is presumably non-owning.
    virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;

    // Virtual so that implementations can be destroyed through this interface.
    virtual ~IOperationResolver() = default;
};
// A registry of builtin operation implementations, exposed as a single
// shared instance through get().
//
// Note that some operations bypass BuiltinOperationResolver (b/124041202).
//
// Usage:
//     const OperationRegistration* operationRegistration =
//             BuiltinOperationResolver::get()->findOperation(operationType);
//     NN_RET_CHECK(operationRegistration != nullptr);
//     NN_RET_CHECK(operationRegistration->validate != nullptr);
//     NN_RET_CHECK(operationRegistration->validate(&context));
//
class BuiltinOperationResolver : public IOperationResolver {
   public:
    // The resolver is a singleton: copying or assigning it is not allowed.
    BuiltinOperationResolver(const BuiltinOperationResolver&) = delete;
    BuiltinOperationResolver& operator=(const BuiltinOperationResolver&) = delete;

    // Returns the single resolver instance. It is a function-local static, so
    // it is constructed on the first call.
    static const BuiltinOperationResolver* get() {
        static BuiltinOperationResolver instance;
        return &instance;
    }

    // Defined out of line.
    const OperationRegistration* findOperation(OperationType operationType) const override;

   private:
    // Private so that only get() can create the instance.
    BuiltinOperationResolver();

    // Defined out of line; presumably stores |operationRegistration| in
    // mRegistrations (TODO confirm against the .cpp).
    void registerOperation(const OperationRegistration* operationRegistration);

    // One slot per operation type; value-initialized so every slot starts as nullptr.
    const OperationRegistration* mRegistrations[kNumberOfOperationTypes]{};
};
// NN_REGISTER_OPERATION creates an OperationRegistration for consumption by
// the operation resolver (see BuiltinOperationResolver).
//
// The macro defines a function named register_<identifier>() that returns a
// pointer to a function-local static OperationRegistration. The registration
// therefore has static storage duration, and the returned pointer stays valid
// for the rest of the program.
//
// Usage:
//   (check OperationRegistration::Flag for available fields and default values.)
//
//   - With default flags.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute);
//
//   - With a customized flag.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
//   - With multiple customized flags.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                           .allowZeroSizedInput = true);
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...)     \
    const OperationRegistration* register_##identifier() {                                    \
        static OperationRegistration registration(OperationType::identifier, operationName,   \
                                                  validate, prepare, execute, {__VA_ARGS__}); \
        return &registration;                                                                 \
    }
#else
// This version ignores CPU execution logic (prepare and execute): it stores
// nullptr in their place. The CPU code is therefore not referenced from here,
// so it can be left out and only validation logic makes it into
// libneuralnetworks_utils.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
                              ...)                                                                 \
    const OperationRegistration* register_##identifier() {                                         \
        static OperationRegistration registration(OperationType::identifier, operationName,        \
                                                  validate, nullptr, nullptr, {__VA_ARGS__});      \
        return &registration;                                                                      \
    }
#endif
} // namespace nn
} // namespace android
#endif // ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H