Skip to content

Commit

Permalink
Simplify template based file generation code.
Browse files Browse the repository at this point in the history
  • Loading branch information
r12f committed Dec 7, 2023
1 parent 44e07b8 commit 0cd229b
Showing 1 changed file with 73 additions and 54 deletions.
127 changes: 73 additions & 54 deletions dash-pipeline/SAI/sai_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,23 @@ def parse(self, p4rt_value, *args, **kwargs):
#
# Parsed SAI objects and parsers
#
# The SAI objects are parsed from the P4Runtime JSON file, generated by p4 compiler, which contains the information
# of all tables and entry information.
#
# The classes below are used to parse the P4Runtime JSON file to get the key information, so we can generate the SAI
# API headers and implementation files afterwards.
#
# At high level, the hiredarchy of the SAI objects is as follows:
#
# DASHSAIExtensions: All DASH SAI extensions.
# - SAIEnum: A single enum type.
# - SAIEnumMember: A single enum member within the enum.
# - SAIAPISet: All information for a single SAI API set, such as routing or CA-PA mapping.
# - SAIAPITableData: All information for a single SAI API table used in the API set.
# - SAIAPITableKey: Information of a single P4 table key defined in the table.
# - SAIAPITableAction: Information of a single P4 table action defined used by the table.
# - SAIAPITableActionParam: Information of a single P4 table action parameter used by the action.
#
class SAIType:
sai_type_to_field = {
'bool': 'booldata',
Expand Down Expand Up @@ -300,7 +317,7 @@ def parse_p4rt(self, p4rt_table_key, v4_or_v6_key_ids):

self.id = p4rt_table_key['id']
self.name = p4rt_table_key[NAME_TAG]
print("Parsing table key: " + self.name)
#print("Parsing table key: " + self.name)

full_key_name, self.sai_key_name = self.name.split(':')
key_tuple = full_key_name.split('.')
Expand Down Expand Up @@ -365,7 +382,7 @@ def parse_p4rt(self, p4rt_table_action, sai_enums):
]
}
'''
print("Parsing table action: " + self.name)
#print("Parsing table action: " + self.name)
self.name = self.name.split('.')[-1]
self.parse_action_params(p4rt_table_action, sai_enums)

Expand All @@ -384,7 +401,7 @@ def parse_action_params(self, p4rt_table_action, sai_enums):
# Parse all params.
for p in p4rt_table_action[PARAMS_TAG]:
param_name = p[NAME_TAG]
param = SAIAPITableActionParam.from_p4rt(p, name = param_name, sai_enums = sai_enums, v4_or_v6_param_ids = v4_or_v6_param_ids)
param = SAIAPITableActionParam.from_p4rt(p, sai_enums = sai_enums, v4_or_v6_param_ids = v4_or_v6_param_ids)
self.params.append(param)

return
Expand All @@ -407,9 +424,9 @@ def parse_p4rt(self, p4rt_table_action_param, sai_enums, v4_or_v6_param_ids):
{ "id": 1, "name": "dst_vnet_id", "bitwidth": 16 }
'''
print("Parsing table action param: " + self.name)
self.id = p4rt_table_action_param['id']
self.name = p4rt_table_action_param[NAME_TAG]
#print("Parsing table action param: " + self.name)

if STRUCTURED_ANNOTATIONS_TAG in p4rt_table_action_param:
self._parse_sai_object_annotation(p4rt_table_action_param)
Expand Down Expand Up @@ -457,7 +474,7 @@ def parse_p4rt(self, p4rt_table, program, all_actions, ignore_tables):
if self.name in ignore_tables:
return

print("Found table: " + self.name)
print("Parsing table: " + self.name)

self.name, self.api_name = self.name.split('|')
if '.' in self.name:
Expand Down Expand Up @@ -555,7 +572,7 @@ def __merge_action_params_to_table_params(self, action):
self.action_params.append(action_param)


class DASHAPISet:
class DASHAPISet(SAIObject):
'''
This class holds all parsed SAI API info for a specific API set, such as routing or CA-PA mapping.
'''
Expand Down Expand Up @@ -623,24 +640,36 @@ def __parse_sai_table_action(self, p4rt_actions, sai_enums):
action_data[action.id] = action
return action_data

#
# SAI Generators
#
class SAITemplateRender:
jinja2_env = None

@classmethod
def new_tm(cls, template_file_path):
if cls.jinja2_env == None:
cls.env = Environment(loader=FileSystemLoader('.'), trim_blocks=True, lstrip_blocks=True)
cls.env.add_extension('jinja2.ext.loopcontrols')
cls.env.add_extension('jinja2.ext.do')

return cls.env.get_template(template_file_path)

def get_uniq_sai_api(sai_api):
""" Only keep one table per group(with same table name) """
groups = set()
sai_api = copy.deepcopy(sai_api)
tables = []
for table in sai_api.tables:
if table.name in groups:
continue
tables.append(table)
groups.add(table.name)
sai_api.tables = tables
return sai_api
def __init__(self, template_file_path):
self.template_file_path = template_file_path
self.tm = SAITemplateRender.new_tm(template_file_path)

def render(self, **kwargs):
return self.tm.render(**kwargs)

def render_to_file(self, target_file_path, **kwargs):
rendered_str = self.tm.render(**kwargs)
write_if_different(target_file_path, rendered_str)

# don't write content to file if file already exists
# and the content is the same, this will not touch
# the file and let make utilize this
def write_if_different(file,content):
def write_if_different(file, content):
if os.path.isfile(file) == True:
o = open(file, "r")
data = o.read()
Expand All @@ -650,41 +679,37 @@ def write_if_different(file,content):
with open(file, 'w') as o:
o.write(content)

def write_sai_impl_files(sai_api):
env = Environment(loader=FileSystemLoader('.'), trim_blocks=True, lstrip_blocks=True)
env.add_extension('jinja2.ext.loopcontrols')
env.add_extension('jinja2.ext.do')
sai_impl_tm = env.get_template('/templates/saiapi.cpp.j2')
if "dash" in sai_api.app_name:
header_prefix = "experimental"
else:
header_prefix = ""
sai_impl_str = sai_impl_tm.render(tables = sai_api.tables, app_name = sai_api.app_name, header_prefix = header_prefix)
write_if_different('./lib/sai' + sai_api.app_name.replace('_', '') + '.cpp',sai_impl_str)
class SAIFileUpdater:
def __init__(self, file_path):
self.file_path = file_path

def write_sai_fixed_api_files(sai_api_full_name_list):
env = Environment(loader=FileSystemLoader('.'))
with open(file_path, 'r') as f:
self.lines = f.readlines()

for filename in ['saifixedapis.cpp', 'saiimpl.h']:
env = Environment(loader=FileSystemLoader('.'), trim_blocks=True, lstrip_blocks=True)
sai_impl_tm = env.get_template('/templates/%s.j2' % filename)
sai_impl_str = sai_impl_tm.render(tables = sai_api.tables, app_name = sai_api.app_name, api_names = sai_api_full_name_list)
def get_uniq_sai_api(sai_api):
""" Only keep one table per group(with same table name) """
groups = set()
sai_api = copy.deepcopy(sai_api)
tables = []
for table in sai_api.tables:
if table.name in groups:
continue
tables.append(table)
groups.add(table.name)
sai_api.tables = tables
return sai_api

write_if_different('./lib/%s' % filename,sai_impl_str)
def write_sai_impl_files(sai_api):
header_prefix = "experimental" if "dash" in sai_api.app_name else ""
SAITemplateRender('templates/saiapi.cpp.j2').render_to_file('./lib/sai' + sai_api.app_name.replace('_', '') + '.cpp', tables = sai_api.tables, app_name = sai_api.app_name, header_prefix = header_prefix)

def write_sai_fixed_api_files(sai_api_full_name_list):
for filename in ['saifixedapis.cpp', 'saiimpl.h']:
SAITemplateRender('templates/%s.j2' % filename).render_to_file('./lib/%s' % filename, tables = sai_api.tables, app_name = sai_api.app_name, api_names = sai_api_full_name_list)

def write_sai_files(sai_api):
# The main file
with open('templates/saiapi.h.j2', 'r') as sai_header_tm_file:
sai_header_tm_str = sai_header_tm_file.read()

env = Environment(loader=FileSystemLoader('.'), trim_blocks=True, lstrip_blocks=True)
env.add_extension('jinja2.ext.loopcontrols')
env.add_extension('jinja2.ext.do')
sai_header_tm = env.get_template('templates/saiapi.h.j2')
sai_header_str = sai_header_tm.render(sai_api = sai_api)

write_if_different('./SAI/experimental/saiexperimental' + sai_api.app_name.replace('_', '') + '.h',sai_header_str)
SAITemplateRender('templates/saiapi.h.j2').render_to_file('./SAI/experimental/saiexperimental' + sai_api.app_name.replace('_', '') + '.h', sai_api = sai_api)

# The SAI Extensions
with open('./SAI/experimental/saiextensions.h', 'r') as f:
Expand Down Expand Up @@ -797,19 +822,14 @@ def write_sai_files(sai_api):
sai_api_name_list.append(sai_api.app_name.replace('_', ''))
sai_api_full_name_list.append(sai_api.app_name)

env = Environment(loader=FileSystemLoader('.'), trim_blocks=True, lstrip_blocks=True)
env.add_extension('jinja2.ext.loopcontrols')
env.add_extension('jinja2.ext.do')

final_sai_enums = []
with open('./SAI/experimental/saitypesextensions.h', 'r') as f:
content = f.read()
for enum in sai_enums:
if enum.name not in content:
final_sai_enums.append(enum)

sai_enums_tm = env.get_template('templates/saienums.j2')
sai_enums_str = sai_enums_tm.render(sai_enums = final_sai_enums)
sai_enums_str = SAITemplateRender('templates/saienums.j2').render(sai_enums = final_sai_enums)
sai_enums_lines = sai_enums_str.split('\n')

# The SAI object struct for entries
Expand All @@ -826,7 +846,6 @@ def write_sai_files(sai_api):

write_if_different('./SAI/experimental/saitypesextensions.h',''.join(new_lines))


write_sai_fixed_api_files(sai_api_full_name_list)

if args.print_sai_lib:
Expand Down

0 comments on commit 0cd229b

Please sign in to comment.