#!/usr/bin/env python
# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
try:
	from os import fstat, stat, remove
	from sys import exit
	from argparse import ArgumentParser, FileType
	from ctypes import sizeof, Structure, c_char, c_int
	from struct import pack, calcsize
	import zlib
except Exception as e:
	print("some module is needed:" + str(e))
	exit(-1)

# struct format of the image header: 4-byte magic string followed by two
# unsigned ints (packed below as version and dtb count). No byte-order
# prefix is given, so struct uses native byte order, size and alignment.
dt_head_info_fmt = '4sII'
# struct format of one dtb entry: one unsigned 64-bit field, four unsigned
# ints and two unsigned 64-bit fields. write_dtb_entry_t() fills the 2nd
# field with the dtb size and the 4th with the dtb offset; the rest are
# written as zero (reserved).
dt_entry_fmt = 'Q4I2Q'
# Version number stored in the image header by main().
dtimg_version = 1
# Number of dtb entries announced in the image header by main().
dtb_count = 1

def write32(output, value):
	"""Write the low 32 bits of value to output in little-endian order.

	Args:
		output: binary file-like object with a write() method.
		value: integer; only its low 32 bits are written, so negative
			numbers (e.g. a signed zlib.crc32 result) come out in
			two's-complement form.
	"""
	# Emit real bytes in a single write: chr() produces text, which a
	# binary stream rejects on Python 3.
	output.write(pack('<I', value & 0xFFFFFFFF))

def compress(filename, input, output):
	"""Write the contents of input to output as a single gzip member.

	The member header is fixed: no flags, a zero modification time,
	XFL=2 and OS=3. The payload is a raw deflate stream at compression
	level 9, followed by the CRC32 and the uncompressed length.

	Args:
		filename: path of the file being compressed. Kept for interface
			compatibility; the length written to the trailer is counted
			from the bytes actually read from input.
		input: binary file-like object to read the uncompressed data from.
		output: binary file-like object receiving the gzip stream.
	"""
	# magic (1f 8b), method 8 (deflate), flags 0, mtime 0, XFL 2, OS 3
	output.write(b'\x1f\x8b\x08\x00' + pack('<I', 0) + b'\x02\x03')

	crcval = zlib.crc32(b'')
	total = 0
	# Negative wbits selects a raw deflate stream (no zlib wrapper); the
	# gzip framing is written by hand around it.
	compobj = zlib.compressobj(9, zlib.DEFLATED, -zlib.MAX_WBITS,
		zlib.DEF_MEM_LEVEL, 0)
	while True:
		data = input.read(1024)
		if not data:
			# An empty read means EOF for both bytes and str streams.
			break
		crcval = zlib.crc32(data, crcval)
		total += len(data)
		output.write(compobj.compress(data))
	output.write(compobj.flush())
	# Trailer: CRC32 and input size, both truncated to 32 bits. Masking
	# also normalizes a signed crc32 result.
	output.write(pack('<II', crcval & 0xFFFFFFFF, total & 0xFFFFFFFF))

def dtb_compress(dtb_file):
	"""Gzip dtb_file into '<dtb_file>.gz' and return the new path.

	Args:
		dtb_file: path of the dtb to compress.

	Returns:
		The path of the compressed file (dtb_file + '.gz').

	On any error a message is printed and the process exits with -1.
	"""
	try:
		outputname = dtb_file + '.gz'
		# Context managers close both handles even when compress() fails.
		with open(dtb_file, 'rb') as src, open(outputname, 'wb') as dst:
			compress(dtb_file, src, dst)
	except Exception as e:
		print('dtb_compress error:' + str(e))
		exit(-1)
	return outputname

class dt_head_info(Structure):
	"""Image header record: magic tag, format version and dtb count."""
	_fields_ = [
		('magic', c_char * 4),  # 4-character tag
		('version', c_int),
		('dt_count', c_int),
	]

class dt_entry_t(Structure):
	"""Per-dtb record holding the blob's size and its offset."""
	_fields_ = [
		('dtb_size', c_int),
		('dtb_offset', c_int),
	]

def align_page_size(offset, pagesize):
	"""Return the padding needed to round offset up to a page boundary.

	Args:
		offset: current size or position in bytes.
		pagesize: alignment in bytes (must be non-zero).

	Returns:
		The number of bytes to add so that offset becomes a multiple of
		pagesize; 0 when offset is already aligned.
	"""
	# (-offset) % pagesize is 0 for aligned offsets, whereas
	# pagesize - (offset % pagesize) would add a whole superfluous page.
	return (-offset) % pagesize

def write_head_info(head_info, args):
	"""Pack head_info with dt_head_info_fmt and write it to args.output.

	Args:
		head_info: object exposing magic, version and dt_count.
		args: parsed command line; args.output is the destination stream.
	"""
	header_fields = (head_info.magic, head_info.version, head_info.dt_count)
	packed_header = pack(dt_head_info_fmt, *header_fields)
	args.output.write(packed_header)

def write_dtb_entry_t(dt_entry, args):
	"""Pack dt_entry with dt_entry_fmt and write it to args.output.

	Only the size and offset carry data; every other field of the
	record is reserved and written as zero.

	Args:
		dt_entry: object exposing dtb_size and dtb_offset.
		args: parsed command line; args.output is the destination stream.
	"""
	reserved = 0
	record = (
		reserved,
		dt_entry.dtb_size,
		reserved,
		dt_entry.dtb_offset,
		reserved,
		reserved,
		reserved,
	)
	args.output.write(pack(dt_entry_fmt, *record))

def write_padding(args, padding):
	"""Write padding zero bytes to args.output.

	Args:
		args: parsed command line; args.output is a binary stream.
		padding: number of zero bytes to emit; nothing is written when
			it is zero or negative.
	"""
	if padding > 0:
		# One bytes write instead of a per-byte loop of text characters,
		# which a binary stream rejects on Python 3.
		args.output.write(b'\x00' * padding)

def write_dtb(args):
	"""Write the dtb entry, status word, padding and dtb payload.

	The offset computation accounts for a header of dt_head_info_fmt
	size, so the header is expected to be in args.output already.
	After it come one dtb entry, a 32-bit status word, zero padding up
	to the page size, the dtb itself (gzip-compressed first when
	args.compress is set) and its trailing padding.

	Args:
		args: parsed command line providing dtb, compress, pagesize and
			output.

	On any error a message is printed and the process exits with -1.
	"""
	dtb_file = args.dtb
	out_dtb = dtb_file
	if args.compress:
		out_dtb = dtb_compress(dtb_file)
	try:
		# Bytes preceding the first padding: header, one entry and the
		# status word. The last term is derived from the same format the
		# status word is packed with, instead of a literal 4.
		dtb_offset = (calcsize(dt_head_info_fmt) +
			      calcsize(dt_entry_fmt) +
			      calcsize('I'))
		padding = align_page_size(dtb_offset, args.pagesize)
		dtb_size = stat(out_dtb).st_size
		dtb_size_padding = align_page_size(dtb_size, args.pagesize)
		# The entry records the padded size and the padded offset.
		dt_entry = dt_entry_t(dtb_size + dtb_size_padding,
				      dtb_offset + padding)
		write_dtb_entry_t(dt_entry, args)
		args.output.write(pack('I', 0)) # SUCCESS code number
		write_padding(args, padding)
		with open(out_dtb, 'rb') as dtb_fd:
			args.output.write(dtb_fd.read(dtb_size))
			write_padding(args, dtb_size_padding)
	except Exception as e:
		print('write dtb error:' + str(e))
		exit(-1)

def clean_gz_file(args):
	"""Delete the temporary '<dtb>.gz' file created for compression.

	Does nothing unless args.compress is set.

	Args:
		args: parsed command line providing compress and dtb.

	If the file cannot be removed, a message is printed and the
	process exits with -1.
	"""
	if not args.compress:
		return
	try:
		remove(args.dtb + '.gz')
	except Exception as e:
		print('clean gz file error:' + str(e))
		exit(-1)

def parse_cmdline():
	"""Build the argument parser and return the parsed command line.

	Options: -c/--compress (flag), -d/--dtb (required path),
	-s/--pagesize (one of 2048, 4096, 8192, 16384; default 2048) and
	-o/--output (required, opened for binary writing by argparse).
	"""
	# Powers of two from 2**11 up to 2**14.
	page_sizes = [1 << shift for shift in range(11, 15)]

	parser = ArgumentParser()
	parser.add_argument('-c', '--compress', action='store_true',
			    help='compress dtb or not')
	parser.add_argument('-d', '--dtb', type=str, required=True,
			    help='path to the dtb')
	parser.add_argument('-s', '--pagesize', type=int, default=2048,
			    choices=page_sizes, help='align page size')
	parser.add_argument('-o', '--output', type=FileType('wb'),
			    required=True, help='output file name')
	parsed = parser.parse_args()
	return parsed

def main():
	"""Parse the command line and build the dt image.

	Writes the 'HSDT' header, then the dtb entry and payload, and
	finally removes the temporary .gz file if compression was used.
	"""
	args = parse_cmdline()
	try:
		# The magic must be bytes: a ctypes c_char array rejects text
		# strings on Python 3 (b'...' is a plain str on Python 2).
		dtimg_head_info = dt_head_info(b'HSDT', dtimg_version, dtb_count)
		write_head_info(dtimg_head_info, args)
		write_dtb(args)
		clean_gz_file(args)
	finally:
		# argparse opened the output file; close it so the image is
		# flushed even when a step exits early.
		args.output.close()

# Run the tool only when this file is executed as a script, not on import.
if __name__ == '__main__':
	main()