// Copyright 2015 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <brillo/streams/stream_utils.h>
#include <limits>
#include <base/bind.h>
#include <brillo/message_loops/message_loop.h>
#include <brillo/streams/stream_errors.h>
namespace brillo {
namespace stream_utils {
namespace {
// Status of asynchronous CopyData operation.
struct CopyDataState {
  // Source stream. It is moved out and handed to whichever completion callback
  // finally runs.
  brillo::StreamPtr in_stream;
  // Destination stream. It is handed back the same way as |in_stream|.
  brillo::StreamPtr out_stream;
  // Scratch buffer shared by reads and writes. Its size caps each read.
  std::vector<uint8_t> buffer;
  // Number of bytes the operation is still allowed to copy. It is
  // default-initialized so that a default-constructed state is never
  // indeterminate.
  uint64_t remaining_to_copy{0};
  // Number of bytes read from |in_stream| (and queued for writing) so far.
  // This value is reported to |success_callback|.
  uint64_t size_copied{0};
  // Invoked once with both streams and |size_copied| when the copy completes.
  CopyDataSuccessCallback success_callback;
  // Invoked with both streams and the error if a read or write fails.
  CopyDataErrorCallback error_callback;
};
// Async CopyData I/O error callback.
// Failure path shared by the read and write halves of CopyData.
// Takes both streams out of |state| and passes them, together with |error|,
// to the user-supplied error callback.
void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
                     const brillo::Error* error) {
  brillo::StreamPtr source = std::move(state->in_stream);
  brillo::StreamPtr destination = std::move(state->out_stream);
  state->error_callback.Run(std::move(source), std::move(destination), error);
}
// Forward declaration.
void PerformRead(const std::shared_ptr<CopyDataState>& state);
// Callback from read operation for CopyData. Writes the read data to the output
// stream and invokes PerformRead when done to restart the copy cycle.
void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
if (size == 0) {
state->success_callback.Run(std::move(state->in_stream),
std::move(state->out_stream),
state->size_copied);
return;
}
state->size_copied += size;
CHECK_GE(state->remaining_to_copy, size);
state->remaining_to_copy -= size;
brillo::ErrorPtr error;
bool success = state->out_stream->WriteAllAsync(
state->buffer.data(), size, base::Bind(&PerformRead, state),
base::Bind(&OnCopyDataError, state), &error);
if (!success)
OnCopyDataError(state, error.get());
}
// Performs the read part of asynchronous CopyData operation. Reads the data
// from input stream and invokes PerformWrite when done to write the data to
// the output stream.
void PerformRead(const std::shared_ptr<CopyDataState>& state) {
brillo::ErrorPtr error;
const uint64_t buffer_size = state->buffer.size();
// |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
// also not overflow size_t, so the static_cast below is safe.
size_t size_to_read =
static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
if (size_to_read == 0)
return PerformWrite(state, 0); // Nothing more to read. Finish operation.
bool success = state->in_stream->ReadAsync(
state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
base::Bind(OnCopyDataError, state), &error);
if (!success)
OnCopyDataError(state, error.get());
}
} // anonymous namespace
namespace {
// Adds an error in the errors::stream::kDomain domain to |error|, attributed
// to |location|. Always returns false, so each public helper below can return
// its result directly.
bool ReportStreamError(const base::Location& location,
                       const std::string& code,
                       const std::string& message,
                       ErrorPtr* error) {
  Error::AddTo(error, location, errors::stream::kDomain, code, message);
  return false;
}
}  // anonymous namespace

// Records a kStreamClosed error. Always returns false.
bool ErrorStreamClosed(const base::Location& location,
                       ErrorPtr* error) {
  return ReportStreamError(location, errors::stream::kStreamClosed,
                           "Stream is closed", error);
}

// Records a kOperationNotSupported error. Always returns false.
bool ErrorOperationNotSupported(const base::Location& location,
                                ErrorPtr* error) {
  return ReportStreamError(location, errors::stream::kOperationNotSupported,
                           "Stream operation not supported", error);
}

// Records a kPartialData error for a read past the end of the stream.
// Always returns false.
bool ErrorReadPastEndOfStream(const base::Location& location,
                              ErrorPtr* error) {
  return ReportStreamError(location, errors::stream::kPartialData,
                           "Reading past the end of stream", error);
}

