"""Multiply primitive optimized for the gemv operation."""

import neon_emitter


class Error(Exception):
  """Base exception type for errors reported by this module."""


class ConfigurationError(Error):
  """Raised when the requested generator configuration is not supported."""


def GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                  count, lhs, rhs_1, rhs_2):
  """Emit the load / multiply / accumulate loop body.

  The emitted sequence is: numerical label 1, a subtraction of 8 from `count`,
  one load from `lhs` and four from `rhs_1`, four 'u8' multiplies accumulated
  into aggregators[0..3], then `lanes_count` loads from `rhs_2` multiplied and
  accumulated into aggregators[4..4 + lanes_count - 1], and finally a
  branch back to label 1.

  Args:
    emitter: object providing the Emit* methods used to write instructions.
    registers: allocator providing DoubleRegister, QuadRegister, FreeRegister
      and FreeRegisters.
    lanes_count: how many registers are loaded from rhs_2 per iteration.
    aggregators: indexable collection with at least lanes_count + 4
      accumulator registers.
    count: register holding the loop counter.
    lhs: register pointing at the left hand side data.
    rhs_1: register pointing at the first block of right hand side data.
    rhs_2: register pointing at the second block of right hand side data.
  """
  emitter.EmitComment('General 1xM lanes loop.')
  emitter.EmitNumericalLabel(1)
  emitter.EmitNewline()
  emitter.EmitComment('Subtract counter.')
  step = emitter.ImmediateConstant(8)
  emitter.EmitSubs(count, count, step)
  emitter.EmitNewline()

  # Allocation order is kept stable: four rhs registers first, then the lhs.
  rhs_regs = []
  for _ in range(4):
    rhs_regs.append(registers.DoubleRegister())
  lhs_reg = registers.DoubleRegister()

  lhs_source = emitter.DereferenceIncrement(lhs, 64)
  emitter.EmitVLoad('1.8', lhs_reg, lhs_source)
  rhs_1_source = emitter.DereferenceIncrement(rhs_1, 64)
  emitter.EmitVLoadA('1.8', rhs_regs, rhs_1_source)

  emitter.EmitPldOffset(lhs, emitter.ImmediateConstant(64))
  emitter.EmitPldOffset(rhs_1, emitter.ImmediateConstant(128))

  products = []
  for _ in range(4):
    products.append(registers.QuadRegister())

  for product, rhs_reg in zip(products, rhs_regs):
    emitter.EmitVMull('u8', product, rhs_reg, lhs_reg)

  # The rhs registers have been consumed by the multiplies above, so the
  # first lanes_count of them are reloaded with data from rhs_2.
  reloaded = rhs_regs[:lanes_count]
  emitter.EmitVLoadA('1.8', reloaded, emitter.DereferenceIncrement(rhs_2, 64))
  emitter.EmitPldOffset(rhs_2, emitter.ImmediateConstant(lanes_count * 32))

  for index, product in enumerate(products):
    emitter.EmitVPadal('u16', aggregators[index], product)

  for index in range(lanes_count):
    emitter.EmitVMull('u8', products[index], rhs_regs[index], lhs_reg)

  for index in range(lanes_count):
    emitter.EmitVPadal('u16', aggregators[index + 4], products[index])

  emitter.EmitNewline()
  emitter.EmitComment('Loop break.')
  emitter.EmitBneBack(1)
  emitter.EmitNewline()

  # Hand the temporaries back to the allocator.
  registers.FreeRegister(lhs_reg)
  registers.FreeRegisters(rhs_regs)
  registers.FreeRegisters(products)


def ReadLeft(emitter, registers, lhs):
  """Emit a '1.32' load from lhs into all lanes of a new quad register.

  Args:
    emitter: object providing AllLanes, Dereference and EmitVLoadA.
    registers: allocator providing QuadRegister, Low and High.
    lhs: register holding the address to load from.

  Returns:
    The newly allocated quad register that receives the loaded value.
  """
  target = registers.QuadRegister()
  low_lanes = emitter.AllLanes(registers.Low(target))
  high_lanes = emitter.AllLanes(registers.High(target))
  source = emitter.Dereference(lhs, None)
  emitter.EmitVLoadA('1.32', [low_lanes, high_lanes], source)
  return target


