Skip to content

Commit

Permalink
Fix pyproject.toml detection when [tool.black] section is omitted (#…
Browse files Browse the repository at this point in the history
…2242)

* Use custom pyproject.toml resolution instead of black's

 When loading its configuration, datamodel-codegen now searches for
 pyproject.toml files with [tool.datamodel-codegen] sections
 independently, rather than relying on black's project root detection
 (which fails if a [tool.black] section is not present).

* Add tests for pyproject.toml configuration handling

* Switch contextlib.chdir to datamodel_code_generator.chdir in test

* Ignore old black version

---------

Co-authored-by: Koudai Aono <[email protected]>
  • Loading branch information
otonnesen and koxudaxi authored Jan 12, 2025
1 parent 396601e commit 6e33a33
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 12 deletions.
34 changes: 23 additions & 11 deletions datamodel_code_generator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from datamodel_code_generator.format import (
DatetimeClassType,
PythonVersion,
black_find_project_root,
is_supported_in_black,
)
from datamodel_code_generator.parser import LiteralType
Expand Down Expand Up @@ -366,6 +365,26 @@ def merge_args(self, args: Namespace) -> None:
setattr(self, field_name, getattr(parsed_args, field_name))


def _get_pyproject_toml_config(source: Path) -> Optional[Dict[str, Any]]:
"""Find and return the [tool.datamodel-codgen] section of the closest
pyproject.toml if it exists.
"""

current_path = source
while current_path != current_path.parent:
if (current_path / 'pyproject.toml').is_file():
pyproject_toml = load_toml(current_path / 'pyproject.toml')
if 'datamodel-codegen' in pyproject_toml.get('tool', {}):
return pyproject_toml['tool']['datamodel-codegen']

if (current_path / '.git').exists():
# Stop early if we see a git repository root.
return None

current_path = current_path.parent
return None


def main(args: Optional[Sequence[str]] = None) -> Exit:
"""Main function."""

Expand All @@ -383,16 +402,9 @@ def main(args: Optional[Sequence[str]] = None) -> Exit:
print(version)
exit(0)

root = black_find_project_root((Path().resolve(),))
pyproject_toml_path = root / 'pyproject.toml'
if pyproject_toml_path.is_file():
pyproject_toml: Dict[str, Any] = {
k.replace('-', '_'): v
for k, v in load_toml(pyproject_toml_path)
.get('tool', {})
.get('datamodel-codegen', {})
.items()
}
pyproject_config = _get_pyproject_toml_config(Path().resolve())
if pyproject_config is not None:
pyproject_toml = {k.replace('-', '_'): v for k, v in pyproject_config.items()}
else:
pyproject_toml = {}

Expand Down
69 changes: 69 additions & 0 deletions tests/data/expected/main_kr/pyproject/output.strictstr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# generated by datamodel-codegen:
# filename: api.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import List, Optional

from pydantic import AnyUrl, BaseModel, Field, StrictStr


class Pet(BaseModel):
id: int
name: StrictStr
tag: Optional[StrictStr] = None


class Pets(BaseModel):
__root__: List[Pet]


class User(BaseModel):
id: int
name: StrictStr
tag: Optional[StrictStr] = None


class Users(BaseModel):
__root__: List[User]


class Id(BaseModel):
__root__: StrictStr


class Rules(BaseModel):
__root__: List[StrictStr]


class Error(BaseModel):
code: int
message: StrictStr


class Api(BaseModel):
apiKey: Optional[StrictStr] = Field(
None, description='To be used as a dataset parameter value'
)
apiVersionNumber: Optional[StrictStr] = Field(
None, description='To be used as a version parameter value'
)
apiUrl: Optional[AnyUrl] = Field(
None, description="The URL describing the dataset's fields"
)
apiDocumentationUrl: Optional[AnyUrl] = Field(
None, description='A URL to the API console for each API'
)


class Apis(BaseModel):
__root__: List[Api]


class Event(BaseModel):
name: Optional[StrictStr] = None


class Result(BaseModel):
event: Optional[Event] = None
42 changes: 41 additions & 1 deletion tests/test_main_kr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from freezegun import freeze_time

from datamodel_code_generator import inferred_message
from datamodel_code_generator import chdir, inferred_message
from datamodel_code_generator.__main__ import Exit, main

try:
Expand Down Expand Up @@ -207,6 +207,46 @@ def test_pyproject():
)


@pytest.mark.skipif(
black.__version__.split('.')[0] == '19',
reason="Installed black doesn't support the old style",
)
@freeze_time('2019-07-26')
def test_pyproject_with_tool_section():
"""Test that a pyproject.toml with a [tool.datamodel-codegen] section is
found and its configuration applied.
"""
with TemporaryDirectory() as output_dir:
output_dir = Path(output_dir)
pyproject_toml = """
[tool.datamodel-codegen]
target-python-version = "3.10"
strict-types = ["str"]
"""
with open(output_dir / 'pyproject.toml', 'w') as f:
f.write(pyproject_toml)
output_file: Path = output_dir / 'output.py'

# Run main from within the output directory so we can find our
# pyproject.toml.
with chdir(output_dir):
return_code: Exit = main(
[
'--input',
str((OPEN_API_DATA_PATH / 'api.yaml').resolve()),
'--output',
str(output_file.resolve()),
]
)

assert return_code == Exit.OK
assert (
output_file.read_text()
# We expect the output to use pydantic.StrictStr in place of str
== (EXPECTED_MAIN_KR_PATH / 'pyproject' / 'output.strictstr.py').read_text()
)


@freeze_time('2019-07-26')
def test_main_use_schema_description():
with TemporaryDirectory() as output_dir:
Expand Down

0 comments on commit 6e33a33

Please sign in to comment.