// Copyright 2013 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "google_apis/gcm/engine/mcs_client.h" #include "base/basictypes.h" #include "base/message_loop/message_loop.h" #include "base/strings/string_number_conversions.h" #include "google_apis/gcm/base/mcs_util.h" #include "google_apis/gcm/base/socket_stream.h" #include "google_apis/gcm/engine/connection_factory.h" #include "google_apis/gcm/engine/rmq_store.h" using namespace google::protobuf::io; namespace gcm { namespace { typedef scoped_ptr<google::protobuf::MessageLite> MCSProto; // TODO(zea): get these values from MCS settings. const int64 kHeartbeatDefaultSeconds = 60 * 15; // 15 minutes. // The category of messages intended for the GCM client itself from MCS. const char kMCSCategory[] = "com.google.android.gsf.gtalkservice"; // The from field for messages originating in the GCM client. const char kGCMFromField[] = "gcm@android.com"; // MCS status message types. const char kIdleNotification[] = "IdleNotification"; // TODO(zea): consume the following message types: // const char kAlwaysShowOnIdle[] = "ShowAwayOnIdle"; // const char kPowerNotification[] = "PowerNotification"; // const char kDataActiveNotification[] = "DataActiveNotification"; // The number of unacked messages to allow before sending a stream ack. // Applies to both incoming and outgoing messages. // TODO(zea): make this server configurable. const int kUnackedMessageBeforeStreamAck = 10; // The global maximum number of pending messages to have in the send queue. const size_t kMaxSendQueueSize = 10 * 1024; // The maximum message size that can be sent to the server. const int kMaxMessageBytes = 4 * 1024; // 4KB, like the server. // Helper for converting a proto persistent id list to a vector of strings. bool BuildPersistentIdListFromProto(const google::protobuf::string& bytes, std::vector<std::string>* id_list) { mcs_proto::SelectiveAck selective_ack; if (!selective_ack.ParseFromString(bytes)) return false; std::vector<std::string> new_list; for (int i = 0; i < selective_ack.id_size(); ++i) { DCHECK(!selective_ack.id(i).empty()); new_list.push_back(selective_ack.id(i)); } id_list->swap(new_list); return true; } } // namespace struct ReliablePacketInfo { ReliablePacketInfo(); ~ReliablePacketInfo(); // The stream id with which the message was sent. uint32 stream_id; // If reliable delivery was requested, the persistent id of the message. std::string persistent_id; // The type of message itself (for easier lookup). uint8 tag; // The protobuf of the message itself. MCSProto protobuf; }; ReliablePacketInfo::ReliablePacketInfo() : stream_id(0), tag(0) { } ReliablePacketInfo::~ReliablePacketInfo() {} MCSClient::MCSClient( const base::FilePath& rmq_path, ConnectionFactory* connection_factory, scoped_refptr<base::SequencedTaskRunner> blocking_task_runner) : state_(UNINITIALIZED), android_id_(0), security_token_(0), connection_factory_(connection_factory), connection_handler_(NULL), last_device_to_server_stream_id_received_(0), last_server_to_device_stream_id_received_(0), stream_id_out_(0), stream_id_in_(0), rmq_store_(rmq_path, blocking_task_runner), heartbeat_interval_( base::TimeDelta::FromSeconds(kHeartbeatDefaultSeconds)), heartbeat_timer_(true, true), blocking_task_runner_(blocking_task_runner), weak_ptr_factory_(this) { } MCSClient::~MCSClient() { } void MCSClient::Initialize( const InitializationCompleteCallback& initialization_callback, const OnMessageReceivedCallback& message_received_callback, const OnMessageSentCallback& message_sent_callback) { DCHECK_EQ(state_, UNINITIALIZED); initialization_callback_ = initialization_callback; message_received_callback_ = message_received_callback; message_sent_callback_ = message_sent_callback; state_ = LOADING; rmq_store_.Load(base::Bind(&MCSClient::OnRMQLoadFinished, weak_ptr_factory_.GetWeakPtr())); connection_factory_->Initialize( base::Bind(&MCSClient::ResetStateAndBuildLoginRequest, weak_ptr_factory_.GetWeakPtr()), base::Bind(&MCSClient::HandlePacketFromWire, weak_ptr_factory_.GetWeakPtr()), base::Bind(&MCSClient::MaybeSendMessage, weak_ptr_factory_.GetWeakPtr())); connection_handler_ = connection_factory_->GetConnectionHandler(); } void MCSClient::Login(uint64 android_id, uint64 security_token) { DCHECK_EQ(state_, LOADED); if (android_id != android_id_ && security_token != security_token_) { DCHECK(android_id); DCHECK(security_token); DCHECK(restored_unackeds_server_ids_.empty()); android_id_ = android_id; security_token_ = security_token; rmq_store_.SetDeviceCredentials(android_id_, security_token_, base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); } state_ = CONNECTING; connection_factory_->Connect(); } void MCSClient::SendMessage(const MCSMessage& message, bool use_rmq) { DCHECK_EQ(state_, CONNECTED); if (to_send_.size() > kMaxSendQueueSize) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(message_sent_callback_, "Message queue full.")); return; } if (message.size() > kMaxMessageBytes) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(message_sent_callback_, "Message too large.")); return; } ReliablePacketInfo* packet_info = new ReliablePacketInfo(); packet_info->protobuf = message.CloneProtobuf(); if (use_rmq) { PersistentId persistent_id = GetNextPersistentId(); DVLOG(1) << "Setting persistent id to " << persistent_id; packet_info->persistent_id = persistent_id; SetPersistentId(persistent_id, packet_info->protobuf.get()); rmq_store_.AddOutgoingMessage(persistent_id, MCSMessage(message.tag(), *(packet_info->protobuf)), base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); } else { // Check that there is an active connection to the endpoint. if (!connection_handler_->CanSendMessage()) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(message_sent_callback_, "Unable to reach endpoint")); return; } } to_send_.push_back(make_linked_ptr(packet_info)); MaybeSendMessage(); } void MCSClient::Destroy() { rmq_store_.Destroy(base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); } void MCSClient::ResetStateAndBuildLoginRequest( mcs_proto::LoginRequest* request) { DCHECK(android_id_); DCHECK(security_token_); stream_id_in_ = 0; stream_id_out_ = 1; last_device_to_server_stream_id_received_ = 0; last_server_to_device_stream_id_received_ = 0; // TODO(zea): expire all messages older than their TTL. // Add any pending acknowledgments to the list of ids. for (StreamIdToPersistentIdMap::const_iterator iter = unacked_server_ids_.begin(); iter != unacked_server_ids_.end(); ++iter) { restored_unackeds_server_ids_.push_back(iter->second); } unacked_server_ids_.clear(); // Any acknowledged server ids which have not been confirmed by the server // are treated like unacknowledged ids. for (std::map<StreamId, PersistentIdList>::const_iterator iter = acked_server_ids_.begin(); iter != acked_server_ids_.end(); ++iter) { restored_unackeds_server_ids_.insert(restored_unackeds_server_ids_.end(), iter->second.begin(), iter->second.end()); } acked_server_ids_.clear(); // Then build the request, consuming all pending acknowledgments. request->Swap(BuildLoginRequest(android_id_, security_token_).get()); for (PersistentIdList::const_iterator iter = restored_unackeds_server_ids_.begin(); iter != restored_unackeds_server_ids_.end(); ++iter) { request->add_received_persistent_id(*iter); } acked_server_ids_[stream_id_out_] = restored_unackeds_server_ids_; restored_unackeds_server_ids_.clear(); // Push all unacknowledged messages to front of send queue. No need to save // to RMQ, as all messages that reach this point should already have been // saved as necessary. while (!to_resend_.empty()) { to_send_.push_front(to_resend_.back()); to_resend_.pop_back(); } DVLOG(1) << "Resetting state, with " << request->received_persistent_id_size() << " incoming acks pending, and " << to_send_.size() << " pending outgoing messages."; heartbeat_timer_.Stop(); state_ = CONNECTING; } void MCSClient::SendHeartbeat() { SendMessage(MCSMessage(kHeartbeatPingTag, mcs_proto::HeartbeatPing()), false); } void MCSClient::OnRMQLoadFinished(const RMQStore::LoadResult& result) { if (!result.success) { state_ = UNINITIALIZED; LOG(ERROR) << "Failed to load/create RMQ state. Not connecting."; initialization_callback_.Run(false, 0, 0); return; } state_ = LOADED; stream_id_out_ = 1; // Login request is hardcoded to id 1. if (result.device_android_id == 0 || result.device_security_token == 0) { DVLOG(1) << "No device credentials found, assuming new client."; initialization_callback_.Run(true, 0, 0); return; } android_id_ = result.device_android_id; security_token_ = result.device_security_token; DVLOG(1) << "RMQ Load finished with " << result.incoming_messages.size() << " incoming acks pending and " << result.outgoing_messages.size() << " outgoing messages pending."; restored_unackeds_server_ids_ = result.incoming_messages; // First go through and order the outgoing messages by recency. std::map<uint64, google::protobuf::MessageLite*> ordered_messages; for (std::map<PersistentId, google::protobuf::MessageLite*>::const_iterator iter = result.outgoing_messages.begin(); iter != result.outgoing_messages.end(); ++iter) { uint64 timestamp = 0; if (!base::StringToUint64(iter->first, ×tamp)) { LOG(ERROR) << "Invalid restored message."; return; } ordered_messages[timestamp] = iter->second; } // Now go through and add the outgoing messages to the send queue in their // appropriate order (oldest at front, most recent at back). for (std::map<uint64, google::protobuf::MessageLite*>::const_iterator iter = ordered_messages.begin(); iter != ordered_messages.end(); ++iter) { ReliablePacketInfo* packet_info = new ReliablePacketInfo(); packet_info->protobuf.reset(iter->second); packet_info->persistent_id = base::Uint64ToString(iter->first); to_send_.push_back(make_linked_ptr(packet_info)); } initialization_callback_.Run(true, android_id_, security_token_); } void MCSClient::OnRMQUpdateFinished(bool success) { LOG_IF(ERROR, !success) << "RMQ Update failed!"; // TODO(zea): Rebuild the store from scratch in case of persistence failure? } void MCSClient::MaybeSendMessage() { if (to_send_.empty()) return; if (!connection_handler_->CanSendMessage()) return; // TODO(zea): drop messages older than their TTL. DVLOG(1) << "Pending output message found, sending."; MCSPacketInternal packet = to_send_.front(); to_send_.pop_front(); if (!packet->persistent_id.empty()) to_resend_.push_back(packet); SendPacketToWire(packet.get()); } void MCSClient::SendPacketToWire(ReliablePacketInfo* packet_info) { // Reset the heartbeat interval. heartbeat_timer_.Reset(); packet_info->stream_id = ++stream_id_out_; DVLOG(1) << "Sending packet of type " << packet_info->protobuf->GetTypeName(); // Set the proper last received stream id to acknowledge received server // packets. DVLOG(1) << "Setting last stream id received to " << stream_id_in_; SetLastStreamIdReceived(stream_id_in_, packet_info->protobuf.get()); if (stream_id_in_ != last_server_to_device_stream_id_received_) { last_server_to_device_stream_id_received_ = stream_id_in_; // Mark all acknowledged server messages as such. Note: they're not dropped, // as it may be that they'll need to be re-acked if this message doesn't // make it. PersistentIdList persistent_id_list; for (StreamIdToPersistentIdMap::const_iterator iter = unacked_server_ids_.begin(); iter != unacked_server_ids_.end(); ++iter) { DCHECK_LE(iter->first, last_server_to_device_stream_id_received_); persistent_id_list.push_back(iter->second); } unacked_server_ids_.clear(); acked_server_ids_[stream_id_out_] = persistent_id_list; } connection_handler_->SendMessage(*packet_info->protobuf); } void MCSClient::HandleMCSDataMesssage( scoped_ptr<google::protobuf::MessageLite> protobuf) { mcs_proto::DataMessageStanza* data_message = reinterpret_cast<mcs_proto::DataMessageStanza*>(protobuf.get()); // TODO(zea): implement a proper status manager rather than hardcoding these // values. scoped_ptr<mcs_proto::DataMessageStanza> response( new mcs_proto::DataMessageStanza()); response->set_from(kGCMFromField); bool send = false; for (int i = 0; i < data_message->app_data_size(); ++i) { const mcs_proto::AppData& app_data = data_message->app_data(i); if (app_data.key() == kIdleNotification) { // Tell the MCS server the client is not idle. send = true; mcs_proto::AppData data; data.set_key(kIdleNotification); data.set_value("false"); response->add_app_data()->CopyFrom(data); response->set_category(kMCSCategory); } } if (send) { SendMessage( MCSMessage(kDataMessageStanzaTag, response.PassAs<const google::protobuf::MessageLite>()), false); } } void MCSClient::HandlePacketFromWire( scoped_ptr<google::protobuf::MessageLite> protobuf) { if (!protobuf.get()) return; uint8 tag = GetMCSProtoTag(*protobuf); PersistentId persistent_id = GetPersistentId(*protobuf); StreamId last_stream_id_received = GetLastStreamIdReceived(*protobuf); if (last_stream_id_received != 0) { last_device_to_server_stream_id_received_ = last_stream_id_received; // Process device to server messages that have now been acknowledged by the // server. Because messages are stored in order, just pop off all that have // a stream id lower than server's last received stream id. HandleStreamAck(last_stream_id_received); // Process server_to_device_messages that the server now knows were // acknowledged. Again, they're in order, so just keep going until the // stream id is reached. StreamIdList acked_stream_ids_to_remove; for (std::map<StreamId, PersistentIdList>::iterator iter = acked_server_ids_.begin(); iter != acked_server_ids_.end() && iter->first <= last_stream_id_received; ++iter) { acked_stream_ids_to_remove.push_back(iter->first); } for (StreamIdList::iterator iter = acked_stream_ids_to_remove.begin(); iter != acked_stream_ids_to_remove.end(); ++iter) { acked_server_ids_.erase(*iter); } } ++stream_id_in_; if (!persistent_id.empty()) { unacked_server_ids_[stream_id_in_] = persistent_id; rmq_store_.AddIncomingMessage(persistent_id, base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); } DVLOG(1) << "Received message of type " << protobuf->GetTypeName() << " with persistent id " << (persistent_id.empty() ? "NULL" : persistent_id) << ", stream id " << stream_id_in_ << " and last stream id received " << last_stream_id_received; if (unacked_server_ids_.size() > 0 && unacked_server_ids_.size() % kUnackedMessageBeforeStreamAck == 0) { SendMessage(MCSMessage(kIqStanzaTag, BuildStreamAck(). PassAs<const google::protobuf::MessageLite>()), false); } switch (tag) { case kLoginResponseTag: { mcs_proto::LoginResponse* login_response = reinterpret_cast<mcs_proto::LoginResponse*>(protobuf.get()); DVLOG(1) << "Received login response:"; DVLOG(1) << " Id: " << login_response->id(); DVLOG(1) << " Timestamp: " << login_response->server_timestamp(); if (login_response->has_error()) { state_ = UNINITIALIZED; DVLOG(1) << " Error code: " << login_response->error().code(); DVLOG(1) << " Error message: " << login_response->error().message(); initialization_callback_.Run(false, 0, 0); return; } state_ = CONNECTED; stream_id_in_ = 1; // To account for the login response. DCHECK_EQ(1U, stream_id_out_); // Pass the login response on up. base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(message_received_callback_, MCSMessage(tag, protobuf.PassAs< const google::protobuf::MessageLite>()))); // If there are pending messages, attempt to send one. if (!to_send_.empty()) { base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(&MCSClient::MaybeSendMessage, weak_ptr_factory_.GetWeakPtr())); } heartbeat_timer_.Start(FROM_HERE, heartbeat_interval_, base::Bind(&MCSClient::SendHeartbeat, weak_ptr_factory_.GetWeakPtr())); return; } case kHeartbeatPingTag: DCHECK_GE(stream_id_in_, 1U); DVLOG(1) << "Received heartbeat ping, sending ack."; SendMessage( MCSMessage(kHeartbeatAckTag, mcs_proto::HeartbeatAck()), false); return; case kHeartbeatAckTag: DCHECK_GE(stream_id_in_, 1U); DVLOG(1) << "Received heartbeat ack."; // TODO(zea): add logic to reconnect if no ack received within a certain // timeout (with backoff). return; case kCloseTag: LOG(ERROR) << "Received close command, closing connection."; state_ = UNINITIALIZED; initialization_callback_.Run(false, 0, 0); // TODO(zea): should this happen in non-error cases? Reconnect? return; case kIqStanzaTag: { DCHECK_GE(stream_id_in_, 1U); mcs_proto::IqStanza* iq_stanza = reinterpret_cast<mcs_proto::IqStanza*>(protobuf.get()); const mcs_proto::Extension& iq_extension = iq_stanza->extension(); switch (iq_extension.id()) { case kSelectiveAck: { PersistentIdList acked_ids; if (BuildPersistentIdListFromProto(iq_extension.data(), &acked_ids)) { HandleSelectiveAck(acked_ids); } return; } case kStreamAck: // Do nothing. The last received stream id is always processed if it's // present. return; default: LOG(WARNING) << "Received invalid iq stanza extension " << iq_extension.id(); return; } } case kDataMessageStanzaTag: { DCHECK_GE(stream_id_in_, 1U); mcs_proto::DataMessageStanza* data_message = reinterpret_cast<mcs_proto::DataMessageStanza*>(protobuf.get()); if (data_message->category() == kMCSCategory) { HandleMCSDataMesssage(protobuf.Pass()); return; } DCHECK(protobuf.get()); base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(message_received_callback_, MCSMessage(tag, protobuf.PassAs< const google::protobuf::MessageLite>()))); return; } default: LOG(ERROR) << "Received unexpected message of type " << static_cast<int>(tag); return; } } void MCSClient::HandleStreamAck(StreamId last_stream_id_received) { PersistentIdList acked_outgoing_persistent_ids; StreamIdList acked_outgoing_stream_ids; while (!to_resend_.empty() && to_resend_.front()->stream_id <= last_stream_id_received) { const MCSPacketInternal& outgoing_packet = to_resend_.front(); acked_outgoing_persistent_ids.push_back(outgoing_packet->persistent_id); acked_outgoing_stream_ids.push_back(outgoing_packet->stream_id); to_resend_.pop_front(); } DVLOG(1) << "Server acked " << acked_outgoing_persistent_ids.size() << " outgoing messages, " << to_resend_.size() << " remaining unacked"; rmq_store_.RemoveOutgoingMessages(acked_outgoing_persistent_ids, base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); HandleServerConfirmedReceipt(last_stream_id_received); } void MCSClient::HandleSelectiveAck(const PersistentIdList& id_list) { // First check the to_resend_ queue. Acknowledgments should always happen // in the order they were sent, so if messages are present they should match // the acknowledge list. PersistentIdList::const_iterator iter = id_list.begin(); for (; iter != id_list.end() && !to_resend_.empty(); ++iter) { const MCSPacketInternal& outgoing_packet = to_resend_.front(); DCHECK_EQ(outgoing_packet->persistent_id, *iter); // No need to re-acknowledge any server messages this message already // acknowledged. StreamId device_stream_id = outgoing_packet->stream_id; HandleServerConfirmedReceipt(device_stream_id); to_resend_.pop_front(); } // If the acknowledged ids aren't all there, they might be in the to_send_ // queue (typically when a StreamAck confirms messages as part of a login // response). for (; iter != id_list.end() && !to_send_.empty(); ++iter) { const MCSPacketInternal& outgoing_packet = to_send_.front(); DCHECK_EQ(outgoing_packet->persistent_id, *iter); // No need to re-acknowledge any server messages this message already // acknowledged. StreamId device_stream_id = outgoing_packet->stream_id; HandleServerConfirmedReceipt(device_stream_id); to_send_.pop_front(); } DCHECK(iter == id_list.end()); DVLOG(1) << "Server acked " << id_list.size() << " messages, " << to_resend_.size() << " remaining unacked."; rmq_store_.RemoveOutgoingMessages(id_list, base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); // Resend any remaining outgoing messages, as they were not received by the // server. DVLOG(1) << "Resending " << to_resend_.size() << " messages."; while (!to_resend_.empty()) { to_send_.push_front(to_resend_.back()); to_resend_.pop_back(); } } void MCSClient::HandleServerConfirmedReceipt(StreamId device_stream_id) { // TODO(zea): use a message id the sender understands. base::MessageLoop::current()->PostTask( FROM_HERE, base::Bind(message_sent_callback_, "Message " + base::UintToString(device_stream_id) + " sent.")); PersistentIdList acked_incoming_ids; for (std::map<StreamId, PersistentIdList>::iterator iter = acked_server_ids_.begin(); iter != acked_server_ids_.end() && iter->first <= device_stream_id;) { acked_incoming_ids.insert(acked_incoming_ids.end(), iter->second.begin(), iter->second.end()); acked_server_ids_.erase(iter++); } DVLOG(1) << "Server confirmed receipt of " << acked_incoming_ids.size() << " acknowledged server messages."; rmq_store_.RemoveIncomingMessages(acked_incoming_ids, base::Bind(&MCSClient::OnRMQUpdateFinished, weak_ptr_factory_.GetWeakPtr())); } MCSClient::PersistentId MCSClient::GetNextPersistentId() { return base::Uint64ToString(base::TimeTicks::Now().ToInternalValue()); } } // namespace gcm