// Copyright 2015 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <brillo/streams/stream_utils.h>

#include <limits>

#include <base/bind.h>
#include <brillo/message_loops/message_loop.h>
#include <brillo/streams/stream_errors.h>

namespace brillo {
namespace stream_utils {

namespace {

// State of one asynchronous CopyData operation. A single instance is shared
// (through std::shared_ptr) by every pending read/write callback bound with
// base::Bind, so it stays alive while any step of the copy is outstanding.
struct CopyDataState {
  // Source and destination streams. They are moved back to the caller through
  // |success_callback| or |error_callback| when the operation ends.
  brillo::StreamPtr in_stream;
  brillo::StreamPtr out_stream;
  // Scratch buffer; each read requests at most buffer.size() bytes.
  std::vector<uint8_t> buffer;
  // Number of bytes still allowed to be copied before the operation finishes.
  // Explicitly initialized so the struct is well-defined even if it is ever
  // default-initialized instead of value-initialized.
  uint64_t remaining_to_copy = 0;
  // Running total of bytes read and passed to the output stream; reported to
  // |success_callback|.
  uint64_t size_copied = 0;
  CopyDataSuccessCallback success_callback;
  CopyDataErrorCallback error_callback;
};

// Error handler shared by every read and write step of CopyData. Moves both
// streams out of |state| and hands them, together with |error|, to the
// caller-supplied error callback.
void OnCopyDataError(const std::shared_ptr<CopyDataState>& state,
                     const brillo::Error* error) {
  CopyDataState& copy = *state;
  copy.error_callback.Run(std::move(copy.in_stream), std::move(copy.out_stream),
                          error);
}

// Forward declaration.
void PerformRead(const std::shared_ptr<CopyDataState>& state);

// Callback from read operation for CopyData. Writes the read data to the output
// stream and invokes PerformRead when done to restart the copy cycle.
void PerformWrite(const std::shared_ptr<CopyDataState>& state, size_t size) {
  if (size == 0) {
    state->success_callback.Run(std::move(state->in_stream),
                                std::move(state->out_stream),
                                state->size_copied);
    return;
  }
  state->size_copied += size;
  CHECK_GE(state->remaining_to_copy, size);
  state->remaining_to_copy -= size;

  brillo::ErrorPtr error;
  bool success = state->out_stream->WriteAllAsync(
      state->buffer.data(), size, base::Bind(&PerformRead, state),
      base::Bind(&OnCopyDataError, state), &error);

  if (!success)
    OnCopyDataError(state, error.get());
}

// Performs the read part of asynchronous CopyData operation. Reads the data
// from input stream and invokes PerformWrite when done to write the data to
// the output stream.
void PerformRead(const std::shared_ptr<CopyDataState>& state) {
  brillo::ErrorPtr error;
  const uint64_t buffer_size = state->buffer.size();
  // |buffer_size| is guaranteed to fit in size_t, so |size_to_read| value will
  // also not overflow size_t, so the static_cast below is safe.
  size_t size_to_read =
      static_cast<size_t>(std::min(buffer_size, state->remaining_to_copy));
  if (size_to_read == 0)
    return PerformWrite(state, 0);  // Nothing more to read. Finish operation.
  bool success = state->in_stream->ReadAsync(
      state->buffer.data(), size_to_read, base::Bind(PerformWrite, state),
      base::Bind(OnCopyDataError, state), &error);

  if (!success)
    OnCopyDataError(state, error.get());
}

}  // anonymous namespace

bool ErrorStreamClosed(const base::Location& location,
                       ErrorPtr* error) {
  Error::AddTo(error,
               location,
               errors::stream::kDomain,
               errors::stream::kStreamClosed,
               "Stream is closed");
  return false;
}

// Adds a kOperationNotSupported error (domain errors::stream::kDomain)
// attributed to |location| to |error|. Always returns false.
bool ErrorOperationNotSupported(const base::Location& location,
                                ErrorPtr* error) {
  Error::AddTo(error, location, errors::stream::kDomain,
               errors::stream::kOperationNotSupported,
               "Stream operation not supported");
  return false;
}

// Adds a kPartialData error (domain errors::stream::kDomain) describing a read
// past the end of the stream, attributed to |location|, to |error|. Always
// returns false.
bool ErrorReadPastEndOfStream(const base::Location& location,
                              ErrorPtr* error) {
  Error::AddTo(error, location, errors::stream::kDomain,
               errors::stream::kPartialData, "Reading past the end of stream");
  return false;
}

bool ErrorOperationTimeout(const base::Location& location,
                           ErrorPtr* error) {
  Error::AddTo(error,
               location,
               errors::stream::kDomain,
               errors::stream::kTimeout,
               "Operation timed out");
  return false;
}

// Checks that |position| + |offset| (where |offset| may be negative) lands in
// [0, std::numeric_limits<int64_t>::max()], i.e. is a valid non-negative
// signed 64-bit stream position. Returns true if it does. Otherwise adds a
// kInvalidParameter error attributed to |location| to |error| and returns
// false.
bool CheckInt64Overflow(const base::Location& location,
                        uint64_t position,
                        int64_t offset,
                        ErrorPtr* error) {
  const uint64_t max_position =
      static_cast<uint64_t>(std::numeric_limits<int64_t>::max());
  bool in_range = false;
  if (offset < 0) {
    // Compute the magnitude in unsigned (modular) arithmetic: negating
    // INT64_MIN as a signed value overflows, which is undefined behavior.
    const uint64_t magnitude = 0 - static_cast<uint64_t>(offset);
    // The result must not go below zero, and must still fit in int64_t: a
    // |position| above INT64_MAX can remain above it after a small step back.
    in_range = position >= magnitude && position - magnitude <= max_position;
  } else {
    // Here 0 <= offset <= max_position, so the subtraction cannot wrap. The
    // comparison is equivalent to position + offset <= max_position without
    // risking overflow in the addition.
    in_range = position <= max_position - static_cast<uint64_t>(offset);
  }
  if (in_range)
    return true;

  Error::AddTo(error,
               location,
               errors::stream::kDomain,
               errors::stream::kInvalidParameter,
               "The stream offset value is out of range");
  return false;
}

// Resolves a seek request into an absolute position. The origin is chosen by
// |whence|: 0 for FROM_BEGIN, |current_position| for FROM_CURRENT,
// |stream_size| for FROM_END. |offset| is then applied to that origin. On
// success the result is stored in |new_position| and true is returned. An
// unrecognized |whence|, or a result rejected by CheckInt64Overflow, adds an
// error to |error| and returns false.
bool CalculateStreamPosition(const base::Location& location,
                             int64_t offset,
                             Stream::Whence whence,
                             uint64_t current_position,
                             uint64_t stream_size,
                             uint64_t* new_position,
                             ErrorPtr* error) {
  uint64_t origin = 0;
  if (whence == Stream::Whence::FROM_BEGIN) {
    origin = 0;
  } else if (whence == Stream::Whence::FROM_CURRENT) {
    origin = current_position;
  } else if (whence == Stream::Whence::FROM_END) {
    origin = stream_size;
  } else {
    Error::AddTo(error, location, errors::stream::kDomain,
                 errors::stream::kInvalidParameter,
                 "Invalid stream position whence");
    return false;
  }

  if (!CheckInt64Overflow(location, origin, offset, error))
    return false;

  // Modular unsigned addition gives the exact sum once the range check has
  // passed, including for negative offsets.
  *new_position = origin + static_cast<uint64_t>(offset);
  return true;
}

// Convenience overload: copies with no size limit, using a 4096-byte buffer.
// The copy therefore ends only when a read yields zero bytes or an error
// occurs.
void CopyData(StreamPtr in_stream,
              StreamPtr out_stream,
              const CopyDataSuccessCallback& success_callback,
              const CopyDataErrorCallback& error_callback) {
  const uint64_t kNoSizeLimit = std::numeric_limits<uint64_t>::max();
  const size_t kDefaultBufferSize = 4096;
  CopyData(std::move(in_stream), std::move(out_stream), kNoSizeLimit,
           kDefaultBufferSize, success_callback, error_callback);
}

// Starts an asynchronous copy of at most |max_size_to_copy| bytes from
// |in_stream| to |out_stream|, transferring up to |buffer_size| bytes per
// read/write cycle. Both streams are handed back through |success_callback|
// (with the number of bytes copied) or |error_callback|.
// NOTE(review): a |buffer_size| of 0 makes the copy finish immediately with 0
// bytes copied; consider rejecting it.
void CopyData(StreamPtr in_stream,
              StreamPtr out_stream,
              uint64_t max_size_to_copy,
              size_t buffer_size,
              const CopyDataSuccessCallback& success_callback,
              const CopyDataErrorCallback& error_callback) {
  auto copy_state = std::make_shared<CopyDataState>();
  copy_state->in_stream = std::move(in_stream);
  copy_state->out_stream = std::move(out_stream);
  copy_state->buffer.resize(buffer_size);
  copy_state->remaining_to_copy = max_size_to_copy;
  copy_state->size_copied = 0;
  copy_state->success_callback = success_callback;
  copy_state->error_callback = error_callback;

  // The first read is posted instead of run inline, so even a copy that
  // completes immediately reports back only after CopyData() has returned.
  brillo::MessageLoop* const loop = brillo::MessageLoop::current();
  loop->PostTask(FROM_HERE, base::Bind(&PerformRead, copy_state));
}

}  // namespace stream_utils
}  // namespace brillo