Skip to content

Commit

Permalink
Add os.path YAML constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshitomo-matsubara committed Mar 29, 2024
1 parent a5a0594 commit c8664cb
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions torchdistill/common/yaml_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def yaml_join(loader, node):
:return: joined string.
:rtype: str
"""
seq = loader.construct_sequence(node)
seq = loader.construct_sequence(node, deep=True)
return ''.join([str(i) for i in seq])


Expand All @@ -34,10 +34,40 @@ def yaml_pathjoin(loader, node):
:return: joined (file) path.
:rtype: str
"""
seq = loader.construct_sequence(node)
seq = loader.construct_sequence(node, deep=True)
return os.path.expanduser(os.path.join(*[str(i) for i in seq]))


def yaml_expanduser(loader, node):
"""
Applies os.path.expanduser to a (file) path.
:param loader: yaml loader.
:type loader: yaml.loader.FullLoader
:param node: node.
:type node: yaml.nodes.Node
:return: (file) path.
:rtype: str
"""
path = loader.construct_python_str(node)
return os.path.expanduser(path)


def yaml_abspath(loader, node):
"""
Applies os.path.abspath to a (file) path.
:param loader: yaml loader.
:type loader: yaml.loader.FullLoader
:param node: node.
:type node: yaml.nodes.Node
:return: (file) path.
:rtype: str
"""
path = loader.construct_python_str(node)
return os.path.abspath(path)


def yaml_import_get(loader, node):
"""
Imports module and get its attribute.
Expand Down Expand Up @@ -145,6 +175,8 @@ def load_yaml_file(yaml_file_path, custom_mode=True):
if custom_mode:
yaml.add_constructor('!join', yaml_join, Loader=yaml.FullLoader)
yaml.add_constructor('!pathjoin', yaml_pathjoin, Loader=yaml.FullLoader)
yaml.add_constructor('!expanduser', yaml_expanduser, Loader=yaml.FullLoader)
yaml.add_constructor('!abspath', yaml_abspath, Loader=yaml.FullLoader)
yaml.add_constructor('!import_get', yaml_import_get, Loader=yaml.FullLoader)
yaml.add_constructor('!import_call', yaml_import_call, Loader=yaml.FullLoader)
yaml.add_constructor('!import_call_method', yaml_import_call_method, Loader=yaml.FullLoader)
Expand Down

0 comments on commit c8664cb

Please sign in to comment.