diff --git a/datamodel_code_generator/__main__.py b/datamodel_code_generator/__main__.py index 03c4e241b..458e0a42b 100644 --- a/datamodel_code_generator/__main__.py +++ b/datamodel_code_generator/__main__.py @@ -53,7 +53,6 @@ from datamodel_code_generator.format import ( DatetimeClassType, PythonVersion, - black_find_project_root, is_supported_in_black, ) from datamodel_code_generator.parser import LiteralType @@ -366,6 +365,26 @@ def merge_args(self, args: Namespace) -> None: setattr(self, field_name, getattr(parsed_args, field_name)) +def _get_pyproject_toml_config(source: Path) -> Optional[Dict[str, Any]]: + """Find and return the [tool.datamodel-codgen] section of the closest + pyproject.toml if it exists. + """ + + current_path = source + while current_path != current_path.parent: + if (current_path / 'pyproject.toml').is_file(): + pyproject_toml = load_toml(current_path / 'pyproject.toml') + if 'datamodel-codegen' in pyproject_toml.get('tool', {}): + return pyproject_toml['tool']['datamodel-codegen'] + + if (current_path / '.git').exists(): + # Stop early if we see a git repository root. + return None + + current_path = current_path.parent + return None + + def main(args: Optional[Sequence[str]] = None) -> Exit: """Main function.""" @@ -383,16 +402,9 @@ def main(args: Optional[Sequence[str]] = None) -> Exit: print(version) exit(0) - root = black_find_project_root((Path().resolve(),)) - pyproject_toml_path = root / 'pyproject.toml' - if pyproject_toml_path.is_file(): - pyproject_toml: Dict[str, Any] = { - k.replace('-', '_'): v - for k, v in load_toml(pyproject_toml_path) - .get('tool', {}) - .get('datamodel-codegen', {}) - .items() - } + pyproject_config = _get_pyproject_toml_config(Path().resolve()) + if pyproject_config is not None: + pyproject_toml = {k.replace('-', '_'): v for k, v in pyproject_config.items()} else: pyproject_toml = {} diff --git a/tests/data/expected/main_kr/pyproject/output.strictstr.py b/tests/data/expected/main_kr/pyproject/output.strictstr.py new file mode 100644 index 000000000..8ddb3f3e4 --- /dev/null +++ b/tests/data/expected/main_kr/pyproject/output.strictstr.py @@ -0,0 +1,69 @@ +# generated by datamodel-codegen: +# filename: api.yaml +# timestamp: 2019-07-26T00:00:00+00:00 + +from __future__ import annotations + +from typing import List, Optional + +from pydantic import AnyUrl, BaseModel, Field, StrictStr + + +class Pet(BaseModel): + id: int + name: StrictStr + tag: Optional[StrictStr] = None + + +class Pets(BaseModel): + __root__: List[Pet] + + +class User(BaseModel): + id: int + name: StrictStr + tag: Optional[StrictStr] = None + + +class Users(BaseModel): + __root__: List[User] + + +class Id(BaseModel): + __root__: StrictStr + + +class Rules(BaseModel): + __root__: List[StrictStr] + + +class Error(BaseModel): + code: int + message: StrictStr + + +class Api(BaseModel): + apiKey: Optional[StrictStr] = Field( + None, description='To be used as a dataset parameter value' + ) + apiVersionNumber: Optional[StrictStr] = Field( + None, description='To be used as a version parameter value' + ) + apiUrl: Optional[AnyUrl] = Field( + None, description="The URL describing the dataset's fields" + ) + apiDocumentationUrl: Optional[AnyUrl] = Field( + None, description='A URL to the API console for each API' + ) + + +class Apis(BaseModel): + __root__: List[Api] + + +class Event(BaseModel): + name: Optional[StrictStr] = None + + +class Result(BaseModel): + event: Optional[Event] = None diff --git a/tests/test_main_kr.py b/tests/test_main_kr.py index 13e157f04..eb32ae1dc 100644 --- a/tests/test_main_kr.py +++ b/tests/test_main_kr.py @@ -7,7 +7,7 @@ import pytest from freezegun import freeze_time -from datamodel_code_generator import inferred_message +from datamodel_code_generator import chdir, inferred_message from datamodel_code_generator.__main__ import Exit, main try: @@ -207,6 +207,46 @@ def test_pyproject(): ) +@pytest.mark.skipif( + black.__version__.split('.')[0] == '19', + reason="Installed black doesn't support the old style", +) +@freeze_time('2019-07-26') +def test_pyproject_with_tool_section(): + """Test that a pyproject.toml with a [tool.datamodel-codegen] section is + found and its configuration applied. + """ + with TemporaryDirectory() as output_dir: + output_dir = Path(output_dir) + pyproject_toml = """ +[tool.datamodel-codegen] +target-python-version = "3.10" +strict-types = ["str"] +""" + with open(output_dir / 'pyproject.toml', 'w') as f: + f.write(pyproject_toml) + output_file: Path = output_dir / 'output.py' + + # Run main from within the output directory so we can find our + # pyproject.toml. + with chdir(output_dir): + return_code: Exit = main( + [ + '--input', + str((OPEN_API_DATA_PATH / 'api.yaml').resolve()), + '--output', + str(output_file.resolve()), + ] + ) + + assert return_code == Exit.OK + assert ( + output_file.read_text() + # We expect the output to use pydantic.StrictStr in place of str + == (EXPECTED_MAIN_KR_PATH / 'pyproject' / 'output.strictstr.py').read_text() + ) + + @freeze_time('2019-07-26') def test_main_use_schema_description(): with TemporaryDirectory() as output_dir: