/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "chre_host/socket_server.h"

#include <poll.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>
#include <cassert>
#include <cerrno>
#include <cinttypes>
#include <csignal>
#include <cstdlib>
#include <cstring>
#include <map>
#include <mutex>

#include <cutils/sockets.h>

#include "chre_host/log.h"

namespace android {
namespace chre {

std::atomic<bool> SocketServer::sSignalReceived(false);

namespace {

/**
 * Blocks every signal for the calling thread. A failure is logged but
 * otherwise ignored.
 */
void maskAllSignals() {
  sigset_t signalMask;
  sigfillset(&signalMask);
  if (sigprocmask(SIG_SETMASK, &signalMask, nullptr) != 0) {
    LOG_ERROR("Couldn't mask all signals", errno);
  }
}

/**
 * Blocks every signal except SIGINT and SIGTERM for the calling thread. A
 * failure is logged but otherwise ignored.
 */
void maskAllSignalsExceptIntAndTerm() {
  sigset_t signalMask;
  sigfillset(&signalMask);
  sigdelset(&signalMask, SIGINT);
  sigdelset(&signalMask, SIGTERM);
  if (sigprocmask(SIG_SETMASK, &signalMask, nullptr) != 0) {
    LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
  }
}

}  // anonymous namespace

/**
 * Marks every client slot of the poll set as unused. Slot 0 is reserved for
 * the listen socket and is populated in serviceSocket().
 */
SocketServer::SocketServer() {
  // Initialize the socket fds field for all inactive client slots to -1, so
  // poll skips over it, and we don't attempt to send on it
  for (size_t i = 1; i <= kMaxActiveClients; i++) {
    mPollFds[i].fd = -1;
    mPollFds[i].events = POLLIN;
  }
}

/**
 * Obtains (or, if permitted, creates) the named server socket, listens on it
 * and services it until the poll loop exits. All client sockets and the
 * server socket are closed before returning.
 *
 * @param socketName Name of the control socket to inherit or create
 * @param allowSocketCreation If true, create the socket when it was not
 *        inherited
 * @param clientMessageCallback Invoked for every packet received from a client
 */
void SocketServer::run(const char *socketName, bool allowSocketCreation,
                       ClientMessageCallback clientMessageCallback) {
  mClientMessageCallback = clientMessageCallback;

  mSockFd = android_get_control_socket(socketName);
  if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
    LOGI("Didn't inherit socket, creating...");
    mSockFd = socket_local_server(socketName,
                                  ANDROID_SOCKET_NAMESPACE_RESERVED,
                                  SOCK_SEQPACKET);
  }

  if (mSockFd == INVALID_SOCKET) {
    LOGE("Couldn't get/create socket");
  } else {
    int ret = listen(mSockFd, kMaxPendingConnectionRequests);
    if (ret < 0) {
      LOG_ERROR("Couldn't listen on socket", errno);
    } else {
      serviceSocket();
    }

    {
      std::lock_guard<std::mutex> lock(mClientsMutex);
      for (const auto &pair : mClients) {
        int clientSocket = pair.first;
        if (close(clientSocket) != 0) {
          LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
               pair.second.clientId, strerror(errno));
        }
      }
      mClients.clear();
    }

    // The client descriptors were just closed; clear them from the poll set
    // too so that no stale descriptor numbers are left behind.
    for (size_t i = 1; i <= kMaxActiveClients; i++) {
      mPollFds[i].fd = -1;
    }

    close(mSockFd);
  }
}

/**
 * Sends the given buffer to every connected client, stopping early if a send
 * was interrupted by a signal. Logs a warning when nobody received it.
 *
 * @param data Pointer to the bytes to send
 * @param length Number of bytes to send
 */
void SocketServer::sendToAllClients(const void *data, size_t length) {
  std::lock_guard<std::mutex> lock(mClientsMutex);

  int deliveredCount = 0;
  for (const auto &pair : mClients) {
    int clientSocket = pair.first;
    uint16_t clientId = pair.second.clientId;
    if (sendToClientSocket(data, length, clientSocket, clientId)) {
      deliveredCount++;
    } else if (errno == EINTR) {
      // Exit early if we were interrupted - we should only get this for
      // SIGINT/SIGTERM, so we should exit quickly
      break;
    }
  }

  if (deliveredCount == 0) {
    LOGW("Got message but didn't deliver to any clients");
  }
}

/**
 * Sends the given buffer to the client with the matching ID.
 *
 * @param data Pointer to the bytes to send
 * @param length Number of bytes to send
 * @param clientId ID that was assigned to the client when it connected
 * @return true if the client was found and the message was sent
 */
bool SocketServer::sendToClientById(const void *data, size_t length,
                                    uint16_t clientId) {
  std::lock_guard<std::mutex> lock(mClientsMutex);

  bool sent = false;
  for (const auto &pair : mClients) {
    uint16_t thisClientId = pair.second.clientId;
    if (thisClientId == clientId) {
      int clientSocket = pair.first;
      sent = sendToClientSocket(data, length, clientSocket, thisClientId);
      break;
    }
  }

  return sent;
}

/**
 * Accepts a pending connection on the listen socket, assigns it a client ID
 * and a poll slot, and records it in mClients. The connection is closed again
 * if the client limit is reached or no poll slot is free.
 */
void SocketServer::acceptClientConnection() {
  int clientSocket = accept(mSockFd, nullptr, nullptr);
  if (clientSocket < 0) {
    LOG_ERROR("Couldn't accept client connection", errno);
  } else if (mClients.size() >= kMaxActiveClients) {
    LOGW("Rejecting client request - maximum number of clients reached");
    close(clientSocket);
  } else {
    ClientData clientData;
    clientData.clientId = mNextClientId++;

    // We currently don't handle wraparound - if we're getting this many
    // connects/disconnects, then something is wrong.
    // TODO: can handle this properly by iterating over the existing clients to
    // avoid a conflict.
    if (clientData.clientId == 0) {
      LOGE("Couldn't allocate client ID");
      std::exit(-1);
    }

    bool slotFound = false;
    for (size_t i = 1; i <= kMaxActiveClients; i++) {
      if (mPollFds[i].fd < 0) {
        mPollFds[i].fd = clientSocket;
        slotFound = true;
        break;
      }
    }

    if (!slotFound) {
      LOGE("Couldn't find slot for client!");
      assert(slotFound);
      close(clientSocket);
    } else {
      {
        std::lock_guard<std::mutex> lock(mClientsMutex);
        mClients[clientSocket] = clientData;
      }
      LOGI("Accepted new client connection (count %zu), assigned client ID %"
           PRIu16, mClients.size(), clientData.clientId);
    }
  }
}

/**
 * Reads one packet from the given client socket without blocking and passes it
 * to the client message callback. A zero-length read is treated as the client
 * disconnecting.
 *
 * @param clientSocket Descriptor of the client socket that has data pending
 */
void SocketServer::handleClientData(int clientSocket) {
  // Look the client up with find() rather than operator[]: operator[] would
  // silently insert a default entry into the shared map (without holding
  // mClientsMutex) if the descriptor were unknown.
  const auto clientIt = mClients.find(clientSocket);
  if (clientIt == mClients.end()) {
    LOGE("Got data on socket %d which has no associated client", clientSocket);
    // Drop the descriptor so the poll loop doesn't keep waking up for it
    disconnectClient(clientSocket);
    return;
  }
  const uint16_t clientId = clientIt->second.clientId;

  ssize_t packetSize = recv(clientSocket, mRecvBuffer.data(),
                            mRecvBuffer.size(), MSG_DONTWAIT);
  if (packetSize < 0) {
    LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
         strerror(errno));
  } else if (packetSize == 0) {
    LOGI("Client %" PRIu16 " disconnected", clientId);
    disconnectClient(clientSocket);
  } else {
    LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
    mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
  }
}

/**
 * Forgets the client associated with the given socket, closes the socket and
 * frees its poll slot.
 *
 * @param clientSocket Descriptor of the client socket to tear down
 */
