Skip to content

Commit

Permalink
Merge pull request #20 from allenai/soldni/cli-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni authored Jul 21, 2023
2 parents 10c2964 + fc831cb commit a37e7c6
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dolma"
version = "0.6.5"
version = "0.7.0"
edition = "2021"
license = "Apache-2.0"

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dolma"
version = "0.6.5"
version = "0.7.0"
description = "Data filters"
license = {text = "Apache-2.0"}
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def print_config(config: Any, console: Optional[Console] = None) -> None:

class BaseCli(Generic[D]):
CONFIG: Type[D]
DESCRIPTION: Optional[str] = None

@classmethod
def make_parser(cls, parser: A) -> A:
Expand Down
6 changes: 4 additions & 2 deletions python/dolma/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

from .deduper import DeduperCli
from .mixer import MixerCli
from .tagger import TaggerCli
from .tagger import ListTaggerCli, TaggerCli

AVAILABLE_COMMANDS = {
"dedupe": DeduperCli,
"mix": MixerCli,
"tag": TaggerCli,
"list": ListTaggerCli
# following functionality is not yet implemented
# "visualize": None,
# "browse": None,
# "stats": None,
Expand All @@ -38,7 +40,7 @@ def main(argv: Optional[List[str]] = None):

for command, cli in AVAILABLE_COMMANDS.items():
if cli is not None:
cli.make_parser(subparsers.add_parser(command))
cli.make_parser(subparsers.add_parser(command, help=cli.DESCRIPTION))

args = parser.parse_args(argv)

Expand Down
1 change: 1 addition & 0 deletions python/dolma/cli/deduper.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class DeduperConfig:

class DeduperCli(BaseCli):
CONFIG = DeduperConfig
DESCRIPTION = "Deduplicate documents or paragraphs using a bloom filter."

@classmethod
def run(cls, parsed_config: DeduperConfig):
Expand Down
1 change: 1 addition & 0 deletions python/dolma/cli/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class MixerConfig:

class MixerCli(BaseCli):
CONFIG = MixerConfig
DESCRIPTION = "Mix documents from multiple streams."

@classmethod
def run(cls, parsed_config: MixerConfig):
Expand Down
40 changes: 37 additions & 3 deletions python/dolma/cli/tagger.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from dataclasses import dataclass
from typing import List, Optional

from omegaconf import MISSING
from rich.console import Console
from rich.table import Table

from dolma.cli import BaseCli, field, print_config
from dolma.cli.shared import WorkDirConfig
from dolma.core.errors import DolmaConfigError
from dolma.core.loggers import get_logger
from dolma.core.paths import glob_path
from dolma.core.registry import TaggerRegistry
from dolma.core.runtime import create_and_run_tagger


Expand All @@ -29,14 +31,18 @@ class TaggerConfig:
default=[],
help="List of taggers to run.",
)
experiment: str = field(
default=MISSING,
experiment: Optional[str] = field(
default=None,
help="Name of the experiment.",
)
processes: int = field(
default=1,
help="Number of parallel processes to use.",
)
ignore_existing: bool = field(
default=False,
help="Whether to ignore existing outputs and re-run the taggers.",
)
debug: bool = field(
default=False,
help="Whether to run in debug mode.",
Expand All @@ -48,6 +54,10 @@ class TaggerConfig:

class TaggerCli(BaseCli):
CONFIG = TaggerConfig
DESCRIPTION = (
"Tag documents or spans of documents using one or more taggers. "
"For a list of available taggers, run `dolma list`."
)

@classmethod
def run(cls, parsed_config: TaggerConfig):
Expand Down Expand Up @@ -76,7 +86,31 @@ def run(cls, parsed_config: TaggerConfig):
destination=parsed_config.destination,
metadata=metadata,
taggers=taggers,
ignore_existing=parsed_config.ignore_existing,
num_processes=parsed_config.processes,
experiment=parsed_config.experiment,
debug=parsed_config.debug,
)


@dataclass
class ListTaggerConfig:
...


class ListTaggerCli(BaseCli):
CONFIG = ListTaggerConfig
DESCRIPTION = "List available taggers."

@classmethod
def run(cls, parsed_config: ListTaggerConfig):
table = Table(title="dolma taggers", style="bold")
table.add_column("name", justify="left", style="cyan")
table.add_column("class", justify="left", style="magenta")

for tagger_name, tagger_cls in sorted(TaggerRegistry.taggers()):
tagger_repr = f"{tagger_cls.__module__}.{tagger_cls.__name__}"
table.add_row(tagger_name, tagger_repr)

console = Console()
console.print(table)
17 changes: 11 additions & 6 deletions python/dolma/core/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
"""

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple

from msgspec import Struct
from typing_extensions import TypeAlias

TaggerOutputValueType: TypeAlias = Tuple[int, int, float]
TaggerOutputType: TypeAlias = List[TaggerOutputValueType]
TaggerOutputDictType: TypeAlias = Dict[str, TaggerOutputType]


class InputSpec(Struct):
Expand All @@ -22,7 +27,7 @@ class InputSpec(Struct):

class OutputSpec(Struct):
id: str
attributes: Dict[str, List[List[Union[int, float]]]]
attributes: Dict[str, List[Tuple[int, int, float]]]
source: Optional[str] = None


Expand Down Expand Up @@ -80,7 +85,7 @@ def select(self, doc: Document) -> str:
return doc.text[self.start : self.end]

@classmethod
def from_spec(cls, attribute_name: str, attribute_value: List[Union[int, float]]) -> "Span":
def from_spec(cls, attribute_name: str, attribute_value: TaggerOutputValueType) -> "Span":
if "__" in attribute_name:
# bff tagger has different name
exp_name, tgr_name, attr_type = attribute_name.split("__", 2)
Expand All @@ -97,12 +102,12 @@ def from_spec(cls, attribute_name: str, attribute_value: List[Union[int, float]]
tagger=tgr_name,
)

def to_spec(self) -> Tuple[str, List[Union[int, float]]]:
def to_spec(self) -> Tuple[str, TaggerOutputValueType]:
assert self.experiment is not None, "Experiment name must be set to convert to spec"
assert self.tagger is not None, "Tagger name must be set to convert to spec"
return (
f"{self.experiment}__{self.tagger}__{self.type}",
[self.start, self.end, self.score],
(self.start, self.end, self.score),
)

def __len__(self) -> int:
Expand Down Expand Up @@ -146,7 +151,7 @@ def from_spec(cls, doc: InputSpec, *attrs_groups: OutputSpec) -> "DocResult":

def to_spec(self) -> Tuple[InputSpec, OutputSpec]:
doc_spec = self.doc.to_spec()
attributes: Dict[str, List[List[Union[int, float]]]] = {}
attributes: Dict[str, List[TaggerOutputValueType]] = {}

for span in self.spans:
attr_name, attr_value = span.to_spec()
Expand Down
31 changes: 25 additions & 6 deletions python/dolma/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,15 @@
import tqdm

from .errors import DolmaError, DolmaRetryableFailure
from .paths import add_suffix, glob_path, make_relative, mkdir_p, sub_prefix
from .paths import (
add_suffix,
glob_path,
join_path,
make_relative,
mkdir_p,
split_path,
sub_prefix,
)

METADATA_SUFFIX = ".done.txt"

Expand Down Expand Up @@ -116,10 +124,6 @@ def __init__(
if any("*" in p for p in itertools.chain(self.dst_prefixes, self.meta_prefixes)):
raise ValueError("Destination and metadata prefixes cannot contain wildcards.")

for i in range(len(self.src_prefixes)):
# adding a wildcard to the end of the each source prefix if it doesn't have one
self.src_prefixes[i] = add_suffix(p, "*") if "*" not in (p := self.src_prefixes[i]) else p

@classmethod
def process_single(
cls,
Expand Down Expand Up @@ -317,9 +321,24 @@ def _get_all_paths(self) -> Tuple[List[str], List[str], List[str]]:
all_source_paths, all_destination_paths, all_metadata_paths = [], [], []

for src_prefix, dst_prefix, meta_prefix in zip(self.src_prefixes, self.dst_prefixes, self.meta_prefixes):
prefix, rel_paths = make_relative(list(glob_path(src_prefix)))
current_source_prefixes = sorted(glob_path(src_prefix))

if len(current_source_prefixes) > 1:
# make relative only makes sense if there is more than one path; otherwise, it's unclear
# what a relative path would be.
prefix, rel_paths = make_relative(current_source_prefixes)
elif len(current_source_prefixes) == 1:
# in case we have a single path, we can just use the path minus the file as the shared prefix,
# and the file as the relative path
prot, parts = split_path(current_source_prefixes[0])
prefix, rel_paths = join_path(prot, *parts[:-1]), [parts[-1]]
else:
raise ValueError(f"Could not find any files matching {src_prefix}")

# shuffle the order of the files so time estimation in progress bars is more accurate
random.shuffle(rel_paths)

# get a list of which metadata files already exist
existing_metadata_names = set(
sub_prefix(path, meta_prefix).strip(METADATA_SUFFIX) for path in glob_path(meta_prefix)
)
Expand Down
57 changes: 43 additions & 14 deletions python/dolma/core/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"join_path",
"is_glob",
"split_glob",
"partition_path",
]


Expand All @@ -27,6 +28,15 @@
}


RE_ANY_ESCAPE = re.compile(r"(?<!\\)(\*\?\[\])")
RE_GLOB_STAR_ESCAPE = re.compile(r"(?<!\\)\*")
RE_GLOB_ONE_ESCAPE = re.compile(r"(?<!\\)\?")
RE_GLOB_OPEN_ESCAPE = re.compile(r"(?<!\\)\[")
RE_GLOB_CLOSE_ESCAPE = re.compile(r"(?<!\\)\]")
ESCAPE_SYMBOLS_MAP = {"*": "\u2581", "?": "\u2582", "[": "\u2583", "]": "\u2584"}
REVERSE_ESCAPE_SYMBOLS_MAP = {v: k for k, v in ESCAPE_SYMBOLS_MAP.items()}


def _get_fs(path: Union[Path, str]) -> AbstractFileSystem:
"""
Get the filesystem class for a given path.
Expand All @@ -46,12 +56,11 @@ def _escape_glob(s: Union[str, Path]) -> str:
"""
Escape glob characters in a string.
"""
r"(?<!\\)[*?[\]]"
s = str(s)
s = re.sub(r"(?<!\\)\*", "\u2581", s)
s = re.sub(r"(?<!\\)\?", "\u2582", s)
s = re.sub(r"(?<!\\)\[", "\u2583", s)
s = re.sub(r"(?<!\\)\]", "\u2584", s)
s = RE_GLOB_STAR_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["*"], s)
s = RE_GLOB_ONE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["?"], s)
s = RE_GLOB_OPEN_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["["], s)
s = RE_GLOB_CLOSE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["]"], s)
return s


