// Copyright (C) 2017 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Protos for performing inference with an EmbeddingNetwork.

syntax = "proto2";

package libtextclassifier.nlp_core;

// Generated code only depends on the lite protobuf runtime.
option optimize_for = LITE_RUNTIME;

// Container for the parameters of a single matrix.  Entries are laid out in
// row-major order.
message MatrixParams {
  // Field numbers that must not be assigned to any field of this message.
  reserved 5, 6;

  // Matrix dimensions.
  optional int32 rows = 1;  // number of rows
  optional int32 cols = 2;  // number of columns

  // Matrix entries in non-quantized (float) form.
  repeated float value = 3 [packed = true];

  // Tells whether the matrix is quantized.  Defaults to false.
  optional bool is_quantized = 4 [default = false];

  // Quantized matrix entries.  Every entry (cf. the "repeated float value"
  // field) is quantized to a uint8 value, i.e., one byte; this field holds the
  // concatenation of all those bytes.
  optional bytes bytes_for_quantized_values = 7;

  // Scale factors needed to dequantize the entries.  Quantization produces one
  // float16 scale factor per column.  Each float16 is written as 2 bytes in
  // little-endian order (least significant byte first); this field holds the
  // concatenation of all those byte pairs.
  optional bytes bytes_for_col_scales = 8;
}

// Holds every parameter of an EmbeddingNetwork.  The network can be either an
// EmbeddingNetwork or a PrecomputedEmbeddingNetwork.  In the precomputed case,
// the embedding weights are in fact the activations of the first hidden layer,
// taken *before* the bias is added and the non-linear transform is applied.
//
// As a consequence, a PrecomputedEmbeddingNetwork stores hidden layers only
// from the second hidden layer onwards, but stores a bias for every hidden
// layer.
message EmbeddingNetworkProto {
  // Field numbers that must not be assigned to any field of this message.
  reserved 6, 8, 9, 10;

  // Embeddings.  If is_precomputed == true, these should hold the activations
  // of the first hidden layer.
  repeated MatrixParams embeddings = 1;

  // Hidden layers and their biases.  If is_precomputed == true, we must have
  // hidden_bias_size() == hidden_size() + 1: the bias of the first hidden
  // layer is stored, but not the layer itself.
  repeated MatrixParams hidden = 2;
  repeated MatrixParams hidden_bias = 3;

  // Last layer of the network.
  optional MatrixParams softmax = 4;
  optional MatrixParams softmax_bias = 5;

  // The i-th element is the number of features that use the i-th embedding
  // space.
  repeated int32 embedding_num_features = 7;

  // Tells whether this proto is meant to store a precomputed network.
  // Defaults to false.
  optional bool is_precomputed = 11 [default = false];

  // True when inference can be run on this EmbeddingNetworkProto without any
  // further matrix transposition.
  //
  // An EmbeddingNetworkProto coming out of a Neurosis training pipeline needs
  // a few of its matrices (e.g., the embedding matrices) transposed before
  // inference is possible.  Whenever that transposition is performed, this
  // flag is negated.  Note: the flag is negated rather than simply set to
  // true, because transposing twice brings us back to the original state.
  optional bool is_transposed = 12 [default = false];

  // Field numbers 100 and above are available for extensions.
  extensions 100 to max;
}