/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "OperationsUtils.cpp"

#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

namespace android {
namespace nn {
namespace wrapper {

namespace {
using ::testing::ElementsAreArray;
}  // namespace

// Broadcasting dimensions {4, 3, 2, 1} against {3, 1, 5} is expected to
// succeed and yield {4, 3, 2, 5}, whichever operand is passed first.
TEST(CalculateBroadcastedShapeTest, Basic) {
    Shape lhs;
    lhs.dimensions = {4, 3, 2, 1};

    Shape rhs;
    rhs.dimensions = {3, 1, 5};

    Shape wanted;
    wanted.dimensions = {4, 3, 2, 5};

    Shape produced;

    // Operands in their original order.
    EXPECT_TRUE(calculateBroadcastedShape(lhs, rhs, &produced));
    EXPECT_THAT(produced.dimensions, ElementsAreArray(wanted.dimensions));

    // Operands swapped; the same output object is reused.
    EXPECT_TRUE(calculateBroadcastedShape(rhs, lhs, &produced));
    EXPECT_THAT(produced.dimensions, ElementsAreArray(wanted.dimensions));
}

// Dimensions {5} and {3} are expected to be rejected in either operand order.
TEST(CalculateBroadcastedShapeTest, FailsOnIncompatible) {
    Shape lhs;
    lhs.dimensions = {5};

    Shape rhs;
    rhs.dimensions = {3};

    Shape produced;
    EXPECT_FALSE(calculateBroadcastedShape(lhs, rhs, &produced));
    EXPECT_FALSE(calculateBroadcastedShape(rhs, lhs, &produced));
}

// Builds an operand type code from an extension prefix and a type index.
//
// The prefix is shifted left by Model::ExtensionTypeEncoding::LOW_BITS_TYPE
// bits and OR-ed with typeWithinExtension. A non-fatal gtest failure is
// recorded if isExtensionOperandType() does not accept the resulting code;
// the code is returned either way.
static int32_t getExtensionType(uint16_t extensionPrefix, uint16_t typeWithinExtension) {
    constexpr uint8_t kShift =
            static_cast<uint8_t>(Model::ExtensionTypeEncoding::LOW_BITS_TYPE);
    const int32_t prefixBits = extensionPrefix << kShift;
    const int32_t encoded = prefixBits | typeWithinExtension;
    EXPECT_TRUE(isExtensionOperandType(static_cast<OperandType>(encoded)));
    return encoded;
}

TEST(TensorHasUnspecifiedDimensionsTest, ExtensionTensorWithUnspecifiedRank) {
    // Regression test for b/124285861.
    const int32_t extensionType = getExtensionType(1, 0);
    EXPECT_TRUE(tensorHasUnspecifiedDimensions(extensionType, /*dim=*/nullptr,
                                               /*dimCount=*/0));
}

TEST(ValidateOperandTypeTest, ExtensionTensorWithUnspecifiedRank) {
    // Regression test for b/124104123.
    constexpr uint16_t kPrefix = 1;
    constexpr uint16_t kTypeIndex = 0;

    // Extension tensor operand with no dimensions supplied.
    ANeuralNetworksOperandType operandType = {
            .type = getExtensionType(kPrefix, kTypeIndex),
            .dimensionCount = 0,
            .dimensions = nullptr,
    };
    Extension::OperandTypeInformation extensionInfo = {
            .type = kTypeIndex,
            .isTensor = true,
            .byteSize = 4,
    };

    // Expected to be accepted when partial specification is allowed...
    EXPECT_EQ(validateOperandType(operandType, &extensionInfo, /*tag=*/"test",
                                  /*allowPartial=*/true),
              ANEURALNETWORKS_NO_ERROR);
    // ...and rejected when it is not.
    EXPECT_EQ(validateOperandType(operandType, &extensionInfo, /*tag=*/"test",
                                  /*allowPartial=*/false),
              ANEURALNETWORKS_BAD_DATA);
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android