Skip to content

Commit

Permalink
refactor: tidy codes
Browse files Browse the repository at this point in the history
Signed-off-by: Jack Cherng <[email protected]>
  • Loading branch information
jfcherng committed Jun 24, 2024
1 parent bd7ddaf commit 09c2c5f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 49 deletions.
42 changes: 36 additions & 6 deletions plugin/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
from __future__ import annotations

import os
import re
import subprocess
from typing import Any
import sys
from collections.abc import Generator, Iterable
from typing import Any, TypeVar

_T = TypeVar("_T")


def camel_to_snake(s: str) -> str:
"""Converts "CamelCase" to "snake_case"."""
return "".join((f"_{c}" if c.isupper() else c) for c in s).strip("_").lower()


def snake_to_camel(s: str, *, upper_first: bool = True) -> str:
"""Converts "snake_case" to "CamelCase"."""
first, *others = s.split("_")
return (first.title() if upper_first else first.lower()) + "".join(map(str.title, others))


if sys.version_info >= (3, 9):
remove_prefix = str.removeprefix
remove_suffix = str.removesuffix
else:

def remove_prefix(s: str, prefix: str) -> str:
"""Remove the prefix from the string. I.e., str.removeprefix in Python 3.9."""
return s[len(prefix) :] if s.startswith(prefix) else s

def remove_suffix(s: str, suffix: str) -> str:
"""Remove the suffix from the string. I.e., str.removesuffix in Python 3.9."""
# suffix="" should not call s[:-0]
return s[: -len(suffix)] if suffix and s.endswith(suffix) else s


def drop_falsy(iterable: Iterable[_T | None]) -> Generator[_T, None, None]:
"""Drops falsy values from the iterable."""
yield from filter(None, iterable)


def get_default_startupinfo() -> Any:
Expand All @@ -14,7 +48,3 @@ def get_default_startupinfo() -> Any:
STARTUPINFO.wShowWindow = subprocess.SW_HIDE # type: ignore
return STARTUPINFO
return None


def lowercase_drive_letter(path: str) -> str:
return re.sub(r"^[A-Z]+(?=:\\)", lambda m: m.group(0).lower(), path)
61 changes: 18 additions & 43 deletions plugin/venv_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import configparser
import os
import re
import shutil
import subprocess
from abc import ABC, abstractmethod
Expand All @@ -16,6 +15,7 @@
from typing_extensions import Self

from .log import log_error
from .utils import camel_to_snake, remove_suffix


def find_venv_by_finder_names(finder_names: Sequence[str], *, project_dir: Path) -> VenvInfo | None:
Expand Down Expand Up @@ -152,32 +152,15 @@ def from_pyvenv_cfg_file(cls, pyvenv_cfg_file: str | Path) -> Self | None:
return cls.from_venv_dir(venv_dir)

@staticmethod
def parse_pyvenv_cfg(pyvenv_cfg: Path) -> dict[str, Any]:
def parse_pyvenv_cfg(pyvenv_cfg: Path) -> dict[str, str]:
# value of these keys are expected to be a string
str_attr = {"command", "executable", "home", "implementation", "prompt", "uv", "version", "version_info"}

def _cast(key: str, val: str) -> Any:
if key in str_attr:
return val
if val.lower() == "true":
return True
if val.lower() == "false":
return False
if val.isdigit():
return int(val)
try:
return float(val)
except ValueError:
pass
return val

config = configparser.ConfigParser()
try:
content = pyvenv_cfg.read_text(encoding="utf-8")
config.read_string(f"[USER]\n{content}")
except Exception:
return {}
return {k: _cast(k, v) for k, v in config.items("USER")}
return dict(config.items("USER"))


class BaseVenvFinder(ABC):
Expand All @@ -188,12 +171,7 @@ def __init__(self, project_dir: Path) -> None:
@final
@classmethod
def name(cls) -> str:
name = cls.__name__
# remove trailing "VenvFinder"
if name.endswith("VenvFinder"):
name = name[: -len("VenvFinder")]
# CamelCase to snake_case
return "".join(f"_{c.lower()}" if c.isupper() else c for c in name).lstrip("_")
return camel_to_snake(remove_suffix(cls.__name__, "VenvFinder"))

@final
@classmethod
Expand Down Expand Up @@ -227,12 +205,12 @@ def _find_venv(self) -> VenvInfo | None:
"""Find the virtual environment. Implement this method by the subclass."""

@staticmethod
def _find_from_venv_dir_candidates(candidates: Iterable[Path]) -> VenvInfo | None:
def _find_from_venv_dirs(venv_dirs: Iterable[Path]) -> VenvInfo | None:
def _filtered_candidates() -> Generator[Path, None, None]:
for candidate in candidates:
for venv_dir in venv_dirs:
try:
if candidate.is_dir():
yield candidate
if venv_dir.is_dir():
yield venv_dir
except PermissionError:
pass

Expand Down Expand Up @@ -277,7 +255,7 @@ def _can_support(cls, project_dir: Path) -> bool:
return True

def _find_venv(self) -> VenvInfo | None:
return self._find_from_venv_dir_candidates(self.project_dir.iterdir())
return self._find_from_venv_dirs(self.project_dir.iterdir())


class EnvVarCondaPrefixVenvFinder(BaseVenvFinder):
Expand All @@ -289,12 +267,10 @@ class EnvVarCondaPrefixVenvFinder(BaseVenvFinder):

@classmethod
def _can_support(cls, project_dir: Path) -> bool:
return True
return "CONDA_PREFIX" in os.environ

def _find_venv(self) -> VenvInfo | None:
if conda_prefix := os.environ.get("CONDA_PREFIX", ""):
return VenvInfo.from_venv_dir(conda_prefix)
return None
return VenvInfo.from_venv_dir(os.environ["CONDA_PREFIX"])


class EnvVarVirtualEnvVenvFinder(BaseVenvFinder):
Expand All @@ -306,12 +282,10 @@ class EnvVarVirtualEnvVenvFinder(BaseVenvFinder):

@classmethod
def _can_support(cls, project_dir: Path) -> bool:
return True
return "VIRTUAL_ENV" in os.environ

def _find_venv(self) -> VenvInfo | None:
if virtual_env := os.environ.get("VIRTUAL_ENV", ""):
return VenvInfo.from_venv_dir(virtual_env)
return None
return VenvInfo.from_venv_dir(os.environ["VIRTUAL_ENV"])


class LocalDotVenvVenvFinder(BaseVenvFinder):
Expand All @@ -326,7 +300,7 @@ def _can_support(cls, project_dir: Path) -> bool:
return True

def _find_venv(self) -> VenvInfo | None:
return self._find_from_venv_dir_candidates((
return self._find_from_venv_dirs((
self.project_dir / ".venv",
self.project_dir / "venv",
))
Expand Down Expand Up @@ -452,7 +426,8 @@ def _find_venv(self) -> VenvInfo | None:
return None
stdout, _, _ = output

if m := re.search(r"^venv: (.*)$", stdout, re.MULTILINE):
venv_dir = m.group(1)
return VenvInfo.from_venv_dir(venv_dir)
for line in stdout.splitlines():
pre, sep, post = line.partition(":")
if sep and pre == "venv":
return VenvInfo.from_venv_dir(post.strip())
return None

0 comments on commit 09c2c5f

Please sign in to comment.