// Copyright (c) 2012 The Chromium Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. #include "net/websockets/websocket_frame_parser.h" #include <algorithm> #include <limits> #include "base/basictypes.h" #include "base/big_endian.h" #include "base/logging.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/scoped_vector.h" #include "net/base/io_buffer.h" #include "net/websockets/websocket_frame.h" namespace { const uint8 kFinalBit = 0x80; const uint8 kReserved1Bit = 0x40; const uint8 kReserved2Bit = 0x20; const uint8 kReserved3Bit = 0x10; const uint8 kOpCodeMask = 0xF; const uint8 kMaskBit = 0x80; const uint8 kPayloadLengthMask = 0x7F; const uint64 kMaxPayloadLengthWithoutExtendedLengthField = 125; const uint64 kPayloadLengthWithTwoByteExtendedLengthField = 126; const uint64 kPayloadLengthWithEightByteExtendedLengthField = 127; } // Unnamed namespace. namespace net { WebSocketFrameParser::WebSocketFrameParser() : current_read_pos_(0), frame_offset_(0), websocket_error_(kWebSocketNormalClosure) { std::fill(masking_key_.key, masking_key_.key + WebSocketFrameHeader::kMaskingKeyLength, '\0'); } WebSocketFrameParser::~WebSocketFrameParser() {} bool WebSocketFrameParser::Decode( const char* data, size_t length, ScopedVector<WebSocketFrameChunk>* frame_chunks) { if (websocket_error_ != kWebSocketNormalClosure) return false; if (!length) return true; // TODO(yutak): Remove copy. buffer_.insert(buffer_.end(), data, data + length); while (current_read_pos_ < buffer_.size()) { bool first_chunk = false; if (!current_frame_header_.get()) { DecodeFrameHeader(); if (websocket_error_ != kWebSocketNormalClosure) return false; // If frame header is incomplete, then carry over the remaining // data to the next round of Decode(). if (!current_frame_header_.get()) break; first_chunk = true; } scoped_ptr<WebSocketFrameChunk> frame_chunk = DecodeFramePayload(first_chunk); DCHECK(frame_chunk.get()); frame_chunks->push_back(frame_chunk.release()); if (current_frame_header_.get()) { DCHECK(current_read_pos_ == buffer_.size()); break; } } // Drain unnecessary data. TODO(yutak): Remove copy. (but how?) buffer_.erase(buffer_.begin(), buffer_.begin() + current_read_pos_); current_read_pos_ = 0; // Sanity check: the size of carried-over data should not exceed // the maximum possible length of a frame header. static const size_t kMaximumFrameHeaderSize = WebSocketFrameHeader::kBaseHeaderSize + WebSocketFrameHeader::kMaximumExtendedLengthSize + WebSocketFrameHeader::kMaskingKeyLength; DCHECK_LT(buffer_.size(), kMaximumFrameHeaderSize); return true; } void WebSocketFrameParser::DecodeFrameHeader() { typedef WebSocketFrameHeader::OpCode OpCode; static const int kMaskingKeyLength = WebSocketFrameHeader::kMaskingKeyLength; DCHECK(!current_frame_header_.get()); const char* start = &buffer_.front() + current_read_pos_; const char* current = start; const char* end = &buffer_.front() + buffer_.size(); // Header needs 2 bytes at minimum. if (end - current < 2) return; uint8 first_byte = *current++; uint8 second_byte = *current++; bool final = (first_byte & kFinalBit) != 0; bool reserved1 = (first_byte & kReserved1Bit) != 0; bool reserved2 = (first_byte & kReserved2Bit) != 0; bool reserved3 = (first_byte & kReserved3Bit) != 0; OpCode opcode = first_byte & kOpCodeMask; bool masked = (second_byte & kMaskBit) != 0; uint64 payload_length = second_byte & kPayloadLengthMask; if (payload_length == kPayloadLengthWithTwoByteExtendedLengthField) { if (end - current < 2) return; uint16 payload_length_16; base::ReadBigEndian(current, &payload_length_16); current += 2; payload_length = payload_length_16; if (payload_length <= kMaxPayloadLengthWithoutExtendedLengthField) websocket_error_ = kWebSocketErrorProtocolError; } else if (payload_length == kPayloadLengthWithEightByteExtendedLengthField) { if (end - current < 8) return; base::ReadBigEndian(current, &payload_length); current += 8; if (payload_length <= kuint16max || payload_length > static_cast<uint64>(kint64max)) { websocket_error_ = kWebSocketErrorProtocolError; } else if (payload_length > static_cast<uint64>(kint32max)) { websocket_error_ = kWebSocketErrorMessageTooBig; } } if (websocket_error_ != kWebSocketNormalClosure) { buffer_.clear(); current_read_pos_ = 0; current_frame_header_.reset(); frame_offset_ = 0; return; } if (masked) { if (end - current < kMaskingKeyLength) return; std::copy(current, current + kMaskingKeyLength, masking_key_.key); current += kMaskingKeyLength; } else { std::fill(masking_key_.key, masking_key_.key + kMaskingKeyLength, '\0'); } current_frame_header_.reset(new WebSocketFrameHeader(opcode)); current_frame_header_->final = final; current_frame_header_->reserved1 = reserved1; current_frame_header_->reserved2 = reserved2; current_frame_header_->reserved3 = reserved3; current_frame_header_->masked = masked; current_frame_header_->payload_length = payload_length; current_read_pos_ += current - start; DCHECK_EQ(0u, frame_offset_); } scoped_ptr<WebSocketFrameChunk> WebSocketFrameParser::DecodeFramePayload( bool first_chunk) { const char* current = &buffer_.front() + current_read_pos_; const char* end = &buffer_.front() + buffer_.size(); uint64 next_size = std::min<uint64>( end - current, current_frame_header_->payload_length - frame_offset_); // This check must pass because |payload_length| is already checked to be // less than std::numeric_limits<int>::max() when the header is parsed. DCHECK_LE(next_size, static_cast<uint64>(kint32max)); scoped_ptr<WebSocketFrameChunk> frame_chunk(new WebSocketFrameChunk); if (first_chunk) { frame_chunk->header = current_frame_header_->Clone(); } frame_chunk->final_chunk = false; if (next_size) { frame_chunk->data = new IOBufferWithSize(static_cast<int>(next_size)); char* io_data = frame_chunk->data->data(); memcpy(io_data, current, next_size); if (current_frame_header_->masked) { // The masking function is its own inverse, so we use the same function to // unmask as to mask. MaskWebSocketFramePayload( masking_key_, frame_offset_, io_data, next_size); } current_read_pos_ += next_size; frame_offset_ += next_size; } DCHECK_LE(frame_offset_, current_frame_header_->payload_length); if (frame_offset_ == current_frame_header_->payload_length) { frame_chunk->final_chunk = true; current_frame_header_.reset(); frame_offset_ = 0; } return frame_chunk.Pass(); } } // namespace net