#!/usr/bin/env python
#
# Copyright (C) 2008 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Signs all the APK files in a target-files zipfile, producing a new
target-files zip.

Usage:  sign_target_files_apks [flags] input_target_files output_target_files

  -s  (--signapk_jar)  <path>
      Path of the signapk.jar file used to sign an individual APK
      file.

  -e  (--extra_apks)  <name,name,...=key>
      Add extra APK name/key pairs as though they appeared in
      apkcerts.txt (so mappings specified by -k and -d are applied).
      Keys specified in -e override any value for that app contained
      in the apkcerts.txt file.  Option may be repeated to give
      multiple extra packages.

  -k  (--key_mapping)  <src_key=dest_key>
      Add a mapping from the key name as specified in apkcerts.txt (the
      src_key) to the real key you wish to sign the package with
      (dest_key).  Option may be repeated to give multiple key
      mappings.

  -d  (--default_key_mappings)  <dir>
      Set up the following key mappings:

        build/target/product/security/testkey   ==>  $dir/releasekey
        build/target/product/security/media     ==>  $dir/media
        build/target/product/security/shared    ==>  $dir/shared
        build/target/product/security/platform  ==>  $dir/platform

      -d and -k options are added to the set of mappings in the order
      in which they appear on the command line.

  -o  (--replace_ota_keys)
      Replace the certificate (public key) used by OTA package
      verification with the one specified in the input target_files
      zip (in the META/otakeys.txt file).  Key remapping (-k and -d)
      is performed on this key.

  -t  (--tag_changes)  <+tag>,<-tag>,...
      Comma-separated list of changes to make to the set of tags (in
      the last component of the build fingerprint).  Prefix each with
      '+' or '-' to indicate whether that tag should be added or
      removed.  Changes are processed in the order they appear.
      Default value is "-test-keys,+ota-rel-keys,+release-keys".

