Skip to content

Commit

Permalink
Fix detection and rewriting of JS modules
Browse files Browse the repository at this point in the history
  • Loading branch information
benoit74 committed Apr 11, 2024
1 parent 4534b69 commit ee3fa4b
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 80 deletions.
27 changes: 23 additions & 4 deletions src/warc2zim/content_rewriting/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
path: str,
record: ArcWarcRecord,
existing_zim_paths: set[ZimPath],
js_modules: set[ZimPath],
):
self.content = get_record_content(record)

Expand All @@ -74,6 +75,7 @@ def __init__(
)

self.rewrite_mode = self.get_rewrite_mode(record, mimetype)
self.js_modules = js_modules

@property
def content_str(self):
Expand All @@ -94,6 +96,8 @@ def rewrite(
return self.rewrite_css()

if self.rewrite_mode == "javascript":
if any(path.value == self.path for path in self.js_modules):
opts["isModule"] = True
return self.rewrite_js(opts)

if self.rewrite_mode == "jsonp":
Expand Down Expand Up @@ -132,6 +136,14 @@ def get_rewrite_mode(self, record, mimetype):

return None

def js_module_found(self, zim_path: ZimPath):
"""Notification helper, for rewriters to call when they have found a JS module
They call it with the JS module expected ZIM path since they are the only one
to know the current document URL/path + the JS module URL.
"""
self.js_modules.add(zim_path)

def rewrite_html(self, head_template: Template, css_insert: str | None):
orig_url = urlsplit(self.orig_url_str)

Expand All @@ -145,9 +157,12 @@ def rewrite_html(self, head_template: Template, css_insert: str | None):
orig_scheme=orig_url.scheme,
orig_host=orig_url.netloc,
)
return HtmlRewriter(self.url_rewriter, head_insert, css_insert).rewrite(
self.content_str
)
return HtmlRewriter(
url_rewriter=self.url_rewriter,
pre_head_insert=head_insert,
post_head_insert=css_insert,
notify_js_module=self.js_module_found,
).rewrite(self.content_str)

@no_title
def rewrite_css(self) -> str | bytes:
Expand All @@ -156,7 +171,11 @@ def rewrite_css(self) -> str | bytes:
@no_title
def rewrite_js(self, opts: dict[str, Any]) -> str | bytes:
ds_rules = get_ds_rules(self.orig_url_str)
rewriter = JsRewriter(self.url_rewriter, ds_rules)
rewriter = JsRewriter(
url_rewriter=self.url_rewriter,
extra_rules=ds_rules,
notify_js_module=self.js_module_found,
)
return rewriter.rewrite(self.content.decode(), opts)

@no_title
Expand Down
62 changes: 42 additions & 20 deletions src/warc2zim/content_rewriting/html.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import io
from collections import namedtuple
from collections.abc import Callable
from html import escape
from html.parser import HTMLParser

Expand All @@ -8,7 +9,7 @@
from warc2zim.content_rewriting.ds import get_ds_rules
from warc2zim.content_rewriting.js import JsRewriter
from warc2zim.content_rewriting.rx_replacer import RxRewriter
from warc2zim.url_rewriting import ArticleUrlRewriter
from warc2zim.url_rewriting import ArticleUrlRewriter, ZimPath

AttrsList = list[tuple[str, str | None]]

