diff --git a/noxfile.py b/noxfile.py index f0e24f642..5783838f7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -27,7 +27,7 @@ "pytest!=7.1.0", "pyyaml", "types-PyYAML", - "typing_extensions", + "typing_extensions>=4.10", "ml-dtypes", ) ONNX = "onnx==1.17" diff --git a/onnxscript/converter.py b/onnxscript/converter.py index a565cacfd..f155f87a1 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -800,6 +800,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: non_scalar_indices.extend(scalar_indices) if non_scalar_indices: last_axis, _ = non_scalar_indices[-1] + else: + # TODO(justinchuby): Clarify what last_axis should be when non_scalar_indices is False + last_axis = None for axis, index_expr in non_scalar_indices: index_value = self._translate_expr(index_expr) axis_attr = self._make_onnx_attr("axis", axis) diff --git a/onnxscript/diagnostics/infra/__init__.py b/onnxscript/diagnostics/infra/__init__.py deleted file mode 100644 index d271aea2e..000000000 --- a/onnxscript/diagnostics/infra/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from ._infra import ( - DiagnosticOptions, - Graph, - Invocation, - Level, - Location, - Rule, - RuleCollection, - Stack, - StackFrame, - Tag, - ThreadFlowLocation, - levels, -) -from .context import Diagnostic, DiagnosticContext, RuntimeErrorWithDiagnosticError - -__all__ = [ - "Diagnostic", - "DiagnosticContext", - "DiagnosticOptions", - "Graph", - "Invocation", - "Level", - "levels", - "Location", - "Rule", - "RuleCollection", - "RuntimeErrorWithDiagnosticError", - "Stack", - "StackFrame", - "Tag", - "ThreadFlowLocation", -] diff --git a/onnxscript/diagnostics/infra/_infra.py b/onnxscript/diagnostics/infra/_infra.py deleted file mode 100644 index 1d8d4264b..000000000 --- a/onnxscript/diagnostics/infra/_infra.py +++ /dev/null @@ -1,321 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""This file defines an additional layer of abstraction on top of the SARIF OM.""" - -from __future__ import annotations - -import dataclasses -import enum -import pprint -from typing import FrozenSet, List, Mapping, Optional, Sequence, Tuple - -from onnxscript.diagnostics.infra import formatter, sarif - - -class Level(enum.IntEnum): - """The level of a diagnostic. - - This class is used to represent the level of a diagnostic. The levels are defined - by the SARIF specification, and are not modifiable. For alternative categories, - please use infra.Tag instead. When selecting a level, please consider the following - guidelines: - - - NONE: Informational result that does not indicate the presence of a problem. - - NOTE: An opportunity for improvement was found. - - WARNING: A potential problem was found. - - ERROR: A serious problem was found. - - This level is a subclass of enum.IntEnum, and can be used as an integer. Its integer - value maps to the logging levels in Python's logging module. The mapping is as - follows: - - Level.NONE = logging.DEBUG = 10 - Level.NOTE = logging.INFO = 20 - Level.WARNING = logging.WARNING = 30 - Level.ERROR = logging.ERROR = 40 - """ - - NONE = 10 - NOTE = 20 - WARNING = 30 - ERROR = 40 - - -levels = Level - - -class Tag(enum.Enum): - """The tag of a diagnostic. This class can be inherited to define custom tags.""" - - -class PatchedPropertyBag(sarif.PropertyBag): - """Key/value pairs that provide additional information about the object. - - The definition of PropertyBag via SARIF spec is "A property bag is an object (ยง3.6) - containing an unordered set of properties with arbitrary names." However it is not - reflected in the json file, and therefore not captured by the python representation. - This patch adds additional **kwargs to the `__init__` method to allow recording - arbitrary key/value pairs. - """ - - def __init__(self, tags: Optional[List[str]] = None, **kwargs): - super().__init__(tags=tags) - self.__dict__.update(kwargs) - - -@dataclasses.dataclass(frozen=True) -class Rule: - id: str - name: str - message_default_template: str - short_description: Optional[str] = None - full_description: Optional[str] = None - full_description_markdown: Optional[str] = None - help_uri: Optional[str] = None - - @classmethod - def from_sarif(cls, **kwargs): - """Returns a rule from the SARIF reporting descriptor.""" - short_description = kwargs.get("short_description", {}).get("text") - full_description = kwargs.get("full_description", {}).get("text") - full_description_markdown = kwargs.get("full_description", {}).get("markdown") - help_uri = kwargs.get("help_uri") - - rule = cls( - id=kwargs["id"], - name=kwargs["name"], - message_default_template=kwargs["message_strings"]["default"]["text"], - short_description=short_description, - full_description=full_description, - full_description_markdown=full_description_markdown, - help_uri=help_uri, - ) - return rule - - def sarif(self) -> sarif.ReportingDescriptor: - """Returns a SARIF reporting descriptor of this Rule.""" - short_description = ( - sarif.MultiformatMessageString(text=self.short_description) - if self.short_description is not None - else None - ) - full_description = ( - sarif.MultiformatMessageString( - text=self.full_description, markdown=self.full_description_markdown - ) - if self.full_description is not None - else None - ) - return sarif.ReportingDescriptor( - id=self.id, - name=self.name, - short_description=short_description, - full_description=full_description, - help_uri=self.help_uri, - ) - - def format(self, level: Level, *args, **kwargs) -> Tuple[Rule, Level, str]: - """Returns a tuple of (rule, level, message) for a diagnostic. - - This method is used to format the message of a diagnostic. The message is - formatted using the default template of this rule, and the arguments passed in - as `*args` and `**kwargs`. The level is used to override the default level of - this rule. - """ - return (self, level, self.format_message(*args, **kwargs)) - - def format_message(self, *args, **kwargs) -> str: - """Returns the formatted default message of this Rule. - - This method should be overridden (with code generation) by subclasses to reflect - the exact arguments needed by the message template. This is a helper method to - create the default message for a diagnostic. - """ - return self.message_default_template.format(*args, **kwargs) - - def pretty_print(self): - pass - - -@dataclasses.dataclass -class Location: - uri: Optional[str] = None - line: Optional[int] = None - message: Optional[str] = None - start_column: Optional[int] = None - end_column: Optional[int] = None - snippet: Optional[str] = None - function: Optional[str] = None - - def sarif(self) -> sarif.Location: - """Returns the SARIF representation of this location.""" - return sarif.Location( - physical_location=sarif.PhysicalLocation( - artifact_location=sarif.ArtifactLocation(uri=self.uri), - region=sarif.Region( - start_line=self.line, - start_column=self.start_column, - end_column=self.end_column, - snippet=sarif.ArtifactContent(text=self.snippet), - ), - ), - message=sarif.Message(text=self.message) if self.message is not None else None, - ) - - def pretty_print(self): - """Prints the location in a traceback style format.""" - unknown = "" - snippet = self.snippet or unknown - uri = self.uri or unknown - function = self.function or unknown - lineno = self.line if self.line is not None else unknown - message = f" # {self.message}" if self.message is not None else "" - print(f' File "{uri}", line {lineno}, in {function}\n {snippet}{message}') - - -@dataclasses.dataclass -class StackFrame: - location: Location - - def sarif(self) -> sarif.StackFrame: - """Returns the SARIF representation of this stack frame.""" - return sarif.StackFrame(location=self.location.sarif()) - - def pretty_print(self): - """Prints the stack frame in a human-readable format.""" - self.location.pretty_print() - - -@dataclasses.dataclass -class Stack: - """Records a stack trace. The frames are in order from newest to oldest stack frame.""" - - frames: List[StackFrame] = dataclasses.field(default_factory=list) - message: Optional[str] = None - - def sarif(self) -> sarif.Stack: - """Returns the SARIF representation of this stack.""" - return sarif.Stack( - frames=[frame.sarif() for frame in self.frames], - message=sarif.Message(text=self.message) if self.message is not None else None, - ) - - def pretty_print(self): - """Prints the stack in a human-readable format.""" - formatter.pretty_print_title(f"Stack: {self.message}", fill_char="-") - for frame in reversed(self.frames): - frame.pretty_print() - - -@dataclasses.dataclass -class ThreadFlowLocation: - """Records code location and the initial state.""" - - location: Location - state: Mapping[str, str] - index: int - stack: Optional[Stack] = None - - def sarif(self) -> sarif.ThreadFlowLocation: - """Returns the SARIF representation of this thread flow location.""" - return sarif.ThreadFlowLocation( - location=self.location.sarif(), - state=self.state, - stack=self.stack.sarif() if self.stack is not None else None, - ) - - def pretty_print(self, verbose: bool = False): - """Prints the thread flow location in a human-readable format.""" - formatter.pretty_print_title(f"Step {self.index}", fill_char="-") - self.location.pretty_print() - if verbose: - print(f"State: {pprint.pformat(self.state)}") - if self.stack is not None: - self.stack.pretty_print() - - -@dataclasses.dataclass -class Graph: - """A graph of diagnostics. - - This class stores the string representation of a model graph. - The `nodes` and `edges` fields are unused in the current implementation. - """ - - graph: str - name: str - description: Optional[str] = None - - def sarif(self) -> sarif.Graph: - """Returns the SARIF representation of this graph.""" - return sarif.Graph( - description=sarif.Message(text=self.graph), - properties=PatchedPropertyBag(name=self.name, description=self.description), - ) - - def pretty_print( - self, - verbose: bool = False, - ): - """Prints the diagnostics in a human-readable format. - - Args: - verbose: If True, prints all information. Otherwise, only prints compact - information. E.g., graph name and description. - log_level: The minimum level of diagnostics to print. - """ - formatter.pretty_print_title(f"Graph: {self.name}", fill_char="-") - print(self.description) - if verbose: - print(self.graph) - - -@dataclasses.dataclass -class RuleCollection: - _rule_id_name_set: FrozenSet[Tuple[str, str]] = dataclasses.field(init=False) - - def __post_init__(self) -> None: - self._rule_id_name_set = frozenset( - { - (field.default.id, field.default.name) - for field in dataclasses.fields(self) - if isinstance(field.default, Rule) - } - ) - - def __contains__(self, rule: Rule) -> bool: - """Checks if the rule is in the collection.""" - return (rule.id, rule.name) in self._rule_id_name_set - - @classmethod - def custom_collection_from_list( - cls, new_collection_class_name: str, rules: Sequence[Rule] - ) -> RuleCollection: - """Creates a custom class inherited from RuleCollection with the list of rules.""" - return dataclasses.make_dataclass( - new_collection_class_name, - [ - ( - formatter.kebab_case_to_snake_case(rule.name), - type(rule), - dataclasses.field(default=rule), - ) - for rule in rules - ], - bases=(cls,), - )() - - -class Invocation: - # TODO: Implement this. - # Tracks top level call arguments and diagnostic options. - def __init__(self) -> None: - raise NotImplementedError() - - -@dataclasses.dataclass -class DiagnosticOptions: - """Options for diagnostic context.""" - - log_verbose: bool = dataclasses.field(default=False) - log_level: Level = dataclasses.field(default=Level.ERROR) diff --git a/onnxscript/diagnostics/infra/context.py b/onnxscript/diagnostics/infra/context.py deleted file mode 100644 index 081ba9f65..000000000 --- a/onnxscript/diagnostics/infra/context.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""A diagnostic context based on SARIF.""" - -from __future__ import annotations - -import contextlib -import dataclasses -import gzip -import logging -import typing -from typing import Callable, Generator, List, Literal, Mapping, Optional - -from onnxscript.diagnostics import infra -from onnxscript.diagnostics.infra import formatter, sarif, utils -from onnxscript.diagnostics.infra.sarif import version as sarif_version - -if typing.TYPE_CHECKING: - from typing_extensions import Self - - -@dataclasses.dataclass -class Diagnostic: - rule: infra.Rule - level: infra.Level - message: Optional[str] = None - locations: List[infra.Location] = dataclasses.field(default_factory=list) - stacks: List[infra.Stack] = dataclasses.field(default_factory=list) - graphs: List[infra.Graph] = dataclasses.field(default_factory=list) - thread_flow_locations: List[infra.ThreadFlowLocation] = dataclasses.field( - default_factory=list - ) - additional_message: Optional[str] = None - tags: List[infra.Tag] = dataclasses.field(default_factory=list) - source_exception: Optional[Exception] = None - """The exception that caused this diagnostic to be created.""" - - def __post_init__(self) -> None: - pass - - def sarif(self) -> sarif.Result: - """Returns the SARIF Result representation of this diagnostic.""" - message = self.message or self.rule.message_default_template - if self.additional_message: - message_markdown = ( - f"{message}\n\n## Additional Message:\n\n{self.additional_message}" - ) - else: - message_markdown = message - - kind: Literal["informational", "fail"] = ( - "informational" if self.level == infra.Level.NONE else "fail" - ) - - sarif_result = sarif.Result( - message=sarif.Message(text=message, markdown=message_markdown), - level=self.level.name.lower(), # type: ignore[arg-type] - rule_id=self.rule.id, - kind=kind, - ) - sarif_result.locations = [location.sarif() for location in self.locations] - sarif_result.stacks = [stack.sarif() for stack in self.stacks] - sarif_result.graphs = [graph.sarif() for graph in self.graphs] - sarif_result.code_flows = [ - sarif.CodeFlow( - thread_flows=[ - sarif.ThreadFlow( - locations=[loc.sarif() for loc in self.thread_flow_locations] - ) - ] - ) - ] - sarif_result.properties = sarif.PropertyBag(tags=[tag.value for tag in self.tags]) - return sarif_result - - def with_location(self: Self, location: infra.Location) -> Self: - """Adds a location to the diagnostic.""" - self.locations.append(location) - return self - - def with_thread_flow_location(self: Self, location: infra.ThreadFlowLocation) -> Self: - """Adds a thread flow location to the diagnostic.""" - self.thread_flow_locations.append(location) - return self - - def with_stack(self: Self, stack: infra.Stack) -> Self: - """Adds a stack to the diagnostic.""" - self.stacks.append(stack) - return self - - def with_graph(self: Self, graph: infra.Graph) -> Self: - """Adds a graph to the diagnostic.""" - self.graphs.append(graph) - return self - - def with_additional_message(self: Self, message: str) -> Self: - """Adds an additional message to the diagnostic.""" - if self.additional_message is None: - self.additional_message = message - else: - self.additional_message = f"{self.additional_message}\n{message}" - return self - - def with_source_exception(self: Self, exception: Exception) -> Self: - """Adds the source exception to the diagnostic.""" - self.source_exception = exception - return self - - def record_python_call_stack(self, frames_to_skip: int) -> infra.Stack: - """Records the current Python call stack.""" - frames_to_skip += 1 # Skip this function. - stack = utils.python_call_stack(frames_to_skip=frames_to_skip) - self.with_stack(stack) - if len(stack.frames) > 0: - self.with_location(stack.frames[0].location) - return stack - - def record_python_call( - self, - fn: Callable, - state: Mapping[str, str], - message: Optional[str] = None, - frames_to_skip: int = 0, - ) -> infra.ThreadFlowLocation: - """Records a python call as one thread flow step.""" - frames_to_skip += 1 # Skip this function. - stack = utils.python_call_stack(frames_to_skip=frames_to_skip, frames_to_log=5) - location = utils.function_location(fn) - location.message = message - # Add function location to the top of the stack. - stack.frames.insert(0, infra.StackFrame(location=location)) - thread_flow_location = infra.ThreadFlowLocation( - location=location, - state=state, - index=len(self.thread_flow_locations), - stack=stack, - ) - self.with_thread_flow_location(thread_flow_location) - return thread_flow_location - - def pretty_print(self, verbose: bool = False, log_level: infra.Level = infra.Level.ERROR): - """Prints the diagnostics in a human-readable format. - - Args: - verbose: If True, prints all information. E.g. stack frames, graphs, etc. - Otherwise, only prints compact information. E.g., rule name and display message. - log_level: The minimum level of diagnostics to print. - """ - if self.level.value < log_level.value: - return - formatter.pretty_print_item_title(f"{self.level.name}: {self.rule.name}") - print(self.message) - print(self.additional_message) - - if not verbose: - print("\n") - return - - formatter.pretty_print_title("Locations", fill_char="-") - for location in self.locations: - location.pretty_print() - for stack in self.stacks: - stack.pretty_print() - formatter.pretty_print_title("Thread Flow Locations", fill_char="-") - for thread_flow_location in self.thread_flow_locations: - thread_flow_location.pretty_print(verbose=verbose) - for graph in self.graphs: - graph.pretty_print(verbose=verbose) - - print() - - # TODO: print help url to rule at the end. - - -class RuntimeErrorWithDiagnosticError(RuntimeError): - """Runtime error with enclosed diagnostic information.""" - - def __init__(self, diagnostic: Diagnostic): - super().__init__(diagnostic.message) - self.diagnostic = diagnostic - - -@dataclasses.dataclass -class DiagnosticContext: - name: str - version: str - options: infra.DiagnosticOptions = dataclasses.field( - default_factory=infra.DiagnosticOptions - ) - diagnostics: List[Diagnostic] = dataclasses.field(init=False, default_factory=list) - logger: logging.Logger = dataclasses.field( - init=True, default_factory=lambda: logging.getLogger().getChild("diagnostics") - ) - # TODO(bowbao): Implement this. - # _invocation: infra.Invocation = dataclasses.field(init=False) - _inflight_diagnostics: List[Diagnostic] = dataclasses.field( - init=False, default_factory=list - ) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return None - - def sarif(self) -> sarif.Run: - """Returns the SARIF Run object.""" - unique_rules = {diagnostic.rule for diagnostic in self.diagnostics} - return sarif.Run( - tool=sarif.Tool( - driver=sarif.ToolComponent( - name=self.name, - version=self.version, - rules=[rule.sarif() for rule in unique_rules], - ) - ), - results=[diagnostic.sarif() for diagnostic in self.diagnostics], - ) - - def sarif_log(self) -> sarif.SarifLog: # type: ignore[name-defined] - """Returns the SARIF Log object.""" - return sarif.SarifLog( - version=sarif_version.SARIF_VERSION, - schema_uri=sarif_version.SARIF_SCHEMA_LINK, - runs=[self.sarif()], - ) - - def to_json(self) -> str: - return formatter.sarif_to_json(self.sarif_log()) - - def dump(self, file_path: str, compress: bool = False) -> None: - """Dumps the SARIF log to a file.""" - if compress: - with gzip.open(file_path, "wt", encoding="utf-8") as f: - f.write(self.to_json()) - else: - with open(file_path, "w", encoding="utf-8") as f: - f.write(self.to_json()) - - def log(self, diagnostic: Diagnostic) -> None: - """Adds a diagnostic to the context. - - Use this method to add diagnostics that are not created by the context. - - Args: - diagnostic: The diagnostic to add. - """ - if not isinstance(diagnostic, Diagnostic): - raise TypeError( - f"Expected diagnostic of type {Diagnostic}, got {type(diagnostic)}" - ) - self.diagnostics.append(diagnostic) - self.logger.log(diagnostic.level, diagnostic.message) - self.logger.log(diagnostic.level, diagnostic.additional_message) - - def log_and_raise_if_error(self, diagnostic: Diagnostic) -> None: - self.log(diagnostic) - if diagnostic.level == infra.Level.ERROR: - raise RuntimeErrorWithDiagnosticError(diagnostic) from diagnostic.source_exception - - @contextlib.contextmanager - def add_inflight_diagnostic( - self, diagnostic: Diagnostic - ) -> Generator[Diagnostic, None, None]: - """Adds a diagnostic to the context. - - Use this method to add diagnostics that are not created by the context. - - Args: - diagnostic: The diagnostic to add. - """ - self._inflight_diagnostics.append(diagnostic) - try: - yield diagnostic - finally: - self._inflight_diagnostics.pop() - - def push_inflight_diagnostic(self, diagnostic: Diagnostic) -> None: - """Pushes a diagnostic to the inflight diagnostics stack. - - Args: - diagnostic: The diagnostic to push. - - Raises: - ValueError: If the rule is not supported by the tool. - """ - self._inflight_diagnostics.append(diagnostic) - - def pop_inflight_diagnostic(self) -> Diagnostic: - """Pops the last diagnostic from the inflight diagnostics stack. - - Returns: - The popped diagnostic. - """ - return self._inflight_diagnostics.pop() - - def inflight_diagnostic(self, rule: Optional[infra.Rule] = None) -> Diagnostic: - if rule is None: - # TODO(bowbao): Create builtin-rules and create diagnostic using that. - if len(self._inflight_diagnostics) <= 0: - raise AssertionError("No inflight diagnostics") - - return self._inflight_diagnostics[-1] - else: - # TODO(bowbao): Improve efficiency with Mapping[Rule, List[Diagnostic]] - for diagnostic in reversed(self._inflight_diagnostics): - if diagnostic.rule == rule: - return diagnostic - raise AssertionError(f"No inflight diagnostic for rule {rule.name}") - - def pretty_print( - self, verbose: Optional[bool] = None, log_level: Optional[infra.Level] = None - ) -> None: - """Prints the diagnostics in a human-readable format. - - Args: - verbose: Whether to print the diagnostics in verbose mode. See Diagnostic.pretty_print. - If not specified, uses the value of 'self.options.log_verbose'. - log_level: The minimum level of diagnostics to print. - If not specified, uses the value of 'self.options.log_level'. - """ - if verbose is None: - verbose = self.options.log_verbose - if log_level is None: - log_level = self.options.log_level - - formatter.pretty_print_title(f"Diagnostic Run {self.name} version {self.version}") - print(f"verbose: {verbose}, log level: {log_level}") - diagnostic_stats = dict.fromkeys(infra.Level, 0) - for diagnostic in self.diagnostics: - diagnostic_stats[diagnostic.level] += 1 - formatter.pretty_print_title( - " ".join(f"{diagnostic_stats[level]} {level.name}" for level in infra.Level) - ) - - for diagnostic in self.diagnostics: - diagnostic.pretty_print(verbose, log_level) - - unprinted_diagnostic_stats = [ - (level, count) - for level, count in diagnostic_stats.items() - if count > 0 and level.value < log_level.value - ] - if unprinted_diagnostic_stats: - print( - f"{' '.join(f'{count} {level.name}' for level, count in unprinted_diagnostic_stats)} " - "were not printed due to the log level." - ) - print() diff --git a/onnxscript/diagnostics/infra/decorator.py b/onnxscript/diagnostics/infra/decorator.py deleted file mode 100644 index 56a362624..000000000 --- a/onnxscript/diagnostics/infra/decorator.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import functools -import traceback -from typing import Any, Callable, Dict, Optional, Tuple, Type - -from onnxscript._internal import runtime_typing -from onnxscript.diagnostics import infra -from onnxscript.diagnostics.infra import formatter, utils - -MessageFormatterType = Callable[..., str] - - -@runtime_typing.checked -def format_message_in_text( - fn: Callable, # pylint: disable=unused-argument - *args: Any, - **kwargs: Any, -) -> str: - return f"{formatter.display_name(fn)}. " - - -@runtime_typing.checked -def format_exception_in_markdown(exception: Exception) -> str: - msg_list = ["### Exception log", "```"] - msg_list.extend( - traceback.format_exception(type(exception), exception, exception.__traceback__) - ) - msg_list.append("```") - return "\n".join(msg_list) - - -@runtime_typing.checked -def format_function_signature_in_markdown( - fn: Callable, - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - format_argument: Callable[[Any], str] = formatter.format_argument, -) -> str: - msg_list = [f"### Function Signature {formatter.display_name(fn)}"] - - state = utils.function_state(fn, args, kwargs) - - for k, v in state.items(): - msg_list.append(f"- {k}: {format_argument(v)}") - - return "\n".join(msg_list) - - -@runtime_typing.checked -def format_return_values_in_markdown( - return_values: Any, - format_argument: Callable[[Any], str] = formatter.format_argument, -) -> str: - return f"- Return value: {format_argument(return_values)}" - - -ModifierCallableType = Callable[ - [infra.Diagnostic, Callable, Tuple[Any, ...], Dict[str, Any], Any], None -] - - -@runtime_typing.checked -def diagnose_call( - rule: infra.Rule, - *, - level: infra.Level = infra.Level.NONE, - diagnostic_type: Type[infra.Diagnostic] = infra.Diagnostic, - format_argument: Callable[[Any], str] = formatter.format_argument, - diagnostic_message_formatter: MessageFormatterType = format_message_in_text, -) -> Callable: - def decorator(fn): - @functools.wraps(fn) - def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements - common_error_message = "diagnose_call can only be applied to callables" - if not callable(fn): - raise AssertionError( # noqa: TRY004 - f"{common_error_message}. Got {type(fn)} instead of callable." - ) - arg0 = args[0] if len(args) > 0 else None - if isinstance(ctx := arg0, infra.DiagnosticContext): - pass - elif isinstance( - ctx := getattr(arg0, "diagnostic_context", None), - infra.DiagnosticContext, - ): - pass - else: - # NOTE: At decorate time, it can't tell if a callable is function or method. - # Technically both are regarded as function at that time. - raise AssertionError( # noqa: TRY004 - f"{common_error_message}. For {fn}, " - f"If it is a function, a DiagnosticContext instance must be present as " - f"the first argument. " - f"If it is a method, a DiagnosticContext instance must be present as " - f"the attribute 'diagnostic_context' of the 'self' argument." - ) - - diag = diagnostic_type( - rule, - level, - diagnostic_message_formatter(fn, *args, **kwargs), - ) - - # pop the decorator frame - # TODO(bowbao): by default diagnostic doesn't have stack. - # So need to check before doing this. Make the code cleaner. - # Option: do not capture stack by default in diagnostic initialization. - stack: Optional[infra.Stack] = None - if len(diag.stacks) > 0: - stack = diag.stacks[0] - stack.frames.pop(0) - - # set function location - fn_location = utils.function_location(fn) - diag.locations.insert(0, fn_location) - # Add function location to the top of the stack. - if stack is not None: - stack.frames.insert(0, infra.StackFrame(location=fn_location)) - - additional_messages = [ - format_function_signature_in_markdown(fn, args, kwargs, format_argument), - ] - - return_values: Any = None - with ctx.add_inflight_diagnostic(diag) as diag: - try: - return_values = fn(*args, **kwargs) - additional_messages.append( - format_return_values_in_markdown(return_values, format_argument) - ) - except Exception as e: # pylint: disable=broad-exception-caught - # Record exception. - diag.level = infra.levels.ERROR - # TODO(bowbao): Message emitting api. - diag.message = diag.message or "" - diag.message += f"Raised from:\n {type(e).__name__}: {e}" - diag.with_source_exception(e) - additional_messages.append(format_exception_in_markdown(e)) - else: - return return_values - finally: - diag.with_additional_message("\n".join(additional_messages).strip()) - ctx.log_and_raise_if_error(diag) - - return wrapper - - return decorator - - -# TODO(bowbao): decorator to report only when failed. diff --git a/onnxscript/diagnostics/infra/formatter.py b/onnxscript/diagnostics/infra/formatter.py deleted file mode 100644 index 1ccf77b5c..000000000 --- a/onnxscript/diagnostics/infra/formatter.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import dataclasses -import json -import re -from typing import Any, Callable, Dict, List, Optional, Union - -from onnxscript._internal import runtime_typing -from onnxscript.diagnostics.infra import sarif - -# A list of types in the SARIF module to support pretty printing. -# This is solely for type annotation for the functions below. -_SarifClass = Union[ - sarif.SarifLog, - sarif.Run, - sarif.ReportingDescriptor, - sarif.Result, -] - - -@runtime_typing.checked -def snake_case_to_camel_case(s: str) -> str: - splits = s.split("_") - if len(splits) <= 1: - return s - return "".join([splits[0], *map(str.capitalize, splits[1:])]) - - -@runtime_typing.checked -def camel_case_to_snake_case(s: str) -> str: - return re.sub(r"([A-Z])", r"_\1", s).lower() - - -@runtime_typing.checked -def kebab_case_to_snake_case(s: str) -> str: - return s.replace("-", "_") - - -@runtime_typing.checked -def _convert_key( - object: Union[Dict[str, Any], Any], convert: Callable[[str], str] -) -> Union[Dict[str, Any], Any]: - """Convert and update keys in a dictionary with "convert". - - Any value that is a dictionary will be recursively updated. - Any value that is a list will be recursively searched. - - Args: - object: The object to update. - convert: The function to convert the keys, e.g. `kebab_case_to_snake_case`. - - Returns: - The updated object. - """ - if not isinstance(object, Dict): - return object - new_dict = {} - for k, v in object.items(): - new_k = convert(k) - if isinstance(v, Dict): - new_v = _convert_key(v, convert) - elif isinstance(v, List): - new_v = [_convert_key(elem, convert) for elem in v] - else: - new_v = v - if new_v is None: - # Otherwise unnesseraily bloated sarif log with "null"s. - continue - if new_v == -1: - # WAR: -1 as default value shouldn't be logged into sarif. - continue - - new_dict[new_k] = new_v - - return new_dict - - -@runtime_typing.checked -def sarif_to_json(attr_cls_obj: _SarifClass, indent: Optional[str] = " ") -> str: - dict = dataclasses.asdict(attr_cls_obj) - dict = _convert_key(dict, snake_case_to_camel_case) - return json.dumps(dict, indent=indent, separators=(",", ":")) - - -@runtime_typing.checked -def pretty_print_title( - title: str, width: int = 80, fill_char: str = "=", print_output: bool = True -) -> str: - """Pretty prints title in below format: - - ==================== title ==================== - """ - msg = f" {title} ".center(width, fill_char) - if print_output: - print(msg) - return msg - - -@runtime_typing.checked -def pretty_print_item_title( - title: str, fill_char: str = "=", print_output: bool = True -) -> str: - """Pretty prints title in below format: - - title - ===== - """ - msg_list = [] - msg_list.append(title) - msg_list.append(fill_char * len(title)) - - msg = "\n".join(msg_list) - if print_output: - print(msg) - return msg - - -@runtime_typing.checked -def format_argument(obj: Any) -> str: - return f"{type(obj)}" - - -@runtime_typing.checked -def display_name(fn: Callable) -> str: - if hasattr(fn, "__qualname__"): - return fn.__qualname__ - elif hasattr(fn, "__name__"): - return fn.__name__ - else: - return str(fn) diff --git a/onnxscript/diagnostics/infra/sarif/__init__.py b/onnxscript/diagnostics/infra/sarif/__init__.py deleted file mode 100644 index e610c3b75..000000000 --- a/onnxscript/diagnostics/infra/sarif/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from onnxscript.diagnostics.infra.sarif._address import Address -from onnxscript.diagnostics.infra.sarif._artifact import Artifact -from onnxscript.diagnostics.infra.sarif._artifact_change import ArtifactChange -from onnxscript.diagnostics.infra.sarif._artifact_content import ArtifactContent -from onnxscript.diagnostics.infra.sarif._artifact_location import ArtifactLocation -from onnxscript.diagnostics.infra.sarif._attachment import Attachment -from onnxscript.diagnostics.infra.sarif._code_flow import CodeFlow -from onnxscript.diagnostics.infra.sarif._configuration_override import ( - ConfigurationOverride, -) -from onnxscript.diagnostics.infra.sarif._conversion import Conversion -from onnxscript.diagnostics.infra.sarif._edge import Edge -from onnxscript.diagnostics.infra.sarif._edge_traversal import EdgeTraversal -from onnxscript.diagnostics.infra.sarif._exception import Exception -from onnxscript.diagnostics.infra.sarif._external_properties import ExternalProperties -from onnxscript.diagnostics.infra.sarif._external_property_file_reference import ( - ExternalPropertyFileReference, -) -from onnxscript.diagnostics.infra.sarif._external_property_file_references import ( - ExternalPropertyFileReferences, -) -from onnxscript.diagnostics.infra.sarif._fix import Fix -from onnxscript.diagnostics.infra.sarif._graph import Graph -from onnxscript.diagnostics.infra.sarif._graph_traversal import GraphTraversal -from onnxscript.diagnostics.infra.sarif._invocation import Invocation -from onnxscript.diagnostics.infra.sarif._location import Location -from onnxscript.diagnostics.infra.sarif._location_relationship import ( - LocationRelationship, -) -from onnxscript.diagnostics.infra.sarif._logical_location import LogicalLocation -from onnxscript.diagnostics.infra.sarif._message import Message -from onnxscript.diagnostics.infra.sarif._multiformat_message_string import ( - MultiformatMessageString, -) -from onnxscript.diagnostics.infra.sarif._node import Node -from onnxscript.diagnostics.infra.sarif._notification import Notification -from onnxscript.diagnostics.infra.sarif._physical_location import PhysicalLocation -from onnxscript.diagnostics.infra.sarif._property_bag import PropertyBag -from onnxscript.diagnostics.infra.sarif._rectangle import Rectangle -from onnxscript.diagnostics.infra.sarif._region import Region -from onnxscript.diagnostics.infra.sarif._replacement import Replacement -from onnxscript.diagnostics.infra.sarif._reporting_configuration import ( - ReportingConfiguration, -) -from onnxscript.diagnostics.infra.sarif._reporting_descriptor import ReportingDescriptor -from onnxscript.diagnostics.infra.sarif._reporting_descriptor_reference import ( - ReportingDescriptorReference, -) -from onnxscript.diagnostics.infra.sarif._reporting_descriptor_relationship import ( - ReportingDescriptorRelationship, -) -from onnxscript.diagnostics.infra.sarif._result import Result -from onnxscript.diagnostics.infra.sarif._result_provenance import ResultProvenance -from onnxscript.diagnostics.infra.sarif._run import Run -from onnxscript.diagnostics.infra.sarif._run_automation_details import ( - RunAutomationDetails, -) -from onnxscript.diagnostics.infra.sarif._sarif_log import SarifLog -from onnxscript.diagnostics.infra.sarif._special_locations import SpecialLocations -from onnxscript.diagnostics.infra.sarif._stack import Stack -from onnxscript.diagnostics.infra.sarif._stack_frame import StackFrame -from onnxscript.diagnostics.infra.sarif._suppression import Suppression -from onnxscript.diagnostics.infra.sarif._thread_flow import ThreadFlow -from onnxscript.diagnostics.infra.sarif._thread_flow_location import ThreadFlowLocation -from onnxscript.diagnostics.infra.sarif._tool import Tool -from onnxscript.diagnostics.infra.sarif._tool_component import ToolComponent -from onnxscript.diagnostics.infra.sarif._tool_component_reference import ( - ToolComponentReference, -) -from onnxscript.diagnostics.infra.sarif._translation_metadata import TranslationMetadata -from onnxscript.diagnostics.infra.sarif._version_control_details import ( - VersionControlDetails, -) -from onnxscript.diagnostics.infra.sarif._web_request import WebRequest -from onnxscript.diagnostics.infra.sarif._web_response import WebResponse - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_address.py b/onnxscript/diagnostics/infra/sarif/_address.py deleted file mode 100644 index c4b691f34..000000000 --- a/onnxscript/diagnostics/infra/sarif/_address.py +++ /dev/null @@ -1,46 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class Address: - """A physical or virtual address, or a range of addresses, in an 'addressable region' (memory or a binary file).""" - - absolute_address: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "absoluteAddress"} - ) - fully_qualified_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullyQualifiedName"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - kind: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "kind"} - ) - length: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "length"} - ) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - offset_from_parent: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "offsetFromParent"} - ) - parent_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "parentIndex"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - relative_address: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "relativeAddress"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact.py b/onnxscript/diagnostics/infra/sarif/_artifact.py deleted file mode 100644 index afec8b5e9..000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact.py +++ /dev/null @@ -1,84 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_content, - _artifact_location, - _message, - _property_bag, -) - - -@dataclasses.dataclass -class Artifact: - """A single artifact. In some cases, this artifact might be nested within another artifact.""" - - contents: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "contents"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - encoding: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "encoding"} - ) - hashes: Any = dataclasses.field(default=None, metadata={"schema_property_name": "hashes"}) - last_modified_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "lastModifiedTimeUtc"} - ) - length: int = dataclasses.field(default=-1, metadata={"schema_property_name": "length"}) - location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - mime_type: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "mimeType"} - ) - offset: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "offset"} - ) - parent_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "parentIndex"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - roles: Optional[ - List[ - Literal[ - "analysisTarget", - "attachment", - "responseFile", - "resultFile", - "standardStream", - "tracedFile", - "unmodified", - "modified", - "added", - "deleted", - "renamed", - "uncontrolled", - "driver", - "extension", - "translation", - "taxonomy", - "policy", - "referencedOnCommandLine", - "memoryContents", - "directory", - "userSpecifiedConfiguration", - "toolSpecifiedConfiguration", - "debugOutputFile", - ] - ] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "roles"}) - source_language: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "sourceLanguage"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_change.py b/onnxscript/diagnostics/infra/sarif/_artifact_change.py deleted file mode 100644 index 3db2c0444..000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact_change.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _property_bag, - _replacement, -) - - -@dataclasses.dataclass -class ArtifactChange: - """A change to a single artifact.""" - - artifact_location: _artifact_location.ArtifactLocation = dataclasses.field( - metadata={"schema_property_name": "artifactLocation"} - ) - replacements: List[_replacement.Replacement] = dataclasses.field( - metadata={"schema_property_name": "replacements"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_content.py b/onnxscript/diagnostics/infra/sarif/_artifact_content.py deleted file mode 100644 index 403806619..000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact_content.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _multiformat_message_string, - _property_bag, -) - - -@dataclasses.dataclass -class ArtifactContent: - """Represents the contents of an artifact.""" - - binary: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "binary"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - rendered: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "rendered"}) - ) - text: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "text"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_artifact_location.py b/onnxscript/diagnostics/infra/sarif/_artifact_location.py deleted file mode 100644 index ed6f9b391..000000000 --- a/onnxscript/diagnostics/infra/sarif/_artifact_location.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class ArtifactLocation: - """Specifies the location of an artifact.""" - - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "uri"} - ) - uri_base_id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "uriBaseId"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_attachment.py b/onnxscript/diagnostics/infra/sarif/_attachment.py deleted file mode 100644 index b58b858e0..000000000 --- a/onnxscript/diagnostics/infra/sarif/_attachment.py +++ /dev/null @@ -1,39 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _message, - _property_bag, - _rectangle, - _region, -) - - -@dataclasses.dataclass -class Attachment: - """An artifact relevant to a result.""" - - artifact_location: _artifact_location.ArtifactLocation = dataclasses.field( - metadata={"schema_property_name": "artifactLocation"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - rectangles: Optional[List[_rectangle.Rectangle]] = dataclasses.field( - default=None, metadata={"schema_property_name": "rectangles"} - ) - regions: Optional[List[_region.Region]] = dataclasses.field( - default=None, metadata={"schema_property_name": "regions"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_code_flow.py b/onnxscript/diagnostics/infra/sarif/_code_flow.py deleted file mode 100644 index 69615f18f..000000000 --- a/onnxscript/diagnostics/infra/sarif/_code_flow.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag, _thread_flow - - -@dataclasses.dataclass -class CodeFlow: - """A set of threadFlows which together describe a pattern of code execution relevant to detecting a result.""" - - thread_flows: List[_thread_flow.ThreadFlow] = dataclasses.field( - metadata={"schema_property_name": "threadFlows"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_configuration_override.py b/onnxscript/diagnostics/infra/sarif/_configuration_override.py deleted file mode 100644 index c2fa3ae0a..000000000 --- a/onnxscript/diagnostics/infra/sarif/_configuration_override.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _property_bag, - _reporting_configuration, - _reporting_descriptor_reference, -) - - -@dataclasses.dataclass -class ConfigurationOverride: - """Information about how a specific rule or notification was reconfigured at runtime.""" - - configuration: _reporting_configuration.ReportingConfiguration = dataclasses.field( - metadata={"schema_property_name": "configuration"} - ) - descriptor: _reporting_descriptor_reference.ReportingDescriptorReference = ( - dataclasses.field(metadata={"schema_property_name": "descriptor"}) - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_conversion.py b/onnxscript/diagnostics/infra/sarif/_conversion.py deleted file mode 100644 index 6078c525f..000000000 --- a/onnxscript/diagnostics/infra/sarif/_conversion.py +++ /dev/null @@ -1,35 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _invocation, - _property_bag, - _tool, -) - - -@dataclasses.dataclass -class Conversion: - """Describes how a converter transformed the output of a static analysis tool from the analysis tool's native output format into the SARIF format.""" - - tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"}) - analysis_tool_log_files: Optional[List[_artifact_location.ArtifactLocation]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "analysisToolLogFiles"} - ) - ) - invocation: Optional[_invocation.Invocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "invocation"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_edge.py b/onnxscript/diagnostics/infra/sarif/_edge.py deleted file mode 100644 index 1142e61dc..000000000 --- a/onnxscript/diagnostics/infra/sarif/_edge.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class Edge: - """Represents a directed edge in a graph.""" - - id: str = dataclasses.field(metadata={"schema_property_name": "id"}) - source_node_id: str = dataclasses.field(metadata={"schema_property_name": "sourceNodeId"}) - target_node_id: str = dataclasses.field(metadata={"schema_property_name": "targetNodeId"}) - label: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "label"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_edge_traversal.py b/onnxscript/diagnostics/infra/sarif/_edge_traversal.py deleted file mode 100644 index dbaba449e..000000000 --- a/onnxscript/diagnostics/infra/sarif/_edge_traversal.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class EdgeTraversal: - """Represents the traversal of a single edge during a graph traversal.""" - - edge_id: str = dataclasses.field(metadata={"schema_property_name": "edgeId"}) - final_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "finalState"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - step_over_edge_count: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "stepOverEdgeCount"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_exception.py b/onnxscript/diagnostics/infra/sarif/_exception.py deleted file mode 100644 index 71c0db73a..000000000 --- a/onnxscript/diagnostics/infra/sarif/_exception.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _exception, _property_bag, _stack - - -@dataclasses.dataclass -class Exception: - """Describes a runtime exception encountered during the execution of an analysis tool.""" - - inner_exceptions: Optional[List[_exception.Exception]] = dataclasses.field( - default=None, metadata={"schema_property_name": "innerExceptions"} - ) - kind: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "kind"} - ) - message: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - stack: Optional[_stack.Stack] = dataclasses.field( - default=None, metadata={"schema_property_name": "stack"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_external_properties.py b/onnxscript/diagnostics/infra/sarif/_external_properties.py deleted file mode 100644 index d63a16aff..000000000 --- a/onnxscript/diagnostics/infra/sarif/_external_properties.py +++ /dev/null @@ -1,96 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _address, - _artifact, - _conversion, - _graph, - _invocation, - _logical_location, - _property_bag, - _result, - _thread_flow_location, - _tool_component, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class ExternalProperties: - """The top-level element of an external property file.""" - - addresses: Optional[List[_address.Address]] = dataclasses.field( - default=None, metadata={"schema_property_name": "addresses"} - ) - artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field( - default=None, metadata={"schema_property_name": "artifacts"} - ) - conversion: Optional[_conversion.Conversion] = dataclasses.field( - default=None, metadata={"schema_property_name": "conversion"} - ) - driver: Optional[_tool_component.ToolComponent] = dataclasses.field( - default=None, metadata={"schema_property_name": "driver"} - ) - extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "extensions"} - ) - externalized_properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "externalizedProperties"} - ) - graphs: Optional[List[_graph.Graph]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphs"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - invocations: Optional[List[_invocation.Invocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "invocations"} - ) - logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "logicalLocations"} - ) - policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "policies"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - results: Optional[List[_result.Result]] = dataclasses.field( - default=None, metadata={"schema_property_name": "results"} - ) - run_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "runGuid"} - ) - schema: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "schema"} - ) - taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "taxonomies"} - ) - thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "threadFlowLocations"} - ) - ) - translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "translations"} - ) - version: Optional[Literal["2.1.0"]] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequests"} - ) - web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponses"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py b/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py deleted file mode 100644 index b5bfec032..000000000 --- a/onnxscript/diagnostics/infra/sarif/_external_property_file_reference.py +++ /dev/null @@ -1,30 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag - - -@dataclasses.dataclass -class ExternalPropertyFileReference: - """Contains information that enables a SARIF consumer to locate the external property file that contains the value of an externalized property associated with the run.""" - - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - item_count: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "itemCount"} - ) - location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py b/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py deleted file mode 100644 index d596a7a87..000000000 --- a/onnxscript/diagnostics/infra/sarif/_external_property_file_references.py +++ /dev/null @@ -1,76 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _external_property_file_reference, - _property_bag, -) - - -@dataclasses.dataclass -class ExternalPropertyFileReferences: - """References to external property files that should be inlined with the content of a root log file.""" - - addresses: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "addresses"}) - artifacts: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "artifacts"}) - conversion: Optional[_external_property_file_reference.ExternalPropertyFileReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "conversion"}) - ) - driver: Optional[_external_property_file_reference.ExternalPropertyFileReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "driver"}) - ) - extensions: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "extensions"}) - externalized_properties: Optional[ - _external_property_file_reference.ExternalPropertyFileReference - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "externalizedProperties"} - ) - graphs: Optional[List[_external_property_file_reference.ExternalPropertyFileReference]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "graphs"}) - ) - invocations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "invocations"}) - logical_locations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "logicalLocations"}) - policies: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "policies"}) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - results: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "results"}) - taxonomies: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "taxonomies"}) - thread_flow_locations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "threadFlowLocations"} - ) - translations: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "translations"}) - web_requests: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "webRequests"}) - web_responses: Optional[ - List[_external_property_file_reference.ExternalPropertyFileReference] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "webResponses"}) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_fix.py b/onnxscript/diagnostics/infra/sarif/_fix.py deleted file mode 100644 index 042f70f47..000000000 --- a/onnxscript/diagnostics/infra/sarif/_fix.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_change, _message, _property_bag - - -@dataclasses.dataclass -class Fix: - """A proposed fix for the problem represented by a result object. A fix specifies a set of artifacts to modify. For each artifact, it specifies a set of bytes to remove, and provides a set of new bytes to replace them.""" - - artifact_changes: List[_artifact_change.ArtifactChange] = dataclasses.field( - metadata={"schema_property_name": "artifactChanges"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_graph.py b/onnxscript/diagnostics/infra/sarif/_graph.py deleted file mode 100644 index f068e663d..000000000 --- a/onnxscript/diagnostics/infra/sarif/_graph.py +++ /dev/null @@ -1,30 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _edge, _message, _node, _property_bag - - -@dataclasses.dataclass -class Graph: - """A network of nodes and directed edges that describes some aspect of the structure of the code (for example, a call graph).""" - - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - edges: Optional[List[_edge.Edge]] = dataclasses.field( - default=None, metadata={"schema_property_name": "edges"} - ) - nodes: Optional[List[_node.Node]] = dataclasses.field( - default=None, metadata={"schema_property_name": "nodes"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_graph_traversal.py b/onnxscript/diagnostics/infra/sarif/_graph_traversal.py deleted file mode 100644 index ec9c92a9f..000000000 --- a/onnxscript/diagnostics/infra/sarif/_graph_traversal.py +++ /dev/null @@ -1,39 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import _edge_traversal, _message, _property_bag - - -@dataclasses.dataclass -class GraphTraversal: - """Represents a path through a graph.""" - - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - edge_traversals: Optional[List[_edge_traversal.EdgeTraversal]] = dataclasses.field( - default=None, metadata={"schema_property_name": "edgeTraversals"} - ) - immutable_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "immutableState"} - ) - initial_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "initialState"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - result_graph_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "resultGraphIndex"} - ) - run_graph_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "runGraphIndex"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_invocation.py b/onnxscript/diagnostics/infra/sarif/_invocation.py deleted file mode 100644 index 6f96c9a86..000000000 --- a/onnxscript/diagnostics/infra/sarif/_invocation.py +++ /dev/null @@ -1,111 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _configuration_override, - _notification, - _property_bag, -) - - -@dataclasses.dataclass -class Invocation: - """The runtime environment of the analysis tool run.""" - - execution_successful: bool = dataclasses.field( - metadata={"schema_property_name": "executionSuccessful"} - ) - account: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "account"} - ) - arguments: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "arguments"} - ) - command_line: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "commandLine"} - ) - end_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "endTimeUtc"} - ) - environment_variables: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "environmentVariables"} - ) - executable_location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "executableLocation"} - ) - exit_code: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitCode"} - ) - exit_code_description: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitCodeDescription"} - ) - exit_signal_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitSignalName"} - ) - exit_signal_number: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "exitSignalNumber"} - ) - machine: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "machine"} - ) - notification_configuration_overrides: Optional[ - List[_configuration_override.ConfigurationOverride] - ] = dataclasses.field( - default=None, - metadata={"schema_property_name": "notificationConfigurationOverrides"}, - ) - process_id: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "processId"} - ) - process_start_failure_message: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "processStartFailureMessage"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - response_files: Optional[List[_artifact_location.ArtifactLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "responseFiles"} - ) - rule_configuration_overrides: Optional[ - List[_configuration_override.ConfigurationOverride] - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "ruleConfigurationOverrides"} - ) - start_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "startTimeUtc"} - ) - stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stderr"} - ) - stdin: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stdin"} - ) - stdout: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stdout"} - ) - stdout_stderr: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "stdoutStderr"} - ) - tool_configuration_notifications: Optional[List[_notification.Notification]] = ( - dataclasses.field( - default=None, - metadata={"schema_property_name": "toolConfigurationNotifications"}, - ) - ) - tool_execution_notifications: Optional[List[_notification.Notification]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "toolExecutionNotifications"} - ) - ) - working_directory: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "workingDirectory"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_location.py b/onnxscript/diagnostics/infra/sarif/_location.py deleted file mode 100644 index 319856f8d..000000000 --- a/onnxscript/diagnostics/infra/sarif/_location.py +++ /dev/null @@ -1,44 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _location_relationship, - _logical_location, - _message, - _physical_location, - _property_bag, - _region, -) - - -@dataclasses.dataclass -class Location: - """A location within a programming artifact.""" - - annotations: Optional[List[_region.Region]] = dataclasses.field( - default=None, metadata={"schema_property_name": "annotations"} - ) - id: int = dataclasses.field(default=-1, metadata={"schema_property_name": "id"}) - logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "logicalLocations"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - physical_location: Optional[_physical_location.PhysicalLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "physicalLocation"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - relationships: Optional[List[_location_relationship.LocationRelationship]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "relationships"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_location_relationship.py b/onnxscript/diagnostics/infra/sarif/_location_relationship.py deleted file mode 100644 index 35ca00c8a..000000000 --- a/onnxscript/diagnostics/infra/sarif/_location_relationship.py +++ /dev/null @@ -1,28 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class LocationRelationship: - """Information about the relation of one location to another.""" - - target: int = dataclasses.field(metadata={"schema_property_name": "target"}) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - kinds: List[str] = dataclasses.field( - default_factory=lambda: ["relevant"], metadata={"schema_property_name": "kinds"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_logical_location.py b/onnxscript/diagnostics/infra/sarif/_logical_location.py deleted file mode 100644 index 7f2880eef..000000000 --- a/onnxscript/diagnostics/infra/sarif/_logical_location.py +++ /dev/null @@ -1,37 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class LogicalLocation: - """A logical location of a construct that produced a result.""" - - decorated_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "decoratedName"} - ) - fully_qualified_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullyQualifiedName"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - kind: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "kind"} - ) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - parent_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "parentIndex"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_message.py b/onnxscript/diagnostics/infra/sarif/_message.py deleted file mode 100644 index 0c9adce22..000000000 --- a/onnxscript/diagnostics/infra/sarif/_message.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class Message: - """Encapsulates a message intended to be read by the end user.""" - - arguments: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "arguments"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - markdown: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "markdown"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - text: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "text"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py b/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py deleted file mode 100644 index 154b9cc41..000000000 --- a/onnxscript/diagnostics/infra/sarif/_multiformat_message_string.py +++ /dev/null @@ -1,25 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class MultiformatMessageString: - """A message string or message format string rendered in multiple formats.""" - - text: str = dataclasses.field(metadata={"schema_property_name": "text"}) - markdown: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "markdown"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_node.py b/onnxscript/diagnostics/infra/sarif/_node.py deleted file mode 100644 index 0f11e3731..000000000 --- a/onnxscript/diagnostics/infra/sarif/_node.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _location, _message, _node, _property_bag - - -@dataclasses.dataclass -class Node: - """Represents a node in a graph.""" - - id: str = dataclasses.field(metadata={"schema_property_name": "id"}) - children: Optional[List[_node.Node]] = dataclasses.field( - default=None, metadata={"schema_property_name": "children"} - ) - label: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "label"} - ) - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_notification.py b/onnxscript/diagnostics/infra/sarif/_notification.py deleted file mode 100644 index f41a9f8d5..000000000 --- a/onnxscript/diagnostics/infra/sarif/_notification.py +++ /dev/null @@ -1,49 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _exception, - _location, - _message, - _property_bag, - _reporting_descriptor_reference, -) - - -@dataclasses.dataclass -class Notification: - """Describes a condition relevant to the tool itself, as opposed to being relevant to a target being analyzed by the tool.""" - - message: _message.Message = dataclasses.field(metadata={"schema_property_name": "message"}) - associated_rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "associatedRule"}) - ) - descriptor: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "descriptor"}) - ) - exception: Optional[_exception.Exception] = dataclasses.field( - default=None, metadata={"schema_property_name": "exception"} - ) - level: Literal["none", "note", "warning", "error"] = dataclasses.field( - default="warning", metadata={"schema_property_name": "level"} - ) - locations: Optional[List[_location.Location]] = dataclasses.field( - default=None, metadata={"schema_property_name": "locations"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - thread_id: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "threadId"} - ) - time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "timeUtc"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_physical_location.py b/onnxscript/diagnostics/infra/sarif/_physical_location.py deleted file mode 100644 index 357e85af4..000000000 --- a/onnxscript/diagnostics/infra/sarif/_physical_location.py +++ /dev/null @@ -1,38 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _address, - _artifact_location, - _property_bag, - _region, -) - - -@dataclasses.dataclass -class PhysicalLocation: - """A physical location relevant to a result. Specifies a reference to a programming artifact together with a range of bytes or characters within that artifact.""" - - address: Optional[_address.Address] = dataclasses.field( - default=None, metadata={"schema_property_name": "address"} - ) - artifact_location: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "artifactLocation"} - ) - context_region: Optional[_region.Region] = dataclasses.field( - default=None, metadata={"schema_property_name": "contextRegion"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - region: Optional[_region.Region] = dataclasses.field( - default=None, metadata={"schema_property_name": "region"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_property_bag.py b/onnxscript/diagnostics/infra/sarif/_property_bag.py deleted file mode 100644 index 0b95c6e6e..000000000 --- a/onnxscript/diagnostics/infra/sarif/_property_bag.py +++ /dev/null @@ -1,19 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - - -@dataclasses.dataclass -class PropertyBag: - """Key/value pairs that provide additional information about the object.""" - - tags: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "tags"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_rectangle.py b/onnxscript/diagnostics/infra/sarif/_rectangle.py deleted file mode 100644 index a7c9aecd1..000000000 --- a/onnxscript/diagnostics/infra/sarif/_rectangle.py +++ /dev/null @@ -1,36 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class Rectangle: - """An area within an image.""" - - bottom: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "bottom"} - ) - left: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "left"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - right: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "right"} - ) - top: Optional[float] = dataclasses.field( - default=None, metadata={"schema_property_name": "top"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_region.py b/onnxscript/diagnostics/infra/sarif/_region.py deleted file mode 100644 index 35a4b7f31..000000000 --- a/onnxscript/diagnostics/infra/sarif/_region.py +++ /dev/null @@ -1,58 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_content, - _message, - _property_bag, -) - - -@dataclasses.dataclass -class Region: - """A region within an artifact where a result was detected.""" - - byte_length: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "byteLength"} - ) - byte_offset: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "byteOffset"} - ) - char_length: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "charLength"} - ) - char_offset: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "charOffset"} - ) - end_column: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "endColumn"} - ) - end_line: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "endLine"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - snippet: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "snippet"} - ) - source_language: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "sourceLanguage"} - ) - start_column: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "startColumn"} - ) - start_line: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "startLine"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_replacement.py b/onnxscript/diagnostics/infra/sarif/_replacement.py deleted file mode 100644 index 125ed7570..000000000 --- a/onnxscript/diagnostics/infra/sarif/_replacement.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag, _region - - -@dataclasses.dataclass -class Replacement: - """The replacement of a single region of an artifact.""" - - deleted_region: _region.Region = dataclasses.field( - metadata={"schema_property_name": "deletedRegion"} - ) - inserted_content: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "insertedContent"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py b/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py deleted file mode 100644 index e3da0a77b..000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_configuration.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Literal, Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class ReportingConfiguration: - """Information about a rule or notification that can be configured at runtime.""" - - enabled: bool = dataclasses.field( - default=True, metadata={"schema_property_name": "enabled"} - ) - level: Literal["none", "note", "warning", "error"] = dataclasses.field( - default="warning", metadata={"schema_property_name": "level"} - ) - parameters: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "parameters"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - rank: float = dataclasses.field(default=-1.0, metadata={"schema_property_name": "rank"}) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py deleted file mode 100644 index 85e14f376..000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor.py +++ /dev/null @@ -1,65 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _multiformat_message_string, - _property_bag, - _reporting_configuration, - _reporting_descriptor_relationship, -) - - -@dataclasses.dataclass -class ReportingDescriptor: - """Metadata that describes a specific report produced by the tool, as part of the analysis it provides or its runtime reporting.""" - - id: str = dataclasses.field(metadata={"schema_property_name": "id"}) - default_configuration: Optional[_reporting_configuration.ReportingConfiguration] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "defaultConfiguration"} - ) - ) - deprecated_guids: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "deprecatedGuids"} - ) - deprecated_ids: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "deprecatedIds"} - ) - deprecated_names: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "deprecatedNames"} - ) - full_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"}) - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - help: Optional[_multiformat_message_string.MultiformatMessageString] = dataclasses.field( - default=None, metadata={"schema_property_name": "help"} - ) - help_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "helpUri"} - ) - message_strings: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "messageStrings"} - ) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - relationships: Optional[ - List[_reporting_descriptor_relationship.ReportingDescriptorRelationship] - ] = dataclasses.field(default=None, metadata={"schema_property_name": "relationships"}) - short_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py deleted file mode 100644 index f4e6f2260..000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_reference.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag, _tool_component_reference - - -@dataclasses.dataclass -class ReportingDescriptorReference: - """Information about how to locate a relevant reporting descriptor.""" - - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - tool_component: Optional[_tool_component_reference.ToolComponentReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "toolComponent"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py b/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py deleted file mode 100644 index 52db517db..000000000 --- a/onnxscript/diagnostics/infra/sarif/_reporting_descriptor_relationship.py +++ /dev/null @@ -1,34 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _message, - _property_bag, - _reporting_descriptor_reference, -) - - -@dataclasses.dataclass -class ReportingDescriptorRelationship: - """Information about the relation of one reporting descriptor to another.""" - - target: _reporting_descriptor_reference.ReportingDescriptorReference = dataclasses.field( - metadata={"schema_property_name": "target"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - kinds: List[str] = dataclasses.field( - default_factory=lambda: ["relevant"], metadata={"schema_property_name": "kinds"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_result.py b/onnxscript/diagnostics/infra/sarif/_result.py deleted file mode 100644 index 3dfa564b5..000000000 --- a/onnxscript/diagnostics/infra/sarif/_result.py +++ /dev/null @@ -1,120 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _attachment, - _code_flow, - _fix, - _graph, - _graph_traversal, - _location, - _message, - _property_bag, - _reporting_descriptor_reference, - _result_provenance, - _stack, - _suppression, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class Result: - """A result produced by an analysis tool.""" - - message: _message.Message = dataclasses.field(metadata={"schema_property_name": "message"}) - analysis_target: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "analysisTarget"} - ) - attachments: Optional[List[_attachment.Attachment]] = dataclasses.field( - default=None, metadata={"schema_property_name": "attachments"} - ) - baseline_state: Optional[Literal["new", "unchanged", "updated", "absent"]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "baselineState"}) - ) - code_flows: Optional[List[_code_flow.CodeFlow]] = dataclasses.field( - default=None, metadata={"schema_property_name": "codeFlows"} - ) - correlation_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "correlationGuid"} - ) - fingerprints: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "fingerprints"} - ) - fixes: Optional[List[_fix.Fix]] = dataclasses.field( - default=None, metadata={"schema_property_name": "fixes"} - ) - graph_traversals: Optional[List[_graph_traversal.GraphTraversal]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphTraversals"} - ) - graphs: Optional[List[_graph.Graph]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphs"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - hosted_viewer_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "hostedViewerUri"} - ) - kind: Literal["notApplicable", "pass", "fail", "review", "open", "informational"] = ( - dataclasses.field(default="fail", metadata={"schema_property_name": "kind"}) - ) - level: Literal["none", "note", "warning", "error"] = dataclasses.field( - default="warning", metadata={"schema_property_name": "level"} - ) - locations: Optional[List[_location.Location]] = dataclasses.field( - default=None, metadata={"schema_property_name": "locations"} - ) - occurrence_count: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "occurrenceCount"} - ) - partial_fingerprints: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "partialFingerprints"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - provenance: Optional[_result_provenance.ResultProvenance] = dataclasses.field( - default=None, metadata={"schema_property_name": "provenance"} - ) - rank: float = dataclasses.field(default=-1.0, metadata={"schema_property_name": "rank"}) - related_locations: Optional[List[_location.Location]] = dataclasses.field( - default=None, metadata={"schema_property_name": "relatedLocations"} - ) - rule: Optional[_reporting_descriptor_reference.ReportingDescriptorReference] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "rule"}) - ) - rule_id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "ruleId"} - ) - rule_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "ruleIndex"} - ) - stacks: Optional[List[_stack.Stack]] = dataclasses.field( - default=None, metadata={"schema_property_name": "stacks"} - ) - suppressions: Optional[List[_suppression.Suppression]] = dataclasses.field( - default=None, metadata={"schema_property_name": "suppressions"} - ) - taxa: Optional[List[_reporting_descriptor_reference.ReportingDescriptorReference]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "taxa"}) - ) - web_request: Optional[_web_request.WebRequest] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequest"} - ) - web_response: Optional[_web_response.WebResponse] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponse"} - ) - work_item_uris: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "workItemUris"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_result_provenance.py b/onnxscript/diagnostics/infra/sarif/_result_provenance.py deleted file mode 100644 index 74ea9e1e9..000000000 --- a/onnxscript/diagnostics/infra/sarif/_result_provenance.py +++ /dev/null @@ -1,39 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _physical_location, _property_bag - - -@dataclasses.dataclass -class ResultProvenance: - """Contains information about how and when a result was detected.""" - - conversion_sources: Optional[List[_physical_location.PhysicalLocation]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "conversionSources"}) - ) - first_detection_run_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "firstDetectionRunGuid"} - ) - first_detection_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "firstDetectionTimeUtc"} - ) - invocation_index: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "invocationIndex"} - ) - last_detection_run_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "lastDetectionRunGuid"} - ) - last_detection_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "lastDetectionTimeUtc"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_run.py b/onnxscript/diagnostics/infra/sarif/_run.py deleted file mode 100644 index 8df4f9b57..000000000 --- a/onnxscript/diagnostics/infra/sarif/_run.py +++ /dev/null @@ -1,126 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _address, - _artifact, - _conversion, - _external_property_file_references, - _graph, - _invocation, - _logical_location, - _property_bag, - _result, - _run_automation_details, - _special_locations, - _thread_flow_location, - _tool, - _tool_component, - _version_control_details, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class Run: - """Describes a single run of an analysis tool, and contains the reported output of that run.""" - - tool: _tool.Tool = dataclasses.field(metadata={"schema_property_name": "tool"}) - addresses: Optional[List[_address.Address]] = dataclasses.field( - default=None, metadata={"schema_property_name": "addresses"} - ) - artifacts: Optional[List[_artifact.Artifact]] = dataclasses.field( - default=None, metadata={"schema_property_name": "artifacts"} - ) - automation_details: Optional[_run_automation_details.RunAutomationDetails] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "automationDetails"}) - ) - baseline_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "baselineGuid"} - ) - column_kind: Optional[Literal["utf16CodeUnits", "unicodeCodePoints"]] = dataclasses.field( - default=None, metadata={"schema_property_name": "columnKind"} - ) - conversion: Optional[_conversion.Conversion] = dataclasses.field( - default=None, metadata={"schema_property_name": "conversion"} - ) - default_encoding: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "defaultEncoding"} - ) - default_source_language: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "defaultSourceLanguage"} - ) - external_property_file_references: Optional[ - _external_property_file_references.ExternalPropertyFileReferences - ] = dataclasses.field( - default=None, - metadata={"schema_property_name": "externalPropertyFileReferences"}, - ) - graphs: Optional[List[_graph.Graph]] = dataclasses.field( - default=None, metadata={"schema_property_name": "graphs"} - ) - invocations: Optional[List[_invocation.Invocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "invocations"} - ) - language: str = dataclasses.field( - default="en-US", metadata={"schema_property_name": "language"} - ) - logical_locations: Optional[List[_logical_location.LogicalLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "logicalLocations"} - ) - newline_sequences: List[str] = dataclasses.field( - default_factory=lambda: ["\r\n", "\n"], - metadata={"schema_property_name": "newlineSequences"}, - ) - original_uri_base_ids: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "originalUriBaseIds"} - ) - policies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "policies"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - redaction_tokens: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "redactionTokens"} - ) - results: Optional[List[_result.Result]] = dataclasses.field( - default=None, metadata={"schema_property_name": "results"} - ) - run_aggregates: Optional[List[_run_automation_details.RunAutomationDetails]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "runAggregates"}) - ) - special_locations: Optional[_special_locations.SpecialLocations] = dataclasses.field( - default=None, metadata={"schema_property_name": "specialLocations"} - ) - taxonomies: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "taxonomies"} - ) - thread_flow_locations: Optional[List[_thread_flow_location.ThreadFlowLocation]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "threadFlowLocations"} - ) - ) - translations: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "translations"} - ) - version_control_provenance: Optional[ - List[_version_control_details.VersionControlDetails] - ] = dataclasses.field( - default=None, metadata={"schema_property_name": "versionControlProvenance"} - ) - web_requests: Optional[List[_web_request.WebRequest]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequests"} - ) - web_responses: Optional[List[_web_response.WebResponse]] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponses"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_run_automation_details.py b/onnxscript/diagnostics/infra/sarif/_run_automation_details.py deleted file mode 100644 index f41dfcc28..000000000 --- a/onnxscript/diagnostics/infra/sarif/_run_automation_details.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag - - -@dataclasses.dataclass -class RunAutomationDetails: - """Information that describes a run's identity and role within an engineering system process.""" - - correlation_guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "correlationGuid"} - ) - description: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "description"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_sarif_log.py b/onnxscript/diagnostics/infra/sarif/_sarif_log.py deleted file mode 100644 index aa39c52f1..000000000 --- a/onnxscript/diagnostics/infra/sarif/_sarif_log.py +++ /dev/null @@ -1,31 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import _external_properties, _property_bag, _run - - -@dataclasses.dataclass -class SarifLog: - """Static Analysis Results Format (SARIF) Version 2.1.0 JSON Schema: a standard format for the output of static analysis tools.""" - - runs: List[_run.Run] = dataclasses.field(metadata={"schema_property_name": "runs"}) - version: Literal["2.1.0"] = dataclasses.field(metadata={"schema_property_name": "version"}) - schema_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "$schema"} - ) - inline_external_properties: Optional[List[_external_properties.ExternalProperties]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "inlineExternalProperties"} - ) - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_special_locations.py b/onnxscript/diagnostics/infra/sarif/_special_locations.py deleted file mode 100644 index ee7897951..000000000 --- a/onnxscript/diagnostics/infra/sarif/_special_locations.py +++ /dev/null @@ -1,24 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag - - -@dataclasses.dataclass -class SpecialLocations: - """Defines locations of special significance to SARIF consumers.""" - - display_base: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "displayBase"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_stack.py b/onnxscript/diagnostics/infra/sarif/_stack.py deleted file mode 100644 index e250b75df..000000000 --- a/onnxscript/diagnostics/infra/sarif/_stack.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _message, _property_bag, _stack_frame - - -@dataclasses.dataclass -class Stack: - """A call stack that is relevant to a result.""" - - frames: List[_stack_frame.StackFrame] = dataclasses.field( - metadata={"schema_property_name": "frames"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_stack_frame.py b/onnxscript/diagnostics/infra/sarif/_stack_frame.py deleted file mode 100644 index 24d9fe820..000000000 --- a/onnxscript/diagnostics/infra/sarif/_stack_frame.py +++ /dev/null @@ -1,33 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _location, _property_bag - - -@dataclasses.dataclass -class StackFrame: - """A function call within a stack trace.""" - - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - module: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "module"} - ) - parameters: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "parameters"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - thread_id: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "threadId"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_suppression.py b/onnxscript/diagnostics/infra/sarif/_suppression.py deleted file mode 100644 index ae477178b..000000000 --- a/onnxscript/diagnostics/infra/sarif/_suppression.py +++ /dev/null @@ -1,36 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Literal, Optional - -from onnxscript.diagnostics.infra.sarif import _location, _property_bag - - -@dataclasses.dataclass -class Suppression: - """A suppression that is relevant to a result.""" - - kind: Literal["inSource", "external"] = dataclasses.field( - metadata={"schema_property_name": "kind"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - justification: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "justification"} - ) - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - state: Optional[Literal["accepted", "underReview", "rejected"]] = dataclasses.field( - default=None, metadata={"schema_property_name": "state"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_thread_flow.py b/onnxscript/diagnostics/infra/sarif/_thread_flow.py deleted file mode 100644 index d3d169367..000000000 --- a/onnxscript/diagnostics/infra/sarif/_thread_flow.py +++ /dev/null @@ -1,40 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _message, - _property_bag, - _thread_flow_location, -) - - -@dataclasses.dataclass -class ThreadFlow: - """Describes a sequence of code locations that specify a path through a single thread of execution such as an operating system or fiber.""" - - locations: List[_thread_flow_location.ThreadFlowLocation] = dataclasses.field( - metadata={"schema_property_name": "locations"} - ) - id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "id"} - ) - immutable_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "immutableState"} - ) - initial_state: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "initialState"} - ) - message: Optional[_message.Message] = dataclasses.field( - default=None, metadata={"schema_property_name": "message"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py b/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py deleted file mode 100644 index 949c42d80..000000000 --- a/onnxscript/diagnostics/infra/sarif/_thread_flow_location.py +++ /dev/null @@ -1,63 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _location, - _property_bag, - _reporting_descriptor_reference, - _stack, - _web_request, - _web_response, -) - - -@dataclasses.dataclass -class ThreadFlowLocation: - """A location visited by an analysis tool while simulating or monitoring the execution of a program.""" - - execution_order: int = dataclasses.field( - default=-1, metadata={"schema_property_name": "executionOrder"} - ) - execution_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "executionTimeUtc"} - ) - importance: Literal["important", "essential", "unimportant"] = dataclasses.field( - default="important", metadata={"schema_property_name": "importance"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - kinds: Optional[List[str]] = dataclasses.field( - default=None, metadata={"schema_property_name": "kinds"} - ) - location: Optional[_location.Location] = dataclasses.field( - default=None, metadata={"schema_property_name": "location"} - ) - module: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "module"} - ) - nesting_level: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "nestingLevel"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - stack: Optional[_stack.Stack] = dataclasses.field( - default=None, metadata={"schema_property_name": "stack"} - ) - state: Any = dataclasses.field(default=None, metadata={"schema_property_name": "state"}) - taxa: Optional[List[_reporting_descriptor_reference.ReportingDescriptorReference]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "taxa"}) - ) - web_request: Optional[_web_request.WebRequest] = dataclasses.field( - default=None, metadata={"schema_property_name": "webRequest"} - ) - web_response: Optional[_web_response.WebResponse] = dataclasses.field( - default=None, metadata={"schema_property_name": "webResponse"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_tool.py b/onnxscript/diagnostics/infra/sarif/_tool.py deleted file mode 100644 index 79589ddf7..000000000 --- a/onnxscript/diagnostics/infra/sarif/_tool.py +++ /dev/null @@ -1,27 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import List, Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag, _tool_component - - -@dataclasses.dataclass -class Tool: - """The analysis tool that was run.""" - - driver: _tool_component.ToolComponent = dataclasses.field( - metadata={"schema_property_name": "driver"} - ) - extensions: Optional[List[_tool_component.ToolComponent]] = dataclasses.field( - default=None, metadata={"schema_property_name": "extensions"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_tool_component.py b/onnxscript/diagnostics/infra/sarif/_tool_component.py deleted file mode 100644 index 47925ed74..000000000 --- a/onnxscript/diagnostics/infra/sarif/_tool_component.py +++ /dev/null @@ -1,115 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, List, Literal, Optional - -from onnxscript.diagnostics.infra.sarif import ( - _artifact_location, - _multiformat_message_string, - _property_bag, - _reporting_descriptor, - _tool_component_reference, - _translation_metadata, -) - - -@dataclasses.dataclass -class ToolComponent: - """A component, such as a plug-in or the driver, of the analysis tool that was run.""" - - name: str = dataclasses.field(metadata={"schema_property_name": "name"}) - associated_component: Optional[_tool_component_reference.ToolComponentReference] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "associatedComponent"} - ) - ) - contents: List[Literal["localizedData", "nonLocalizedData"]] = dataclasses.field( - default_factory=lambda: ["localizedData", "nonLocalizedData"], - metadata={"schema_property_name": "contents"}, - ) - dotted_quad_file_version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "dottedQuadFileVersion"} - ) - download_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "downloadUri"} - ) - full_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"}) - ) - full_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullName"} - ) - global_message_strings: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "globalMessageStrings"} - ) - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - information_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "informationUri"} - ) - is_comprehensive: Optional[bool] = dataclasses.field( - default=None, metadata={"schema_property_name": "isComprehensive"} - ) - language: str = dataclasses.field( - default="en-US", metadata={"schema_property_name": "language"} - ) - localized_data_semantic_version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "localizedDataSemanticVersion"} - ) - locations: Optional[List[_artifact_location.ArtifactLocation]] = dataclasses.field( - default=None, metadata={"schema_property_name": "locations"} - ) - minimum_required_localized_data_semantic_version: Optional[str] = dataclasses.field( - default=None, - metadata={"schema_property_name": "minimumRequiredLocalizedDataSemanticVersion"}, - ) - notifications: Optional[List[_reporting_descriptor.ReportingDescriptor]] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "notifications"}) - ) - organization: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "organization"} - ) - product: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "product"} - ) - product_suite: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "productSuite"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - release_date_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "releaseDateUtc"} - ) - rules: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field( - default=None, metadata={"schema_property_name": "rules"} - ) - semantic_version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "semanticVersion"} - ) - short_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"}) - ) - supported_taxonomies: Optional[List[_tool_component_reference.ToolComponentReference]] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "supportedTaxonomies"} - ) - ) - taxa: Optional[List[_reporting_descriptor.ReportingDescriptor]] = dataclasses.field( - default=None, metadata={"schema_property_name": "taxa"} - ) - translation_metadata: Optional[_translation_metadata.TranslationMetadata] = ( - dataclasses.field( - default=None, metadata={"schema_property_name": "translationMetadata"} - ) - ) - version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py b/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py deleted file mode 100644 index 09cc2b908..000000000 --- a/onnxscript/diagnostics/infra/sarif/_tool_component_reference.py +++ /dev/null @@ -1,28 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _property_bag - - -@dataclasses.dataclass -class ToolComponentReference: - """Identifies a particular toolComponent object, either the driver or an extension.""" - - guid: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "guid"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "name"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_translation_metadata.py b/onnxscript/diagnostics/infra/sarif/_translation_metadata.py deleted file mode 100644 index f05125a59..000000000 --- a/onnxscript/diagnostics/infra/sarif/_translation_metadata.py +++ /dev/null @@ -1,40 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import ( - _multiformat_message_string, - _property_bag, -) - - -@dataclasses.dataclass -class TranslationMetadata: - """Provides additional metadata related to translation.""" - - name: str = dataclasses.field(metadata={"schema_property_name": "name"}) - download_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "downloadUri"} - ) - full_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "fullDescription"}) - ) - full_name: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "fullName"} - ) - information_uri: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "informationUri"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - short_description: Optional[_multiformat_message_string.MultiformatMessageString] = ( - dataclasses.field(default=None, metadata={"schema_property_name": "shortDescription"}) - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_version_control_details.py b/onnxscript/diagnostics/infra/sarif/_version_control_details.py deleted file mode 100644 index f56498bb6..000000000 --- a/onnxscript/diagnostics/infra/sarif/_version_control_details.py +++ /dev/null @@ -1,37 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_location, _property_bag - - -@dataclasses.dataclass -class VersionControlDetails: - """Specifies the information necessary to retrieve a desired revision from a version control system.""" - - repository_uri: str = dataclasses.field(metadata={"schema_property_name": "repositoryUri"}) - as_of_time_utc: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "asOfTimeUtc"} - ) - branch: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "branch"} - ) - mapped_to: Optional[_artifact_location.ArtifactLocation] = dataclasses.field( - default=None, metadata={"schema_property_name": "mappedTo"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - revision_id: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "revisionId"} - ) - revision_tag: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "revisionTag"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_web_request.py b/onnxscript/diagnostics/infra/sarif/_web_request.py deleted file mode 100644 index b574882f9..000000000 --- a/onnxscript/diagnostics/infra/sarif/_web_request.py +++ /dev/null @@ -1,43 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag - - -@dataclasses.dataclass -class WebRequest: - """Describes an HTTP request.""" - - body: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "body"} - ) - headers: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "headers"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - method: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "method"} - ) - parameters: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "parameters"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - protocol: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "protocol"} - ) - target: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "target"} - ) - version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/_web_response.py b/onnxscript/diagnostics/infra/sarif/_web_response.py deleted file mode 100644 index 3753036ab..000000000 --- a/onnxscript/diagnostics/infra/sarif/_web_response.py +++ /dev/null @@ -1,43 +0,0 @@ -# DO NOT EDIT! This file was generated by jschema_to_python version 0.0.1.dev29, -# with extension for dataclasses and type annotation. - -from __future__ import annotations - -import dataclasses -from typing import Any, Optional - -from onnxscript.diagnostics.infra.sarif import _artifact_content, _property_bag - - -@dataclasses.dataclass -class WebResponse: - """Describes the response to an HTTP request.""" - - body: Optional[_artifact_content.ArtifactContent] = dataclasses.field( - default=None, metadata={"schema_property_name": "body"} - ) - headers: Any = dataclasses.field( - default=None, metadata={"schema_property_name": "headers"} - ) - index: int = dataclasses.field(default=-1, metadata={"schema_property_name": "index"}) - no_response_received: Optional[bool] = dataclasses.field( - default=None, metadata={"schema_property_name": "noResponseReceived"} - ) - properties: Optional[_property_bag.PropertyBag] = dataclasses.field( - default=None, metadata={"schema_property_name": "properties"} - ) - protocol: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "protocol"} - ) - reason_phrase: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "reasonPhrase"} - ) - status_code: Optional[int] = dataclasses.field( - default=None, metadata={"schema_property_name": "statusCode"} - ) - version: Optional[str] = dataclasses.field( - default=None, metadata={"schema_property_name": "version"} - ) - - -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/sarif/version.py b/onnxscript/diagnostics/infra/sarif/version.py deleted file mode 100644 index 020a28bf7..000000000 --- a/onnxscript/diagnostics/infra/sarif/version.py +++ /dev/null @@ -1,7 +0,0 @@ -from typing import Final - -SARIF_VERSION: Final = "2.1.0" -SARIF_SCHEMA_LINK: Final = ( - "https://docs.oasis-open.org/sarif/sarif/v2.1.0/cs01/schemas/sarif-schema-2.1.0.json" -) -# flake8: noqa diff --git a/onnxscript/diagnostics/infra/utils.py b/onnxscript/diagnostics/infra/utils.py deleted file mode 100644 index 463fc3ea0..000000000 --- a/onnxscript/diagnostics/infra/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import functools -import inspect -import traceback -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple - -from onnxscript._internal import runtime_typing -from onnxscript.diagnostics.infra import _infra, formatter - - -@runtime_typing.checked -def python_frame(frame: traceback.FrameSummary) -> _infra.StackFrame: - """Returns a StackFrame for the given traceback.FrameSummary.""" - snippet = frame.line - - return _infra.StackFrame( - location=_infra.Location( - uri=frame.filename, - line=frame.lineno, - snippet=snippet, - function=frame.name, - message=snippet, - ) - ) - - -@runtime_typing.checked -def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack: - """Returns the current Python call stack.""" - if frames_to_skip < 0: - raise ValueError("frames_to_skip must be non-negative") - if frames_to_log < 0: - raise ValueError("frames_to_log must be non-negative") - frames_to_skip += 2 # Skip this function and beartype. - stack = _infra.Stack() - # Frames are returned in order of oldest to newest. - frames = traceback.extract_stack(limit=frames_to_skip + frames_to_log) - frames.reverse() - stack.frames = [python_frame(frame) for frame in frames[frames_to_skip:]] - stack.message = "Python call stack" - return stack - - -@functools.lru_cache -def _function_source_info(fn: Callable) -> Tuple[Sequence[str], int, Optional[str]]: - """Returns the source lines, line number, and source file path for the given function. - - Essentially, inspect.getsourcelines() and inspect.getsourcefile() combined. - Caching is applied to reduce the performance impact of this function. - """ - source_lines, lineno = inspect.getsourcelines(fn) - return source_lines, lineno, inspect.getsourcefile(fn) - - -@runtime_typing.checked -def function_location(fn: Callable) -> _infra.Location: - """Returns a Location for the given function.""" - source_lines, lineno, uri = _function_source_info(fn) - snippet = source_lines[0].strip() if len(source_lines) > 0 else "" - return _infra.Location( - uri=uri, - line=lineno, - snippet=snippet, - message=formatter.display_name(fn), - ) - - -@runtime_typing.checked -def function_state( - fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any] -) -> Mapping[str, Any]: - bind = inspect.signature(fn).bind(*args, **kwargs) - return bind.arguments diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index daf63d86a..f59505ccc 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -832,7 +832,7 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto): existing_value_info = {info.name: info for info in onnx_model.graph.value_info} # Override value_info for top level graph inputs. - for input in self.torch_graph.inputs(): + for input in self.torch_graph.inputs(): # pylint: disable=not-an-iterable if input not in self._value_to_tensor: raise RuntimeError(f"Input '{input.debugName()}' has no type.") tensor = self._value_to_tensor[input] @@ -847,7 +847,7 @@ def _override_with_symbolic_value_info_proto(self, onnx_model: onnx.ModelProto): break # Override value_info for top level graph outputs. - for output in self.torch_graph.outputs(): + for output in self.torch_graph.outputs(): # pylint: disable=not-an-iterable if output not in self._value_to_tensor: raise RuntimeError(f"Output '{output.debugName()}' has no type.") tensor = self._value_to_tensor[output] diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 584c178d5..1145e9b13 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1495,10 +1495,19 @@ def aten_bitwise_xor(self: TInt, other: TInt) -> TInt: return op.BitwiseXor(self, other) -def aten_blackman_window(window_length: int) -> TensorType: +@torch_op("aten::blackman_window", trace_only=True) +def aten_blackman_window( + window_length: int, + dtype: int = 1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() + if dtype is None or dtype == -1: + dtype = 1 + return op.BlackmanWindow(window_length, output_datatype=dtype) def aten_block_diag(tensors: Sequence[TensorType]) -> TensorType: @@ -3921,16 +3930,38 @@ def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL: return op.And(self, op.Not(other)) -def aten_hamming_window(window_length: int) -> TensorType: +@torch_op("aten::hamming_window", trace_only=True) +def aten_hamming_window( + window_length: int, + dtype: int = 1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() + if dtype is None or dtype == -1: + dtype = 1 + # ONNX uses different alpha/beta values for the Hamming window + # Whereas PyTorch uses alpha=0.54, beta=0.46, ONNX uses + # alpha=0.543478, beta=0.456522. This causes a slight difference + # in the output values, but we still uses the HammingWindow op for performance. + return op.HammingWindow(window_length, output_datatype=dtype) -def aten_hann_window(window_length: int) -> TensorType: +@torch_op("aten::hann_window", trace_only=True) +def aten_hann_window( + window_length: int, + dtype: int = 1, + layout: str = "", + device: str = "", + pin_memory: bool = False, +) -> TensorType: """hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor""" - raise NotImplementedError() + if dtype is None or dtype == -1: + dtype = 1 + return op.HannWindow(window_length, output_datatype=dtype) def aten_hardshrink(self: TensorType, lambd: float = 0.5) -> TensorType: diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 661a5cd82..1ecfa0911 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -300,6 +300,29 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) -> return default +@register("Reshape") +def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue: + """Replace a Reshape node by Identity when applicable.""" + input = _get_input(node, 0) + shape = _get_input(node, 1) + if input is None or shape is None: + return None + input_shape = input.shape + if input_shape is None: + return None + input_shape_dims = list(input_shape.dims) + if any(not isinstance(dim, int) for dim in input_shape_dims): + return None + shape_value = _get_numpy_value(shape) + if shape_value is None: + return None + target_shape_dims = shape_value.tolist() + if input_shape_dims == target_shape_dims: + # No need to check for special values like -1, 0, etc. here + return op.Identity(input) + return None + + @register("Cast") def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue: input = _get_input(node, 0) diff --git a/onnxscript/rewriter/_ir_utils.py b/onnxscript/rewriter/_ir_utils.py index 7c303556a..1d657a5ab 100644 --- a/onnxscript/rewriter/_ir_utils.py +++ b/onnxscript/rewriter/_ir_utils.py @@ -2,6 +2,9 @@ # Licensed under the MIT License. from __future__ import annotations +import math +from typing import Callable + import numpy as np import onnxscript.ir as ir @@ -77,3 +80,27 @@ def get_singleton_value(val: ir.Value | None): if np_val is not None and np_val.size == 1: return np_val.item() return None + + +def is_singleton_value( + val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None +) -> bool: + """Returns True if the value is a single element tensor with given value, and False otherwise.""" + scalar = get_singleton_value(val) + if scalar is None: + return False + if callable(expected): + return expected(scalar) + if isinstance(expected, int): + return expected == scalar + # rtol must be specified for float comparison + assert rtol is not None + return math.isclose(scalar, expected, rel_tol=rtol) + + +def has_rank(value: ir.Value | None, rank: int) -> bool: + """Returns True if the value is statically known to have the given rank, and False otherwise.""" + if value is None: + return False + shape = value.shape + return (shape is not None) and (shape.rank() == rank) diff --git a/onnxscript/rewriter/generic_pattern.py b/onnxscript/rewriter/generic_pattern.py index 2926f5964..de06d7a22 100644 --- a/onnxscript/rewriter/generic_pattern.py +++ b/onnxscript/rewriter/generic_pattern.py @@ -551,7 +551,12 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> orp.MatchResult | None: + if not remove_nodes: + raise NotImplementedError( + "remove_nodes=False is not implemented in GenericPatternMatcher" + ) del model del graph_or_function self.verbose = verbose diff --git a/onnxscript/rewriter/onnxruntime/xformers/__init__.py b/onnxscript/rewriter/onnxruntime/xformers/__init__.py index 44b5591d8..43cec1352 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/__init__.py +++ b/onnxscript/rewriter/onnxruntime/xformers/__init__.py @@ -1,3 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from __future__ import annotations + +__all__ = [ + "fuse_rms_normalization", + "fuse_normalization", + "fuse_rotary_embedding", + "fuse_cos_sin_cache", +] + +from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding +from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization diff --git a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py index 0b4e2c55f..b9ed0aecf 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py +++ b/onnxscript/rewriter/onnxruntime/xformers/_test_utils.py @@ -25,7 +25,7 @@ def ort_run(model_name: str, model, inputs): providers = ["CPUExecutionProvider"] with tempfile.TemporaryDirectory() as temp_dir: model_path = os.path.join(temp_dir, f"{model_name}.onnx") - io.save(model, model_path) + _save(model, model_path) # Run model session = onnxruntime.InferenceSession(model_path, providers=providers) ort_outputs = session.run(None, inputs) diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py new file mode 100644 index 000000000..46272ccf9 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import numpy as np + +import onnxscript.ir as ir +from onnxscript.optimizer import remove_unused_nodes +from onnxscript.rewriter import _ir_utils, pattern + +# Rewrite the computation of cos/sin cache into the form expected by ORT's custom ops. + +# We match against the following code pattern: +# Original code (from transformers) for computing cos/sin cache for RoPE: +# https://github.com/huggingface/transformers/blob/0ade1caa356dce6b70ef8293addeb0898f177206/src/transformers/models/llama/modeling_llama.py#L135 +# position_ids_expanded = position_ids[:, None, :].float() +# freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) +# emb = torch.cat((freqs, freqs), dim=-1) +# cos = emb.cos() +# sin = emb.sin() +# +# We rewrite this pattern into the following form: +# inv_freq_values = inv_freq_expanded.reshape(1, -1) +# pos_id_range = np.arange(max_pos_id, dtype=np.float32).reshape(-1, 1) +# angles = np.matmul(pos_id_range, inv_freq_values) +# cos_value = np.cos(angles) +# sin_value = np.sin(angles) +# cos_2d = op.Constant(value=ir.tensor(cos_value)) +# sin_2d = op.Constant(value=ir.tensor(sin_value)) +# +# This produces cos/sin values in a form that can be used by ORT's custom ops. + +# TODO: To apply the pattern-rewrite, we need to know the maximum position id. +# Need to find a way to get this information from the model or its config. + + +class CosSinCacheFusion(pattern.RewriteRuleClassBase): + def __init__(self, name: str, max_pos_id: int): + # This pattern makes use of shared Cos/Sin values. So, we can't remove the + # matched nodes as part of the rewrite-step. We apply a separate final + # pass to remove unused nodes. + super().__init__(name, remove_nodes=False) + self._max_pos_id = max_pos_id + + def pattern(self, op, x, inv_freq, position_ids, interleaved, num_heads): + position_ids_expanded = op.Unsqueeze(position_ids, 1) + position_ids_expanded = op.Cast(position_ids_expanded, to=ir.DataType.FLOAT) + freqs = op.MatMul(inv_freq, position_ids_expanded) + freqs = op.Transpose(freqs, perm=[0, 2, 1]) + emb = op.Concat(freqs, freqs, axis=-1) + cos = op.Cos(emb) + sin = op.Sin(emb) + cos_4d = op.Unsqueeze(cos, 1) # convert + sin_4d = op.Unsqueeze(sin, 1) + return op.RotaryEmbedding( + x, + cos_4d, + sin_4d, + interleaved=interleaved, + num_heads=num_heads, + _domain="ai.onnxruntime.fusion", + ) + + def check(self, context, inv_freq, position_ids, **_) -> bool: + if not _ir_utils.has_rank(position_ids, 2): + return False + if not _ir_utils.has_rank(inv_freq, 3): + return False + inv_freq_shape = inv_freq.shape + if inv_freq.const_value is None: + return False + return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1 + + def rewrite(self, op, x, inv_freq, position_ids, interleaved, num_heads, **_): + inv_freq_values = inv_freq.const_value.numpy().reshape(1, -1) + pos_id_range = np.arange(self._max_pos_id, dtype=np.float32).reshape(-1, 1) + angles = np.matmul(pos_id_range, inv_freq_values) + cos_value = np.cos(angles) + sin_value = np.sin(angles) + cos_2d = op.Constant(value=ir.tensor(cos_value)) + sin_2d = op.Constant(value=ir.tensor(sin_value)) + return op.RotaryEmbedding( + x, + position_ids, + cos_2d, + sin_2d, + interleaved=interleaved, + num_heads=num_heads, + _domain="com.microsoft", + ) + + +_rule = CosSinCacheFusion.rule("CosSinCache", 2048) + +cos_sin_cache_rules = pattern.RewriteRuleSet([_rule]) + + +def fuse_cos_sin_cache(model: ir.Model) -> int: + count = cos_sin_cache_rules.apply_to_model(model) + print(f"CosSinCache count: {count}") + remove_unused_nodes(model) + return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py new file mode 100644 index 000000000..dfe6625a8 --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/cos_sin_cache_test.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers import fuse_cos_sin_cache, fuse_rotary_embedding +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run + + +class TestCosSinCacheTransform(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + inputs = smollm_test.get_ort_inputs() + original_outputs = ort_run("original", model, inputs) + count = fuse_rotary_embedding(model) + self.assertGreater(count, 0) + count = fuse_cos_sin_cache(model) + self.assertGreater(count, 0) + new_outputs = ort_run("optimized", model, inputs) + assert_allclose(new_outputs, original_outputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py index 1f7a96df1..1e348acfb 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization.py @@ -35,14 +35,10 @@ def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool): cast_input: Whether to cast input to do the normalization in a different precision. cast_normalized: Whether to cast the normalized output to the target dtype (same as scale). """ - self._name = name + super().__init__(name=name) self._cast_input = cast_input self._cast_normalized = cast_normalized - @property - def name(self): - return self._name - def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): if self._cast_input: x = op.Cast(x, to=compute_dtype) @@ -95,5 +91,5 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): def fuse_rms_normalization(model: ir.Model) -> None: - count = rms_normalization_ruleset.apply_to_model(model, verbose=5) + count = rms_normalization_ruleset.apply_to_model(model) print(f"RMS Normalization count: {count}") diff --git a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py index 79a966838..30080474c 100644 --- a/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py +++ b/onnxscript/rewriter/onnxruntime/xformers/rms_normalization_test.py @@ -4,21 +4,12 @@ import unittest -import onnx - import onnxscript.optimizer from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData from onnxscript.rewriter.onnxruntime.xformers._test_utils import assert_allclose, ort_run from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization -def model_repr(self): - return f"Model({self.graph.name})" - - -onnx.ModelProto.__repr__ = model_repr - - class TestRmsNormalization(unittest.TestCase): def test_smollm(self): smollm_test = _SmollmTestData() diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py new file mode 100644 index 000000000..b36cf2c9b --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter import _ir_utils, pattern + +# Add first version of the RotaryEmbeddingFusion rule. This considers only one simple pattern +# for full rotation without interleaving. +# TODO(rama): Add pattern variations to handle other cases (interleaved, as well as partial rotation). + +# Note: This targets the new op being proposed to ONNX. This version does not exist in ORT yet. +# so it can't be tested by running against ORT. See cos_sin_cache.py for a transformation that +# rewrites the pattern into one that can be run against ORT. + + +def _rotate_half_pattern(op, x, start1, end1, start2, end2): + # Slice(input, starts, ends, axes, steps) + x1 = op.Slice(x, start1, end1, [3], [1]) + x2 = op.Slice(x, start2, end2, [3], [1]) + minus_x2 = op.Neg(x2) + rotated_x = op.Concat(minus_x2, x1, axis=-1) + return rotated_x + + +class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, cos, sin, start1, end1, start2, end2): + return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin + + def check(self, op, x, start1, end1, start2, end2, **_): + # x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads) + if x is None or x.shape is None or len(x.shape) != 4: + return False + if not isinstance(x.shape[1], int): + return False + head_size = x.shape[3] + if not isinstance(head_size, int): + return False + half_head_size = head_size // 2 + + # Check that x is being split into two equal halves of size half_head_size + return ( + _ir_utils.is_singleton_value(start1, 0) + and _ir_utils.is_singleton_value(end1, half_head_size) + and _ir_utils.is_singleton_value(start2, half_head_size) + and _ir_utils.is_singleton_value(end2, lambda x: x >= head_size) + ) + + def rewrite(self, op, x, cos, sin, **_): + num_heads = x.shape[1] + return op.RotaryEmbedding( + x, cos, sin, interleaved=0, num_heads=num_heads, _domain="ai.onnxruntime.fusion" + ) + + +_rule = RotaryEmbeddingFusion.rule() + +rotary_embedding_rules = pattern.RewriteRuleSet([_rule]) + + +def fuse_rotary_embedding(model: ir.Model) -> int: + count = rotary_embedding_rules.apply_to_model(model) + print(f"Rotary Embedding count: {count}") + return count diff --git a/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py new file mode 100644 index 000000000..6f8d37dee --- /dev/null +++ b/onnxscript/rewriter/onnxruntime/xformers/rotary_embedding_test.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import onnxscript.optimizer +from onnxscript.rewriter.onnxruntime.xformers._smollm_1layer import _SmollmTestData +from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding + + +class TestRotaryEmbedding(unittest.TestCase): + def test_smollm(self): + smollm_test = _SmollmTestData() + model = smollm_test.get_onnx_model() + onnxscript.optimizer.optimize(model) + fuse_rotary_embedding(model) + op_types = [n.op_type for n in model.graph] + self.assertIn("RotaryEmbedding", op_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index f2faf77c3..a961ae872 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -946,6 +946,7 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node.""" @@ -1144,6 +1145,7 @@ def _match_single_output_node( model: ir.Model, graph_or_function: ir.Graph | ir.Function, node: ir.Node, + check_removable: bool, ) -> MatchResult: del model del graph_or_function @@ -1162,13 +1164,13 @@ def _match_single_output_node( output_values = self._get_output_values() if output_values is None: return match - if not _valid_to_replace(match.nodes, output_values): + if check_removable and not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) return match - def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: + def _multi_match(self, candidate: Iterable[ir.Node], check_removable: bool) -> MatchResult: """Find a match for a pattern with multiple output nodes. For a pattern with K output nodes, the input candidate should specify K nodes @@ -1176,6 +1178,8 @@ def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: Args: candidate: An iterable of nodes that will be matched against the pattern output nodes. + check_removable: If True, check that the matched nodes can be removed (that is, that + they are not used elsewhere in the graph). """ match = self._match for pattern_node, node in zip(self.pattern.output_nodes, candidate): @@ -1185,7 +1189,7 @@ def _multi_match(self, candidate: Iterable[ir.Node]) -> MatchResult: if output_values is None: return match - if not _valid_to_replace(match.nodes, output_values): + if check_removable and not _valid_to_replace(match.nodes, output_values): return match.fail("Matched nodes have other uses preventing replacement.") match.outputs.extend(output_values) @@ -1197,6 +1201,7 @@ def match( graph_or_function: ir.Graph | ir.Function, node: ir.Node, verbose: int = 0, + remove_nodes: bool = True, ) -> MatchResult: """Match the pattern against the subgraph ending at the given node. @@ -1216,7 +1221,9 @@ def match( if self.pattern.has_single_output_node: self._init_match(verbose) - return self._match_single_output_node(model, graph_or_function, node) + return self._match_single_output_node( + model, graph_or_function, node, check_removable=remove_nodes + ) else: # Note: This is a potentially expensive algorithm for matching patterns with # multiple output nodes. For patterns with N output nodes, we try all possible @@ -1243,7 +1250,7 @@ def get_nodes(pattern_node): match = None for combination in itertools.product(*candidates): self._init_match(verbose) - match = self._multi_match(combination) + match = self._multi_match(combination, check_removable=remove_nodes) if match: return match if match is None: @@ -1260,6 +1267,7 @@ def __init__( matcher: PatternMatcher | Callable[[GraphPattern], PatternMatcher] | None = None, verbose: int = 0, name: str | None = None, + remove_nodes: bool = True, ) -> None: """Create a rewrite rule. @@ -1275,6 +1283,7 @@ def __init__( If not provided, a default matcher will be used. verbose: The verbosity level of the rule. name: An optional name for the pattern that will show up in verbose logging. + remove_nodes: If True, the matched nodes will be removed from the graph. """ if not isinstance(target_pattern, GraphPattern): @@ -1298,6 +1307,7 @@ def __init__( self._matcher = matcher(self._target_pattern) self._verbose = verbose self.name = name + self.remove_nodes = remove_nodes def __str__(self) -> str: if self.name: @@ -1317,7 +1327,9 @@ def try_rewrite( if verbose and verbose > 2: print(f"[try_rewrite] {self}") verbose = verbose if verbose is not None else self._verbose - match = self._matcher.match(model, graph_or_function, node, verbose=verbose) + match = self._matcher.match( + model, graph_or_function, node, verbose=verbose, remove_nodes=self.remove_nodes + ) if match: context = None # TODO(rama) for var in self._target_pattern.inputs: @@ -1440,19 +1452,23 @@ class RewriteRuleClassBase: def rule(cls, *args, **kwargs): instance = cls(*args, **kwargs) return RewriteRule( - instance.pattern, instance.rewrite, instance.check, name=instance.name + instance.pattern, + instance.rewrite, + instance.check, + name=instance.name, + remove_nodes=instance.remove_nodes, ) - @property - def name(self): - """Default implementation of name property.""" - return self.__class__.__name__ + def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None: + self.name = name or self.__class__.__name__ + self.remove_nodes = remove_nodes def pattern(self, op, *args, **kwargs): raise NotImplementedError("Method 'pattern' must be implemented by derived class.") def check(self, op, *args, **kwargs): - raise NotImplementedError("Method 'check' must be implemented by derived class.") + # Default check function that always returns True. + return True def rewrite(self, op, *args, **kwargs): raise NotImplementedError("Method 'rewrite' must be implemented by derived class.") @@ -1488,7 +1504,7 @@ def _apply_to_graph_or_function( _convenience.replace_nodes_and_values( graph_or_function, node, - delta.match.nodes, + delta.match.nodes if rule.remove_nodes else [], delta.new_nodes, delta.match.outputs, delta.new_outputs, diff --git a/pyproject.toml b/pyproject.toml index 4771d85b9..61128ac9e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", ] -dependencies = ["numpy", "onnx>=1.16", "typing_extensions", "ml_dtypes", "packaging"] +dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes", "packaging"] [tool.setuptools.packages.find] include = ["onnxscript*"] diff --git a/pyproject_pylint.toml b/pyproject_pylint.toml index 227a361b8..673439074 100644 --- a/pyproject_pylint.toml +++ b/pyproject_pylint.toml @@ -24,6 +24,7 @@ disable = [ "too-many-instance-attributes", "too-many-lines", "too-many-locals", + "too-many-positional-arguments", "too-many-public-methods", "too-many-return-statements", "too-many-statements", # TODO: we should work on these: too-many-xxx series diff --git a/requirements-dev.txt b/requirements-dev.txt index 103fab8ab..355fce3bf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -2,7 +2,7 @@ setuptools>=61.0.0 numpy onnx-weekly>=1.17.0.dev20240325 onnxruntime>=1.17.0 -typing_extensions +typing_extensions>=4.10 rich>=13.7.1 # Docs site diff --git a/requirements/ci/requirements-onnx-weekly.txt b/requirements/ci/requirements-onnx-weekly.txt index 5dc19b92d..ccae99b0b 100644 --- a/requirements/ci/requirements-onnx-weekly.txt +++ b/requirements/ci/requirements-onnx-weekly.txt @@ -1 +1 @@ -onnx-weekly==1.18.0.dev20241217 +onnx-weekly==1.18.0.dev20250106 diff --git a/requirements/lintrunner/requirements.txt b/requirements/lintrunner/requirements.txt index e6adda625..3ea357152 100644 --- a/requirements/lintrunner/requirements.txt +++ b/requirements/lintrunner/requirements.txt @@ -1,11 +1,11 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.8.0 # RUFF, RUFF-FIX -ruff==0.8.4 +ruff==0.8.6 # MYPY mypy==1.10.1 -types-PyYAML==6.0.12.20240808 +types-PyYAML==6.0.12.20241230 # PYLINT -pylint==2.17.6 +pylint==3.3.3 # EDITORCONFIG-CHECKER editorconfig-checker==3.0.3 diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index 91f1df916..4dc486c5e 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1935,6 +1935,16 @@ def shape(size, rank, with_batch_channel=True): ) +def sample_inputs_window_functions(op_info, device, dtype, requires_grad, **kwargs): + del op_info + del kwargs + del device + del requires_grad + + for window_length in [2, 3, 7, 10, 32]: + yield opinfo_core.SampleInput(window_length, kwargs=dict(dtype=dtype)) + + class _TestParamsMaxPoolEmptyStrideBase: # Adapted from https://github.com/pytorch/pytorch/blob/d6d55f8590eab05d2536756fb4efcfb2d07eb81a/torch/testing/_internal/common_methods_invocations.py#L3203 def __init__(self): @@ -2037,6 +2047,13 @@ def __init__(self): sample_inputs_func=sample_inputs_bernoulli_p_deterministic, supports_out=False, ), + opinfo_core.OpInfo( + "ops.aten.blackman_window", + aten_name="blackman_window", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_window_functions, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.col2im", aten_name="col2im", @@ -2115,6 +2132,20 @@ def __init__(self): lhs_make_tensor_kwargs=dict(low=0), rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0), ), + opinfo_core.OpInfo( + "ops.aten.hamming_window", + aten_name="hamming_window", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_window_functions, + supports_out=False, + ), + opinfo_core.OpInfo( + "ops.aten.hann_window", + aten_name="hann_window", + dtypes=common_dtype.floating_types_and(torch.bfloat16), + sample_inputs_func=sample_inputs_window_functions, + supports_out=False, + ), opinfo_core.OpInfo( "ops.aten.index.Tensor", aten_name="index.Tensor", diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 07164d594..bebd9a8ab 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -695,6 +695,7 @@ def _where_input_wrangler( TorchLibOpInfo("bitwise_right_shift_int64", core_ops.aten_bitwise_right_shift_int64), TorchLibOpInfo("bitwise_right_shift_int8", core_ops.aten_bitwise_right_shift_int8), TorchLibOpInfo("bitwise_xor", core_ops.aten_bitwise_xor), + TorchLibOpInfo("ops.aten.blackman_window", core_ops.aten_blackman_window), TorchLibOpInfo("bmm", core_ops.aten_bmm), TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to), TorchLibOpInfo("cat", core_ops.aten_cat).skip( @@ -1630,6 +1631,12 @@ def _where_input_wrangler( matcher=lambda sample: any(dim == 0 for dim in sample.input.shape), reason="Using op.InstanceNormalization to simulate GroupNorm, which does not support 0-dim input", ), + TorchLibOpInfo( + "ops.aten.hamming_window", + core_ops.aten_hamming_window, + tolerance={torch.float32: (8e-2, 6e-3)}, + ), + TorchLibOpInfo("ops.aten.hann_window", core_ops.aten_hann_window), TorchLibOpInfo("heaviside", core_ops.aten_heaviside), TorchLibOpInfo( "nn.functional.grid_sample",