def ReadRight(emitter, registers, rhs, count):
  """Emit a '1.32' load from rhs into a register sized for `count` elements.

  Args:
    emitter: object providing Dereference and EmitVLoad.
    registers: allocator providing DoubleRegister and QuadRegister.
    rhs: register holding the address to load from.
    count: number of elements to hold; 1 or 2 selects a double register,
      3 or 4 selects a quad register.

  Returns:
    The newly allocated register that receives the loaded data.

  Raises:
    ConfigurationError: if count is not 1, 2, 3 or 4.
  """
  if count not in (1, 2, 3, 4):
    raise ConfigurationError('Unsupported elements no: %d' % count)

  if count in (1, 2):
    allocate = registers.DoubleRegister
  else:
    allocate = registers.QuadRegister
  target = allocate()

  source = emitter.Dereference(rhs, 64)
  emitter.EmitVLoad('1.32', target, source)
  return target


def DuplicateGeneralRegister(emitter, registers, general_register,
                             min_register):
  """Emit a '32' VDup of a general register into a new quad register.

  Args:
    emitter: object providing EmitVDup.
    registers: allocator providing QuadRegister.
    general_register: the general purpose register to duplicate from.
    min_register: value forwarded unchanged to registers.QuadRegister.

  Returns:
    The newly allocated quad register used as the VDup destination.
  """
  target = registers.QuadRegister(min_register)
  emitter.EmitVDup('32', target, general_register)
  return target


def GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                  result_type, lhs_add, rhs_add, lhs, rhs_1,
                                  rhs_2, results):
  """Generates assembly responsible for reducing the 4 way aggregators.

  Every aggregator is folded with pairwise adds, the first four sums are
  packed into one quad register and the remaining `lanes_count` sums into a
  second register. Optional offsets are then added, the sums are optionally
  converted to float and scaled, and the result is stored to `results`.

  Args:
    emitter: object providing the Emit* methods used to write instructions.
    registers: allocator providing Low, High, MapParameter and the register
      allocation methods used by the offset/scale helpers.
    lanes_count: number of aggregators beyond the first four (1 to 4).
    aggregators: indexable collection of lanes_count + 4 quad registers.
    result_type: 'float' enables the convert-and-scale step; any other value
      stores the sums without conversion.
    lhs_add: when truthy, an offset read from `lhs` is added to the sums.
    rhs_add: when truthy, offsets read from `rhs_1` / `rhs_2` are added.
    lhs: register pointing at the left hand side offset.
    rhs_1: register pointing at the offsets for the first four sums.
    rhs_2: register pointing at the offsets for the remaining sums.
    results: register pointing at the output location.

  Raises:
    ConfigurationError: propagated from ReadRight when rhs_add is set and
      lanes_count is outside 1..4.
  """
  # Compare by value. An identity test against a string literal only succeeds
  # when the caller's string happens to be the very same interned object.
  is_float = result_type == 'float'

  if lhs_add:
    left_offset = ReadLeft(emitter, registers, lhs)
  else:
    left_offset = None

  if rhs_add:
    right_offset_1 = ReadRight(emitter, registers, rhs_1, 4)
    right_offset_2 = ReadRight(emitter, registers, rhs_2, lanes_count)
  else:
    right_offset_1 = None
    right_offset_2 = None

  if is_float:
    result_scale = DuplicateGeneralRegister(
        emitter, registers, registers.MapParameter('result_scale'), 4)
  else:
    result_scale = None

  emitter.EmitNewline()
  emitter.EmitComment('Horizontal reduce aggregators.')
  # First pass: fold the high half of every aggregator into its low half.
  for aggregator in aggregators:
    emitter.EmitVPadd('u32', registers.Low(aggregator),
                      registers.Low(aggregator), registers.High(aggregator))

  # Second pass: pack the sums of aggregators 0..3 into aggregators[0].
  temp = aggregators[0]
  emitter.EmitVPadd('u32', registers.Low(temp), registers.Low(aggregators[0]),
                    registers.Low(aggregators[1]))
  emitter.EmitVPadd('u32', registers.High(temp), registers.Low(aggregators[2]),
                    registers.Low(aggregators[3]))

  # The sums of aggregators 4.. are packed into aggregators[1], whose own sum
  # has already been consumed above. One or two lanes need only its low half;
  # with an odd lane count the last aggregator is paired with itself.
  if lanes_count == 1:
    temp_2 = registers.Low(aggregators[1])
    emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                      registers.Low(aggregators[4]))
  elif lanes_count == 2:
    temp_2 = registers.Low(aggregators[1])
    emitter.EmitVPadd('u32', temp_2, registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
  elif lanes_count == 3:
    temp_2 = aggregators[1]
    emitter.EmitVPadd('u32', registers.Low(temp_2),
                      registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
    emitter.EmitVPadd('u32', registers.High(temp_2),
                      registers.Low(aggregators[6]),
                      registers.Low(aggregators[6]))
  elif lanes_count == 4:
    temp_2 = aggregators[1]
    emitter.EmitVPadd('u32', registers.Low(temp_2),
                      registers.Low(aggregators[4]),
                      registers.Low(aggregators[5]))
    emitter.EmitVPadd('u32', registers.High(temp_2),
                      registers.Low(aggregators[6]),
                      registers.Low(aggregators[7]))
  else:
    temp_2 = None

  if lhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add lhs offsets to aggregated rows.')
    emitter.EmitVAdd('s32', temp, temp, left_offset)
    # temp_2 is a double register for 1-2 lanes, so only half of the offset
    # register is used in that case.
    if lanes_count == 1 or lanes_count == 2:
      emitter.EmitVAdd('s32', temp_2, temp_2, registers.Low(left_offset))
    elif lanes_count == 3 or lanes_count == 4:
      emitter.EmitVAdd('s32', temp_2, temp_2, left_offset)

  if rhs_add:
    emitter.EmitNewline()
    emitter.EmitComment('Add rhs offset to aggregated rows.')
    emitter.EmitVAdd('s32', temp, temp, right_offset_1)
    emitter.EmitVAdd('s32', temp_2, temp_2, right_offset_2)

  if is_float:
    emitter.EmitNewline()
    emitter.EmitComment('Convert to float and scale.')
    emitter.EmitVCvt('f32', 's32', temp, temp)
    emitter.EmitVCvt('f32', 's32', temp_2, temp_2)
    emitter.EmitVMul('f32', temp, temp, result_scale)
    if lanes_count == 1 or lanes_count == 2:
      emitter.EmitVMul('f32', temp_2, temp_2, registers.Low(result_scale))
    elif lanes_count == 3 or lanes_count == 4:
      emitter.EmitVMul('f32', temp_2, temp_2, result_scale)

  emitter.EmitNewline()
  emitter.EmitComment('Store results.')
  # Odd lane counts store the full registers first (advancing `results`) and
  # then a single lane for the last value.
  if lanes_count == 1:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp)],
                        emitter.DereferenceIncrement(results, None))
    emitter.EmitVStore('1.32', emitter.Lane(temp_2, 0),
                       emitter.Dereference(results, None))
  elif lanes_count == 2:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                 temp_2], emitter.Dereference(results, None))
  elif lanes_count == 3:
    emitter.EmitVStoreA(
        '1.32',
        [registers.Low(temp), registers.High(temp), registers.Low(temp_2)],
        emitter.DereferenceIncrement(results, None))
    emitter.EmitVStore('1.32', emitter.Lane(
        registers.High(temp_2), 0), emitter.Dereference(results, None))
  elif lanes_count == 4:
    emitter.EmitVStoreA('1.32', [registers.Low(temp), registers.High(temp),
                                 registers.Low(temp_2), registers.High(temp_2)],
                        emitter.Dereference(results, None))