Expand All @@ -21,6 +22,7 @@ def __init__(
url_rewriter: ArticleUrlRewriter,
pre_head_insert: str,
post_head_insert: str | None,
notify_js_module: Callable[[ZimPath], None],
):
super().__init__(convert_charrefs=False)
self.url_rewriter = url_rewriter
Expand All @@ -29,9 +31,10 @@ def __init__(
self.output = None
# This works only for tag without children.
# But as we use it to get the title, we are ok
self._active_tag = None
self.html_rewrite_context = None
self.pre_head_insert = pre_head_insert
self.post_head_insert = post_head_insert
self.notify_js_module = notify_js_module

def rewrite(self, content: str) -> RewritenHtml:
if self.output is not None:
Expand All @@ -49,10 +52,25 @@ def send(self, value: str):
self.output.write(value) # pyright: ignore[reportOptionalMemberAccess]

def handle_starttag(self, tag: str, attrs: AttrsList, *, auto_close: bool = False):
self._active_tag = tag
if tag == "script":
if "json" in (self.extract_attr(attrs, "type") or ""):
self._active_tag = "json"
script_type = self.extract_attr(attrs, "type")
if script_type == "json":
self.html_rewrite_context = "json"
elif script_type == "module":
self.html_rewrite_context = "js-module"
else:
self.html_rewrite_context = "js-classic"
elif tag == "link":
self.html_rewrite_context = "link"
link_rel = self.extract_attr(attrs, "rel")
if link_rel == "modulepreload":
self.html_rewrite_context = "js-module"
elif link_rel == "preload":
preload_type = self.extract_attr(attrs, "as")
if preload_type == "script":
self.html_rewrite_context = "js-classic"
else:
self.html_rewrite_context = tag

self.send(f"<{tag}")
if attrs:
Expand All @@ -63,7 +81,7 @@ def handle_starttag(self, tag: str, attrs: AttrsList, *, auto_close: bool = Fals
)
else:
url_rewriter = self.url_rewriter
self.send(self.transform_attrs(attrs, url_rewriter, self.css_rewriter))
self.send(self.transform_attrs(attrs, url_rewriter))

if auto_close:
self.send(" />")
Expand All @@ -73,25 +91,31 @@ def handle_starttag(self, tag: str, attrs: AttrsList, *, auto_close: bool = Fals
self.send(self.pre_head_insert)

def handle_endtag(self, tag: str):
self._active_tag = None
self.html_rewrite_context = None
if tag == "head" and self.post_head_insert:
self.send(self.post_head_insert)
self.send(f"</{tag}>")

def handle_startendtag(self, tag: str, attrs: AttrsList):
self.handle_starttag(tag, attrs, auto_close=True)
self._active_tag = None
self.html_rewrite_context = None

def handle_data(self, data: str):
if self._active_tag == "title" and self.title is None:
if self.html_rewrite_context == "title" and self.title is None:
self.title = data.strip()
elif self._active_tag == "style":
elif self.html_rewrite_context == "style":
data = self.css_rewriter.rewrite(data)
elif self._active_tag == "script":
rules = get_ds_rules(self.url_rewriter.article_url.value)
elif self.html_rewrite_context and self.html_rewrite_context.startswith("js-"):
if data.strip():
data = JsRewriter(self.url_rewriter, rules).rewrite(data)
elif self._active_tag == "json":
data = JsRewriter(
url_rewriter=self.url_rewriter,
extra_rules=get_ds_rules(self.url_rewriter.article_url.value),
notify_js_module=self.notify_js_module,
).rewrite(
data,
opts={"isModule": self.html_rewrite_context == "js-module"},
)
elif self.html_rewrite_context == "json":
if data.strip():
rules = get_ds_rules(self.url_rewriter.article_url.value)
if rules:
Expand Down Expand Up @@ -120,12 +144,13 @@ def process_attr(
self,
attr: tuple[str, str | None],
url_rewriter: UrlRewriterProto,
css_rewriter: CssRewriter,
) -> tuple[str, str | None]:
if not attr[1]:
return attr

if attr[0] in ("href", "src"):
if self.html_rewrite_context == "js-module":
self.notify_js_module(self.url_rewriter.get_item_path(attr[1]))
return (attr[0], url_rewriter(attr[1]))
if attr[0] == "srcset":
value_list = attr[1].split(",")
Expand All @@ -137,7 +162,7 @@ def process_attr(
new_value_list.append(new_value)
return (attr[0], ", ".join(new_value_list))
if attr[0] == "style":
return (attr[0], css_rewriter.rewrite_inline(attr[1]))
return (attr[0], self.css_rewriter.rewrite_inline(attr[1]))
return attr

def format_attr(self, name: str, value: str | None) -> str:
Expand All @@ -150,11 +175,8 @@ def transform_attrs(
self,
attrs: AttrsList,
url_rewriter: UrlRewriterProto,
css_rewriter: CssRewriter,
) -> str:
processed_attrs = (
self.process_attr(attr, url_rewriter, css_rewriter) for attr in attrs
)
processed_attrs = (self.process_attr(attr, url_rewriter) for attr in attrs)
return " ".join(self.format_attr(*attr) for attr in processed_attrs)

def extract_attr(
Expand Down
51 changes: 24 additions & 27 deletions src/warc2zim/content_rewriting/js.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from collections.abc import Iterable
from collections.abc import Callable, Iterable
from typing import Any

from warc2zim.content_rewriting import UrlRewriterProto
Expand All @@ -14,13 +14,6 @@
)
from warc2zim.url_rewriting import ZimPath

# Regex used to check if we ar import or exporting things in the string
# ie : If we are a module
IMPORT_RX = re.compile(r"""^\s*?import\s*?[{"'*]""")
EXPORT_RX = re.compile(
r"^\s*?export\s*?({([\s\w,$\n]+?)}[\s;]*|default|class)\s+", re.M
)

# The regex used to rewrite `import ...` in module code.
IMPORT_MATCH_RX = re.compile(
r"""^\s*?import(?:['"\s]*(?:[\w*${}\s,]+from\s*)?['"\s]?['"\s])(?:.*?)['"\s]""",
Expand Down Expand Up @@ -188,13 +181,6 @@ def create_js_rules() -> list[TransformationRule]:
REWRITE_JS_RULES = create_js_rules()


def js_rewriter_builder(url_rewriter: UrlRewriterProto):
def build_js_rewriter(extra_rules):
return JsRewriter(url_rewriter, extra_rules)

return build_js_rewriter


class JsRewriter(RxRewriter):
"""
JsRewriter is in charge of rewriting the js code stored in our zim file.
Expand All @@ -203,13 +189,15 @@ class JsRewriter(RxRewriter):
def __init__(
self,
url_rewriter: UrlRewriterProto,
notify_js_module: Callable[[ZimPath], None],
extra_rules: Iterable[TransformationRule] | None = None,
):
super().__init__(None)
self.extra_rules = extra_rules or []
self.first_buff = self._init_local_declaration(GLOBAL_OVERRIDES)
self.last_buff = "\n}"
self.url_rewriter = url_rewriter
self.notify_js_module = notify_js_module

def _init_local_declaration(self, local_decls: Iterable[str]) -> str:
"""
Expand Down Expand Up @@ -244,24 +232,17 @@ def _get_module_decl(self, local_decls: Iterable[str]) -> str:
f"""import {{ {", ".join(local_decls)} }} from "{wb_module_decl_url}";\n"""
)

def _detect_is_module(self, text: str) -> bool:
if "import" in text and IMPORT_RX.search(text):
return True
if "export" in text and EXPORT_RX.search(text):
return True
return False

def rewrite(self, text: str, opts: dict[str, Any] | None = None) -> str:
"""
Rewrite the js code in `text`.
"""
opts = opts or {}
if not opts.get("isModule"):
opts["isModule"] = self._detect_is_module(text)

is_module = opts.get("isModule", False)

rules = REWRITE_JS_RULES[:]

if opts["isModule"]:
if is_module:
rules.append(self._get_esm_import_rule())

rules += self.extra_rules
Expand All @@ -270,7 +251,7 @@ def rewrite(self, text: str, opts: dict[str, Any] | None = None) -> str:

new_text = super().rewrite(text, opts)

if opts["isModule"]:
if is_module:
return self._get_module_decl(GLOBAL_OVERRIDES) + new_text

if GLOBALS_RX.search(text):
Expand All @@ -282,11 +263,27 @@ def rewrite(self, text: str, opts: dict[str, Any] | None = None) -> str:
return new_text

def _get_esm_import_rule(self) -> TransformationRule:
def get_rewriten_import_url(url):
"""Rewrite the import URL
This takes into account that the result must be a relative URL, i.e. it
cannot be 'vendor.module.js' but must be './vendor.module.js'.
"""
url = self.url_rewriter(url)
if not (
url.startswith("/") or url.startswith("./") or url.startswith("../")
):
url = "./" + url
return url

def rewrite_import():
def func(m_object, _opts):
def sub_funct(match):
self.notify_js_module(
self.url_rewriter.get_item_path(match.group(2))
)
return (
f"{match.group(1)}{self.url_rewriter(match.group(2))}"
f"{match.group(1)}{get_rewriten_import_url(match.group(2))}"
f"{match.group(3)}"
)

Expand Down
2 changes: 2 additions & 0 deletions src/warc2zim/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self, args):
self.added_zim_items: set[ZimPath] = set()
self.revisits: dict[ZimPath, ZimPath] = {}
self.expected_zim_items: set[ZimPath] = set()
self.js_modules: set[ZimPath] = set()

# progress file handling
self.stats_filename = (
Expand Down Expand Up @@ -532,6 +533,7 @@ def add_items_for_warc_record(self, record):
self.head_template,
self.css_insert,
self.expected_zim_items,
self.js_modules,
)

if len(payload_item.content) != 0:
Expand Down
7 changes: 4 additions & 3 deletions src/warc2zim/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ def __init__(
head_template: Template,
css_insert: str | None,
existing_zim_paths: set[ZimPath],
js_modules: set[ZimPath],
):
super().__init__()

self.path = path
self.mimetype = get_record_mime_type(record)
(self.title, self.content) = Rewriter(path, record, existing_zim_paths).rewrite(
head_template, css_insert
)
(self.title, self.content) = Rewriter(
path, record, existing_zim_paths, js_modules
).rewrite(head_template, css_insert)

def get_hints(self):
is_front = self.mimetype.startswith("text/html")
Expand Down
7 changes: 7 additions & 0 deletions src/warc2zim/url_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,13 @@ def __init__(self, article_url: HttpUrl, existing_zim_paths: set[ZimPath]):
self.existing_zim_paths = existing_zim_paths
self.missing_zim_paths: set[ZimPath] = set()

def get_item_path(self, item_url: str) -> ZimPath:
"""Utility to transform an item URL into a ZimPath"""

item_absolute_url = urljoin(self.article_url.value, item_url)
item_path = normalize(HttpUrl(item_absolute_url))
return item_path

def __call__(self, item_url: str, *, rewrite_all_url: bool = True) -> str:
"""Rewrite a url contained in a article.
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest


@pytest.fixture(scope="module")
def no_js_notify():

def no_js_notify_handler(_: str):
pass

yield no_js_notify_handler
Loading

0 comments on commit ee3fa4b

Please sign in to comment.