"""Qnt primitive used by the GEMM function.

Emits the 'qnt_<lanes>x8[_<leftovers>][_aligned]' inline-assembly kernels
through a NEON emitter, and the 'multi_qnt_<rows>x8[_aligned]' functions that
switch on 'count % 8' to pick the matching kernel.
"""

import neon_emitter


class Error(Exception):
  """Module level error."""


class ConfigurationError(Error):
  """Unsupported configuration."""


class QntLane(object):
  """Registers used to process one lane (row) of unquantized results.

  Attributes:
    source: register holding the address the lane is read from.
    output: register holding the address the lane is written to.
    offset: quad register holding the offset added to every loaded value.
    load_1: quad register receiving the first half of each loaded chunk.
    load_2: quad register receiving the second half of each loaded chunk.
  """

  def __init__(self, source, output, offset, load_1, load_2):
    self.source = source
    self.output = output
    self.offset = offset
    self.load_1 = load_1
    self.load_2 = load_2


def _KernelParameters():
  """Returns a fresh [type, name] list describing the shared C++ signature.

  Both the qnt kernels and the multi_qnt dispatchers take exactly these
  parameters, in this order. A new list is built on every call so that an
  emitter is never handed a shared mutable object.
  """
  return [['const std::int32_t*', 'source'],
          ['std::int32_t', 'count'],
          ['std::int32_t', 'stride'],
          ['const std::int32_t*', 'offsets'],
          ['std::uint8_t*', 'destination'],
          ['std::int32_t', 'destination_stride'],
          ['std::int32_t', 'multiplicative_offset'],
          ['std::int32_t', 'rounding_offset'],
          ['std::int32_t', 'shift']]


def BuildName(lanes, leftovers, aligned):
  """Returns the kernel name for the given configuration.

  The result is 'qnt_<lanes>x8', followed by '_<leftovers>' when leftovers is
  non-zero and by '_aligned' when aligned is truthy.
  """
  name = 'qnt_%dx8' % lanes
  if leftovers:
    name += '_%d' % leftovers
  if aligned:
    name += '_aligned'
  return name


def LoadAndDuplicateOffsets(emitter, registers, lanes, offsets):
  """Emits one all-lanes load of an offset per lane.

  Args:
    emitter: emitter receiving the load instructions.
    registers: register allocator; one quad register is taken per lane.
    lanes: number of lanes; must be 1, 2 or 3.
    offsets: register holding the address of the offsets; it is dereferenced
      with post-increment once per lane.

  Returns:
    The list of allocated quad registers, in lane order.

  Raises:
    ConfigurationError: if lanes is not 1, 2 or 3.
  """
  if lanes not in (1, 2, 3):
    raise ConfigurationError('Unsupported number of lanes: %d' % lanes)

  offset_registers = []
  for _ in range(lanes):
    register = registers.QuadRegister()
    # Both halves of the quad register are filled by the same all-lanes load.
    emitter.EmitVLoadA(
        '1.32',
        [emitter.AllLanes(registers.Low(register)),
         emitter.AllLanes(registers.High(register))],
        emitter.DereferenceIncrement(offsets, 32))
    offset_registers.append(register)
  return offset_registers


def GenerateQntLanes(emitter, registers, qnt_lanes, source, stride, destination,
                     destination_stride, offsets):
  """Prepare lanes for reading unquantized multiplication results.

  The first lane uses the source and destination registers directly. Every
  following lane gets freshly allocated general registers that are set to the
  previous lane's pointers plus stride / destination_stride.

  Args:
    emitter: emitter receiving the setup instructions.
    registers: register allocator.
    qnt_lanes: number of lanes to prepare (1, 2 or 3).
    source: register holding the input address of the first lane.
    stride: register holding the distance between consecutive input lanes.
    destination: register holding the output address of the first lane.
    destination_stride: register holding the distance between output lanes.
    offsets: register holding the address of the per-lane offsets.

  Returns:
    A list of QntLane objects, one per lane.

  Raises:
    ConfigurationError: if qnt_lanes is not 1, 2 or 3.
  """
  offset_registers = LoadAndDuplicateOffsets(emitter, registers, qnt_lanes,
                                             offsets)

  lanes = []
  previous_input = source
  previous_output = destination
  for i, offset_register in enumerate(offset_registers):
    if i == 0:
      lanes.append(QntLane(source,
                           destination,
                           offset_register,
                           registers.QuadRegister(),  # load 1
                           registers.QuadRegister()))  # load 2
      continue

    # Allocation order (general registers first, then the two quad registers)
    # is kept stable so the generated register assignment does not change.
    input_register = registers.GeneralRegister()
    output_register = registers.GeneralRegister()
    lanes.append(QntLane(input_register,
                         output_register,
                         offset_register,
                         registers.QuadRegister(),  # load 1
                         registers.QuadRegister()))  # load 2
    emitter.EmitAdd(input_register, previous_input, stride)
    emitter.EmitAdd(output_register, previous_output, destination_stride)
    previous_input = input_register
    previous_output = output_register
  return lanes


