From 8c7b360685f1222563340b95e98041619a9d098b Mon Sep 17 00:00:00 2001 From: dolf Date: Tue, 1 Oct 2024 13:11:01 +0200 Subject: [PATCH] Allow specifying the encoding of the VBA source code. Replace invalid sequences instead of failing. --- mkdocstrings_handlers/vba/_handler.py | 21 ++++++++++--- test/handler/__init__.py | 0 test/handler/test_collect.py | 45 +++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) create mode 100644 test/handler/__init__.py create mode 100644 test/handler/test_collect.py diff --git a/mkdocstrings_handlers/vba/_handler.py b/mkdocstrings_handlers/vba/_handler.py index af4b8ad..60828f9 100644 --- a/mkdocstrings_handlers/vba/_handler.py +++ b/mkdocstrings_handlers/vba/_handler.py @@ -12,7 +12,6 @@ MutableMapping, Dict, Mapping, - Set, Tuple, ) @@ -40,9 +39,17 @@ class VbaHandler(BaseHandler): The directory in which to look for VBA files. """ - def __init__(self, *, base_dir: Path, **kwargs: Any) -> None: + encoding: str + """ + The encoding to use when reading VBA files. + Excel exports .bas and .cls files as `latin1`. + See https://en.wikipedia.org/wiki/ISO/IEC_8859-1 . + """ + + def __init__(self, *, base_dir: Path, encoding: str, **kwargs: Any) -> None: super().__init__(**kwargs) self.base_dir = base_dir + self.encoding = encoding name: str = "vba" """ @@ -121,9 +128,7 @@ def collect( if not p.exists(): raise CollectionError("File not found.") - with p.open("r") as f: - code = f.read() - + code = p.read_text(encoding=self.encoding, errors="replace") code = collapse_long_lines(code) return VbaModuleInfo( @@ -178,6 +183,7 @@ def get_handler( theme: str = "material", custom_templates: str | None = None, config_file_path: str | None = None, + encoding: str = "latin1", **kwargs: Any, ) -> VbaHandler: """ @@ -187,6 +193,10 @@ def get_handler( theme: The theme to use when rendering contents. custom_templates: Directory containing custom templates. config_file_path: The MkDocs configuration file path. + encoding: + The encoding to use when reading VBA files. + Excel exports .bas and .cls files as `latin1`. + See https://en.wikipedia.org/wiki/ISO/IEC_8859-1 . kwargs: Extra keyword arguments that we don't use. Returns: @@ -198,6 +208,7 @@ def get_handler( if config_file_path else Path(".").resolve() ), + encoding=encoding, handler="vba", theme=theme, custom_templates=custom_templates, diff --git a/test/handler/__init__.py b/test/handler/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/handler/test_collect.py b/test/handler/test_collect.py new file mode 100644 index 0000000..5fdff67 --- /dev/null +++ b/test/handler/test_collect.py @@ -0,0 +1,45 @@ +import unittest +from pathlib import Path +from tempfile import TemporaryDirectory + +from mkdocstrings_handlers.vba import get_handler + +# noinspection PyProtectedMember +from mkdocstrings_handlers.vba._types import VbaModuleInfo + + +def _test_collect(*, write_bytes: bytes, read_encoding: str) -> VbaModuleInfo: + with TemporaryDirectory() as tmp_dir_str: + tmp_dir = Path(tmp_dir_str) + handler = get_handler(encoding=read_encoding) + p = tmp_dir / "source.bas" + p.write_bytes(write_bytes) + return handler.collect(identifier=p.as_posix(), config={}) + + +class TestCollect(unittest.TestCase): + + def test_undefined_unicode(self) -> None: + # See https://symbl.cc/en/unicode-table/#undefined-0 for values that are undefined in Unicode. + # \xe2\xbf\xaf is utf-8 for the undefined Unicode point U+2FEF + module_info = _test_collect( + write_bytes=b"Foo \xe2\xbf\xaf Bar", + read_encoding="utf-8", + ) + self.assertEqual(["Foo \u2fef Bar"], module_info.source) + + def test_invalid_utf8(self) -> None: + # invalid start byte + module_info = _test_collect( + write_bytes=b"\x89\x89\x89\x89", + read_encoding="utf-8", + ) + self.assertEqual(["����"], module_info.source) + + def test_invalid_latin1(self) -> None: + module_info = _test_collect( + write_bytes="🎵".encode("utf-8"), + read_encoding="latin1", + ) + # Since `latin1` is a single-byte encoding, it can't detect invalid sequences, and so we get mojibake. + self.assertEqual(["ð\x9f\x8eµ"], module_info.source)