def BuildName(result_type, lhs_add, rhs_add, lanes):
  """Compose the name of a generated multiplication function.

  Args:
    result_type: text appended after the lanes part of the name.
    lhs_add: when truthy, '_lhsadd' is appended.
    rhs_add: when truthy, '_rhsadd' is appended (after '_lhsadd' if both).
    lanes: number formatted with %d into the '<lanes>x8' part of the name.

  Returns:
    The composed name, e.g. 'mul_1x8_5x8_int32_lhsadd_rhsadd'.
  """
  base = 'mul_1x8_%dx8_%s' % (lanes, result_type)
  optional_suffixes = ((lhs_add, '_lhsadd'), (rhs_add, '_rhsadd'))
  return base + ''.join(
      suffix for enabled, suffix in optional_suffixes if enabled)


def CppResultType(result_type):
  """Return the C++ pointer type used for the result parameter.

  Args:
    result_type: 'int32' or 'float'.

  Returns:
    'std::int32_t*' for 'int32' and 'float*' for 'float'.

  Raises:
    ConfigurationError: if result_type is neither 'int32' nor 'float'.
  """
  # Strings are compared by value: `is` would reject an equal string that is
  # not the same object as the literal (e.g. one built at runtime).
  if result_type == 'int32':
    return 'std::int32_t*'
  elif result_type == 'float':
    return 'float*'
  else:
    raise ConfigurationError('Unsupported result type: %s' % result_type)