def DuplicateRegister(emitter, registers, value):
  """Allocates a quad register and emits a 32-bit VDup of value into it.

  Returns:
    The newly allocated quad register.
  """
  register = registers.QuadRegister()
  emitter.EmitVDup('32', register, value)
  return register


def GenerateQuantize(emitter, registers, lanes, lane_temps,
                     multiplicative_offset, rounding_offset, shift):
  """Inner loop for quantization: add offsets, multiply, round, shift.

  Each step is emitted for all lanes before the next step starts, so the same
  operation on different lanes ends up on adjacent instructions.

  Args:
    emitter: emitter receiving the instructions.
    registers: register allocator (used only to address register halves).
    lanes: list of [value, offset, narrowed] register triples; value is
      updated in place and finally narrowed into narrowed.
    lane_temps: quad registers whose contents are narrowed once more into
      their own low halves.
    multiplicative_offset: register every value is multiplied by.
    rounding_offset: register added to every value after the multiply.
    shift: register passed as the shift operand of the VShl step.
  """
  for lane in lanes:
    emitter.EmitVAdd('i32', lane[0], lane[0], lane[1])

  for lane in lanes:
    emitter.EmitVMul('i32', lane[0], lane[0], multiplicative_offset)

  for lane in lanes:
    emitter.EmitVAdd('i32', lane[0], lane[0], rounding_offset)

  for lane in lanes:
    emitter.EmitVShl('s32', lane[0], lane[0], shift)

  for lane in lanes:
    emitter.EmitVQmovn('s32', lane[2], lane[0])

  for lane_temp in lane_temps:
    emitter.EmitVQmovun('s16', registers.Low(lane_temp), lane_temp)


def GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                              rounding_offset, shift, alignment):
  """Load unquantized data from lanes, quantize, store final result.

  Handles one full chunk per lane: four double registers of 32-bit values are
  loaded, quantized and stored from the low half of a temporary register.

  Args:
    emitter: emitter receiving the instructions.
    registers: register allocator; one temporary quad register per lane is
      allocated and freed again before returning.
    lanes: list of QntLane objects.
    multiplicative_offset: see GenerateQuantize.
    rounding_offset: see GenerateQuantize.
    shift: see GenerateQuantize.
    alignment: alignment hint passed to the output dereference, or None.
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  for lane in lanes:
    emitter.EmitVLoadA(
        '1.32',
        [registers.Low(lane.load_1), registers.High(lane.load_1),
         registers.Low(lane.load_2), registers.High(lane.load_2)],
        emitter.DereferenceIncrement(lane.source, 64))

  for lane in lanes:
    emitter.EmitPld(lane.source)

  quantize_setup = []
  for lane_temp, lane in zip(lane_temps, lanes):
    quantize_setup.append(
        [lane.load_1, lane.offset, registers.Low(lane_temp)])
    quantize_setup.append(
        [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  for lane_temp, lane in zip(lane_temps, lanes):
    emitter.EmitVStore('1.8', registers.Low(lane_temp),
                       emitter.DereferenceIncrement(lane.output, alignment))

  for lane_temp in lane_temps:
    registers.FreeRegister(lane_temp)


def GenerateLoadLeftovers(emitter, registers, leftovers, lanes):
  """Handle non multiple of 8 leftover loading.

  Loads exactly `leftovers` 32-bit values per lane into load_1 / load_2. Pairs
  of values are loaded as whole double registers; an odd trailing value is
  loaded into lane 0 of the next double register.

  Args:
    emitter: emitter receiving the load instructions.
    registers: register allocator (used only to address register halves).
    leftovers: number of values left to load, 1 to 7.
    lanes: list of QntLane objects.

  Raises:
    ConfigurationError: if leftovers is not in the range 1 to 7.
  """
  if leftovers == 1:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.Low(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 2:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        registers.Low(lane.load_1),
                        emitter.Dereference(lane.source, 64))
  elif leftovers == 3:
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        registers.Low(lane.load_1),
                        emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.High(lane.load_1), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 4:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 5:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.Low(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  elif leftovers == 6:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1),
                          registers.Low(lane.load_2)],
                         emitter.Dereference(lane.source, 64))
  elif leftovers == 7:
    for lane in lanes:
      emitter.EmitVLoadA('1.32',
                         [registers.Low(lane.load_1),
                          registers.High(lane.load_1),
                          registers.Low(lane.load_2)],
                         emitter.DereferenceIncrement(lane.source, 64))
    for lane in lanes:
      emitter.EmitVLoad('1.32',
                        emitter.Lane(registers.High(lane.load_2), 0),
                        emitter.Dereference(lane.source, None))
  else:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)


def GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes):
  """Handle non multiple of 8 leftover storing.

  Stores exactly `leftovers` bytes per lane from the low half of that lane's
  temporary register, using the widest single-lane stores (32, 16, then 8 bit)
  that add up to the requested count.

  Args:
    emitter: emitter receiving the store instructions.
    registers: register allocator (used only to address register halves).
    leftovers: number of bytes left to store, 1 to 7.
    lane_temps: quad registers holding the quantized values, one per lane.
    lanes: list of QntLane objects, parallel to lane_temps.

  Raises:
    ConfigurationError: if leftovers is not in the range 1 to 7.
  """
  # Each entry is [double register holding the result, output address].
  setup = [[registers.Low(temp), lane.output]
           for temp, lane in zip(lane_temps, lanes)]

  if leftovers == 1:
    for lane in setup:
      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 0),
                         emitter.Dereference(lane[1], None))
  elif leftovers == 2:
    for lane in setup:
      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 0),
                         emitter.Dereference(lane[1], None))
  elif leftovers == 3:
    for lane in setup:
      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 0),
                         emitter.DereferenceIncrement(lane[1], None))
    for lane in setup:
      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 2),
                         emitter.Dereference(lane[1], None))
  elif leftovers == 4:
    for lane in setup:
      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
                         emitter.Dereference(lane[1], None))
  elif leftovers == 5:
    for lane in setup:
      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
                         emitter.DereferenceIncrement(lane[1], None))
    for lane in setup:
      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 4),
                         emitter.Dereference(lane[1], None))
  elif leftovers == 6:
    for lane in setup:
      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
                         emitter.DereferenceIncrement(lane[1], None))
    for lane in setup:
      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 2),
                         emitter.Dereference(lane[1], None))
  elif leftovers == 7:
    for lane in setup:
      emitter.EmitVStore('1.32', emitter.Lane(lane[0], 0),
                         emitter.DereferenceIncrement(lane[1], None))
    for lane in setup:
      emitter.EmitVStore('1.16', emitter.Lane(lane[0], 2),
                         emitter.DereferenceIncrement(lane[1], None))
    for lane in setup:
      emitter.EmitVStore('1.8', emitter.Lane(lane[0], 6),
                         emitter.DereferenceIncrement(lane[1], None))
  else:
    raise ConfigurationError('Unsupported leftovers count: %d' % leftovers)


def GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift):
  """Handle leftovers if row size not a multiple of 8.

  Args:
    emitter: emitter receiving the instructions.
    registers: register allocator; one temporary quad register per lane is
      allocated (and not released, this being the last step of a kernel).
    leftovers: number of trailing values per lane, 1 to 7.
    lanes: list of QntLane objects.
    multiplicative_offset: see GenerateQuantize.
    rounding_offset: see GenerateQuantize.
    shift: see GenerateQuantize.

  Raises:
    ConfigurationError: if leftovers is not in the range 1 to 7.
  """
  lane_temps = [registers.QuadRegister() for _ in lanes]

  GenerateLoadLeftovers(emitter, registers, leftovers, lanes)

  quantize_setup = []
  for lane_temp, lane in zip(lane_temps, lanes):
    quantize_setup.append(
        [lane.load_1, lane.offset, registers.Low(lane_temp)])
    # load_2 only holds data when more than four values were loaded.
    if leftovers > 4:
      quantize_setup.append(
          [lane.load_2, lane.offset, registers.High(lane_temp)])

  GenerateQuantize(emitter, registers, quantize_setup, lane_temps,
                   multiplicative_offset, rounding_offset, shift)

  GenerateStoreLeftovers(emitter, registers, leftovers, lane_temps, lanes)


def GenerateQntNx8(emitter, qnt_lanes, leftovers, aligned):
  """Emits optimized quantization code for given lanes and row size.

  Args:
    emitter: emitter receiving the function.
    qnt_lanes: number of lanes (rows) handled at once; 1, 2 or 3.
    leftovers: value of 'count % 8' this kernel is specialized for; 0 to 7.
    aligned: whether the kernel may assume an 8-byte aligned destination.

  Raises:
    ConfigurationError: if leftovers or qnt_lanes is out of range.
  """
  if leftovers < 0 or leftovers > 7:
    raise ConfigurationError('Leftovers should be between 0 and 7 inclusive.')
  if qnt_lanes < 1 or qnt_lanes > 3:
    raise ConfigurationError('Qnt_lanes should be 1, 2 or 3.')

  name = BuildName(qnt_lanes, leftovers, aligned)

  emitter.EmitFunctionBeginA(name, _KernelParameters(), 'void')
  emitter.EmitAssert('count %% 8 == %d' % leftovers)
  emitter.EmitAssert('count >= 8')
  emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(source) % 8 == 0')
  if aligned:
    emitter.EmitAssert('reinterpret_cast<std::uintptr_t>(destination) % 8 == 0')
    # NOTE(review): the stride requirement is taken to apply to aligned
    # kernels only, since it is what keeps the extra lanes' outputs aligned;
    # confirm against the generated header.
    if qnt_lanes > 1:
      emitter.EmitAssert('destination_stride % 8 == 0')
  emitter.EmitAsmBegin()

  registers = neon_emitter.NeonRegisters()

  count = registers.MapParameter('count')

  multiplicative_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('multiplicative_offset'))
  rounding_offset = DuplicateRegister(
      emitter, registers, registers.MapParameter('rounding_offset'))
  shift = DuplicateRegister(
      emitter, registers, registers.MapParameter('shift'))

  lanes = GenerateQntLanes(
      emitter, registers, qnt_lanes,
      registers.MapParameter('source'),
      registers.MapParameter('stride'),
      registers.MapParameter('destination'),
      registers.MapParameter('destination_stride'),
      registers.MapParameter('offsets'))

  if leftovers:
    # Take the leftovers off the count up front, so the main loop below (8
    # values per iteration) counts down to exactly zero.
    emitter.EmitSubs(count, count, emitter.ImmediateConstant(leftovers))
    emitter.EmitBeqFront(2)

  emitter.EmitNewline()
  emitter.EmitNumericalLabel(1)
  emitter.EmitSubs(count, count, emitter.ImmediateConstant(8))

  GenerateLoadQuantizeStore(emitter, registers, lanes, multiplicative_offset,
                            rounding_offset, shift, 64 if aligned else None)

  emitter.EmitNewline()
  emitter.EmitBneBack(1)

  if leftovers:
    emitter.EmitNumericalLabel(2)
    GenerateLeftoverLoadQuantizeStore(emitter, registers, leftovers, lanes,
                                      multiplicative_offset, rounding_offset,
                                      shift)

  emitter.EmitAsmEnd(registers.MappedParameters(), [],
                     registers.Clobbers() + ['cc', 'memory'])
  emitter.EmitFunctionEnd()


def BuildMultiQuantizeName(aligned, rows):
  """Returns 'multi_qnt_<rows>x8', with '_aligned' appended when aligned."""
  name = 'multi_qnt_%dx8' % rows
  if aligned:
    name = '%s_aligned' % name
  return name


def GenerateMultiQuantize(emitter, aligned, rows):
  """Emit main quantization code that switches between optimized versions.

  The emitted function switches on 'count % 8' and forwards all of its
  arguments to the qnt kernel specialized for that leftover count.

  Args:
    emitter: emitter receiving the function.
    aligned: whether to dispatch to the '_aligned' kernels.
    rows: number of rows (lanes) the dispatched kernels handle.
  """
  name = BuildMultiQuantizeName(aligned, rows)
  parameters = _KernelParameters()
  # The dispatcher forwards every parameter unchanged, in declaration order.
  arguments = [parameter_name for _, parameter_name in parameters]

  emitter.EmitFunctionBeginA(name, parameters, 'void')
  emitter.EmitSwitch('count % 8')

  for leftovers in range(8):
    emitter.EmitCase(leftovers)
    emitter.PushIndent()
    emitter.EmitCall(BuildName(rows, leftovers, aligned), list(arguments))
    emitter.EmitBreak()
    emitter.PopIndent()

  emitter.EmitSwitchEnd()
  emitter.EmitFunctionEnd()


def GenerateFunctions(neon, cc):
  """Emits every qnt kernel through neon and every dispatcher through cc.

  Kernels are generated for aligned and unaligned destinations, 1 to 3 lanes
  and 0 to 7 leftovers; dispatchers for aligned and unaligned, 1 to 3 rows.
  """
  for aligned in [True, False]:
    for lanes in range(1, 4):
      for leftovers in range(8):
        GenerateQntNx8(neon, lanes, leftovers, aligned)
        neon.EmitNewline()

  for aligned in [True, False]:
    for rows in range(1, 4):
      GenerateMultiQuantize(cc, aligned, rows)
      cc.EmitNewline()