# -*- coding: utf-8 -*-
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for tracker file functionality."""

import errno
import hashlib
import json
import os
import re

from boto import config
from gslib.exception import CommandException
from gslib.util import CreateDirIfNeeded
from gslib.util import GetGsutilStateDir
from gslib.util import ResumableThreshold
from gslib.util import UTF8

# The maximum length of a file name can vary wildly between different
# operating systems, so we always ensure that tracker file names are shorter
# than 100 characters in order to avoid any such issues.
# _HashAndReturnPath asserts that every generated tracker file name (the name
# itself, not the full path) is strictly shorter than this limit.
MAX_TRACKER_FILE_NAME_LENGTH = 100


# Message template used when a tracker file cannot be written. It takes two
# %s arguments: the tracker file path and the OS error string.
TRACKER_FILE_UNWRITABLE_EXCEPTION_TEXT = (
    'Couldn\'t write tracker file (%s): %s. This can happen if gsutil is '
    'configured to save tracker files to an unwritable directory.')


class TrackerFileType(object):
  """String constants naming the kinds of tracker files used in this module.

  The lower-cased value is used by _HashAndReturnPath as the prefix of the
  on-disk tracker file name.
  """
  # Resumable upload; the tracker name encodes destination bucket and object.
  UPLOAD = 'upload'
  # Non-sliced resumable download; the tracker name encodes the local
  # destination path.
  DOWNLOAD = 'download'
  # One component of a sliced download; the tracker name additionally encodes
  # the component number.
  DOWNLOAD_COMPONENT = 'download_component'
  # Parallel upload; the tracker name encodes destination and source.
  PARALLEL_UPLOAD = 'parallel_upload'
  # Parent tracker file of a sliced download (read for 'num_components').
  SLICED_DOWNLOAD = 'sliced_download'
  # Rewrite operation; paths are built by GetRewriteTrackerFilePath, not
  # GetTrackerFilePath.
  REWRITE = 'rewrite'


def _HashFilename(filename):
  """Shortens a file name by replacing most of it with its SHA1 digest.

  The resulting name has the form:

      TRACKER_<hash>.<trailing>

  where <hash> is the SHA1 hex digest of the UTF-8 encoded original name and
  <trailing> is the final 16 bytes of that encoded name. Because maximum file
  name lengths differ between operating systems, the point of this function
  is to produce a name that stays well under 100 characters.

  Args:
    filename: file name to be hashed (unicode, or a UTF-8 encoded str).

  Returns:
    shorter, hashed version of passed file name
  """
  # Normalize to unicode first so that both input flavors end up as the same
  # UTF-8 byte string (a non-UTF-8 byte string fails here, as before).
  if not isinstance(filename, unicode):
    filename = unicode(filename, UTF8)
  encoded_name = filename.encode(UTF8)
  digest = hashlib.sha1(encoded_name).hexdigest()
  trailing = encoded_name[-16:]
  return ''.join(['TRACKER_', digest, '.', trailing])


def CreateTrackerDirIfNeeded():
  """Returns the gsutil tracker file directory, creating it when missing.

  The directory is taken from the 'resumable_tracker_dir' setting in the
  'GSUtil' section of the boto config; when that is not set, it defaults to
  the 'tracker-files' subdirectory of the gsutil state directory.

  Returns:
    The pathname to the tracker directory.
  """
  default_tracker_dir = os.path.join(GetGsutilStateDir(), 'tracker-files')
  configured_dir = config.get('GSUtil', 'resumable_tracker_dir',
                              default_tracker_dir)
  CreateDirIfNeeded(configured_dir)
  return configured_dir


def GetRewriteTrackerFilePath(src_bucket_name, src_obj_name, dst_bucket_name,
                              dst_obj_name, api_selector):
  """Builds the tracker file path for a rewrite operation.

  Args:
    src_bucket_name: Source bucket (string).
    src_obj_name: Source object (string).
    dst_bucket_name: Destination bucket (string).
    dst_obj_name: Destination object (string)
    api_selector: API to use for this operation.

  Returns:
    File path to tracker file.
  """
  # The source/destination bucket and object names plus the API are all part
  # of the (pre-hash) tracker file name.
  name_fields = (src_bucket_name, src_obj_name, dst_bucket_name,
                 dst_obj_name, api_selector)
  unsanitized_name = 'rewrite__%s__%s__%s__%s__%s.token' % name_fields
  # Path separators (forward and back slashes) become underscores.
  sanitized_name = re.sub('[/\\\\]', '_', unsanitized_name)
  return _HashAndReturnPath(sanitized_name, TrackerFileType.REWRITE)


def GetTrackerFilePath(dst_url, tracker_file_type, api_selector, src_url=None,
                       component_num=None):
  """Gets the tracker file name described by the arguments.

  Args:
    dst_url: Destination URL for tracker file.
    tracker_file_type: TrackerFileType for this operation.
    api_selector: API to use for this operation.
    src_url: Source URL for the source file name for parallel uploads.
    component_num: Component number if this is a download component, else None.

  Returns:
    File path to tracker file.

  Raises:
    NotImplementedError: if tracker_file_type is TrackerFileType.REWRITE;
        use GetRewriteTrackerFilePath for rewrites.
    ValueError: if tracker_file_type is not a known TrackerFileType.
  """
  if tracker_file_type == TrackerFileType.UPLOAD:
    # Encode the dest bucket and object name into the tracker file name.
    raw_name = 'resumable_upload__%s__%s__%s.url' % (
        dst_url.bucket_name, dst_url.object_name, api_selector)
  elif tracker_file_type == TrackerFileType.DOWNLOAD:
    # Encode the fully-qualified dest file name into the tracker file name.
    raw_name = 'resumable_download__%s__%s.etag' % (
        os.path.realpath(dst_url.object_name), api_selector)
  elif tracker_file_type == TrackerFileType.DOWNLOAD_COMPONENT:
    # Encode the fully-qualified dest file name and the component number
    # into the tracker file name.
    raw_name = 'resumable_download__%s__%s__%d.etag' % (
        os.path.realpath(dst_url.object_name), api_selector, component_num)
  elif tracker_file_type == TrackerFileType.PARALLEL_UPLOAD:
    # Encode the dest bucket and object names as well as the source file name
    # into the tracker file name.
    raw_name = 'parallel_upload__%s__%s__%s__%s.url' % (
        dst_url.bucket_name, dst_url.object_name, src_url, api_selector)
  elif tracker_file_type == TrackerFileType.SLICED_DOWNLOAD:
    # Encode the fully-qualified dest file name into the tracker file name.
    raw_name = 'sliced_download__%s__%s.etag' % (
        os.path.realpath(dst_url.object_name), api_selector)
  elif tracker_file_type == TrackerFileType.REWRITE:
    # Should use GetRewriteTrackerFilePath instead.
    raise NotImplementedError()
  else:
    # Fail with a clear message instead of an UnboundLocalError below.
    raise ValueError('Unsupported tracker file type: %s' % tracker_file_type)

  # Path separators (forward and back slashes) are not valid inside a single
  # file name, so replace them with underscores.
  res_tracker_file_name = re.sub('[/\\\\]', '_', raw_name)
  return _HashAndReturnPath(res_tracker_file_name, tracker_file_type)


def DeleteDownloadTrackerFiles(dst_url, api_selector):
  """Removes every tracker file associated with downloading to dst_url.

  This covers both the non-sliced download tracker file and all sliced
  download tracker files (parent plus components).

  Args:
    dst_url: StorageUrl describing the destination file.
    api_selector: The Cloud API implementation used.
  """
  # The tracker file for a regular (non-sliced) download goes first.
  plain_download_tracker = GetTrackerFilePath(
      dst_url, TrackerFileType.DOWNLOAD, api_selector)
  DeleteTrackerFile(plain_download_tracker)

  # Then the sliced download parent tracker file and its components.
  for sliced_tracker_path in GetSlicedDownloadTrackerFilePaths(dst_url,
                                                               api_selector):
    DeleteTrackerFile(sliced_tracker_path)


def GetSlicedDownloadTrackerFilePaths(dst_url, api_selector,
                                      num_components=None):
  """Gets a list of sliced download tracker file paths.

  The list consists of the parent tracker file path in index 0, followed by
  one component tracker file path per component in [1:]. The component paths
  are computed from the component count; they are not checked for existence.

  Args:
    dst_url: Destination URL for tracker file.
    api_selector: API to use for this operation.
    num_components: The number of component tracker files, if already known.
                    If not known, the number will be retrieved from the parent
                    tracker file on disk.
  Returns:
    List of tracker file paths. If num_components is not given and cannot be
    read from the parent tracker file (file missing, unreadable, not valid
    JSON, or lacking a 'num_components' entry), the list contains only the
    parent tracker file path.
  """
  parallel_tracker_file_path = GetTrackerFilePath(
      dst_url, TrackerFileType.SLICED_DOWNLOAD, api_selector)
  tracker_file_paths = [parallel_tracker_file_path]

  # If we don't know the number of components, check the tracker file.
  if num_components is None:
    try:
      with open(parallel_tracker_file_path, 'r') as tracker_file:
        num_components = json.load(tracker_file)['num_components']
    except (IOError, ValueError, KeyError, TypeError):
      # IOError: missing/unreadable file. ValueError: not JSON.
      # KeyError/TypeError: valid JSON without a 'num_components' entry
      # (e.g. a truncated or foreign file). In every case only the parent
      # path can be reported.
      return tracker_file_paths

  tracker_file_paths.extend(
      GetTrackerFilePath(dst_url, TrackerFileType.DOWNLOAD_COMPONENT,
                         api_selector, component_num=i)
      for i in range(num_components))

  return tracker_file_paths


def _HashAndReturnPath(res_tracker_file_name, tracker_file_type):
  """Turns an unhashed tracker file name into its final on-disk path.

  The final file name is the lower-cased tracker file type, an underscore,
  and the hashed form of res_tracker_file_name; it lives in the tracker
  directory (which is created if needed).

  Args:
    res_tracker_file_name: The tracker file name prior to it being hashed.
    tracker_file_type: The TrackerFileType of res_tracker_file_name.

  Returns:
    Final (hashed) tracker file path.
  """
  tracker_dir = CreateTrackerDirIfNeeded()
  hashed_name = _HashFilename(res_tracker_file_name)
  type_prefix = str(tracker_file_type).lower()
  file_name = '%s_%s' % (type_prefix, hashed_name)
  full_path = '%s%s%s' % (tracker_dir, os.sep, file_name)
  # Only the file name (not the directory part) is subject to the limit.
  assert len(file_name) < MAX_TRACKER_FILE_NAME_LENGTH
  return full_path


def DeleteTrackerFile(tracker_file_name):
  """Deletes the given tracker file if it exists.

  Args:
    tracker_file_name: Path of the tracker file to delete. A falsy value
        (None or empty string) or a path that does not exist is a no-op.

  Raises:
    OSError: if the file exists but cannot be removed for a reason other than
        it having disappeared in the meantime.
  """
  # The existence check is kept so that paths that cannot be stat'ed remain a
  # silent no-op, exactly as before.
  if tracker_file_name and os.path.exists(tracker_file_name):
    try:
      os.unlink(tracker_file_name)
    except OSError as e:
      # Another process or thread may have removed the file between the
      # existence check and the unlink; that is the outcome we wanted anyway.
      if e.errno != errno.ENOENT:
        raise


def HashRewriteParameters(
    src_obj_metadata, dst_obj_metadata, projection, src_generation=None,
    gen_match=None, meta_gen_match=None, canned_acl=None, fields=None,
    max_bytes_per_call=None):
  """Computes an MD5 hex digest over the parameters of a rewrite call.

  A rewrite can only be resumed when it is issued with exactly the same input
  parameters, so the rewrite tracker file has to capture them. Hashing the
  inputs makes that comparison cheap. If a user performs a
  same-source/same-destination rewrite via a different command (for example,
  with a changed ACL), the hashes will not match and the rewrite is restarted
  from the beginning.

  Args:
    src_obj_metadata: apitools Object describing source object. Must include
      bucket, name, and etag.
    dst_obj_metadata: apitools Object describing destination object. Must
      include bucket and object name
    projection: Projection used for the API call.
    src_generation: Optional source generation.
    gen_match: Optional generation precondition.
    meta_gen_match: Optional metageneration precondition.
    canned_acl: Optional canned ACL string.
    fields: Optional fields to include in response.
    max_bytes_per_call: Optional maximum bytes rewritten per call.

  Returns:
    MD5 hex digest Hash of the input parameters, or None if required parameters
    are missing.
  """
  # Each required input must be present and non-empty; checks short-circuit
  # so attributes are never read from a missing metadata object.
  src_is_complete = (src_obj_metadata and src_obj_metadata.bucket and
                     src_obj_metadata.name and src_obj_metadata.etag)
  if not src_is_complete:
    return None
  dst_is_complete = (dst_obj_metadata and dst_obj_metadata.bucket and
                     dst_obj_metadata.name)
  if not dst_is_complete:
    return None
  if not projection:
    return None

  hashed_inputs = (
      src_obj_metadata, dst_obj_metadata, projection, src_generation,
      gen_match, meta_gen_match, canned_acl, fields, max_bytes_per_call)
  digest = hashlib.md5()
  # The string form of every input is fed to the digest, in this fixed order.
  for value in hashed_inputs:
    digest.update(str(value))
  return digest.hexdigest()


def ReadRewriteTrackerFile(tracker_file_name, rewrite_params_hash):
  """Looks up the rewrite token stored in a rewrite tracker file.

  The tracker file holds the parameter hash on its first line and the rewrite
  token on its second line. A tracker file that exists but cannot be read
  causes a message to be printed; a missing file is silently ignored.

  Args:
    tracker_file_name: Tracker file path string.
    rewrite_params_hash: MD5 hex digest of rewrite call parameters constructed
        by HashRewriteParameters.

  Returns:
    String rewrite_token for resuming rewrite requests if a matching tracker
    file exists, None otherwise (which will result in starting a new rewrite).
  """
  # Without a parameter hash there is nothing a tracker file could match.
  if not rewrite_params_hash:
    return None
  try:
    with open(tracker_file_name, 'r') as fp:
      stored_hash = fp.readline().rstrip('\n')
      if stored_hash != rewrite_params_hash:
        return None
      # The line following the hash is the rewrite token.
      return fp.readline().rstrip('\n')
  except IOError as e:
    # A non-existent file is expected the first time a rewrite is attempted;
    # anything else is reported to the user.
    if e.errno != errno.ENOENT:
      print('Couldn\'t read Copy tracker file (%s): %s. Restarting copy '
            'from scratch.' %
            (tracker_file_name, e.strerror))
  return None


def WriteRewriteTrackerFile(tracker_file_name, rewrite_params_hash,
                            rewrite_token):
  """Stores a rewrite's parameter hash and token in a tracker file.

  The file content is the hash on the first line and the token on the second,
  each terminated by a newline.

  Args:
    tracker_file_name: Tracker file path string.
    rewrite_params_hash: MD5 hex digest of rewrite call parameters constructed
        by HashRewriteParameters.
    rewrite_token: Rewrite token string returned by the service.
  """
  tracker_contents = '%s\n%s\n' % (rewrite_params_hash, rewrite_token)
  _WriteTrackerFile(tracker_file_name, tracker_contents)


def ReadOrCreateDownloadTrackerFile(src_obj_metadata, dst_url, logger,
                                    api_selector, start_byte,
                                    existing_file_size, component_num=None):
  """Checks for a download tracker file and creates one if it does not exist.

  The methodology for determining the download start point differs between
  normal and sliced downloads. For normal downloads, the existing bytes in
  the file are presumed to be correct and have been previously downloaded from
  the server (if a tracker file exists). In this case, the existing file size
  is used to determine the download start point. For sliced downloads, the
  number of bytes previously retrieved from the server cannot be determined
  from the existing file size, and so the number of bytes known to have been
  previously downloaded is retrieved from the tracker file.

  A tracker file that does not match the source object, cannot be read, or is
  malformed results in a warning, a freshly written tracker file, and a
  download that starts at start_byte.

  Args:
    src_obj_metadata: Metadata for the source object. Must include etag and
                      generation.
    dst_url: Destination URL for tracker file.
    logger: For outputting log messages.
    api_selector: API to use for this operation.
    start_byte: The start byte of the byte range for this download.
    existing_file_size: Size of existing file for this download on disk.
    component_num: The component number, if this is a component of a parallel
                   download, else None.

  Returns:
    A (tracker_file_name, download_start_byte) tuple, where:
    tracker_file_name: The name of the tracker file, if one was used (None
                       for objects smaller than the resumable threshold).
    download_start_byte: The first byte that still needs to be downloaded.
  """
  assert src_obj_metadata.etag

  tracker_file_name = None
  if src_obj_metadata.size < ResumableThreshold():
    # Don't create a tracker file for a small downloads; cross-process resumes
    # won't work, but restarting a small download is inexpensive.
    return tracker_file_name, start_byte

  download_name = dst_url.object_name
  if component_num is None:
    tracker_file_type = TrackerFileType.DOWNLOAD
  else:
    tracker_file_type = TrackerFileType.DOWNLOAD_COMPONENT
    download_name += ' component %d' % component_num

  tracker_file_name = GetTrackerFilePath(dst_url, tracker_file_type,
                                         api_selector,
                                         component_num=component_num)
  # Check to see if we already have a matching tracker file.
  try:
    with open(tracker_file_name, 'r') as tracker_file:
      if tracker_file_type == TrackerFileType.DOWNLOAD:
        # Regular download tracker files hold just the etag.
        etag_value = tracker_file.readline().rstrip('\n')
        if etag_value == src_obj_metadata.etag:
          return tracker_file_name, existing_file_size
      elif tracker_file_type == TrackerFileType.DOWNLOAD_COMPONENT:
        # Component tracker files hold a JSON object (see
        # WriteDownloadComponentTrackerFile).
        component_data = json.loads(tracker_file.read())
        if (component_data['etag'] == src_obj_metadata.etag and
            component_data['generation'] == src_obj_metadata.generation):
          return tracker_file_name, component_data['download_start_byte']

    logger.warn('Tracker file doesn\'t match for download of %s. Restarting '
                'download from scratch.' % download_name)

  except (IOError, ValueError, KeyError, TypeError) as e:
    # Ignore non-existent file (happens first time a download
    # is attempted on an object), but warn user for other errors.
    # ValueError means the component file is not JSON; KeyError/TypeError mean
    # it is JSON but lacks the expected entries. Only IOError carries errno.
    if not isinstance(e, IOError) or e.errno != errno.ENOENT:
      logger.warn('Couldn\'t read download tracker file (%s): %s. Restarting '
                  'download from scratch.' % (tracker_file_name, str(e)))

  # There wasn't a matching tracker file, so create one and then start the
  # download from scratch.
  if tracker_file_type == TrackerFileType.DOWNLOAD:
    _WriteTrackerFile(tracker_file_name, '%s\n' % src_obj_metadata.etag)
  elif tracker_file_type == TrackerFileType.DOWNLOAD_COMPONENT:
    WriteDownloadComponentTrackerFile(tracker_file_name, src_obj_metadata,
                                      start_byte)
  return tracker_file_name, start_byte


def WriteDownloadComponentTrackerFile(tracker_file_name, src_obj_metadata,
                                      current_file_pos):
  """Writes the JSON tracker file for one download component.

  The stored JSON object records the source object's etag and generation
  together with the byte position from which the component should resume.

  Args:
    tracker_file_name: The name of the tracker file.
    src_obj_metadata: Metadata for the source object. Must include etag.
    current_file_pos: The current position in the file.
  """
  tracker_state = {}
  tracker_state['etag'] = src_obj_metadata.etag
  tracker_state['generation'] = src_obj_metadata.generation
  tracker_state['download_start_byte'] = current_file_pos

  serialized_state = json.dumps(tracker_state)
  _WriteTrackerFile(tracker_file_name, serialized_state)


def _WriteTrackerFile(tracker_file_name, data):
  """Creates or overwrites a tracker file, storing the input data.

  The file is created with owner-only read/write permission (0600) if it does
  not already exist, and any previous content is discarded.

  Args:
    tracker_file_name: Path of the tracker file to write.
    data: String content to store in the file.

  Returns:
    False on success.

  Raises:
    CommandException: (via RaiseUnwritableTrackerFileException) if the file
        cannot be opened or written.
  """
  try:
    # O_TRUNC is required: without it, writing data shorter than the existing
    # content leaves stale trailing bytes behind and corrupts the file.
    fd = os.open(tracker_file_name,
                 os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as tf:
      tf.write(data)
    return False
  except (IOError, OSError) as e:
    # This helper always raises, so no explicit `raise` is needed here.
    RaiseUnwritableTrackerFileException(tracker_file_name, e.strerror)


def RaiseUnwritableTrackerFileException(tracker_file_name, error_str):
  """Raises a CommandException reporting an unwritable tracker file.

  Args:
    tracker_file_name: Path of the tracker file that could not be written.
    error_str: Description of the underlying error.

  Raises:
    CommandException: always.
  """
  message = TRACKER_FILE_UNWRITABLE_EXCEPTION_TEXT % (tracker_file_name,
                                                      error_str)
  raise CommandException(message)