Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better Arm64 input register loading and incrementation. #7614

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions gemm_compiler/aarch64_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def inner_loop(self, M, N):
w_step=self.register_bytes() * N_COUNT,
)
for l in self.weights_asm()['loop']:
if N_COUNT % 2 == 0:
if N_COUNT % 2 != 0:
asm_string += l.format(
W_ptr=self.w_ptr_register(),
W=self.w_registers()[nr],
Expand Down Expand Up @@ -276,17 +276,15 @@ def clamp_inputs_and_outputs(
def increment_ptr(self, ptr, step):
return f'add {ptr}, {ptr}, {step}\n'

def zero_gp_register(self, reg):
return f'eor {reg}, {reg}, {reg}\n'
def initialize_k_register(self, reg):
kc_register = self.kc_register()
return f'mov {reg}, {kc_register}\n'

def cmp_k_and_jump_if_less(self, label):
kc_register = self.kc_register()
k_register = self.k_register()
return """add {k_register}, {k_register}, 4
cmp {kc_register}, {k_register}
bne {label}\n""".format(
label=label, k_register=k_register, kc_register=kc_register
)
return f"""subs {k_register}, {k_register}, 4
bne {label}\n"""

def epilogue(self, M, N, isa):
restore_stack = """
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/avx512f_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def store(
if pop_c:
asm_string += '\n' + '# Pop output pointers from the stack.\n'
c_reg_offset = 0
POP_C = 'mov {C_REG}, [rsp + {offset}]\n'
POP_C = 'mov {C_REG}, [rsp - {offset}]\n'
for mr in range(0, M):
sp_offset = 128 + (mr) * 16 + 8
asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset)
Expand All @@ -208,7 +208,7 @@ def store(
)
if pop_c:
asm_string += '\n' + '# Write output pointers to the stack.\n'
POP_C = 'mov [rsp + {offset}], {C_REG}\n'
POP_C = 'mov [rsp - {offset}], {C_REG}\n'
for mr in range(0, M):
sp_offset = 128 + (mr) * 16 + 8
asm_string += POP_C.format(C_REG=cm_registers[mr], offset=sp_offset)
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/base_architecture.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def increment_ptr(self, ptr, step):
raise NotImplementedError

@abstractmethod
def zero_gp_register(self, reg):
"""Zero the given general purpose register."""
def initialize_k_register(self, reg):
"""Initialized the given general purpose register for inner loop control."""
raise NotImplementedError

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions gemm_compiler/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def generate_gemm_microkernel(

# the outer loop label
asm_string += '\nouter_loop:\n'
asm_string += '# Zero k counter.\n'
asm_string += isa.zero_gp_register(k_register)
asm_string += '# Initialize k counter.\n'
asm_string += isa.initialize_k_register(k_register)

# Read a registers from the stack if required
asm_string += isa.read_a_registers(M=M)
Expand Down
20 changes: 0 additions & 20 deletions gemm_compiler/neondot_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,6 @@ def quantization_params(self):
def quantization_params_register(self):
return 'x24'

def input_asm(self):
in_asm = {
'loop': [
'ldr d{AM}, [{AM_ptr}, {a_offset}]\n',
]
}
return in_asm

def weights_asm(self):
w_asm = {
'loop': [
'ldr q{W}, [{W_ptr}, {offset}]\n',
],
'loop_2': [
'ldp q{W}, q{W_1}, [{W_ptr}, {offset}]\n',
],
'after': 'add {W}, {W}, {w_step}\n',
}
return w_asm

def compute_asm(self):
c_asm = {
'loop': ['sdot v{ACC}.4s, v{W}.16b, v{A}.4b[0]\n'],
Expand Down
30 changes: 11 additions & 19 deletions gemm_compiler/neonfma_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,19 @@ def w_registers(self):
def input_asm(self):
in_asm = {
'loop': [
'ldr d{AM}, [{AM_ptr}, {a_offset}]\n',
'ldr s{AM}, [{AM_ptr}], 4\n',
]
}
return in_asm

def weights_asm(self):
w_asm = {
'loop': [
'ldr q{W}, [{W_ptr}, {offset}]\n',
'ldr q{W}, [{W_ptr}], 16\n',
],
'loop_2': [
'ldp q{W}, q{W_1}, [{W_ptr}, {offset}]\n',
'ldp q{W}, q{W_1}, [{W_ptr}], 32\n',
],
'after': 'add {W}, {W}, {w_step}\n',
}
return w_asm

Expand Down Expand Up @@ -139,22 +138,21 @@ def store(
cmp {nc}, {n_step}
b.lo tail_{N_2}\n""".format(n_step=N, N_2=N // 2, nc=nc_reg)
for mr in range(0, M):
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}]\n'.format(
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format(
ACC=accumulators[mr],
ACC_1=accumulators[M + mr],
c_reg=cm_registers[mr],
)
for nr in range(2, N_COUNT, 2):
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}, {offset}]\n'.format(
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format(
ACC=accumulators[M * 2 + mr],
ACC_1=accumulators[M * 3 + mr],
c_reg=cm_registers[mr],
offset=self.register_bytes() * nr,
)
for mr in range(0, M):
asm_string += 'add {cm}, {cm}, {cn_stride}\n'.format(
cn_stride=N_COUNT * 16, cm=cm_registers[mr]
)
AM_PTR = self.am_registers()[mr]
kc_register = self.kc_register()
asm_string += f'sub {AM_PTR}, {AM_PTR}, {kc_register}\n'
CHECK = """
sub {nc}, {nc}, {n_step}
b.ne outer_loop
Expand All @@ -167,7 +165,7 @@ def store(
\ntail_8:
tbz {nc_lo}, 3, tail_4\n""".format(nc_lo=nc_lo)
for mr in range(0, M):
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}]\n'.format(
asm_string += 'stp q{ACC}, q{ACC_1}, [{c_reg}], 32\n'.format(
ACC=accumulators[mr],
ACC_1=accumulators[mr + M],
c_reg=cm_registers[mr],
Expand All @@ -179,30 +177,24 @@ def store(
asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format(
ACC0=accumulators[mr + M], ACC1=accumulators[mr + 3 * M]
)
for mr in range(0, M):
asm_string += 'add {cm}, {cm}, 32\n'.format(cm=cm_registers[mr])
asm_string += """
\ntail_4:
tbz {nc_lo}, 2, tail_2\n""".format(nc_lo=nc_lo)
for mr in range(0, M):
asm_string += 'str q{ACC}, [{c_reg}]\n'.format(
asm_string += 'str q{ACC}, [{c_reg}], 16\n'.format(
ACC=accumulators[mr], c_reg=cm_registers[mr]
)
for mr in range(0, M):
asm_string += 'mov v{ACC0}.16b, v{ACC1}.16b\n'.format(
ACC0=accumulators[mr], ACC1=accumulators[mr + M]
)
for mr in range(0, M):
asm_string += 'add {cm}, {cm}, 16\n'.format(cm=cm_registers[mr])
asm_string += """
\ntail_2:
tbz {nc_lo}, 1, tail_1\n""".format(nc_lo=nc_lo)
for mr in range(0, M):
asm_string += 'str d{ACC}, [{c_reg}]\n'.format(
asm_string += 'str d{ACC}, [{c_reg}], 8\n'.format(
ACC=accumulators[mr], c_reg=cm_registers[mr]
)
for mr in range(0, M):
asm_string += 'add {c_reg}, {c_reg}, 8\n'.format(c_reg=cm_registers[mr])
for mr in range(0, M):
asm_string += 'dup d{ACC}, v{ACC}.d[1]\n'.format(ACC=accumulators[mr])
asm_string += """
Expand Down
14 changes: 7 additions & 7 deletions gemm_compiler/x64_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ def input_output_register_setup(self, M):
cmovle {aM}, {aM_1}
cmovle {cM}, {cM_1}\n"""
INPUT_OUTPUT_REGISTER_PUSH = """
mov [rsp + {a_rsp_offset}], {aM}
mov [rsp + {c_rsp_offset}], {cM}\n"""
mov [rsp - {a_rsp_offset}], {aM}
mov [rsp - {c_rsp_offset}], {cM}\n"""
ret = ''
if self.stack_size(M) != 0:
ret += """sub rsp, {stack_size}\n""".format(
Expand All @@ -208,11 +208,11 @@ def input_output_register_setup(self, M):
ret += (
'# Write rsi (a pointer) to the stack as we need the register.\n'
)
ret += 'mov [rsp + 128], rsi\n'
ret += 'mov [rsp - 128], rsi\n'
ret += (
'# Write r10 (c pointer) to the stack as we need the register.\n'
)
ret += 'mov [rsp + 136], r10\n'
ret += 'mov [rsp - 136], r10\n'
for mr in range(1, M):
# cycle size of 2 if required
if M > self.max_M_before_spilling():
Expand Down Expand Up @@ -262,7 +262,7 @@ def read_a_registers(self, M):
if M <= self.max_M_before_spilling():
return ''
ret = '# Read a pointers from stack into GP registers.\n'
POP_A = 'mov {aM}, [rsp + {a_rsp_offset}]\n'
POP_A = 'mov {aM}, [rsp - {a_rsp_offset}]\n'
for mr in range(0, M):
a_rsp_offset = 128 + mr * 16
ret += POP_A.format(aM=registers[mr], a_rsp_offset=a_rsp_offset)
Expand All @@ -272,7 +272,7 @@ def read_a_registers(self, M):
def increment_ptr(self, ptr, step):
return f'add {ptr}, {step}\n'

def zero_gp_register(self, reg):
def initialize_k_register(self, reg):
return f'mov {reg}, 0\n'

def cmp_k_and_jump_if_less(self, label):
Expand All @@ -287,7 +287,7 @@ def cmp_k_and_jump_if_less(self, label):

def load_from_stack(self, reg, offset):
"""Load 8 bytes from the given offset from the stack pointer to reg."""
return f'mov {reg}, [rsp + {offset}]\n'
return f'mov {reg}, [rsp - {offset}]\n'

def epilogue(self, M, N, isa):
restore_stack = '\nreturn:\n'
Expand Down
Loading
Loading