Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ipython parser for python >= 3.3 #229

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 42 additions & 3 deletions pipreqs/pipreqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@
import requests
from yarg import json2package
from yarg.exceptions import HTTPError
import copy

if sys.version_info.major == 3 and sys.version_info.minor >= 3:
import IPython
import nbformat

SUPPORT_IPYTHON = True
transformer = IPython.core.inputtransformer2.TransformerManager()
else:
SUPPORT_IPYTHON = False

from pipreqs import __version__

Expand Down Expand Up @@ -110,10 +120,12 @@ def get_all_imports(

walk = os.walk(path, followlinks=follow_links)
for root, dirs, files in walk:
files_raw = copy.deepcopy(files)

dirs[:] = [d for d in dirs if d not in ignore_dirs]

candidates.append(os.path.basename(root))
files = [fn for fn in files if os.path.splitext(fn)[1] == ".py"]
files = [fn for fn in files_raw if os.path.splitext(fn)[1] == ".py"]

candidates += [os.path.splitext(fn)[0] for fn in files]
for file_name in files:
Expand All @@ -136,6 +148,33 @@ def get_all_imports(
else:
logging.error("Failed on file: %s" % file_name)
raise exc

if SUPPORT_IPYTHON:
files_ipy = [fn for fn in files_raw if os.path.splitext(fn)[1] == ".ipynb"]

for file_name in files_ipy:
file_name = os.path.join(root, file_name)
nb = nbformat.read(file_name, as_version=4)
contents = ""
for cell in nb.cells:
if cell.cell_type == "code":
contents += transformer.transform_cell(cell.source) + "\n"
try:
tree = ast.parse(contents)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for subnode in node.names:
raw_imports.add(subnode.name)
elif isinstance(node, ast.ImportFrom):
raw_imports.add(node.module)
except Exception as exc:
if ignore_errors:
traceback.print_exc(exc)
logging.warn("Failed on file: %s" % file_name)
continue
else:
logging.error("Failed on file: %s" % file_name)
raise exc

# Clean up imports
for name in [n for n in raw_imports if n]:
Expand All @@ -157,8 +196,8 @@ def get_all_imports(
return list(packages - data)


def filter_line(line):
return len(line) > 0 and line[0] != "#"
def filter_line(l):
return len(l) > 0 and l[0] != "#"


def generate_requirements_file(path, imports):
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
wheel==0.23.0
Yarg==0.1.9
docopt==0.6.2
docopt==0.6.2
IPython==7.17.0
nbformat==5.0.7
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
history = history_file.read().replace('.. :changelog:', '')

requirements = [
'docopt', 'yarg'
'docopt', 'yarg', 'IPython>=7.0.0', 'nbformat>=5.0.0'
]

setup(
Expand Down