diff --git a/python/sdist/amici/cxxcodeprinter.py b/python/sdist/amici/cxxcodeprinter.py index e6e377b331..a6ff17e835 100644 --- a/python/sdist/amici/cxxcodeprinter.py +++ b/python/sdist/amici/cxxcodeprinter.py @@ -303,7 +303,9 @@ def get_switch_statement( indentation_step: Optional[str] = " " * 4, ): """ - Generate code for switch statement + Generate code for a C++ switch statement. + + Generate code for a C++ switch statement with a ``break`` after each case. :param condition: Condition for switch @@ -321,26 +323,39 @@ def get_switch_statement( :return: Code for switch expression as list of strings """ - lines = [] - if not cases: - return lines + return [] indent0 = indentation_level * indentation_step indent1 = (indentation_level + 1) * indentation_step indent2 = (indentation_level + 2) * indentation_step + + # try to find redundant statements and collapse those cases + # map statements to case expressions + cases_map: dict[tuple[str, ...], list[str]] = {} for expression, statements in cases.items(): if statements: - lines.extend( + statement_code = tuple( [ - f"{indent1}case {expression}:", *(f"{indent2}{statement}" for statement in statements), f"{indent2}break;", ] ) - - if lines: - lines.insert(0, f"{indent0}switch({condition}) {{") - lines.append(indent0 + "}") - - return lines + case_code = f"{indent1}case {expression}:" + + cases_map[statement_code] = cases_map.get(statement_code, []) + [ + case_code + ] + + if not cases_map: + return [] + + return [ + f"{indent0}switch({condition}) {{", + *( + code + for codes in cases_map.items() + for code in itertools.chain.from_iterable(reversed(codes)) + ), + indent0 + "}", + ]