def GetParameters(result_type):
  """Build the [type, name] parameter list of the generated function.

  Args:
    result_type: 'int32' or 'float'; selects the type of the `result`
      parameter and whether a `result_scale` parameter is appended.

  Returns:
    A list of [cpp_type, name] pairs: lhs, rhs_1, rhs_2, count, result and,
    for 'float' results only, result_scale.

  Raises:
    ConfigurationError: propagated from CppResultType for unsupported types.
  """
  params = [['const std::uint8_t*', 'lhs'], ['const std::uint8_t*', 'rhs_1'],
            ['const std::uint8_t*', 'rhs_2'], ['std::int32_t', 'count'],
            [CppResultType(result_type), 'result']]
  # Value comparison: an identity check against the literal would skip the
  # scale parameter for an equal string that is a different object.
  if result_type == 'float':
    params.append(['float', 'result_scale'])
  return params


def GenerateAndClearAggregators(emitter, registers, aggregator_count):
  """Allocate the aggregator registers and emit the code that clears them.

  The first three aggregators are set from an immediate zero; each later one
  is copied from the aggregator allocated three positions earlier.

  Args:
    emitter: object providing EmitNewline, EmitComment, EmitVMov and
      ImmediateConstant.
    registers: allocator providing QuadRegister.
    aggregator_count: number of aggregators to allocate.

  Returns:
    The list of allocated quad registers, in allocation order.
  """
  emitter.EmitNewline()
  emitter.EmitComment('Clear aggregators.')
  allocated = []
  for index in range(aggregator_count):
    current = registers.QuadRegister()
    if index < 3:
      source = emitter.ImmediateConstant(0)
    else:
      source = allocated[index - 3]
    emitter.EmitVMov('i32', current, source)
    allocated.append(current)
  emitter.EmitNewline()
  return allocated


def GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count):
  """Emit one complete multiplication function for 4 + lanes_count lanes.

  Args:
    emitter: object providing the Emit* methods used to write the function.
    result_type: result type name, forwarded to BuildName, GetParameters and
      the reduce/store stage.
    lhs_add: whether the lhs offset is added to the results.
    rhs_add: whether the rhs offsets are added to the results.
    lanes_count: number of lanes beyond the first four; must be 1, 2, 3 or 4.

  Raises:
    ConfigurationError: if lanes_count is outside 1..4.
  """
  if lanes_count < 1 or lanes_count > 4:
    raise ConfigurationError('Lanes should be: 1, 2, 3 or 4.')

  total_lanes = lanes_count + 4
  function_name = BuildName(result_type, lhs_add, rhs_add, total_lanes)
  parameters = GetParameters(result_type)
  emitter.EmitFunctionBeginA(function_name, parameters, 'inline void')

  for condition in ('count % 8 == 0', 'count >= 8'):
    emitter.EmitAssert(condition)
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  # Parameters are mapped in a fixed order: count, lhs, rhs_1, rhs_2.
  count, lhs, rhs_1, rhs_2 = [
      registers.MapParameter(parameter)
      for parameter in ('count', 'lhs', 'rhs_1', 'rhs_2')
  ]

  for pointer in (lhs, rhs_1, rhs_2):
    emitter.EmitPld(pointer)

  aggregators = GenerateAndClearAggregators(emitter, registers, total_lanes)

  GenerateLoadMultiplyAggregate(emitter, registers, lanes_count, aggregators,
                                count, lhs, rhs_1, rhs_2)

  # The result pointer is mapped only after the main loop has been emitted.
  results = registers.MapParameter('result')
  GenerateAggregatorReduceStore(emitter, registers, lanes_count, aggregators,
                                result_type, lhs_add, rhs_add, lhs, rhs_1,
                                rhs_2, results)

  inputs = registers.MappedParameters()
  clobbers = registers.Clobbers() + ['cc', 'memory']
  emitter.EmitAsmEnd(inputs, [], clobbers)
  emitter.EmitFunctionEnd()


def GenerateFunctions(emitter, result_type, lhs_add, rhs_add):
  """Emit the multiplication function for each lanes_count from 1 to 4.

  A newline is emitted after every generated function.

  Args:
    emitter: object providing the Emit* methods used to write the functions.
    result_type: result type name forwarded to GenerateMul1x8Mx8.
    lhs_add: forwarded to GenerateMul1x8Mx8.
    rhs_add: forwarded to GenerateMul1x8Mx8.
  """
  for lanes_count in (1, 2, 3, 4):
    GenerateMul1x8Mx8(emitter, result_type, lhs_add, rhs_add, lanes_count)
    emitter.EmitNewline()


if __name__ == '__main__':
  # Script entry point: generate the int32 variants with both the lhs and
  # rhs offset additions enabled.
  GenerateFunctions(
      neon_emitter.NeonEmitter(), result_type='int32', lhs_add=True,
      rhs_add=True)