void SocketServer::disconnectClient(int clientSocket) {
  {
    std::lock_guard<std::mutex> lock(mClientsMutex);
    mClients.erase(clientSocket);
  }
  close(clientSocket);

  bool removed = false;
  for (size_t i = 1; i <= kMaxActiveClients; i++) {
    if (mPollFds[i].fd == clientSocket) {
      mPollFds[i].fd = -1;
      removed = true;
      break;
    }
  }

  if (!removed) {
    LOGE("Out of sync");
    assert(removed);
  }
}

/**
 * Sends the given buffer on a client socket and logs the outcome.
 *
 * On return, errno holds the value produced by send() (or 0 if send() did not
 * set it), so that callers can inspect the reason for a failure.
 *
 * @param data Pointer to the bytes to send
 * @param length Number of bytes to send
 * @param clientSocket Descriptor of the destination socket
 * @param clientId ID of the destination client, used for logging only
 * @return true if send() reported that at least one byte was sent
 */
bool SocketServer::sendToClientSocket(const void *data, size_t length,
                                      int clientSocket, uint16_t clientId) {
  errno = 0;
  const ssize_t bytesSent = send(clientSocket, data, length, 0);
  // Capture errno immediately: the logging calls below may overwrite it, and
  // sendToAllClients() relies on it to detect EINTR.
  const int sendErrno = errno;

  if (bytesSent < 0) {
    LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
         clientId, strerror(sendErrno));
  } else if (bytesSent == 0) {
    LOGW("Client %" PRIu16 " disconnected before message could be delivered",
         clientId);
  } else {
    LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
         clientId);
  }

  errno = sendErrno;
  return (bytesSent > 0);
}

/**
 * Runs the poll loop: accepts new connections on the listen socket and
 * dispatches incoming client data until SIGINT/SIGTERM is received or ppoll()
 * fails.
 */
void SocketServer::serviceSocket() {
  constexpr size_t kListenIndex = 0;
  static_assert(kListenIndex == 0,
                "Code assumes that the first index is "
                "always the listen socket");

  mPollFds[kListenIndex].fd = mSockFd;
  mPollFds[kListenIndex].events = POLLIN;

  // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
  // and ignore other signals
  sigset_t signalMask;
  sigfillset(&signalMask);
  sigdelset(&signalMask, SIGINT);
  sigdelset(&signalMask, SIGTERM);

  // Masking signals here ensure that after this point, we won't handle INT/TERM
  // until after we call into ppoll()
  maskAllSignals();
  std::signal(SIGINT, signalHandler);
  std::signal(SIGTERM, signalHandler);

  LOGI("Ready to accept connections");
  while (!sSignalReceived) {
    int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
    // Save errno before changing the signal mask: unmasking INT/TERM can run
    // the signal handler, which logs and may therefore overwrite errno.
    const int pollErrno = errno;
    maskAllSignalsExceptIntAndTerm();
    if (ret == -1) {
      LOGI("Exiting poll loop: %s", strerror(pollErrno));
      break;
    }

    if (mPollFds[kListenIndex].revents & POLLIN) {
      acceptClientConnection();
    }

    for (size_t i = 1; i <= kMaxActiveClients; i++) {
      if (mPollFds[i].fd < 0) {
        continue;
      }

      if (mPollFds[i].revents & POLLIN) {
        handleClientData(mPollFds[i].fd);
      }
    }

    // Mask all signals to ensure that sSignalReceived can't become true between
    // checking it in the while condition and calling into ppoll()
    maskAllSignals();
  }
}

/**
 * Handler installed for SIGINT and SIGTERM; flags the poll loop to exit.
 *
 * @param signal Number of the signal that was caught
 */
void SocketServer::signalHandler(int signal) {
  // NOTE(review): logging from a signal handler is presumably not
  // async-signal-safe - confirm against the LOGD implementation.
  LOGD("Caught signal %d", signal);
  sSignalReceived = true;
}

}  // namespace chre
}  // namespace android