// Records a kTimeout error. Always returns false.
bool ErrorOperationTimeout(const base::Location& location,
                           ErrorPtr* error) {
  return ReportStreamError(location, errors::stream::kTimeout,
                           "Operation timed out", error);
}
// Checks that moving |offset| bytes from |position| gives a position that
// neither wraps the unsigned 64-bit range nor exceeds the largest value a
// signed 64-bit integer can hold.
// Returns true when the resulting position is valid. Otherwise it adds an
// errors::stream::kInvalidParameter error (attributed to |location|) to
// |error| and returns false.
bool CheckInt64Overflow(const base::Location& location,
                        uint64_t position,
                        int64_t offset,
                        ErrorPtr* error) {
  constexpr uint64_t kMaxPosition =
      static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
  bool in_range = false;
  if (offset < 0) {
    // Subtracting the offset. Take its magnitude with unsigned (modular)
    // arithmetic: `-offset` on an int64_t is undefined behavior for
    // INT64_MIN, while this expression yields exactly 2^63 for it.
    const uint64_t magnitude = uint64_t{0} - static_cast<uint64_t>(offset);
    // Make sure we do not underflow. Also make sure the result is within the
    // signed 64-bit range, which matters when |position| itself lies above it.
    in_range = position >= magnitude && position - magnitude <= kMaxPosition;
  } else {
    // Adding the offset. First make sure we do not overflow unsigned 64 bits.
    // Then check that we end up within the limits of a signed 64-bit integer.
    const uint64_t magnitude = static_cast<uint64_t>(offset);
    in_range = position <= std::numeric_limits<uint64_t>::max() - magnitude &&
               position + magnitude <= kMaxPosition;
  }
  if (in_range)
    return true;

  Error::AddTo(error,
               location,
               errors::stream::kDomain,
               errors::stream::kInvalidParameter,
               "The stream offset value is out of range");
  return false;
}
// Resolves a seek request into an absolute position.
// The origin is chosen from |whence|: zero, |current_position|, or
// |stream_size|. |offset| is then applied to it.
// On success, stores the result in |*new_position| and returns true.
// Returns false with |error| populated in two cases: |whence| is not a known
// value, or CheckInt64Overflow rejects the origin/offset pair.
bool CalculateStreamPosition(const base::Location& location,
                             int64_t offset,
                             Stream::Whence whence,
                             uint64_t current_position,
                             uint64_t stream_size,
                             uint64_t* new_position,
                             ErrorPtr* error) {
  uint64_t origin = 0;
  if (whence == Stream::Whence::FROM_BEGIN) {
    origin = 0;
  } else if (whence == Stream::Whence::FROM_CURRENT) {
    origin = current_position;
  } else if (whence == Stream::Whence::FROM_END) {
    origin = stream_size;
  } else {
    Error::AddTo(error,
                 location,
                 errors::stream::kDomain,
                 errors::stream::kInvalidParameter,
                 "Invalid stream position whence");
    return false;
  }

  if (CheckInt64Overflow(location, origin, offset, error)) {
    // Modular uint64_t addition. The check above guarantees the true
    // mathematical result is in range, so this value is exact.
    *new_position = origin + static_cast<uint64_t>(offset);
    return true;
  }
  return false;
}
// Convenience overload of CopyData with no byte limit and a 4096-byte buffer.
// With the limit set to the maximum uint64_t, the copy runs until a read
// completes with zero bytes (presumably end of stream) or an error occurs.
void CopyData(StreamPtr in_stream,
              StreamPtr out_stream,
              const CopyDataSuccessCallback& success_callback,
              const CopyDataErrorCallback& error_callback) {
  constexpr size_t kDefaultBufferSize = 4096;
  constexpr uint64_t kUnlimited = std::numeric_limits<uint64_t>::max();
  CopyData(std::move(in_stream), std::move(out_stream), kUnlimited,
           kDefaultBufferSize, success_callback, error_callback);
}
// Asynchronously copies at most |max_size_to_copy| bytes from |in_stream| to
// |out_stream|, using a scratch buffer of |buffer_size| bytes.
// Ownership of both streams is handed back through |success_callback| (with
// the byte count) or |error_callback| (with the error).
// The first read is posted to the current MessageLoop rather than started
// inline, so no I/O or callback runs before this function returns.
// NOTE(review): a |buffer_size| of 0 makes the operation finish immediately
// with zero bytes copied.
void CopyData(StreamPtr in_stream,
              StreamPtr out_stream,
              uint64_t max_size_to_copy,
              size_t buffer_size,
              const CopyDataSuccessCallback& success_callback,
              const CopyDataErrorCallback& error_callback) {
  std::shared_ptr<CopyDataState> state = std::make_shared<CopyDataState>();
  state->success_callback = success_callback;
  state->error_callback = error_callback;
  state->remaining_to_copy = max_size_to_copy;
  state->size_copied = 0;
  state->buffer.assign(buffer_size, 0);
  state->in_stream = std::move(in_stream);
  state->out_stream = std::move(out_stream);

  brillo::MessageLoop* const loop = brillo::MessageLoop::current();
  loop->PostTask(FROM_HERE, base::Bind(&PerformRead, state));
}
} // namespace stream_utils
} // namespace brillo