Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateus Latrova authored and mateuslatrova committed Nov 27, 2023
1 parent df509bb commit 932deea
Show file tree
Hide file tree
Showing 7 changed files with 1,108 additions and 145 deletions.
110 changes: 70 additions & 40 deletions pipreqs/pipreqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
<compat> | e.g. Flask~=1.1.2
<gt> | e.g. Flask>=1.1.2
<no-pin> | e.g. Flask
--ignore-notebooks Ignore jupyter notebook files.
--scan-notebooks Look for imports in jupyter notebook files.
"""
from contextlib import contextmanager
import os
Expand All @@ -49,18 +49,22 @@
from yarg import json2package
from yarg.exceptions import HTTPError

try:
PythonExporter = None
ignore_notebooks = False
from nbconvert import PythonExporter
except ImportError:
pass

from pipreqs import __version__

REGEXP = [re.compile(r"^import (.+)$"), re.compile(r"^from ((?!\.+).*?) import (?:.*)$")]


scan_noteboooks = False


class NbconvertNotInstalled(ImportError):
default_message = (
"In order to scan jupyter notebooks, please install the nbconvert and ipython libraries"
)

def __init__(self, message=default_message):
super().__init__(message)


@contextmanager
def _open(filename=None, mode="r"):
Expand Down Expand Up @@ -115,31 +119,22 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
ignore_dirs_parsed.append(os.path.basename(os.path.realpath(e)))
ignore_dirs.extend(ignore_dirs_parsed)

extensions = get_file_extensions()

walk = os.walk(path, followlinks=follow_links)
for root, dirs, files in walk:
dirs[:] = [d for d in dirs if d not in ignore_dirs]

candidates.append(os.path.basename(root))
if notebooks_are_enabled():
files = [fn for fn in files if file_ext_is_allowed(fn, [".py", ".ipynb"])]
else:
files = [fn for fn in files if file_ext_is_allowed(fn, [".py"])]
py_files = [file for file in files if file_ext_is_allowed(file, [".py"])]
candidates.extend([os.path.splitext(filename)[0] for filename in py_files])

candidates = list(
map(
lambda fn: os.path.splitext(fn)[0],
filter(lambda fn: file_ext_is_allowed(fn, [".py"]), files),
)
)
files = [fn for fn in files if file_ext_is_allowed(fn, extensions)]

for file_name in files:
file_name = os.path.join(root, file_name)
contents = ""
if file_ext_is_allowed(file_name, [".py"]):
with open(file_name, "r", encoding=encoding) as f:
contents = f.read()
elif file_ext_is_allowed(file_name, [".ipynb"]) and notebooks_are_enabled():
contents = ipynb_2_py(file_name, encoding=encoding)
contents = read_file_content(file_name, encoding)

try:
tree = ast.parse(contents)
for node in ast.walk(tree):
Expand Down Expand Up @@ -176,8 +171,17 @@ def get_all_imports(path, encoding="utf-8", extra_ignore_dirs=None, follow_links
return list(packages - data)


def notebooks_are_enabled():
return PythonExporter and not ignore_notebooks
def get_file_extensions():
return [".py", ".ipynb"] if scan_noteboooks else [".py"]


def read_file_content(file_name: str, encoding="utf-8"):
if file_ext_is_allowed(file_name, [".py"]):
with open(file_name, "r", encoding=encoding) as f:
contents = f.read()
elif file_ext_is_allowed(file_name, [".ipynb"]) and scan_noteboooks:
contents = ipynb_2_py(file_name, encoding=encoding)
return contents


def file_ext_is_allowed(file_name, acceptable):
Expand All @@ -195,7 +199,6 @@ def ipynb_2_py(file_name, encoding="utf-8"):
str: parsed string
"""

exporter = PythonExporter()
(body, _) = exporter.from_filename(file_name)

Expand Down Expand Up @@ -484,12 +487,27 @@ def dynamic_versioning(scheme, imports):
return imports, symbol


def handle_scan_noteboooks():
if not scan_noteboooks:
logging.info("Not scanning for jupyter notebooks.")
return

try:
global PythonExporter
from nbconvert import PythonExporter
except ImportError:
raise NbconvertNotInstalled()


def init(args):
global ignore_notebooks
global scan_noteboooks
encoding = args.get("--encoding")
extra_ignore_dirs = args.get("--ignore")
follow_links = not args.get("--no-follow-links")
ignore_notebooks = args.get("--ignore-notebooks")

scan_noteboooks = args.get("--scan-notebooks", False)
handle_scan_noteboooks()

input_path = args["<path>"]

if encoding is None:
Expand All @@ -500,8 +518,15 @@ def init(args):
if extra_ignore_dirs:
extra_ignore_dirs = extra_ignore_dirs.split(",")

path = args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
if not args["--print"] and not args["--savepath"] and not args["--force"] and os.path.exists(path):
path = (
args["--savepath"] if args["--savepath"] else os.path.join(input_path, "requirements.txt")
)
if (
not args["--print"]
and not args["--savepath"]
and not args["--force"]
and os.path.exists(path)
):
logging.warning("requirements.txt already exists, " "use --force to overwrite it")
return

Expand Down Expand Up @@ -532,14 +557,17 @@ def init(args):
# the list of exported modules, installed locally
# and the package name is not in the list of local module names
# it add to difference
difference = [x for x in candidates if
# aggregate all export lists into one
# flatten the list
# check if candidate is in exports
x.lower() not in [y for x in local for y in x['exports']]
and
# check if candidate is package names
x.lower() not in [x['name'] for x in local]]
difference = [
x
for x in candidates
if
# aggregate all export lists into one
# flatten the list
# check if candidate is in exports
x.lower() not in [y for x in local for y in x["exports"]] and
# check if candidate is package names
x.lower() not in [x["name"] for x in local]
]

imports = local + get_imports_info(difference, proxy=proxy, pypi_server=pypi_server)
# sort imports based on lowercase name of package, similar to `pip freeze`.
Expand All @@ -558,7 +586,9 @@ def init(args):
if scheme in ["compat", "gt", "no-pin"]:
imports, symbol = dynamic_versioning(scheme, imports)
else:
raise ValueError("Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead")
raise ValueError(
"Invalid argument for mode flag, " "use 'compat', 'gt' or 'no-pin' instead"
)
else:
symbol = "=="

Expand Down
Loading

0 comments on commit 932deea

Please sign in to comment.