Skip to content

Commit

Permalink
ParseDecl API returns a struct to preserve API stability
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Feb 6, 2025
1 parent e04d94c commit ee172f9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
21 changes: 18 additions & 3 deletions ast_canopy/ast_canopy/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
# SPDX-License-Identifier: Apache-2.0

import subprocess
import shutil
import os
import tempfile
import logging
from typing import Optional
from dataclasses import dataclass

from numba.cuda.cuda_paths import get_nvidia_nvvm_ctk, get_cuda_home

Expand All @@ -17,6 +19,16 @@
logger = logging.getLogger(f"AST_Canopy.{__name__}")


@dataclass
class Declarations:
structs: list[Struct]
functions: list[Function]
function_templates: list[bindings.FunctionTemplate]
class_templates: list[ClassTemplate]
typedefs: list[bindings.Typedef]
enums: list[bindings.Enum]


def get_default_cuda_path() -> Optional[str]:
"""Return the path to the default CUDA home directory."""

Expand All @@ -38,7 +50,7 @@ def get_default_nvcc_path() -> Optional[str]:
nvvm_path = get_nvidia_nvvm_ctk()

if not nvvm_path:
return
return shutil.which("nvcc")

root = os.path.dirname(os.path.dirname(nvvm_path))
nvcc_path = os.path.join(root, "bin", "nvcc")
Expand Down Expand Up @@ -79,7 +91,10 @@ def get_default_cuda_compiler_include(default="/usr/local/cuda/include") -> str:

nvcc_bin = get_default_nvcc_path()
if not nvcc_bin:
logger.warning("Could not find NVCC binary. Using default nvcc bin from env.")
logger.warning(
"Could not find NVCC binary. AST_Canopy will attempt to "
"invoke `nvcc` directly in the subsequent commands."
)
nvcc_bin = "nvcc"

with tempfile.NamedTemporaryFile(suffix=".cu") as tmp_file:
Expand Down Expand Up @@ -230,7 +245,7 @@ def custom_cuda_home() -> list[str]:
ClassTemplate.from_c_obj(c_obj) for c_obj in decls.class_templates
]

return (
return Declarations(
structs,
functions,
decls.function_templates,
Expand Down
21 changes: 11 additions & 10 deletions ast_canopy/tests/test_parse_from_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
import pickle
from dataclasses import astuple

import pytest

Expand Down Expand Up @@ -80,7 +81,7 @@ def test_load_ast_structs(sample_struct_source, test_pickle):
sample_struct_source, [sample_struct_source], "sm_80"
)

structs, _, _, _, _, _ = decls
structs, _, _, _, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(s) for s in structs]
Expand Down Expand Up @@ -186,7 +187,7 @@ def test_load_ast_functions(sample_function_source, test_pickle):
sample_function_source, [sample_function_source], "sm_80"
)

_, functions, _, _, _, _ = decls
_, functions, _, _, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(f) for f in functions]
Expand Down Expand Up @@ -232,7 +233,7 @@ def test_load_ast_typedefs(sample_typedef_source, test_pickle):
sample_typedef_source, [sample_typedef_source], "sm_80"
)

structs, _, _, _, typedefs, _ = decls
structs, _, _, _, typedefs, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(t) for t in typedefs]
Expand Down Expand Up @@ -261,7 +262,7 @@ def test_load_ast_function_templates(sample_function_template_source, test_pickl
sample_function_template_source, [sample_function_template_source], "sm_80"
)

_, _, ft, _, _, _ = decls
_, _, ft, _, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(f) for f in ft]
Expand Down Expand Up @@ -300,7 +301,7 @@ def test_load_ast_class_templates(sample_class_template_source, test_pickle):
sample_class_template_source, [sample_class_template_source], "sm_80"
)

_, _, _, ct, _, _ = decls
_, _, _, ct, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(c) for c in ct]
Expand Down Expand Up @@ -357,7 +358,7 @@ def test_load_ast_nested_structs(sample_nested_structs_source, test_pickle):
sample_nested_structs_source, [sample_nested_structs_source], "sm_80"
)

structs, _, _, _, _, _ = decls
structs, _, _, _, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(s) for s in structs]
Expand All @@ -377,7 +378,7 @@ def test_load_ast_access_specifiers(sample_access_specifier_source, test_pickle)
sample_access_specifier_source, [sample_access_specifier_source], "sm_80"
)

structs, _, _, _, _, _ = decls
structs, _, _, _, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(s) for s in structs]
Expand Down Expand Up @@ -412,7 +413,7 @@ def test_load_enum(sample_enum_source, test_pickle):
sample_enum_source, [sample_enum_source], "sm_80"
)

_, _, _, _, _, enums = decls
_, _, _, _, _, enums = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(e) for e in enums]
Expand Down Expand Up @@ -453,7 +454,7 @@ def test_load_struct_function_execution_space(
sample_execution_space_source, [sample_execution_space_source], "sm_80"
)

structs, functions, _, _, _, _ = decls
structs, functions, _, _, _, _ = astuple(decls)

if test_pickle:
pickled = [pickle.dumps(s) for s in structs]
Expand Down Expand Up @@ -526,7 +527,7 @@ def test_load_by_cc(cc, answer, sample_load_by_cc_source):
sample_load_by_cc_source, [sample_load_by_cc_source], cc
)

structs, functions, _, _, _, _ = decls
structs, functions, _, _, _, _ = astuple(decls)

assert len(structs) == len(answer["structs"])
assert len(functions) == len(answer["functions"])
Expand Down

0 comments on commit ee172f9

Please sign in to comment.