"""

import sys

# Bail out early on interpreters older than Python 2.4 (hexversion
# 0x02040000), before the remaining imports below are attempted.
if sys.hexversion < 0x02040000:
  print >> sys.stderr, "Python 2.4 or newer is required."
  sys.exit(1)

import cStringIO
import copy
import os
import re
import subprocess
import tempfile
import zipfile

import common

# Option object shared with the common module; the attributes added below
# are this script's defaults and are updated by the option handler in main().
OPTIONS = common.OPTIONS

# APK name -> key name, filled in by -e/--extra_apks.
OPTIONS.extra_apks = {}
# Source key name -> destination key name, filled in by -k and -d.
OPTIONS.key_map = {}
# Set to True by -o/--replace_ota_keys.
OPTIONS.replace_ota_keys = False
# Tag additions ("+tag") and removals ("-tag") applied by RewriteProps();
# replaced wholesale by -t/--tag_changes.
OPTIONS.tag_changes = ("-test-keys", "+ota-rel-keys", "+release-keys")

def GetApkCerts(tf_zip):
  """Build the APK name -> signing key map for a target-files zip.

  Reads META/apkcerts.txt from tf_zip.  Every non-blank line must look like

    name="<apk>" certificate="<key>.x509.pem" private_key="<key>.pk8"

  with the same <key> in both attributes.  The entries of
  OPTIONS.extra_apks are added afterwards, so they override anything read
  from the file.  Each key name is translated through OPTIONS.key_map; a
  name without a mapping is used unchanged.

  Args:
    tf_zip: the input target-files archive; only its read(name) method is
      used here.

  Returns:
    A dict mapping APK file name to the (mapped) key name.

  Raises:
    common.ExternalError: if a line of apkcerts.txt cannot be parsed.
  """
  line_re = re.compile(r'^name="(.*)"\s+certificate="(.*)\.x509\.pem"\s+'
                       r'private_key="\2\.pk8"$')
  certmap = {}
  for line in tf_zip.read("META/apkcerts.txt").split("\n"):
    line = line.strip()
    if not line: continue
    m = line_re.match(line)
    if not m:
      # ExternalError is what the __main__ guard reports cleanly; the
      # previously referenced SigningError is not defined anywhere in
      # this file.
      raise common.ExternalError(
          "failed to parse line from apkcerts.txt:\n" + line)
    certmap[m.group(1)] = OPTIONS.key_map.get(m.group(2), m.group(2))
  for apk, cert in OPTIONS.extra_apks.iteritems():
    certmap[apk] = OPTIONS.key_map.get(cert, cert)
  return certmap


def CheckAllApksSigned(input_tf_zip, apk_key_map):
  """Abort unless every APK in the input zip has an entry in apk_key_map.

  APKs are the archive members whose name ends in ".apk"; they are looked
  up in apk_key_map by base name.  If any are missing, the list is printed
  together with a hint about -e and the process exits with status 1.
  """
  apk_names = [os.path.basename(entry.filename)
               for entry in input_tf_zip.infolist()
               if entry.filename.endswith(".apk")]
  unknown_apks = [apk for apk in apk_names if apk not in apk_key_map]
  if not unknown_apks:
    return

  print "ERROR: no key specified for:\n\n ",
  print "\n  ".join(unknown_apks)
  print "\nUse '-e <apkname>=' to specify a key (which may be an"
  print "empty string to not sign this apk)."
  sys.exit(1)


def SharedUserForApk(data):
  """Return the android:sharedUserId declared by an APK, or None.

  The APK contents are written to a temporary file and its
  AndroidManifest.xml is dumped with "aapt dump xmltree".  The value of the
  first android:sharedUserId attribute found in that dump is returned.

  Args:
    data: the raw contents of the APK.

  Returns:
    The shared user id string, or None if the manifest declares none.

  Raises:
    common.ExternalError: if aapt exits with a non-zero status.
  """
  tmp = tempfile.NamedTemporaryFile()
  try:
    tmp.write(data)
    # aapt opens the file by name, so the contents must be flushed first.
    tmp.flush()

    p = common.Run(["aapt", "dump", "xmltree", tmp.name, "AndroidManifest.xml"],
                   stdout=subprocess.PIPE)
    manifest_dump, _ = p.communicate()
  finally:
    # Close the temporary file even when aapt could not be run.
    tmp.close()

  if p.returncode != 0:
    raise common.ExternalError("failed to run aapt dump")

  for line in manifest_dump.split("\n"):
    m = re.match(r'^\s*A: android:sharedUserId\([0-9a-fx]*\)="([^"]*)" .*$',
                 line)
    if m:
      return m.group(1)
  return None


def CheckSharedUserIdsConsistent(input_tf_zip, apk_key_map):
  """Abort if APKs sharing a user id would be signed with different keys.

  Every ".apk" member of input_tf_zip is inspected with SharedUserForApk()
  and grouped by shared user id and by the key apk_key_map assigns to it.
  If any shared user id ends up with more than one key, a report is printed
  and the process exits with status 1; otherwise the function returns
  normally.
  """
  # shared user id -> {key name -> [apk names]}
  apks_by_user = {}
  # Width of the key column in the report.
  key_width = len("(unknown key)")

  for entry in input_tf_zip.infolist():
    if not entry.filename.endswith(".apk"):
      continue

    apk_name = os.path.basename(entry.filename)
    user_id = SharedUserForApk(input_tf_zip.read(entry.filename))
    signing_key = apk_key_map[apk_name]
    key_width = max(key_width, len(signing_key))

    if user_id is None:
      continue
    keys_for_user = apks_by_user.setdefault(user_id, {})
    keys_for_user.setdefault(signing_key, []).append(apk_name)

  # A consistent shared user has exactly one key across all of its apks.
  conflicts = [(uid, by_key) for uid, by_key in apks_by_user.iteritems()
               if len(by_key) > 1]
  if not conflicts:
    return

  print "ERROR:  shared user inconsistency.  All apks wanting to use"
  print "        a given shared user must be signed with the same key."
  print
  conflicts.sort()
  for user, keys in conflicts:
    print 'shared user id "%s":' % (user,)
    for key, apps in keys.iteritems():
      print '  %-*s   %s' % (key_width, key or "(unknown key)", apps[0])
      for a in apps[1:]:
        print (' ' * (key_width+5)) + a
    print

  sys.exit(1)


def SignApk(data, keyname, pw):
  """Sign one APK and return the signed contents.

  The unsigned contents are written to a temporary file, common.SignFile()
  is asked to produce a signed copy (with align=4) in a second temporary
  file, and that file is read back.  Both temporary files are closed before
  returning, including when signing fails.

  Args:
    data: the raw contents of the unsigned APK.
    keyname: name of the key, passed through to common.SignFile().
    pw: password for that key, passed through to common.SignFile().

  Returns:
    The contents of the signed APK.
  """
  unsigned = tempfile.NamedTemporaryFile()
  try:
    unsigned.write(data)
    # The signer reads the file by name, so flush the contents first.
    unsigned.flush()

    signed = tempfile.NamedTemporaryFile()
    try:
      common.SignFile(unsigned.name, signed.name, keyname, pw, align=4)
      return signed.read()
    finally:
      signed.close()
  finally:
    unsigned.close()


def SignApks(input_tf_zip, output_tf_zip, apk_key_map, key_passwords):
  maxsize = max([len(os.path.basename(i.filename))
                 for i in input_tf_zip.infolist()
                 if i.filename.endswith('.apk')])

  for info in input_tf_zip.infolist():
    data = input_tf_zip.read(info.filename)
    out_info = copy.copy(info)
    if info.filename.endswith(".apk"):
      name = os.path.basename(info.filename)
      key = apk_key_map[name]
      if key:
        print "    signing: %-*s (%s)" % (maxsize, name, key)
        signed_data = SignApk(data, key, key_passwords[key])
        output_tf_zip.writestr(out_info, signed_data)
      else:
        # an APK we're not supposed to sign.
        print "NOT signing: %s" % (name,)
        output_tf_zip.writestr(out_info, data)
    elif info.filename in ("SYSTEM/build.prop",
                           "RECOVERY/RAMDISK/default.prop"):
      print "rewriting %s:" % (info.filename,)
      new_data = RewriteProps(data)
      output_tf_zip.writestr(out_info, new_data)
    else:
      # a non-APK file; copy it verbatim
      output_tf_zip.writestr(out_info, data)


def RewriteProps(data):
  output = []
  for line in data.split("\n"):
    line = line.strip()
    original_line = line
    if line and line[0] != '#':
      key, value = line.split("=", 1)
      if key == "ro.build.fingerprint":
        pieces = line.split("/")
        tags = set(pieces[-1].split(","))
        for ch in OPTIONS.tag_changes:
          if ch[0] == "-":
            tags.discard(ch[1:])
          elif ch[0] == "+":
            tags.add(ch[1:])
        line = "/".join(pieces[:-1] + [",".join(sorted(tags))])
      elif key == "ro.build.description":
        pieces = line.split(" ")
        assert len(pieces) == 5
        tags = set(pieces[-1].split(","))
        for ch in OPTIONS.tag_changes:
          if ch[0] == "-":
            tags.discard(ch[1:])
          elif ch[0] == "+":
            tags.add(ch[1:])
        line = " ".join(pieces[:-1] + [",".join(sorted(tags))])
    if line != original_line:
      print "  replace: ", original_line
      print "     with: ", line
    output.append(line)
  return "\n".join(output) + "\n"


def ReplaceOtaKeys(input_tf_zip, output_tf_zip):
  try:
    keylist = input_tf_zip.read("META/otakeys.txt").split()
  except KeyError:
    raise ExternalError("can't read META/otakeys.txt from input")

  mapped_keys = []
  for k in keylist:
    m = re.match(r"^(.*)\.x509\.pem$", k)
    if not m:
      raise ExternalError("can't parse \"%s\" from META/otakeys.txt" % (k,))
    k = m.group(1)
    mapped_keys.append(OPTIONS.key_map.get(k, k) + ".x509.pem")

  if mapped_keys:
    print "using:\n   ", "\n   ".join(mapped_keys)
    print "for OTA package verification"
  else:
    mapped_keys.append(
        OPTIONS.key_map["build/target/product/security/testkey"] + ".x509.pem")
    print "META/otakeys.txt has no keys; using", mapped_keys[0]

  # recovery uses a version of the key that has been slightly
  # predigested (by DumpPublicKey.java) and put in res/keys.

  p = common.Run(["java", "-jar",
                  os.path.join(OPTIONS.search_path, "framework", "dumpkey.jar")]
                 + mapped_keys,
                 stdout=subprocess.PIPE)
  data, _ = p.communicate()
  if p.returncode != 0:
    raise ExternalError("failed to run dumpkeys")
  common.ZipWriteStr(output_tf_zip, "RECOVERY/RAMDISK/res/keys", data)

  # SystemUpdateActivity uses the x509.pem version of the keys, but
  # put into a zipfile system/etc/security/otacerts.zip.

  tempfile = cStringIO.StringIO()
  certs_zip = zipfile.ZipFile(tempfile, "w")
  for k in mapped_keys:
    certs_zip.write(k)
  certs_zip.close()
  common.ZipWriteStr(output_tf_zip, "SYSTEM/etc/security/otacerts.zip",
                     tempfile.getvalue())


def main(argv):

  def option_handler(o, a):
    if o in ("-s", "--signapk_jar"):
      OPTIONS.signapk_jar = a
    elif o in ("-e", "--extra_apks"):
      names, key = a.split("=")
      names = names.split(",")
      for n in names:
        OPTIONS.extra_apks[n] = key
    elif o in ("-d", "--default_key_mappings"):
      OPTIONS.key_map.update({
          "build/target/product/security/testkey": "%s/releasekey" % (a,),
          "build/target/product/security/media": "%s/media" % (a,),
          "build/target/product/security/shared": "%s/shared" % (a,),
          "build/target/product/security/platform": "%s/platform" % (a,),
          })
    elif o in ("-k", "--key_mapping"):
      s, d = a.split("=")
      OPTIONS.key_map[s] = d
    elif o in ("-o", "--replace_ota_keys"):
      OPTIONS.replace_ota_keys = True
    elif o in ("-t", "--tag_changes"):
      new = []
      for i in a.split(","):
        i = i.strip()
        if not i or i[0] not in "-+":
          raise ValueError("Bad tag change '%s'" % (i,))
        new.append(i[0] + i[1:].strip())
      OPTIONS.tag_changes = tuple(new)
    else:
      return False
    return True

  args = common.ParseOptions(argv, __doc__,
                             extra_opts="s:e:d:k:ot:",
                             extra_long_opts=["signapk_jar=",
                                              "extra_apks=",
                                              "default_key_mappings=",
                                              "key_mapping=",
                                              "replace_ota_keys",
                                              "tag_changes="],
                             extra_option_handler=option_handler)

  if len(args) != 2:
    common.Usage(__doc__)
    sys.exit(1)

  input_zip = zipfile.ZipFile(args[0], "r")
  output_zip = zipfile.ZipFile(args[1], "w")

  apk_key_map = GetApkCerts(input_zip)
  CheckAllApksSigned(input_zip, apk_key_map)
  CheckSharedUserIdsConsistent(input_zip, apk_key_map)

  key_passwords = common.GetKeyPasswords(set(apk_key_map.values()))
  SignApks(input_zip, output_zip, apk_key_map, key_passwords)

  if OPTIONS.replace_ota_keys:
    ReplaceOtaKeys(input_zip, output_zip)

  input_zip.close()
  output_zip.close()

  print "done."


# Script entry point: run main() on the command-line arguments and turn a
# common.ExternalError into a short error message and exit status 1.
if __name__ == '__main__':
  try:
    main(sys.argv[1:])
  except common.ExternalError, e:
    print
    print "   ERROR: %s" % (e,)
    print
    sys.exit(1)