/* * Copyright (C) 2017 The Android Open Source Project * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "common/task-context.h" #include <stdlib.h> #include <string> #include "util/base/integral_types.h" #include "util/base/logging.h" #include "util/strings/numbers.h" namespace libtextclassifier { namespace nlp_core { namespace { int32 ParseInt32WithDefault(const std::string &s, int32 defval) { int32 value = defval; return ParseInt32(s.c_str(), &value) ? value : defval; } int64 ParseInt64WithDefault(const std::string &s, int64 defval) { int64 value = defval; return ParseInt64(s.c_str(), &value) ? value : defval; } double ParseDoubleWithDefault(const std::string &s, double defval) { double value = defval; return ParseDouble(s.c_str(), &value) ? value : defval; } } // namespace TaskInput *TaskContext::GetInput(const std::string &name) { // Return existing input if it exists. for (int i = 0; i < spec_.input_size(); ++i) { if (spec_.input(i).name() == name) return spec_.mutable_input(i); } // Create new input. TaskInput *input = spec_.add_input(); input->set_name(name); return input; } TaskInput *TaskContext::GetInput(const std::string &name, const std::string &file_format, const std::string &record_format) { TaskInput *input = GetInput(name); if (!file_format.empty()) { bool found = false; for (int i = 0; i < input->file_format_size(); ++i) { if (input->file_format(i) == file_format) found = true; } if (!found) input->add_file_format(file_format); } if (!record_format.empty()) { bool found = false; for (int i = 0; i < input->record_format_size(); ++i) { if (input->record_format(i) == record_format) found = true; } if (!found) input->add_record_format(record_format); } return input; } void TaskContext::SetParameter(const std::string &name, const std::string &value) { TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")"; // If the parameter already exists update the value. for (int i = 0; i < spec_.parameter_size(); ++i) { if (spec_.parameter(i).name() == name) { spec_.mutable_parameter(i)->set_value(value); return; } } // Add new parameter. TaskSpec::Parameter *param = spec_.add_parameter(); param->set_name(name); param->set_value(value); } std::string TaskContext::GetParameter(const std::string &name) const { // First try to find parameter in task specification. for (int i = 0; i < spec_.parameter_size(); ++i) { if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); } // Parameter not found, return empty std::string. return ""; } int TaskContext::GetIntParameter(const std::string &name) const { std::string value = GetParameter(name); return ParseInt32WithDefault(value, 0); } int64 TaskContext::GetInt64Parameter(const std::string &name) const { std::string value = GetParameter(name); return ParseInt64WithDefault(value, 0); } bool TaskContext::GetBoolParameter(const std::string &name) const { std::string value = GetParameter(name); return value == "true"; } double TaskContext::GetFloatParameter(const std::string &name) const { std::string value = GetParameter(name); return ParseDoubleWithDefault(value, 0.0); } std::string TaskContext::Get(const std::string &name, const char *defval) const { // First try to find parameter in task specification. for (int i = 0; i < spec_.parameter_size(); ++i) { if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); } // Parameter not found, return default value. return defval; } std::string TaskContext::Get(const std::string &name, const std::string &defval) const { return Get(name, defval.c_str()); } int TaskContext::Get(const std::string &name, int defval) const { std::string value = Get(name, ""); return ParseInt32WithDefault(value, defval); } int64 TaskContext::Get(const std::string &name, int64 defval) const { std::string value = Get(name, ""); return ParseInt64WithDefault(value, defval); } double TaskContext::Get(const std::string &name, double defval) const { std::string value = Get(name, ""); return ParseDoubleWithDefault(value, defval); } bool TaskContext::Get(const std::string &name, bool defval) const { std::string value = Get(name, ""); return value.empty() ? defval : value == "true"; } std::string TaskContext::InputFile(const TaskInput &input) { if (input.part_size() == 0) { TC_LOG(ERROR) << "No file for TaskInput " << input.name(); return ""; } if (input.part_size() > 1) { TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name(); } return input.part(0).file_pattern(); } bool TaskContext::Supports(const TaskInput &input, const std::string &file_format, const std::string &record_format) { // Check file format. if (input.file_format_size() > 0) { bool found = false; for (int i = 0; i < input.file_format_size(); ++i) { if (input.file_format(i) == file_format) { found = true; break; } } if (!found) return false; } // Check record format. if (input.record_format_size() > 0) { bool found = false; for (int i = 0; i < input.record_format_size(); ++i) { if (input.record_format(i) == record_format) { found = true; break; } } if (!found) return false; } return true; } } // namespace nlp_core } // namespace libtextclassifier