Expand All @@ -60,10 +69,8 @@ def _unescape_glob(s: Union[str, Path]) -> str:
Unescape glob characters in a string.
"""
s = str(s)
s = re.sub("\u2581", "*", s)
s = re.sub("\u2582", "?", s)
s = re.sub("\u2583", "[", s)
s = re.sub("\u2584", "]", s)
for k, v in REVERSE_ESCAPE_SYMBOLS_MAP.items():
s = s.replace(k, v)
return s


Expand All @@ -87,6 +94,26 @@ def _unpathify(protocol: str, path: Path) -> str:
return path_str


def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]:
"""Partition a path into its protocol, symbols before a glob, and symbols after a glob."""
# split the path into its protocol and path components
prot, path_obj = _pathify(path)

# we need to first figure out if this path has a glob by checking if any of the escaped symbols for
# globs are in the path.
glob_locs = [i for i, p in enumerate(path_obj.parts) if any(c in p for c in REVERSE_ESCAPE_SYMBOLS_MAP)]

# make the path components before the glob
pre_glob_path = path_obj.parts[: glob_locs[0]] if glob_locs else path_obj.parts
pre_glob_path = tuple(_unescape_glob(p) for p in pre_glob_path)

# make the path components after the glob
post_glob_path = path_obj.parts[glob_locs[0] + 1 :] if glob_locs else ()
post_glob_path = tuple(_unescape_glob(p) for p in post_glob_path)

return prot, pre_glob_path, post_glob_path


def split_path(path: str) -> Tuple[str, Tuple[str, ...]]:
"""
Split a path into its protocol and path components.
Expand All @@ -95,7 +122,7 @@ def split_path(path: str) -> Tuple[str, Tuple[str, ...]]:
return protocol, tuple(_unescape_glob(p) for p in _path.parts)


def join_path(protocol: str, *parts: Union[str, Iterable[str]]) -> str:
def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> str:
"""
Join a path from its protocol and path components.
"""
Expand All @@ -106,14 +133,17 @@ def join_path(protocol: str, *parts: Union[str, Iterable[str]]) -> str:
return _unescape_glob(path)


def glob_path(path: Union[Path, str], hidden_files: bool = False) -> Iterator[str]:
def glob_path(path: Union[Path, str], hidden_files: bool = False, autoglob_dirs: bool = True) -> Iterator[str]:
"""
Expand a glob path into a list of paths.
"""
path = str(path)
protocol = urlparse(path).scheme
fs = _get_fs(path)

if fs.isdir(path) and autoglob_dirs:
path = join_path(None, path, "*")

for gl in fs.glob(path):
gl = str(gl)

Expand Down Expand Up @@ -191,14 +221,13 @@ def make_relative(paths: List[str]) -> Tuple[str, List[str]]:
if len(paths) == 0:
raise ValueError("Cannot make relative path of empty list")

common_prot, common_parts = (p := _pathify(paths[0]))[0], p[1].parts
common_prot, common_parts, _ = partition_path(paths[0])

for path in paths:
current_prot, current_path = _pathify(path)
current_prot, current_parts, _ = partition_path(path)
if current_prot != common_prot:
raise ValueError(f"Protocols of {path} and {paths[0]} do not match")

current_parts = current_path.parts
for i in range(min(len(common_parts), len(current_parts))):
if common_parts[i] != current_parts[i]:
common_parts = common_parts[:i]
Expand Down
Loading

0 comments on commit a37e7c6

Please sign in to comment.