diff --git a/requirements.txt b/requirements.txt index 8555d54cc..e055f4c98 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,5 +21,3 @@ peewee # Information collection psutil pynvml -watchdog - diff --git a/swanlab/api/http.py b/swanlab/api/http.py index 909db30b6..3a95469da 100644 --- a/swanlab/api/http.py +++ b/swanlab/api/http.py @@ -12,14 +12,13 @@ from .info import LoginInfo, ProjectInfo, ExperimentInfo from .auth.login import login_by_key from .cos import CosClient +from swanlab.data.modules import MediaBuffer from swanlab.error import NetworkError, ApiError from swanlab.package import get_host_api from swanlab.utils import FONT from swanlab.log import swanlog import requests -from swanlab.data.modules import MediaBuffer - def decode_response(resp: requests.Response) -> Union[Dict, AnyStr]: """ diff --git a/swanlab/api/upload/__init__.py b/swanlab/api/upload/__init__.py index ef0e83b8f..a95b0b7ca 100644 --- a/swanlab/api/upload/__init__.py +++ b/swanlab/api/upload/__init__.py @@ -8,13 +8,10 @@ 上传相关接口 """ from ..http import get_http, sync_error_handler -from .model import ColumnModel, MediaModel, ScalarModel +from .model import ColumnModel, MediaModel, ScalarModel, FileModel from typing import List -from swanlab.error import FileError, ApiError +from swanlab.error import ApiError from swanlab.log import swanlog -import json -import yaml -import os house_url = '/house/metrics' @@ -82,33 +79,20 @@ def upload_scalar_metrics(scalar_metrics: List[ScalarModel]): @sync_error_handler -def upload_files(files: List[str]): +def upload_files(files: List[FileModel]): """ 上传files文件夹中的内容 :param files: 文件列表,内部为文件绝对路径 """ http = get_http() - # 去重list - files = list(set(files)) - files = {os.path.basename(x): x for x in files} - # 读取文件配置,生成更新信息 - data = {} - for filename, filepath in files.items(): - if filename not in _valid_files: - continue - try: - with open(filepath, 'r') as f: - if _valid_files[filename][1] == 'json': - data[_valid_files[filename][0]] = json.load(f) - elif _valid_files[filename][1] == 'yaml': - d = yaml.load(f, Loader=yaml.FullLoader) - if d is None: - raise FileError - data[_valid_files[filename][0]] = d - else: - data[_valid_files[filename][0]] = f.read() - except json.decoder.JSONDecodeError: - raise FileError + # 去重所有的FileModel,留下一个 + if len(files) == 0: + return swanlog.warning("No files to upload.") + file_model = files[0] + if len(files) > 1: + for i in range(1, len(files) - 1): + file_model = FileModel.create(files[i], file_model) + data = file_model.to_dict() http.put(f'/project/{http.groupname}/{http.projname}/runs/{http.exp_id}/profile', data) diff --git a/swanlab/api/upload/model.py b/swanlab/api/upload/model.py index 9c0989b12..a7e4a7403 100644 --- a/swanlab/api/upload/model.py +++ b/swanlab/api/upload/model.py @@ -9,8 +9,8 @@ """ from enum import Enum from typing import List - from swanlab.data.modules import MediaBuffer +from datetime import datetime class ColumnModel: @@ -125,3 +125,37 @@ def to_dict(self): "index": self.step, "epoch": self.epoch } + + +class FileModel: + """ + 运行时文件信息上传模型 + """ + + def __init__(self, requirements: str = None, metadata: dict = None, config: dict = None): + self.requirements = requirements + self.metadata = metadata + self.config = config + self.create_time = datetime.now() + """ + 主要用于去重,保留最新的文件 + """ + + @classmethod + def create(cls, r1: "FileModel", r2: "FileModel") -> "FileModel": + """ + 比较两个FileModel,创建一个新的 + 如果新的newer不存在,则使用older的数据,否则使用newer的数据 + """ + newer, older = (r1, r2) if r1.create_time > r2.create_time else (r2, r1) + rq = newer.requirements if newer.requirements else older.requirements + md = newer.metadata if newer.metadata else older.metadata + cf = newer.config if newer.config else older.config + return cls(rq, md, cf) + + def to_dict(self): + """ + 序列化,会删除为None的字段 + """ + d = {"requirements": self.requirements, "metadata": self.metadata, "config": self.config} + return {k: v for k, v in d.items() if v is not None} diff --git a/swanlab/cloud/dog/README.md b/swanlab/cloud/dog/README.md deleted file mode 100644 index 6d0f0f86b..000000000 --- a/swanlab/cloud/dog/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# 看门小狗🐶 - -看门狗用于嗅探并监视文件资源、系统信息,并且完成向服务端推送的工作。 -本部分不涉及认证工作,单纯进行嗅探。 - diff --git a/swanlab/cloud/dog/__init__.py b/swanlab/cloud/dog/__init__.py deleted file mode 100644 index cabd2c745..000000000 --- a/swanlab/cloud/dog/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024-03-17 16:57:32 -@File: swanlab/data/dog/__init__.py -@IDE: vscode -@Description: - 看门狗模块,导出run函数,开启看门狗,新建其他子线程 -""" diff --git a/swanlab/cloud/dog/log_sniffer.py b/swanlab/cloud/dog/log_sniffer.py deleted file mode 100644 index 08b23473a..000000000 --- a/swanlab/cloud/dog/log_sniffer.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/4/4 14:11 -@File: log_sniffer.py -@IDE: pycharm -@Description: - 日志嗅探器 - 嗅探不做上传操作,只做采集操作,将采集到的日志、差异信息发送给日志聚合器 -""" -from ..utils import ThreadTaskABC, ThreadUtil -from watchdog.observers import Observer -from queue import Queue -from .sniffer_queue import SnifferQueue -from .metadata_handle import MetaHandle -from typing import List -from ..task_types import UploadType -import time - - -class LogSnifferTask(ThreadTaskABC): - """ - 日志嗅探器,负责监听日志信息的改动,当日志信息发生改动时将改动的部分包装发送给日志聚合器 - """ - SNIFFER_TIMEOUT = 2 - __QUEUE: Queue = Queue() - - def __init__(self, meta_path: str): - """ - 初始化日志嗅探器 - :param meta_path: 元数据文件夹路径,由于目前只有元数据文件夹需要嗅探,因此只需要传入元数据文件夹路径 - 后续如果有其他需要嗅探的文件夹,可以将此处改成传入handle类 - """ - self.__sniffer_queue = SnifferQueue(self.__QUEUE, readable=True, writable=False) - """ - 日志嗅探器队列,用于存放从一系列Handler中收集到的日志信息 - """ - self.__observer = Observer(timeout=self.SNIFFER_TIMEOUT) - self.__observer.schedule( - MetaHandle(self.__QUEUE, watched_path=meta_path), - meta_path, - recursive=True - ) - # observer,启动! - self.__observer.start() - - def callback(self, u: ThreadUtil, *args): - # 文件事件可能会有延迟,因此需要等待一段时间 - time.sleep(self.SNIFFER_TIMEOUT) - self.__observer.stop() - self.pass_msg(u) - - def pass_msg(self, u: ThreadUtil): - all_sniffer_msg: List = self.__sniffer_queue.get_all() - if not all_sniffer_msg or len(all_sniffer_msg) == 0: - return - # 去重,由于现在只有files元数据文件,所以只需要针对它去重就行 - # 遍历所有的消息 - files = {UploadType.FILE: []} - for msg in all_sniffer_msg: - for path in msg[0]: - if path not in files[UploadType.FILE]: - files[UploadType.FILE].append(path) - new_msg = (UploadType.FILE, files[UploadType.FILE]) - u.queue.put(new_msg) - - def task(self, u: ThreadUtil, *args): - """ - 任务执行函数,在此处收集处理的所有日志信息,解析、包装、发送给日志聚合器 - :param u: 线程工具类 - """ - # 在此处完成日志信息聚合 - self.pass_msg(u) diff --git a/swanlab/cloud/dog/metadata_handle.py b/swanlab/cloud/dog/metadata_handle.py deleted file mode 100644 index 338b155ab..000000000 --- a/swanlab/cloud/dog/metadata_handle.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/4/5 18:20 -@File: metadata_handle.py -@IDE: pycharm -@Description: - 元数据处理器,看门狗嗅探元数据文件夹,向聚合器发送元数据信息 -""" -from ..task_types import UploadType -from .sniffer_queue import SnifferQueue -from typing import Union, List -from watchdog.events import FileSystemEventHandler, FileSystemEvent -from swanlab.log import swanlog -import os -from queue import Queue - - -class MetaHandle(FileSystemEventHandler): - ValidFiles = ['config.yaml', 'requirements.txt', 'swanlab-metadata.json'] - """ - 有效的元数据文件列表,只有这些文件会被传输,如果出现其他文件出现waning - """ - - ModifiableFiles = [ValidFiles[0], ValidFiles[2]] - """ - 可修改的元数据文件列表(其他只会传输一次) - """ - - def __init__(self, queue: Queue, watched_path: str): - """ - 初始化日志嗅探处理器 - :param watched_path: 监听的路径,用作初始对照 - """ - self.watched_path = watched_path - self.queue = SnifferQueue(queue, readable=False) - self.on_init_upload() - - def list_all_meta_files(self) -> List[str]: - """ - 列出所有的元数据文件 - """ - files = [x for x in os.listdir(self.watched_path) if os.path.isfile(self.fmt_file_path(x)[0])] - return [x for x in files if x in self.ValidFiles] - - def fmt_file_path(self, file_name: Union[List[str], str]) -> List[str]: - """ - 格式化文件路径 - """ - if isinstance(file_name, str): - file_name = [file_name] - return [os.path.join(self.watched_path, x) for x in file_name] - - def on_init_upload(self): - """ - 实例化的时候进行一次文件扫描,watched_path下所有ValidFiles生成一个一个msg发给队列 - """ - meta_files = self.list_all_meta_files() - if len(meta_files) == 0: - return swanlog.warning("empty meta files, it might be a bug?") - self.queue.put((self.fmt_file_path(meta_files), UploadType.FILE)) - - def on_modified(self, event: FileSystemEvent) -> None: - """ - 文件被修改时触发 - """ - if event.is_directory: - return - file_name = os.path.basename(event.src_path) - if file_name not in self.ModifiableFiles: - # 被忽略 - return swanlog.warning(f"file {file_name} is not allowed to be modified") - self.queue.put((self.fmt_file_path(file_name), UploadType.FILE)) diff --git a/swanlab/cloud/dog/sniffer_queue.py b/swanlab/cloud/dog/sniffer_queue.py deleted file mode 100644 index 1d637053e..000000000 --- a/swanlab/cloud/dog/sniffer_queue.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/4/5 21:23 -@File: sniffer_queue.py -@IDE: pycharm -@Description: - 嗅探器队列,负责收集所有嗅探线程注册的日志信息 -""" -from typing import List -from queue import Queue -from ..utils import LogQueue - - -class SnifferQueue(LogQueue): - - def __init__(self, queue: Queue, readable: bool = True, writable: bool = True): - super().__init__(queue, readable, writable) - - def put(self, msg: LogQueue.MsgType): - """ - 向管道中写入日志信息,日志信息必须是函数,聚合器会依次执行他们 - :param msg: 日志信息 - """ - super().put(msg) - - def get(self) -> LogQueue.MsgType: - """ - 从管道中读取日志信息 - :return: 日志信息 - """ - return super().get() - - def get_all(self) -> List[LogQueue.MsgType]: - """ - 从管道中读取所有的日志信息 - :return: 日志信息 - """ - return super().get_all() - - def put_all(self, msgs: List[LogQueue.MsgType]): - """ - 向管道中写入所有日志信息 - :param msgs: 日志信息 - """ - super().put_all(msgs) diff --git a/swanlab/data/callback_cloud.py b/swanlab/data/callback_cloud.py index 035c78770..83116bb29 100644 --- a/swanlab/data/callback_cloud.py +++ b/swanlab/data/callback_cloud.py @@ -7,10 +7,9 @@ @Description: 云端回调 """ -from .run.callback import MetricInfo, ColumnInfo -from swanlab.cloud import UploadType -from swanlab.error import ApiError -from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel +from .run.callback import MetricInfo, ColumnInfo, RuntimeInfo +from swanlab.data.cloud import UploadType +from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel from swanlab.api import LoginInfo, create_http, terminal_login from swanlab.api.upload import upload_logs from swanlab.log import swanlog @@ -22,8 +21,7 @@ from swanlab.error import KeyFileError from swanlab.env import get_swanlab_folder from .callback_local import LocalRunCallback, get_run, SwanLabRunState -from swanlab.cloud import LogSnifferTask, ThreadPool -from swanlab.db import Experiment +from swanlab.data.cloud import ThreadPool from swanlab.utils import create_time from swanlab.package import get_package_version, get_package_latest_version import json @@ -93,7 +91,7 @@ def _clean_handler(self): return swanlog.debug("SwanLab is exiting, please wait.") self._train_finish_print() # 如果正在运行 - run.finish() if run.is_running else swanlog.debug("Duplicate finish, ignore it.") + run.finish() if run.running else swanlog.debug("Duplicate finish, ignore it.") def _except_handler(self, tp, val, tb): if self.exiting: @@ -124,22 +122,11 @@ def on_init(self, project: str, workspace: str, logdir: str = None) -> int: def on_run(self): swanlog.install(self.settings.console_dir) # 注册实验信息 - try: - get_http().mount_exp( - exp_name=self.settings.exp_name, - colors=self.settings.exp_colors, - description=self.settings.description, - ) - except ApiError as e: - if e.resp.status_code == 409: - FONT.brush("", 50) - swanlog.error("The experiment name already exists, please change the experiment name") - Experiment.purely_delete(run_id=self.settings.run_id) - sys.exit(409) - - # 资源嗅探器 - sniffer = LogSnifferTask(self.settings.files_dir) - self.pool.create_thread(sniffer.task, name="sniffer", callback=sniffer.callback) + get_http().mount_exp( + exp_name=self.settings.exp_name, + colors=self.settings.exp_colors, + description=self.settings.description, + ) # 向swanlog注册输出流回调 def _write_call_call(message): @@ -159,6 +146,17 @@ def _write_call_call(message): if in_jupyter(): show_button_html(experiment_url) + def on_runtime_info_update(self, r: RuntimeInfo): + # 执行local逻辑,保存文件到本地 + super(CloudRunCallback, self).on_runtime_info_update(r) + # 添加上传任务到线程池 + rc = r.config.to_dict() if r.config is not None else None + rr = r.requirements.info if r.requirements is not None else None + rm = r.metadata.to_dict() if r.metadata is not None else None + # 不需要json序列化,上传时会自动序列化 + f = FileModel(requirements=rr, config=rc, metadata=rm) + self.pool.queue.put((UploadType.FILE, [f])) + def on_column_create(self, column_info: ColumnInfo): error = None if column_info.error is not None: diff --git a/swanlab/data/callback_local.py b/swanlab/data/callback_local.py index 3a309fd78..66beb4803 100644 --- a/swanlab/data/callback_local.py +++ b/swanlab/data/callback_local.py @@ -7,12 +7,10 @@ @Description: 基本回调函数注册表,此时不考虑云端情况 """ -from typing import Callable from swanlab.log import swanlog from swanlab.utils.font import FONT from swanlab.data.run.main import get_run, SwanLabRunState -from swanlab.data.run.callback import SwanLabRunCallback, MetricInfo -from swanlab.data.system import get_system_info, get_requirements +from swanlab.data.run.callback import SwanLabRunCallback, MetricInfo, RuntimeInfo from swanlab.env import ROOT from datetime import datetime import traceback @@ -107,29 +105,11 @@ def _clean_handler(self): return swanlog.debug("SwanLab Runtime has been cleaned manually.") self._train_finish_print() # 如果正在运行 - run.finish() if run.is_running else swanlog.debug("Duplicate finish, ignore it.") + run.finish() if run.running else swanlog.debug("Duplicate finish, ignore it.") def on_init(self, proj_name: str, workspace: str, logdir: str = None): self._init_logdir(logdir) - def before_init_experiment( - self, - run_id: str, - exp_name: str, - description: str, - num: int, - suffix: str, - setter: Callable[[str, str, str, str], None], - ): - requirements_path = self.settings.requirements_path - metadata_path = self.settings.metadata_path - # 将实验依赖存入 requirements.txt - with open(requirements_path, "w") as f: - f.write(get_requirements()) - # 将实验环境(硬件信息、git信息等等)存入 swanlab-metadata.json - with open(metadata_path, "w") as f: - json.dump(get_system_info(self.settings), f) - def on_run(self): swanlog.install(self.settings.console_dir) # 注入系统回调 @@ -138,8 +118,15 @@ def on_run(self): self._train_begin_print() swanlog.info("Experiment_name: " + FONT.yellow(self.settings.exp_name)) self._watch_tip_print() - if not os.path.exists(self.settings.log_dir): - os.mkdir(self.settings.log_dir) + + def on_runtime_info_update(self, r: RuntimeInfo): + # 更新运行时信息 + if r.requirements is not None: + r.requirements.write(self.settings.files_dir) + if r.metadata is not None: + r.metadata.write(self.settings.files_dir) + if r.config is not None: + r.config.write(self.settings.files_dir) def on_metric_create(self, metric_info: MetricInfo): # 出现任何错误直接返回 diff --git a/swanlab/cloud/__init__.py b/swanlab/data/cloud/__init__.py similarity index 69% rename from swanlab/cloud/__init__.py rename to swanlab/data/cloud/__init__.py index c3508e67d..d2c072ad0 100644 --- a/swanlab/cloud/__init__.py +++ b/swanlab/data/cloud/__init__.py @@ -9,6 +9,5 @@ """ from .start_thread import ThreadPool from .task_types import UploadType -from .dog.log_sniffer import LogSnifferTask -__all__ = ["UploadType", "LogSnifferTask", "ThreadPool"] +__all__ = ["UploadType", "ThreadPool"] diff --git a/swanlab/cloud/log_collector.py b/swanlab/data/cloud/log_collector.py similarity index 100% rename from swanlab/cloud/log_collector.py rename to swanlab/data/cloud/log_collector.py diff --git a/swanlab/cloud/start_thread.py b/swanlab/data/cloud/start_thread.py similarity index 100% rename from swanlab/cloud/start_thread.py rename to swanlab/data/cloud/start_thread.py diff --git a/swanlab/cloud/task_types.py b/swanlab/data/cloud/task_types.py similarity index 100% rename from swanlab/cloud/task_types.py rename to swanlab/data/cloud/task_types.py diff --git a/swanlab/cloud/utils.py b/swanlab/data/cloud/utils.py similarity index 100% rename from swanlab/cloud/utils.py rename to swanlab/data/cloud/utils.py diff --git a/swanlab/data/modules/__init__.py b/swanlab/data/modules/__init__.py index 6fbc91b36..adfe28eb5 100644 --- a/swanlab/data/modules/__init__.py +++ b/swanlab/data/modules/__init__.py @@ -4,7 +4,7 @@ from .text import Text from .line import Line, FloatConvertible from typing import Union, List -from .wrapper import DataWrapper, ErrorInfo +from .wrapper import DataWrapper, WrapperErrorInfo DataType = Union[int, float, FloatConvertible, BaseType, List[BaseType]] ChartType = BaseType.Chart @@ -19,6 +19,6 @@ "Line", "DataType", "ChartType", - "ErrorInfo", + "WrapperErrorInfo", "MediaBuffer" ] diff --git a/swanlab/data/modules/base.py b/swanlab/data/modules/base.py index b26f25ddf..ae21be20a 100644 --- a/swanlab/data/modules/base.py +++ b/swanlab/data/modules/base.py @@ -13,7 +13,6 @@ 可以看到有两次转换,如果每次都是耗时操作,那么会导致性能问题,所以对于单个类而言,上次转换的结果会被保存 """ from abc import ABC, abstractmethod -from ..settings import SwanDataSettings from typing import List, Dict, Optional, ByteString, Union, Tuple from swanlab.log import swanlog from enum import Enum @@ -109,19 +108,15 @@ def __init__(self): self.key: Optional[str] = None # 保存的step self.step: Optional[int] = None - # 当前运行时配置 - self.settings: Optional[SwanDataSettings] = None - def inject(self, key: str, step: int, settings: SwanDataSettings): + def inject(self, key: str, step: int): """ 注入属性 :param key: key名称 :param step: 当前步骤 - :param settings: 当前运行时配置 """ self.key = key self.step = step - self.settings = settings class BaseType(ABC, DynamicProperty): @@ -219,9 +214,9 @@ def file_name(self): def file_name(self, value): if not isinstance(value, str) or not value: raise TypeError(f"Expected str, but got {type(value)}") - if self.__file_name is not None: - # 此时意味着使用类似 [] * 2 的操作复制了多个相同的实例,这允许,但不推荐 - swanlog.warning("You are log duplicate and same instance, this is not recommended") + # if self.__file_name is not None: + # # 此时意味着使用类似 [] * 2 的操作复制了多个相同的实例,这允许,但不推荐 + # swanlog.warning("You are logging a duplicate and same instance, this is not recommended") self.__file_name = value diff --git a/swanlab/data/modules/wrapper.py b/swanlab/data/modules/wrapper.py index 39180e9c1..b5765fd59 100644 --- a/swanlab/data/modules/wrapper.py +++ b/swanlab/data/modules/wrapper.py @@ -13,7 +13,7 @@ from .line import Line -class ErrorInfo: +class WrapperErrorInfo: """ DataWrapper转换时的错误信息 """ @@ -91,7 +91,7 @@ def parsed(self) -> bool: return self.__result is not None or self.__error is not None @property - def error(self) -> Optional[ErrorInfo]: + def error(self) -> Optional[WrapperErrorInfo]: """ 解析时候的错误信息 """ @@ -114,13 +114,13 @@ def parse(self, **kwargs) -> Optional[ParseResult]: # [Line] if self.type == Line: if len(self.__data) > 1: - self.__error = ErrorInfo("float", "list(Line)", result.chart) + self.__error = WrapperErrorInfo("float", "list(Line)", result.chart) else: d.inject(**kwargs) try: result.float, _ = d.parse() except DataTypeError as e: - self.__error = ErrorInfo(e.expected, e.got, result.chart) + self.__error = WrapperErrorInfo(e.expected, e.got, result.chart) self.__result = result return self.__result @@ -132,7 +132,7 @@ def parse(self, **kwargs) -> Optional[ParseResult]: i.inject(**kwargs) d, r = i.parse() except DataTypeError as e: - self.__error = ErrorInfo(e.expected, e.got, result.chart) + self.__error = WrapperErrorInfo(e.expected, e.got, result.chart) return None data.append(d) buffers.append(r) @@ -156,8 +156,8 @@ def __filter_list(li: List): return None @classmethod - def create_duplicate_error(cls) -> ErrorInfo: + def create_duplicate_error(cls) -> WrapperErrorInfo: """ 快捷创建一个重复错误 """ - return ErrorInfo(None, None, None, duplicated=True) + return WrapperErrorInfo(None, None, None, duplicated=True) diff --git a/swanlab/data/run/callback.py b/swanlab/data/run/callback.py deleted file mode 100644 index c8e3ff4c0..000000000 --- a/swanlab/data/run/callback.py +++ /dev/null @@ -1,366 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/5/5 20:23 -@File: callback.py -@IDE: pycharm -@Description: - 回调函数注册抽象类 -""" -from typing import Union, Optional, Callable, Dict, List -from abc import ABC, abstractmethod -from swanlab.data.settings import SwanDataSettings -from swanlab.data.modules import ChartType, ErrorInfo, MediaBuffer -from swanlab.log import swanlog -from swanlab.utils.font import FONT -from swanlab.env import is_windows -from swanlab.package import get_package_version -from urllib.parse import quote -import atexit -import sys -import os - - -class ColumnInfo: - """ - 列信息,当创建列时,会生成这个对象 - """ - - def __init__( - self, - key: str, - namespace: str, - chart: ChartType, - sort: Optional[int] = None, - error: Optional[ErrorInfo] = None, - reference: Optional[str] = None, - config: Optional[Dict] = None, - ): - self.key = key - """ - 列的key名称 - """ - self.namespace = namespace - """ - 列的命名空间 - """ - self.chart = chart - """ - 列的图表类型 - """ - self.error = error - """ - 列的类型错误信息 - """ - self.reference = reference - """ - 列的参考对象 - """ - self.sort = sort - """ - 列在namespace中的排序 - """ - self.config = config if config is not None else {} - """ - 列的额外配置信息 - """ - - @property - def got(self): - """ - 传入的错误类型,如果列出错,返回错误类型,如果没出错,`暂时`返回None - """ - if self.error is None: - return None - return self.error.got - - @property - def expected(self): - """ - 期望的类型,如果列出错,返回期望的类型,如果没出错,`暂时`返回None - """ - if self.error is None: - return None - return self.error.expected - - -class MetricInfo: - """ - 指标信息,当新的指标被log时,会生成这个对象 - """ - __SUMMARY_NAME = "_summary.json" - - def __init__( - self, - key: str, - column_info: ColumnInfo, - error: Optional[ErrorInfo], - metric: Union[Dict, None] = None, - summary: Union[Dict, None] = None, - step: int = None, - epoch: int = None, - logdir: str = None, - metric_file_name: str = None, - media_dir: str = None, - buffers: List[MediaBuffer] = None, - ): - self.__error = error - - self.key = quote(key, safe="") - """ - 指标的key名称,被quote编码 - """ - self.column_info = column_info - """ - 指标对应的列信息 - """ - self.metric = metric - """ - 指标信息,error时为None - """ - self.summary = summary - """ - 摘要信息,error时为None - """ - self.step = step - """ - 当前指标的步数,error时为None - """ - self.epoch = epoch - """ - 当前指标对应本地的行数,error时为None - """ - self.metric_path = None if self.error else os.path.join(logdir, self.key, metric_file_name) - """ - 指标文件的路径,error时为None - """ - self.summary_path = None if self.error else os.path.join(logdir, self.key, self.__SUMMARY_NAME) - """ - 摘要文件的路径,error时为None - """ - self.media_dir = media_dir - """ - 静态文件的根文件夹 - """ - self.buffers = buffers - """ - 需要上传的媒体数据,比特流,error时为None,如果上传为非媒体类型(或Text类型),也为None - """ - # 写入文件名称,对应上传时的文件名称 - if self.buffers is not None: - for i, buffer in enumerate(self.buffers): - buffer.file_name = "{}/{}".format(self.key, metric["data"][i]) - - @property - def error(self) -> bool: - """ - 这条指标信息是否有错误,错误分几种: - 1. 列错误,列一开始就出现问题 - 2. 重复错误 - 3. 指标错误 - """ - return self.error_info is not None or self.column_error_info is not None - - @property - def column_error_info(self) -> Optional[ErrorInfo]: - """ - 列错误信息 - """ - return self.column_info.error - - @property - def error_info(self) -> Optional[ErrorInfo]: - """ - 指标错误信息 - """ - return self.__error - - @property - def duplicated_error(self) -> bool: - """ - 是否是重复的指标 - """ - return self.__error and self.__error.duplicated - - @property - def data(self) -> Union[Dict, None]: - """ - 指标数据的data字段 - """ - if self.error: - return None - return self.metric["data"] - - -class U: - """ - 工具函数类,隔离SwanLabRunCallback回调与其他工具函数 - """ - - def __init__(self): - self.settings: Optional[SwanDataSettings] = None - - def inject(self, settings: SwanDataSettings): - """ - 为SwanLabRunCallback注入settings等一些依赖,因为实例化可能在SwanLabRun之前发生 - :param settings: SwanDataSettings, 数据配置 - :return: - """ - self.settings = settings - - @staticmethod - def formate_abs_path(path: str) -> str: - """这主要针对windows环境,输入的绝对路径可能不包含盘符,这里进行补充 - 主要是用于打印效果 - 如果不是windows环境,直接返回path,相当于没有调用这个函数 - - Parameters - ---------- - path : str - 待转换的路径 - - Returns - ------- - str - 增加了盘符的路径 - """ - if not is_windows(): - return path - if not os.path.isabs(path): - return path - need_add = len(path) < 3 or path[1] != ":" - # 处理反斜杠, 保证路径的正确性 - path = path.replace("/", "\\") - if need_add: - return os.path.join(os.getcwd()[:2], path) - return path - - def _train_begin_print(self): - """ - 训练开始时的打印信息 - """ - swanlog.debug("SwanLab Runtime has initialized") - swanlog.debug("SwanLab will take over all the print information of the terminal from now on") - swanlog.info("Tracking run with swanlab version " + get_package_version()) - local_path = FONT.magenta(FONT.bold(self.formate_abs_path(self.settings.run_dir))) - swanlog.info("Run data will be saved locally in " + local_path) - - def _watch_tip_print(self): - """ - watch命令提示打印 - """ - swanlog.info( - "🌟 Run `" - + FONT.bold("swanlab watch -l {}".format(self.formate_abs_path(self.settings.swanlog_dir))) - + "` to view SwanLab Experiment Dashboard locally" - ) - - def _train_finish_print(self): - """ - 打印结束信息 - """ - swanlog.info("Experiment {} has completed".format(FONT.yellow(self.settings.exp_name))) - - -class SwanLabRunCallback(ABC, U): - """ - SwanLabRunCallback,回调函数注册类,所有以`on_`和`before_`开头的函数都会在对应的时机被调用 - 为了方便管理: - 1. `_`开头的函数为内部函数,不会被调用,且写在最开头 - 2. 所有回调按照逻辑上的触发顺序排列 - 3. 所有回调不要求全部实现,只需实现需要的回调即可 - """ - - def _register_sys_callback(self): - """ - 注册系统回调,内部使用 - """ - sys.excepthook = self._except_handler - atexit.register(self._clean_handler) - - def _unregister_sys_callback(self): - """ - 注销系统回调,内部使用 - """ - sys.excepthook = sys.__excepthook__ - atexit.unregister(self._clean_handler) - - def _clean_handler(self): - """ - 正常退出清理函数,此函数调用`run.finish` - """ - pass - - def _except_handler(self, tp, val, tb): - """ - 异常退出清理函数 - """ - pass - - def on_init(self, proj_name: str, workspace: str, logdir: str = None): - """ - 执行`swanlab.init`时调用 - 此时运行时环境变量没有被设置,此时修改环境变量还是有效的 - :param logdir: str, 用户设置的日志目录 - :param proj_name: str, 项目名称 - :param workspace: str, 工作空间 - """ - pass - - def before_init_experiment( - self, - run_id: str, - exp_name: str, - description: str, - num: int, - suffix: str, - setter: Callable[[str, str, str, str], None], - ): - """ - 在初始化实验之前调用,此时SwanLabRun已经初始化完毕 - :param run_id: str, SwanLabRun的运行id - :param exp_name: str, 实验名称 - :param description: str, 实验描述 - :param num: int, 历史实验数量 - :param suffix: str, 实验后缀 - :param setter: Callable[[str, str, str, str], None], 设置实验信息的函数,在这里设置实验信息 - """ - pass - - def on_run(self): - """ - SwanLabRun初始化完毕时调用 - """ - pass - - def on_log(self): - """ - 每次执行swanlab.log时调用 - """ - pass - - def on_column_create(self, column_info: ColumnInfo): - """ - 列创建回调函数,新增列信息时调用 - """ - pass - - def on_metric_create(self, metric_info: MetricInfo): - """ - 指标创建回调函数,新增指标信息时调用 - """ - pass - - def on_stop(self, error: str = None): - """ - 训练结束时的回调函数 - """ - pass - - @abstractmethod - def __str__(self) -> str: - """ - 返回当前回调函数的名称,这条应该是一个全局唯一的标识 - 在operator中会用到这个名称,必须唯一 - """ - pass diff --git a/swanlab/data/run/callback/__init__.py b/swanlab/data/run/callback/__init__.py new file mode 100644 index 000000000..7d1e443cc --- /dev/null +++ b/swanlab/data/run/callback/__init__.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/6/5 16:27 +@File: __init__.py +@IDE: pycharm +@Description: + 回调函数注册抽象模块 +""" +from typing import Callable +from abc import ABC, abstractmethod +from .utils import U +from .models import * +import atexit +import sys + + +class SwanLabRunCallback(ABC, U): + """ + SwanLabRunCallback,回调函数注册类,所有以`on_`和`before_`开头的函数都会在对应的时机被调用 + 为了方便管理: + 1. `_`开头的函数为内部函数,不会被调用,且写在最开头 + 2. 所有回调按照逻辑上的触发顺序排列 + 3. 带有from_*后缀的回调函数代表调用者来自其他地方,比如config、operator等,这将通过settings对象传递 + 4. 所有回调不要求全部实现,只需实现需要的回调即可 + """ + + def _register_sys_callback(self): + """ + 注册系统回调,内部使用 + """ + sys.excepthook = self._except_handler + atexit.register(self._clean_handler) + + def _unregister_sys_callback(self): + """ + 注销系统回调,内部使用 + """ + sys.excepthook = sys.__excepthook__ + atexit.unregister(self._clean_handler) + + def _clean_handler(self): + """ + 正常退出清理函数,此函数调用`run.finish` + """ + pass + + def _except_handler(self, tp, val, tb): + """ + 异常退出清理函数 + """ + pass + + def on_init(self, proj_name: str, workspace: str, logdir: str = None): + """ + 执行`swanlab.init`时调用 + 此时运行时环境变量没有被设置,此时修改环境变量还是有效的 + :param logdir: str, 用户设置的日志目录 + :param proj_name: str, 项目名称 + :param workspace: str, 工作空间 + """ + pass + + def before_init_experiment( + self, + run_id: str, + exp_name: str, + description: str, + num: int, + suffix: str, + setter: Callable[[str, str, str, str], None], + ): + """ + 在初始化实验之前调用,此时SwanLabRun已经初始化完毕 + :param run_id: str, SwanLabRun的运行id + :param exp_name: str, 实验名称 + :param description: str, 实验描述 + :param num: int, 历史实验数量 + :param suffix: str, 实验后缀 + :param setter: Callable[[str, str, str, str], None], 设置实验信息的函数,在这里设置实验信息 + """ + pass + + def on_run(self): + """ + SwanLabRun初始化完毕时调用 + """ + pass + + def on_run_error_from_operator(self, e: OperateErrorInfo): + """ + SwanLabRun初始化错误时被操作员调用 + """ + pass + + def on_runtime_info_update(self, r: RuntimeInfo): + """ + 运行时信息更新时调用 + :param r: RuntimeInfo, 运行时信息 + """ + pass + + def on_log(self): + """ + 每次执行swanlab.log时调用 + """ + pass + + def on_column_create(self, column_info: ColumnInfo): + """ + 列创建回调函数,新增列信息时调用 + """ + pass + + def on_metric_create(self, metric_info: MetricInfo): + """ + 指标创建回调函数,新增指标信息时调用 + """ + pass + + def on_stop(self, error: str = None): + """ + 训练结束时的回调函数 + """ + pass + + @abstractmethod + def __str__(self) -> str: + """ + 返回当前回调函数的名称,这条应该是一个全局唯一的标识 + 在operator中会用到这个名称,必须唯一 + """ + pass diff --git a/swanlab/data/run/callback/models/__init__.py b/swanlab/data/run/callback/models/__init__.py new file mode 100644 index 000000000..7ff7f991f --- /dev/null +++ b/swanlab/data/run/callback/models/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/6/5 16:29 +@File: __init__.py +@IDE: pycharm +@Description: + 与回调函数通信时的模型 +""" +from .key import MediaBuffer, MetricInfo, ColumnInfo +from .error import OperateErrorInfo +from .runtime import RuntimeInfo + +__all__ = ["MediaBuffer", "MetricInfo", "ColumnInfo", "OperateErrorInfo", "RuntimeInfo"] diff --git a/swanlab/data/run/callback/models/error.py b/swanlab/data/run/callback/models/error.py new file mode 100644 index 000000000..332624ba4 --- /dev/null +++ b/swanlab/data/run/callback/models/error.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/6/5 16:30 +@File: error.py +@IDE: pycharm +@Description: + 错误模型 +""" + + +class OperateErrorInfo: + """ + 操作错误信息,当操作员操作回调发生错误时,会生成这个对象,传给相应的回调函数 + """ + + def __init__(self, desc: str): + self.desc = desc + """ + 错误描述 + """ + + def __str__(self): + return f"SwanLabError: {self.desc}" diff --git a/swanlab/data/run/callback/models/key.py b/swanlab/data/run/callback/models/key.py new file mode 100644 index 000000000..06d338618 --- /dev/null +++ b/swanlab/data/run/callback/models/key.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/6/5 16:32 +@File: key.py +@IDE: pycharm +@Description: + 与Key相关的回调函数触发时的模型 +""" +from typing import Union, Optional, Dict, List +from swanlab.data.modules import ChartType, WrapperErrorInfo, MediaBuffer +from urllib.parse import quote +import os + + +class ColumnInfo: + """ + 列信息,当创建列时,会生成这个对象 + """ + + def __init__( + self, + key: str, + namespace: str, + chart: ChartType, + sort: Optional[int] = None, + error: Optional[WrapperErrorInfo] = None, + reference: Optional[str] = None, + config: Optional[Dict] = None, + ): + self.key = key + """ + 列的key名称 + """ + self.namespace = namespace + """ + 列的命名空间 + """ + self.chart = chart + """ + 列的图表类型 + """ + self.error = error + """ + 列的类型错误信息 + """ + self.reference = reference + """ + 列的参考对象 + """ + self.sort = sort + """ + 列在namespace中的排序 + """ + self.config = config if config is not None else {} + """ + 列的额外配置信息 + """ + + @property + def got(self): + """ + 传入的错误类型,如果列出错,返回错误类型,如果没出错,`暂时`返回None + """ + if self.error is None: + return None + return self.error.got + + @property + def expected(self): + """ + 期望的类型,如果列出错,返回期望的类型,如果没出错,`暂时`返回None + """ + if self.error is None: + return None + return self.error.expected + + +class MetricInfo: + """ + 指标信息,当新的指标被log时,会生成这个对象 + """ + __SUMMARY_NAME = "_summary.json" + + def __init__( + self, + key: str, + column_info: ColumnInfo, + error: Optional[WrapperErrorInfo], + metric: Union[Dict, None] = None, + summary: Union[Dict, None] = None, + step: int = None, + epoch: int = None, + logdir: str = None, + metric_file_name: str = None, + media_dir: str = None, + buffers: List[MediaBuffer] = None, + ): + self.__error = error + + self.key = quote(key, safe="") + """ + 指标的key名称,被quote编码 + """ + self.column_info = column_info + """ + 指标对应的列信息 + """ + self.metric = metric + """ + 指标信息,error时为None + """ + self.summary = summary + """ + 摘要信息,error时为None + """ + self.step = step + """ + 当前指标的步数,error时为None + """ + self.epoch = epoch + """ + 当前指标对应本地的行数,error时为None + """ + self.metric_path = None if self.error else os.path.join(logdir, self.key, metric_file_name) + """ + 指标文件的路径,error时为None + """ + self.summary_path = None if self.error else os.path.join(logdir, self.key, self.__SUMMARY_NAME) + """ + 摘要文件的路径,error时为None + """ + self.media_dir = media_dir + """ + 静态文件的根文件夹 + """ + self.buffers = buffers + """ + 需要上传的媒体数据,比特流,error时为None,如果上传为非媒体类型(或Text类型),也为None + """ + # 写入文件名称,对应上传时的文件名称 + if self.buffers is not None: + for i, buffer in enumerate(self.buffers): + buffer.file_name = "{}/{}".format(self.key, metric["data"][i]) + + @property + def error(self) -> bool: + """ + 这条指标信息是否有错误,错误分几种: + 1. 列错误,列一开始就出现问题 + 2. 重复错误 + 3. 指标错误 + """ + return self.error_info is not None or self.column_error_info is not None + + @property + def column_error_info(self) -> Optional[WrapperErrorInfo]: + """ + 列错误信息 + """ + return self.column_info.error + + @property + def error_info(self) -> Optional[WrapperErrorInfo]: + """ + 指标错误信息 + """ + return self.__error + + @property + def duplicated_error(self) -> bool: + """ + 是否是重复的指标 + """ + return self.__error and self.__error.duplicated + + @property + def data(self) -> Union[Dict, None]: + """ + 指标数据的data字段 + """ + if self.error: + return None + return self.metric["data"] diff --git a/swanlab/data/run/callback/models/runtime.py b/swanlab/data/run/callback/models/runtime.py new file mode 100644 index 000000000..dbfb55de2 --- /dev/null +++ b/swanlab/data/run/callback/models/runtime.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/6/5 16:55 +@File: runtime.py +@IDE: pycharm +@Description: + 运行时信息模型 +""" +from abc import ABC, abstractmethod +from typing import Optional, Any +import json +import yaml +import os + + +class InfoWriter(ABC): + + def __init__(self, info): + self.info = info + + @abstractmethod + def write(self, path: str): + """ + 写入到本地文件的方法 + :param path: 写入的路径,不包含文件名 + """ + pass + + @abstractmethod + def dumps(self): + """ + 将信息转换为字符串 + """ + pass + + @abstractmethod + def to_dict(self): + """ + 将信息转换为字典 + """ + pass + + +class RequirementInfo(InfoWriter): + """ + 依赖信息 + """ + + def __init__(self, info: str): + super().__init__(info) + self.name = "requirements.txt" + + def write(self, path: str): + with open(os.path.join(path, self.name), "w", encoding="utf-8") as f: + f.write(self.info) + + def dumps(self): + return self.info + + def to_dict(self): + raise NotImplementedError("RequirementInfo has no to_dict method") + + +class MetadataInfo(InfoWriter): + """ + 系统信息 + """ + + def __init__(self, info: dict): + super().__init__(info) + self.name = "swanlab-metadata.json" + + def write(self, path: str): + with open(os.path.join(path, self.name), "w", encoding="utf-8") as f: + f.write(self.dumps()) + + def dumps(self): + return json.dumps(self.to_dict(), ensure_ascii=False) + + def to_dict(self): + return self.info + + +class ConfigInfo(InfoWriter): + """ + 配置信息 + """ + + def __init__(self, info: dict): + super().__init__(info) + self.name = "config.yaml" + self.__data = None + + def write(self, path: str): + with open(os.path.join(path, self.name), "w", encoding="utf-8") as f: + f.write(self.dumps()) + + def dumps(self): + return yaml.dump(self.to_dict(), allow_unicode=True) + + def to_dict(self): + """ + 返回配置信息的字典形式 + """ + # 遍历每一个配置项,值改为value,增加desc和sort字段 + # 原因是序列化时可能会丢失key的排序信息,所以这里增加一个sort字段 + # 没有在__init__中直接修改是因为可能会有其他地方需要原始数据,并且会丢失一些性能 + if self.__data is not None: + return self.__data + self.__data = { + k: { + "value": v, + "sort": i, + "desc": "" + } for i, (k, v) in enumerate(self.info.items()) + } + return self.__data + + +class RuntimeInfo: + """ + 运行时信息,包括系统信息,依赖信息等 + 如果某些信息为None,代表没有传入配置 + """ + + def __init__(self, requirements: str = None, metadata: dict = None, config: dict = None): + """ + :param requirements: python依赖信息 + :param metadata: 系统信息 + :param config: 上传的配置信息 + """ + self.requirements: Optional[RequirementInfo] = RequirementInfo( + requirements + ) if requirements is not None else None + + self.metadata: Optional[MetadataInfo] = MetadataInfo( + metadata + ) if metadata is not None else None + + self.config: Optional[ConfigInfo] = ConfigInfo( + config + ) if config is not None else None diff --git a/swanlab/data/run/callback/utils.py b/swanlab/data/run/callback/utils.py new file mode 100644 index 000000000..b11d07652 --- /dev/null +++ b/swanlab/data/run/callback/utils.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +r""" +@DATE: 2024/6/5 16:34 +@File: utils.py +@IDE: pycharm +@Description: + 工具类 +""" +from typing import Optional, Any +from swanlab.data.run.settings import SwanDataSettings +from swanlab.log import swanlog +from swanlab.utils.font import FONT +from swanlab.env import is_windows +from swanlab.package import get_package_version +import os + + +class U: + """ + 工具函数类,隔离SwanLabRunCallback回调与其他工具函数 + """ + + def __init__(self): + self.settings: Optional[SwanDataSettings] = None + + def inject(self, settings: SwanDataSettings): + """ + 为SwanLabRunCallback注入settings等一些依赖,因为实例化可能在SwanLabRun之前发生 + :param settings: SwanDataSettings, 数据配置 + :return: + """ + self.settings = settings + + @staticmethod + def fmt_windows_path(path: str) -> str: + """这主要针对windows环境,输入的绝对路径可能不包含盘符,这里进行补充 + 主要是用于打印效果 + 如果不是windows环境,直接返回path,相当于没有调用这个函数 + + Parameters + ---------- + path : str + 待转换的路径 + + Returns + ------- + str + 增加了盘符的路径 + """ + if not is_windows(): + return path + if not os.path.isabs(path): + return path + need_add = len(path) < 3 or path[1] != ":" + # 处理反斜杠, 保证路径的正确性 + path = path.replace("/", "\\") + if need_add: + return os.path.join(os.getcwd()[:2], path) + return path + + def _train_begin_print(self): + """ + 训练开始时的打印信息 + """ + swanlog.debug("SwanLab Runtime has initialized") + swanlog.debug("SwanLab will take over all the print information of the terminal from now on") + swanlog.info("Tracking run with swanlab version " + get_package_version()) + local_path = FONT.magenta(FONT.bold(self.fmt_windows_path(self.settings.run_dir))) + swanlog.info("Run data will be saved locally in " + local_path) + + def _watch_tip_print(self): + """ + watch命令提示打印 + """ + swanlog.info( + "🌟 Run `" + + FONT.bold("swanlab watch -l {}".format(self.fmt_windows_path(self.settings.swanlog_dir))) + + "` to view SwanLab Experiment Dashboard locally" + ) + + def _train_finish_print(self): + """ + 打印结束信息 + """ + swanlog.info("Experiment {} has completed".format(FONT.yellow(self.settings.exp_name))) diff --git a/swanlab/data/run/config.py b/swanlab/data/run/config.py index 7949e3bac..a601fe526 100644 --- a/swanlab/data/run/config.py +++ b/swanlab/data/run/config.py @@ -7,26 +7,31 @@ @Description: SwanLabConfig 配置类 """ -from typing import Any -from collections.abc import Mapping +from typing import Any, Mapping, Union +from collections.abc import MutableMapping import yaml import argparse -from ..settings import SwanDataSettings from swanlab.log import swanlog import datetime import math +from typing import Callable, Optional +from .callback import RuntimeInfo +from swanlab.data.modules import Line +import re +import json -def json_serializable(obj: dict): +def json_serializable(obj): """ 将传入的字典转换为JSON可序列化格式。 + :raises TypeError: 对象不是JSON可序列化的 """ # 如果对象是基本类型,则直接返回 if isinstance(obj, (int, float, str, bool, type(None))): if isinstance(obj, float) and math.isnan(obj): - return "nan" + return Line.nan if isinstance(obj, float) and math.isinf(obj): - return "inf" + return Line.inf return obj # 将日期和时间转换为字符串 @@ -41,315 +46,192 @@ def json_serializable(obj: dict): elif isinstance(obj, dict): return {str(key): json_serializable(value) for key, value in obj.items()} - else: - # 对于其他不可序列化的类型,转换为字符串表示 - return str(obj) + # 对于可变映射,递归调用此函数处理值,并将key转换为字典 + elif isinstance(obj, MutableMapping): + return {str(key): json_serializable(value) for key, value in obj.items()} + raise TypeError(f"Object {obj} is not JSON serializable") -def thirdparty_config_process(data) -> dict: +def third_party_config_process(data) -> dict: """ 对于一些特殊的第三方库的处理,例如omegaconf + :raises TypeError: 适配的写入的第三方库都没有命中,抛出TypeError """ # 如果是omegaconf的DictConfig,则转换为字典 try: - import omegaconf - + import omegaconf # noqa if isinstance(data, omegaconf.DictConfig): return omegaconf.OmegaConf.to_container(data, resolve=True, throw_on_missing=True) - except Exception as e: + except ImportError: pass # 如果是argparse的Namespace,则转换为字典 if isinstance(data, argparse.Namespace): return vars(data) - return data + raise TypeError -class SwanLabConfig(Mapping): +def parse(config) -> dict: """ - The SwanConfig class is used for realize the invocation method of `run.config.lr`. + Check the configuration item and convert it to a JSON serializable format. """ + if config is None: + return {} + # 1. 第三方配置类型判断与转换 + try: + return third_party_config_process(config) + except TypeError: + pass + # 2. 将config转换为可被json序列化的字典 + try: + return json_serializable(config) + except TypeError: # noqa + pass + # 3. 尝试序列化,序列化成功直接返回 + try: + return json.loads(json.dumps(config)) + except Exception as e: # noqa + # 还失败就没办法了,👋 + raise TypeError(f"config: {config} is not a json serialized dict, error: {e}") - # 配置字典 - __config = dict() - - # 运行时设置 - __settings = dict() - - @property - def _inited(self): - return self.__settings.get("save_path") is not None - def __init__(self, config: dict = None, settings: SwanDataSettings = None): - """ - 实例化配置类,如果settings不为None,说明是通过swanlab.init调用的,否则是通过swanlab.config调用的 +__config_attr__ = ['_SwanLabConfig__config', '_SwanLabConfig__on_setter'] - Parameters - ---------- - settings : SwanDataSettings, optional - 运行时设置 - """ - if config is None: - config = {} - self.__config.update(config) - self.__settings["save_path"] = settings.config_path if settings is not None else None - self.__settings["should_save"] = settings.should_save if settings is not None else False - if self._inited: - self.save() +class SwanLabConfig(MutableMapping): + """ + The SwanConfig class is used for realize the invocation method of `run.config.lr`. - @property - def should_shave(self): - return self.__settings.get("should_save") + Attention: + The configuration item must be JSON serializable; Cannot set private attributes by `.__xxx`. + """ - @staticmethod - def __check_config(config: dict) -> dict: + def __init__( + self, + config: Union[MutableMapping, argparse.Namespace] = None, + on_setter: Optional[Callable[[RuntimeInfo], Any]] = None + ): """ - 检查配置是否合法,确保它可以被 JSON/YAML 序列化,并返回转换后的配置字典。 + 实例化配置类,如果settings不为None,说明是通过swanlab.init调用的,否则是通过swanlab.config调用的 """ - if config is None: - return {} - # config必须可以被json序列化 - try: - # 第三方配置类型判断与转换 - config = thirdparty_config_process(config) - - # 将config转换为可被json序列化的字典 - config = json_serializable(dict(config)) - - # 尝试序列化,如果还是失败就退出 - yaml.dump(config) - - except: - raise TypeError(f"config: {config} is not a valid dict, which can be json serialized") - - return config + # 每一个实例有自己的config + self.__config = {} + if config is not None: + self.__config.update(parse(config)) + self.__on_setter = on_setter @staticmethod - def __check_private(name: str): - """ - 检查属性名是否是私有属性,如果是私有属性,抛出异常 - - Parameters - ---------- - name : str - 属性名 - - Raises - ---------- - AttributeError - 如果属性名是私有属性,抛出异常 + def __fmt_config(config: dict): """ - methods = ["set", "get", "pop"] - swanlog.debug(f"Check private attribute: {name}") - if name.startswith("__") or name.startswith("_SwanLabConfig__") or name in methods: - raise AttributeError("You can not get private attribute") - - def __setattr__(self, name: str, value: Any) -> None: - """ - 自定义属性设置方法。如果属性名不是私有属性,则同时更新配置字典并保存。 - 允许通过点号方式设置属性,但不允许设置私有属性: - ```python - run.config.lr = 0.01 # 允许 - run.config._lr = 0.01 # 允许 - run.config.__lr = 0.01 # 不允许 - ``` - - 值得注意的是类属性的设置不会触发此方法 + 格式化config,值改为value字段,增加desc和sort字段 """ + # 遍历每一个配置项,值改为value + sort = 0 + for key, value in config.items(): + config[key] = {"value": value, "desc": "", "sort": sort} + sort += 1 - # 判断是否是私有属性 - name = str(name) - self.__check_private(name) - # 设置属性,并判断是否已经初始化,如果是,则调用保存方法 - self.__dict__[name] = value - # 同步到配置字典 - self.__config[name] = value - self.save() - - def __setitem__(self, name: str, value: Any) -> None: + def __save(self): """ - 以字典方式设置配置项的值,并保存,但不允许设置私有属性: - ```python - run.config["lr"] = 0.01 # 允许 - run.config["_lr"] = 0.01 # 允许 - run.config["__lr"] = 0.01 # 不允许 - ``` + 保存config为dict """ - # 判断是否是私有属性 - name = str(name) - self.__check_private(name) - self.__config[name] = value - self.save() - - def set(self, name: str, value: Any) -> None: - """ - Explicitly set the value of a configuration item and save it. For example: - - ```python - run.config.set("lr", 0.01) # Allowed - run.config.set("_lr", 0.01) # Allowed - run.config.set("__lr", 0.01) # Not allowed - ``` - - Parameters - ---------- - name: str - Name of the configuration item - value: Any - Value of the configuration item + if not self.__on_setter: + return swanlog.debug("The configuration is not saved because the setter is not set.") + try: + # 深度拷贝一次,防止引用传递 + data = yaml.load(yaml.dump(self.__config), Loader=yaml.FullLoader) + except Exception as e: + swanlog.error(f"Error occurred when saving config: {e}") + return + # 遍历每一个配置项,值改为value,如果是字典,则递归调用 + self.__fmt_config(data) + r = RuntimeInfo(config=self.__config) + self.__on_setter(r) + swanlog.debug(f"Save configuration.") - Raises - ---------- - AttributeError - If the attribute name is private, an exception is raised - """ - name = str(name) - self.__check_private(name) - self.__config[name] = value - self.save() + # ---------------------------------- 实现对象风格 ---------------------------------- - def pop(self, name: str) -> bool: + def __delattr__(self, name: str): """ - Delete a configuration item; if the item does not exist, skip. - - Parameters - ---------- - name : str - Name of the configuration item - - Returns - ---------- - bool - True if deletion is successful, False otherwise + 删除配置项,如果配置项不存在 """ + # _*__正则开头的属性不允许删除 + if re.match(r"_.*__", name): + raise AttributeError(f"Attribute '{name}' is private and cannot be deleted") try: del self.__config[name] - self.save() - return True + self.__save() except KeyError: - return False + raise AttributeError(f"You have not deleted '{name}' in the config of the current experiment") - def get(self, name: str): + def __getattr__(self, name: str): """ - Get the value of a configuration item. If the item does not exist, raise AttributeError. - - Parameters - ---------- - name : str - Name of the configuration item - - Returns - ---------- - value : Any - Value of the configuration item - - Raises - ---------- - AttributeError - If the configuration item does not exist, an AttributeError is raised + 如果以点号方式访问属性且属性不存在于类中,尝试从配置字典中获取。 """ try: return self.__config[name] except KeyError: - raise AttributeError(f"You have not retrieved '{name}' in the config of the current experiment") - - def clean(self): - """ - 清空配置字典 - """ - self.__config.clear() + raise AttributeError(f"You have not get '{name}' in the config of the current experiment") - def update(self, data: dict): + def __setattr__(self, name: str, value: Any) -> None: """ - Update the configuration item with the dict provided and save it. - - :param data: dict of configuration items + Custom setter attribute, user can not set private attributes. """ + name = str(name) + if name in __config_attr__: + return super().__setattr__(name, value) + # _*__正则开头的属性不允许设置 + if re.match(r"_.*__", name): + raise AttributeError(f"Attribute '{name}' is private and cannot be set") + # 否则应该设置到配置字典中 + self.__config[name] = parse(value) + self.__save() - self.__config.update(data) - self.save() - - def __getattr__(self, name: str): - """ - 如果以点号方式访问属性且属性不存在于类中,尝试从配置字典中获取。 - """ - try: - return self.__config[name] - except KeyError: - raise AttributeError(f"You have not get '{name}' in the config of the current experiment") + # ---------------------------------- 实现字典风格 ---------------------------------- - def __getitem__(self, name: str): + def get(self, name: str, default=None): """ - 以字典方式获取配置项的值。 + Get the value of a configuration item. If the item does not exist, raise AttributeError. """ try: return self.__config[name] except KeyError: - raise AttributeError(f"You have not get '{name}' in the config of the current experiment") + return default - def __delattr__(self, name: str) -> bool: + def __delitem__(self, name: str): """ 删除配置项,如果配置项不存在,跳过 - - Parameters - ---------- - name : str - 配置项名称 - - Returns - ---------- - bool - 是否删除成功 """ try: del self.__config[name] - return True + self.__save() except KeyError: - return False + raise KeyError(f"You have not set '{name}' in the config of the current experiment when deleting") - def __delitem__(self, name: str) -> bool: + def __getitem__(self, name: str): """ - 删除配置项,如果配置项不存在,跳过 - - Parameters - ---------- - name : str - 配置项名称 - - Returns - ---------- - bool - 是否删除成功 + 以字典方式获取配置项的值 """ + # 如果self.__dict__中有name属性,则返回 + if not isinstance(name, str): + raise TypeError(f"Key must be a string, but got {type(name)}") + # 以_SwanLabConfig__开头,删除 + if name.startswith("_SwanLabConfig__"): + name = name[15:] try: - del self.__config[name] - return True + return self.__config[name] except KeyError: - return False + raise KeyError(f"You have not get '{name}' in the config of the current experiment") - def save(self): + def __setitem__(self, name: str, value: Any) -> None: """ - 保存config为json,不必校验config的YAML格式,将在写入时完成校验 + Set the value of a configuration item. If the item does not exist, create it. + User are not allowed to set private attributes. """ - if not self.should_shave: - return - swanlog.debug("Save config to {}".format(self.__settings.get("save_path"))) - - serialization_config = self.__check_config(self.__config) - with open(self.__settings.get("save_path"), "w") as f: - # 将config的每个key的value转换为desc和value两部分,value就是原来的value,desc是None - # 这样做的目的是为了在web界面中显示config的内容,desc是用于描述value的 - config = { - key: { - "desc": None, - "sort": index, - "value": value, - } - for index, (key, value) in enumerate(serialization_config.items()) - } - yaml.dump(config, f) + name = str(name) + self.__config[name] = parse(value) + self.__save() def __iter__(self): """ @@ -365,3 +247,45 @@ def __len__(self): def __str__(self): return str(self.__config) + + # ---------------------------------- 其他函数 ---------------------------------- + + def set(self, name: str, value: Any): + """ + Explicitly set the value of a configuration item and save it. + Private attributes are not allowed to be set. + """ + name = str(name) + self.__config[name] = parse(value) + self.__save() + + def pop(self, name: str): + """ + Delete a configuration item; if the item does not exist, skip. + """ + try: + t = self.__config[name] + del self.__config[name] + self.__save() + return t + except KeyError: + return None + + def update(self, __m: Union[MutableMapping, argparse.Namespace] = None, **kwargs): + """ + Update the configuration with the key/value pairs from __m, overwriting existing keys. + """ + if __m is not None: + for k, v in parse(__m).items(): + self.__config[k] = v + for k, v in kwargs.items(): + self.__config[k] = parse(v) + self.__save() + + def clean(self): + """ + Clean the configuration. + Attention: This method will reset the instance and instance will not automatically save the configuration. + """ + self.__config.clear() + self.__on_setter = None diff --git a/swanlab/data/run/exp.py b/swanlab/data/run/exp.py index efc4f38ec..961bb7eee 100644 --- a/swanlab/data/run/exp.py +++ b/swanlab/data/run/exp.py @@ -1,9 +1,9 @@ -from swanlab.data.settings import SwanDataSettings -from swanlab.data.modules import DataWrapper, Line, ErrorInfo +from swanlab.data.run.settings import SwanDataSettings +from swanlab.data.modules import DataWrapper, Line from swanlab.log import swanlog from typing import Dict, Optional from swanlab.utils import create_time -from .callback import MetricInfo, ColumnInfo +from .callback import MetricInfo, ColumnInfo, RuntimeInfo from .operator import SwanLabRunOperator import json import math @@ -28,6 +28,7 @@ def __init__(self, settings: SwanDataSettings, operator: SwanLabRunOperator) -> self.settings = settings # 当前实验的所有tag数据字段 self.keys: Dict[str, SwanLabKey] = {} + # TODO 操作员不传递给实验 self.__operator = operator def add(self, key: str, data: DataWrapper, step: int = None) -> MetricInfo: @@ -59,7 +60,7 @@ def add(self, key: str, data: DataWrapper, step: int = None) -> MetricInfo: if step in key_obj.steps: swanlog.warning(f"Step {step} on key {key} already exists, ignored.") return MetricInfo(key, key_obj.column_info, DataWrapper.create_duplicate_error()) - data.parse(step=step, settings=self.settings, key=key) + data.parse(step=step, key=key) # ---------------------------------- 图表创建 ---------------------------------- diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 27e25d0ad..d28caa68a 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -7,18 +7,20 @@ @Description: 在此处定义SwanLabRun类并导出 """ -from ..settings import SwanDataSettings +from .settings import SwanDataSettings from swanlab.log import swanlog from swanlab.data.modules import MediaType, DataWrapper, FloatConvertible, Line +from .system import get_system_info, get_requirements +from swanlab.package import get_package_version +from swanlab.utils.file import check_key_format from .config import SwanLabConfig -import random from enum import Enum from .exp import SwanLabExp from datetime import datetime -from typing import Callable, Optional, Dict, MutableMapping -from .operator import SwanLabRunOperator +from typing import Callable, Optional, Dict +from .operator import SwanLabRunOperator, RuntimeInfo from swanlab.env import get_mode -from ...utils.file import check_key_format +import random class SwanLabRunState(Enum): @@ -43,7 +45,7 @@ def __init__( project_name: str = None, experiment_name: str = None, description: str = None, - run_config: MutableMapping = None, + run_config=None, log_level: str = None, suffix: str = None, exp_num: int = None, @@ -62,7 +64,7 @@ def __init__( description : str, optional 实验描述,用于对当前实验进行更详细的介绍或标注 如果不提供此参数(为None),可以在web界面中进行修改,这意味着必须在此改为空字符串"" - run_config : MutableMapping, optional + run_config : Any, optional 实验参数配置,可以在web界面中显示,如学习率、batch size等 不需要做任何限制,但必须是字典类型,可被json序列化,否则会报错 log_level : str, optional @@ -77,9 +79,7 @@ def __init__( operator : SwanLabRunOperator, optional 实验操作员,用于批量处理回调函数的调用,如果不提供此参数(为None),则会自动生成一个实例 """ - - global run - if run is not None: + if self.is_started(): raise RuntimeError("SwanLabRun has been initialized") # ---------------------------------- 初始化类内参数 ---------------------------------- @@ -93,11 +93,12 @@ def __init__( self.__settings = SwanDataSettings(run_id=self.__run_id, should_save=not self.__operator.disabled) self.__operator.inject(self.__settings) # ---------------------------------- 初始化日志记录器 ---------------------------------- - # output、console_dir等内容不依赖于实验名称的设置 swanlog.set_level(self.__check_log_level(log_level)) # ---------------------------------- 初始化配置 ---------------------------------- - # 给外部1个config - self.__config = SwanLabConfig(run_config, self.__settings) + global config + config.update(run_config) + setattr(config, "_SwanLabConfig__on_setter", self.__operator.on_runtime_info_update) + self.__config = config # ---------------------------------- 注册实验 ---------------------------------- self.__exp: SwanLabExp = self.__register_exp(experiment_name, description, suffix, num=exp_num) # 实验状态标记,如果status不为0,则无法再次调用log方法 @@ -107,12 +108,19 @@ def __init__( def _(state: SwanLabRunState): self.__state = state - global _change_run_state + global _change_run_state, run _change_run_state = _ run = self # ---------------------------------- 初始化完成 ---------------------------------- self.__operator.on_run() + # 执行__save,必须在on_run之后,因为on_run之前部分的信息还没完全初始化 + getattr(config, "_SwanLabConfig__save")() + # 系统信息采集 + self.__operator.on_runtime_info_update(RuntimeInfo( + requirements=get_requirements(), + metadata=get_system_info(get_package_version(), self.settings.log_dir) + )) @property def operator(self) -> SwanLabRunOperator: @@ -138,16 +146,32 @@ def get_state(cls) -> SwanLabRunState: global run return run.state if run is not None else SwanLabRunState.NOT_STARTED + @staticmethod + def is_started() -> bool: + """ + If the experiment has been initialized, return True, otherwise return False. + """ + return get_run() is not None + @property - def is_crashed(self) -> bool: + def crashed(self) -> bool: + """ + If the experiment is marked as 'CRASHED', return True, otherwise return False. + """ return self.__state == SwanLabRunState.CRASHED @property - def is_success(self) -> bool: + def success(self) -> bool: + """ + If the experiment is marked as 'SUCCESS', return True, otherwise return False. + """ return self.__state == SwanLabRunState.SUCCESS @property - def is_running(self) -> bool: + def running(self) -> bool: + """ + If the experiment is marked as 'RUNNING', return True, otherwise return False. + """ return self.__state == SwanLabRunState.RUNNING @staticmethod @@ -162,11 +186,11 @@ def finish(state: SwanLabRunState = SwanLabRunState.SUCCESS, error=None): :param state: The state of the experiment, it can be 'SUCCESS', 'CRASHED' or 'RUNNING'. :param error: The error message when the experiment is marked as 'CRASHED'. If not 'CRASHED', it should be None. """ - global run + global run, config # 分为几步 # 1. 设置数据库实验状态为对应状态 # 2. 判断是否为云端同步,如果是则开始关闭线程池和同步状态 - # 3. 清空run和config对象,run改为局部变量_run,config被清空 + # 3. 清空run和config对象,run改为局部变量_run,新建一个config对象,原本的config对象内容转移到新的config对象,全局config被清空 # 4. 返回_run if run is None: raise RuntimeError("The run object is None, please call `swanlab.init` first.") @@ -182,7 +206,10 @@ def finish(state: SwanLabRunState = SwanLabRunState.SUCCESS, error=None): # disabled 模式下没有install,所以会报错 pass - run.config.clean() + # ---------------------------------- 清空config和run ---------------------------------- + _config = SwanLabConfig(config) + setattr(run, "_SwanLabRun__config", _config) + config.clean() _run, run = run, None return _run @@ -196,7 +223,7 @@ def settings(self) -> SwanDataSettings: return self.__settings @property - def config(self): + def config(self) -> SwanLabConfig: """ This property allows you to access the 'config' content passed through `init`, and allows you to modify it. The latest configuration after each modification @@ -293,11 +320,9 @@ def __register_exp( 注册实验,将实验配置写入数据库中,完成实验配置的初始化 """ - # ---------------------------------- 初始化实验 ---------------------------------- - def setter(exp_name: str, light_color: str, dark_color: str, desc: str): """ - 设置实验相关信息 + 设置实验相关信息的函数 :param exp_name: 实验名称 :param light_color: 亮色 :param dark_color: 暗色 @@ -305,12 +330,13 @@ def setter(exp_name: str, light_color: str, dark_color: str, desc: str): :return: """ # 实验创建成功,设置实验相关信息 - self.__settings.exp_name = exp_name + self.settings.exp_name = exp_name self.settings.exp_colors = (light_color, dark_color) self.settings.description = desc self.__operator.before_init_experiment(self.__run_id, experiment_name, description, num, suffix, setter) - return SwanLabExp(self.__settings, operator=self.__operator) + + return SwanLabExp(self.settings, operator=self.__operator) @staticmethod def __check_log_level(log_level: str) -> str: @@ -331,8 +357,12 @@ def __check_log_level(log_level: str) -> str: """ 修改实验状态的函数,用于在实验状态改变时调用 """ + +# 全局唯一的config对象,不应该重新赋值 config: Optional["SwanLabConfig"] = SwanLabConfig() -"""Global config instance. After the user calls finish(), config will be set to None.""" +""" +Global config instance. +""" def _set_run_state(state: SwanLabRunState): diff --git a/swanlab/data/run/operator.py b/swanlab/data/run/operator.py index dd46280df..04d80284c 100644 --- a/swanlab/data/run/operator.py +++ b/swanlab/data/run/operator.py @@ -8,8 +8,10 @@ 回调函数操作员,批量处理回调函数的调用 """ from typing import List, Union, Dict, Any, Callable -from .callback import SwanLabRunCallback, MetricInfo, ColumnInfo -from ..settings import SwanDataSettings +from .callback import SwanLabRunCallback, MetricInfo, ColumnInfo, OperateErrorInfo, RuntimeInfo +from swanlab.data.run.settings import SwanDataSettings +import swanlab.error as E +from swanlab.utils import FONT OperatorReturnType = Dict[str, Any] @@ -91,7 +93,18 @@ def before_init_experiment( ) def on_run(self): - return self.__run_all("on_run") + try: + return self.__run_all("on_run") + except E.ApiError as e: + FONT.brush("", 50) + if e.resp.status_code == 409: + error = OperateErrorInfo("The experiment name already exists, please change the experiment name") + return self.__run_all("on_run_error_from_operator", error) + else: + raise e + + def on_runtime_info_update(self, r: RuntimeInfo): + return self.__run_all("on_runtime_info_update", r) def on_log(self): return self.__run_all("on_log") @@ -103,4 +116,7 @@ def on_column_create(self, column_info: ColumnInfo): return self.__run_all("on_column_create", column_info) def on_stop(self, error: str = None): - return self.__run_all("on_stop", error) + r = self.__run_all("on_stop", error) + # 清空所有注册的回调函数 + self.callbacks.clear() + return r diff --git a/swanlab/data/settings.py b/swanlab/data/run/settings.py similarity index 56% rename from swanlab/data/settings.py rename to swanlab/data/run/settings.py index 9b678651a..915119650 100644 --- a/swanlab/data/settings.py +++ b/swanlab/data/run/settings.py @@ -8,50 +8,20 @@ 数据收集部分配置,此为运行时生成的配置, """ import os -from ..env import get_swanlog_dir +from swanlab.env import get_swanlog_dir from typing import Tuple from swanlab.package import get_package_version -class SwanDataSettings: - def __init__(self, run_id: str, should_save: bool) -> None: - """实验名称 - - Parameters - ---------- - exp_name : str - 实验名称,实验名称应该唯一,由0-9,a-z,A-Z," ","_","-","/"组成 - 但此处不做限制 - run_id : str - 实验运行id,由时间戳生成,用于区分不同实验存储目录 - """ +class LazySettings: + """ + 需要外界设置的信息,他们并不在一开始就被赋予意义,如果在设置前访问,返回None + """ + + def __init__(self): self.__exp_name = None self.__exp_colors = None self.__description = None - # 日志存放目录 - self.__swanlog_dir: str = get_swanlog_dir() - # 日志存放目录的上一级目录,默认情况下这应该是项目根目录 - self.__root_dir: str = os.path.dirname(self.__swanlog_dir) - # 实验运行id - self.__run_id: str = run_id - self.__version = get_package_version() - self.__should_save = should_save - - @property - def should_save(self): - """ - 是否应该保存实验信息 - """ - return self.__should_save - - def mkdir(self, path: str) -> None: - """创建目录""" - if not os.path.exists(path) and self.should_save: - os.makedirs(path, exist_ok=True) - - @property - def version(self) -> str: - return self.__version @property def exp_name(self) -> str: @@ -91,65 +61,109 @@ def description(self, description: str) -> None: raise ValueError("description can only be set once") self.__description = description + +class SwanDataSettings(LazySettings): + """ + SwanLabRun的配置信息,包括当前实验路径信息等 + + 涉及路径的属性都已经自动转换为绝对路径,并且文件夹路径在`should_save=True`时会自动创建 + + "几乎"所有属性都为只读属性,只有在初始化时可以设置 + """ + + def __init__(self, run_id: str, should_save: bool) -> None: + """ + 初始化 + :param run_id: 实验运行id + :param should_save: 是否应该保存实验相关信息,如果保存,相关文件夹将自动创建(文件不会自动创建) + """ + LazySettings.__init__(self) + # ---------------------------------- 静态信息 ---------------------------------- + self.__should_save = should_save + self.__run_id = run_id + self.__version = get_package_version() + # ---------------------------------- 文件夹信息 ---------------------------------- + logdir = get_swanlog_dir() + self.__root_dir = os.path.dirname(logdir) + self.__swanlog_dir = logdir + self.__run_dir = os.path.join(logdir, run_id) + self.__console_dir = os.path.join(self.run_dir, "console") + self.__log_dir = os.path.join(self.run_dir, "logs") + self.__files_dir = os.path.join(self.run_dir, "files") + self.__media_dir = os.path.join(self.run_dir, "media") + # ---------------------------------- 文件信息 ---------------------------------- + self.__error_path = os.path.join(self.console_dir, "error.log") + + def mkdir(self, path: str) -> None: + """创建目录""" + if not os.path.exists(path) and self.should_save: + os.makedirs(path, exist_ok=True) + + # ---------------------------------- 静态属性 ---------------------------------- + + @property + def should_save(self): + """ + 是否应该保存实验信息 + """ + return self.__should_save + + @property + def version(self) -> str: + return self.__version + @property def run_id(self) -> str: """实验运行id""" return self.__run_id + # ---------------------------------- 文件夹属性 ---------------------------------- + @property def root_dir(self) -> str: """根目录""" + # 必然存在,不需要创建 return self.__root_dir @property def swanlog_dir(self) -> str: """swanlog目录,存储所有实验的日志目录,也是runs.swanlab数据库的存储目录""" + self.mkdir(self.__swanlog_dir) return self.__swanlog_dir @property def run_dir(self) -> str: """实验日志、信息文件夹路径""" - return os.path.join(get_swanlog_dir(), self.run_id) - - @property - def error_path(self) -> str: - """错误日志文件路径""" - return os.path.join(self.console_dir, "error.log") + self.mkdir(self.__run_dir) + return self.__run_dir @property def log_dir(self) -> str: """记录用户日志文件夹路径""" - return os.path.join(self.run_dir, "logs") + self.mkdir(self.__log_dir) + return self.__log_dir @property def console_dir(self) -> str: """记录终端日志路径""" - return os.path.join(self.run_dir, "console") + self.mkdir(self.__console_dir) + return self.__console_dir @property def media_dir(self) -> str: - """媒体文件路径""" - path = os.path.join(self.run_dir, "media") - self.mkdir(path) - return path + """静态资源路径""" + self.mkdir(self.__media_dir) + return self.__media_dir @property def files_dir(self) -> str: """实验配置信息路径""" - path = os.path.join(self.run_dir, "files") - self.mkdir(path) - return path + self.mkdir(self.__files_dir) + return self.__files_dir - @property - def requirements_path(self) -> str: - """实验依赖的存储文件""" - return os.path.join(self.files_dir, "requirements.txt") + # ---------------------------------- 文件属性 ---------------------------------- @property - def metadata_path(self) -> str: - """实验环境存储文件""" - return os.path.join(self.files_dir, "swanlab-metadata.json") - - @property - def config_path(self) -> str: - return os.path.join(self.files_dir, "config.yaml") + def error_path(self) -> str: + """错误日志文件路径""" + return self.__error_path diff --git a/swanlab/data/system/__init__.py b/swanlab/data/run/system/__init__.py similarity index 100% rename from swanlab/data/system/__init__.py rename to swanlab/data/run/system/__init__.py diff --git a/swanlab/data/system/bin/apple_gpu_stats b/swanlab/data/run/system/bin/apple_gpu_stats similarity index 100% rename from swanlab/data/system/bin/apple_gpu_stats rename to swanlab/data/run/system/bin/apple_gpu_stats diff --git a/swanlab/data/system/info.py b/swanlab/data/run/system/info.py similarity index 94% rename from swanlab/data/system/info.py rename to swanlab/data/run/system/info.py index c1c609d88..d5be93daf 100644 --- a/swanlab/data/system/info.py +++ b/swanlab/data/run/system/info.py @@ -13,8 +13,7 @@ import subprocess import multiprocessing import pynvml -from ...log import swanlog -from swanlab.data.settings import SwanDataSettings +from swanlab.log import swanlog def __replace_second_colon(input_string, replacement): @@ -23,7 +22,7 @@ def __replace_second_colon(input_string, replacement): if first_colon_index != -1: second_colon_index = input_string.find(":", first_colon_index + 1) if second_colon_index != -1: - return input_string[:second_colon_index] + replacement + input_string[second_colon_index + 1:] + return input_string[:second_colon_index] + replacement + input_string[second_colon_index + 1 :] return input_string @@ -108,7 +107,7 @@ def __get_nvidia_gpu_info(): gpu_name = gpu_name.decode("utf-8") info["type"].append(gpu_name) # 获取 GPU 的总显存, 单位为GB - info["memory"].append(round(pynvml.nvmlDeviceGetMemoryInfo(handle).total / (1024 ** 3))) + info["memory"].append(round(pynvml.nvmlDeviceGetMemoryInfo(handle).total / (1024**3))) except pynvml.NVMLError as e: swanlog.debug(f"An error occurred when getting GPU info: {e}") @@ -190,7 +189,7 @@ def __get_memory_size(): try: # 获取系统总内存大小 mem = psutil.virtual_memory() - total_memory = round(mem.total / (1024 ** 3)) # 单位为GB + total_memory = round(mem.total / (1024**3)) # 单位为GB return total_memory except Exception as e: swanlog.debug(f"An error occurred when getting memory size: {e}") @@ -226,13 +225,13 @@ def get_requirements() -> str: return None -def get_system_info(settings: SwanDataSettings): - """获取系统信息""" +def get_system_info(version: str, logdir: str): + """获取系统信息 + :param version: swanlab版本号 + :param logdir: swanlab日志目录 + """ return { - "swanlab": { - "version": settings.version, - "logdir": settings.swanlog_dir - }, + "swanlab": {"version": version, "logdir": logdir}, "hostname": socket.gethostname(), "os": platform.platform(), "python": platform.python_version(), diff --git a/swanlab/data/system/monitor.py b/swanlab/data/run/system/monitor.py similarity index 100% rename from swanlab/data/system/monitor.py rename to swanlab/data/run/system/monitor.py diff --git a/swanlab/data/sdk.py b/swanlab/data/sdk.py index 837b7c2a0..0b268d030 100644 --- a/swanlab/data/sdk.py +++ b/swanlab/data/sdk.py @@ -52,11 +52,6 @@ def _check_proj_name(name: str) -> str: return _name -def _is_inited(): - """检查是否已经初始化""" - return get_run() is not None - - def login(api_key: str = None): """ Login to SwanLab Cloud. If you already have logged in, you can use this function to relogin. @@ -68,7 +63,7 @@ def login(api_key: str = None): api_key : str authentication key, if not provided, the key will be read from the key file. """ - if _is_inited(): + if SwanLabRun.is_started(): raise RuntimeError("You must call swanlab.login() before using init()") CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.get_login_info() @@ -145,11 +140,11 @@ def init( SwanLab will attempt to replace them from the configuration file you provided; otherwise, it will use the parameters you passed as the definitive ones. """ - run = get_run() - if run is not None: + if SwanLabRun.is_started(): swanlog.warning("You have already initialized a run, the init function will be ignored") - return run + return get_run() # ---------------------------------- 一些变量、格式检查 ---------------------------------- + # TODO 下个版本删除 if "cloud" in kwargs: swanlog.warning( "The `cloud` parameter in swanlab.init is deprecated and will be removed in the future" @@ -191,6 +186,23 @@ def init( return run +def should_call_after_init(text): + """ + 装饰器,限制必须在实验初始化后调用 + """ + + def decorator(func): + def wrapper(*args, **kwargs): + if not SwanLabRun.is_started(): + raise RuntimeError(text) + return func(*args, **kwargs) + + return wrapper + + return decorator + + +@should_call_after_init("You must call swanlab.init() before using log()") def log(data: Dict[str, DataType], step: int = None): """ Log a row of data to the current run. @@ -206,13 +218,12 @@ def log(data: Dict[str, DataType], step: int = None): The step number of the current data, if not provided, it will be automatically incremented. If step is duplicated, the data will be ignored. """ - if not _is_inited(): - raise RuntimeError("You must call swanlab.init() before using log()") run = get_run() ll = run.log(data, step) return ll +@should_call_after_init("You must call swanlab.init() before using finish()") def finish(state: SwanLabRunState = SwanLabRunState.SUCCESS, error=None): """ Finish the current run and close the current experiment @@ -222,9 +233,7 @@ def finish(state: SwanLabRunState = SwanLabRunState.SUCCESS, error=None): If you mark the experiment as 'CRASHED' manually, `error` must be provided. """ run = get_run() - if not get_run(): - raise RuntimeError("You must call swanlab.data.init() before using finish()") - if not run.is_running: + if not run.running: return swanlog.error("After experiment is finished, you can't call finish() again.") run.finish(state, error) diff --git a/swanlab/db/callback.py b/swanlab/db/callback.py index adf176f03..1767db685 100644 --- a/swanlab/db/callback.py +++ b/swanlab/db/callback.py @@ -17,6 +17,7 @@ from swanlab.utils.file import check_exp_name_format, check_desc_format from datetime import datetime import time +import sys class GlomCallback(SwanLabRunCallback): @@ -94,6 +95,12 @@ def on_init(self, proj_name: str, *args, **kwargs): # 初始化项目数据库 Project.init(proj_name) + def on_run_error_from_operator(self, e): + # 更新数据库中的实验状态 + swanlog.error(e) + Experiment.purely_delete(run_id=self.settings.run_id) + sys.exit(409) + def before_init_experiment( self, run_id: str, diff --git a/swanlab/env.py b/swanlab/env.py index 787150247..265f0bd51 100644 --- a/swanlab/env.py +++ b/swanlab/env.py @@ -76,7 +76,7 @@ def get_mode(env: Optional[Env] = None) -> Optional[str]: def get_swanlog_dir(env: Optional[Env] = None) -> Optional[str]: - """获取swanlog路径 + """获取swanlog路径,返回值为绝对路径 Returns ------- diff --git a/swanlab/package.json b/swanlab/package.json index d13ccf075..e7ad8392c 100644 --- a/swanlab/package.json +++ b/swanlab/package.json @@ -1,6 +1,6 @@ { "name": "swanlab", - "version": "0.3.8", + "version": "0.3.9", "description": "", "python": "true", "host": { diff --git a/test/create_experiment.py b/test/create_experiment.py index fd8fb7c7c..1921ff3f2 100644 --- a/test/create_experiment.py +++ b/test/create_experiment.py @@ -42,28 +42,13 @@ if epoch % 10 == 0: # 测试audio sample_rate = 44100 - test_audio_arr = np.random.randn(2, 100000) - swanlab.log( - { - "test/audio": [swanlab.Audio(test_audio_arr, sample_rate, caption="test")] * (epoch // 10), - }, - step=epoch, - ) + audios = [swanlab.Audio(np.random.randn(2, 100000), sample_rate, caption="test") for _ in range(epoch // 10)] + swanlab.log({"test/audio": audios}, step=epoch) # 测试image - test_image = np.random.randint(0, 255, (100, 100, 3)) - swanlab.log( - { - "test/image": [swanlab.Image(test_image, caption="test")] * (epoch // 10), - }, - step=epoch, - ) + images = [swanlab.Image(np.random.randint(0, 255, (100, 100, 3)), caption="test") for _ in range(epoch // 10)] + swanlab.log({"test/image": images}, step=epoch) # 测试text - swanlab.log( - { - "text": swanlab.Text("这是一段测试文本", caption="test"), - }, - step=epoch, - ) + swanlab.log({"text": swanlab.Text("这是一段测试文本", caption="test")}, step=epoch) # 测试折线图 swanlab.log({"t/accuracy": acc, "loss": loss, "loss2": loss2}) else: diff --git a/test/unit/data/run/pytest_config.py b/test/unit/data/run/pytest_config.py index 005166a61..17a012dba 100644 --- a/test/unit/data/run/pytest_config.py +++ b/test/unit/data/run/pytest_config.py @@ -1,8 +1,11 @@ -from swanlab.data.run.main import SwanLabRun, get_run, SwanLabConfig, swanlog +import math +import yaml +from swanlab.data.run.main import SwanLabRun, get_run, swanlog, get_config +from swanlab.data.run.config import SwanLabConfig, parse, Line, RuntimeInfo, MutableMapping from tutils import clear import pytest -import swanlab import omegaconf +import argparse @pytest.fixture(scope="function", autouse=True) @@ -18,247 +21,337 @@ def setup_function(): swanlog.enable_log() -class TestSwanLabRunConfig: - def test_config_normal_haverun(self): +def test_parse(): + """ + 测试config.parse函数 + """ + # ---------------------------------- omegaConf ---------------------------------- + config_data = { + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + } + cfg = omegaconf.OmegaConf.create(config_data) + config = parse(cfg) + assert yaml.dump(config) == yaml.dump(config_data) + + # ---------------------------------- 自定义继承自MutableMapping的类 ---------------------------------- + + class Test(MutableMapping): + def __init__(self, a, b): + self.data = {"a": a, "b": b} + + def __setitem__(self, __key, __value): + self.data[__key] = __value + + def __delitem__(self, __key): + del self.data[__key] + + def __getitem__(self, __key): + return self.data.get(__key, None) + + def __len__(self): + return len(self.data) + + def __iter__(self): + return iter(self.data) + + config_data = {"a": 1, "b": "mnist", "c/d": [1, 2, 3], "e/f/h": {"a": 1, "b": {"c": 2}}, "test": Test(1, 2)} + config = parse(config_data) + assert config["test"]["a"] == 1 + assert config["test"]["b"] == 2 + # ---------------------------------- 包含NaN或者INF的dict对象 ---------------------------------- + config_data = { + "inf": math.inf, + "nan": math.nan, + } + config = parse(config_data) + assert config["inf"] == Line.inf + assert config["nan"] == Line.nan + # ---------------------------------- argparse.Namespace ---------------------------------- + config_data = argparse.Namespace(a=1, b="mnist", c=[1, 2, 3], d={"a": 1, "b": {"c": 2}}) + config = parse(config_data) + assert yaml.dump(config) == yaml.dump(vars(config_data)) + + +class TestSwanLabConfigOperation: + """ + 单独测试TestSwanLabRunConfig这个类 + """ + + def test_basic_operation_object(self): """ - 初始化时有config参数,测试三种获取数据的方式,且使用的run对象 + 测试类的基本操作,增删改 """ - config_data = { - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - } - run = SwanLabRun(run_config=config_data) - assert isinstance(run.config, SwanLabConfig) - assert len(run.config) == 4 - - assert run.config == config_data - - assert run.config["a"] == 1 - assert run.config["b"] == "mnist" - assert run.config["c/d"] == [1, 2, 3] - assert run.config["c/d"][0] == 1 - assert run.config["e/f/h"] == {"a": 1, "b": {"c": 2}} - assert run.config["e/f/h"]["a"] == 1 - assert run.config["e/f/h"]["b"]["c"] == 2 - - assert run.config.a == 1 - assert run.config.b == "mnist" - - assert run.config.get("a") == 1 - assert run.config.get("b") == "mnist" - assert run.config.get("c/d") == [1, 2, 3] - assert run.config.get("c/d")[0] == 1 - assert run.config.get("e/f/h") == {"a": 1, "b": {"c": 2}} - assert run.config.get("e/f/h")["a"] == 1 - assert run.config["e/f/h"]["b"]["c"] == 2 - - run.config.save() - - def test_config_finish_haverun(self): + config = SwanLabConfig() + # ---------------------------------- 对象风格设置 ---------------------------------- + config.a = 1 + assert config.a == 1 + config.a = 2 + assert config.a == 2 + with pytest.raises(AttributeError): + config.__a = 1 + # 不存在的属性 + with pytest.raises(AttributeError): + config.b # noqa + # 删除属性 + del config.a + with pytest.raises(AttributeError): + del config.a # 重复删除报错 + with pytest.raises(AttributeError): + config.a # noqa + + def test_basic_operation_dict(self): """ - 测试在run.finish()之后config是否置空 + 测试字典的基本操作,增删改 """ - config_data = { - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - } - - run = SwanLabRun(run_config=config_data) - run.finish() - - assert isinstance(run.config, SwanLabConfig) - assert len(run.config) == 0 - - run.config.save() - - def test_config_normal(self): + config = SwanLabConfig() + # ---------------------------------- 字典风格设置 ---------------------------------- + config["a"] = 1 + assert config["a"] == 1 + config["a"] = 2 + assert config["a"] == 2 + # 删除属性 + del config["a"] + with pytest.raises(KeyError): + del config["a"] + with pytest.raises(KeyError): + config["a"] # noqa + # 字典风格可以设置,读取,删除私有属性 + config["__a"] = 1 + assert config["__a"] == 1 + del config["__a"] + with pytest.raises(KeyError): + del config["__a"] # 重复删除失败 + with pytest.raises(KeyError): + config["__a"] # noqa + # int访问,设置 + config[1] = 1 # noqa + with pytest.raises(TypeError): + assert config[1] == 1 # noqa + + def test_dict_iter(self): """ - 初始化时有config参数,测试三种获取数据的方式,直接用全局的config对象 + 测试字典风格的迭代 """ - config_data = { - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - } - config = SwanLabConfig(config=config_data) - - assert isinstance(config, SwanLabConfig) - assert len(config) == 4 - - assert config == config_data - - assert config["a"] == 1 - assert config["b"] == "mnist" - assert config["c/d"] == [1, 2, 3] - assert config["c/d"][0] == 1 - assert config["e/f/h"] == {"a": 1, "b": {"c": 2}} - assert config["e/f/h"]["a"] == 1 - assert config["e/f/h"]["b"]["c"] == 2 - - assert config.a == 1 - assert config.b == "mnist" - - assert config.get("a") == 1 - assert config.get("b") == "mnist" - assert config.get("c/d") == [1, 2, 3] - assert config.get("c/d")[0] == 1 - assert config.get("e/f/h") == {"a": 1, "b": {"c": 2}} - assert config.get("e/f/h")["a"] == 1 - assert config["e/f/h"]["b"]["c"] == 2 - - config.save() - config.clean() - - def test_config_update(self): + config = SwanLabConfig() + ll = ["a", "b", "c", "d"] + for i in ll: + config[i] = i + assert set(config) == {"a", "b", "c", "d"} + index = 0 + # 返回顺序相同 + for key in config: + assert key == ll[index] + index += 1 + + def test_dict_len(self): """ - 测试config初始为空,之后通过update的方式添加config参数 + 测试字典风格的长度 """ - - config_data = { - "a": 1, - "b": "mnist", - "c/d": [1, 2, 3], - "e/f/h": {"a": 1, "b": {"c": 2}}, - } - - update_data = { - "a": 2, - "e/f/h": [4, 5, 6], - "j": 3, - } - config = SwanLabConfig() assert len(config) == 0 + config["a"] = 1 + assert len(config) == 1 + config["b"] = 2 + assert len(config) == 2 + del config["a"] + assert len(config) == 1 + del config["b"] + assert len(config) == 0 - # 第一次更新 - config.update(config_data) - assert config == config_data - assert len(config) == 4 - - assert config["a"] == 1 - assert config["b"] == "mnist" - assert config["c/d"] == [1, 2, 3] - assert config["c/d"][0] == 1 - assert config["e/f/h"] == {"a": 1, "b": {"c": 2}} - assert config["e/f/h"]["a"] == 1 - assert config["e/f/h"]["b"]["c"] == 2 - - # 第二次更新 - config.update(update_data) - assert len(config) == 5 - - assert config["a"] == 2 - assert config["e/f/h"] == [4, 5, 6] - assert config["e/f/h"][0] == 4 - assert config["j"] == 3 - - config.save() + def test_func_operation(self): + """ + 测试内置函数操作 + """ + config = SwanLabConfig() + # ---------------------------------- get ---------------------------------- + a = config.get("a") + assert a is None + a = config.get("a", 1) + assert a == 1 + config["a"] = 5 + a = config.get("a") + assert a == 5 + # ---------------------------------- set ---------------------------------- + config.set("b", 1) + assert config["b"] == 1 + config.set("__b", 1) + assert config["__b"] == 1 + # ---------------------------------- pop ---------------------------------- + config["c"] = 9 + c = config.pop("c") + assert c == 9 + assert config.pop("d") is None + # ---------------------------------- clean ---------------------------------- + config["e"] = 1 + config["g"] = 0 config.clean() + assert len(config) == 0 + with pytest.raises(KeyError): + config["e"] # noqa + # ---------------------------------- update ---------------------------------- + config["x"] = 1 + config["y"] = 2 + config["z"] = {"a": 1, "b": 2} + config.update({"x": 2, "y": 3, "z": {"a": 2, "b": 3}}) + assert config["x"] == 2 + assert config["y"] == 3 + assert config["z"] == {"a": 2, "b": 3} + # update自己 + _config = SwanLabConfig() + _config.update(config) + assert _config == config + # update,argparse.Namespace + _config = SwanLabConfig() + _config.update(argparse.Namespace(a=1, b=2)) + assert _config["a"] == 1 + assert _config["b"] == 2 + # update, use kwargs + _config = SwanLabConfig() + _config.update(a=2, b=1) + assert _config["a"] == 2 + assert _config["b"] == 1 + + +def test_on_setter(): + """ + 测试on_setter函数,在设置属性时触发 + """ + num = 1 + + def on_setter(_: RuntimeInfo): + nonlocal num + num += 1 + + config = SwanLabConfig(on_setter=on_setter) + + # ---------------------------------- 对象、字典风格 ---------------------------------- + + # 设置触发 + config.a = 1 + assert num == 2 + del config.a + assert num == 3 + config["b"] = 1 + assert num == 4 + del config["b"] + assert num == 5 + + # 读取不触发 + config.x = 1 + assert num == 6 + _ = config.x + assert num == 6 + config["y"] = 1 + assert num == 7 + _ = config["y"] + assert num == 7 + + # ---------------------------------- api ---------------------------------- + + # 设置触发 + config.set("c", 1) + assert num == 8 + config.pop("c") + assert num == 9 + config.update({"d": {}}) + assert num == 10 + # 深层设置无法触发 + config.d["e"] = 1 + assert num == 10 + + # 读取不触发 + config.get("f", 1) + config["g"] = 1 + assert num == 11 + config.get("g") + assert num == 11 + + # ---------------------------------- clean以后再设置无法触发 ---------------------------------- + + config.clean() + config.h = 1 + assert num == 11 + + +class TestSwanLabConfigWithRun: + """ + 测试SwanLabConfig与SwanLabRun的交互 + """ - def test_config_get_config(self): + def test_use_dict(self): """ - 测试get_config + 正常流程,输入字典 """ - config_data = { + run = SwanLabRun(run_config={ "a": 1, "b": "mnist", "c/d": [1, 2, 3], "e/f/h": {"a": 1, "b": {"c": 2}}, - } - - assert isinstance(swanlab.get_config(), SwanLabConfig) - assert len(swanlab.get_config()) == 0 - - config = SwanLabConfig(config=config_data) - - assert isinstance(swanlab.get_config(), SwanLabConfig) - assert len(swanlab.get_config()) == 4 - - config.save() - config.clean() - - def test_config_from_omegaconf(self): + }) + config = run.config + _config = get_config() + assert config["a"] == _config["a"] == 1 + assert config["b"] == _config["b"] == "mnist" + assert config["c/d"] == _config["c/d"] == [1, 2, 3] + + def test_use_omegaconf(self): """ - 测试config导入omegaconf的情况 + 正常流程,输入OmegaConf """ - config_data = { + run = SwanLabRun(run_config=omegaconf.OmegaConf.create({ "a": 1, "b": "mnist", "c/d": [1, 2, 3], "e/f/h": {"a": 1, "b": {"c": 2}}, - } - cfg = omegaconf.OmegaConf.create(config_data) - config = SwanLabConfig(config=config_data) - - assert isinstance(config, SwanLabConfig) - assert len(config) == 4 - - assert config["a"] == 1 - assert config["b"] == "mnist" - assert config["c/d"] == [1, 2, 3] - assert config["e/f/h"] == {"a": 1, "b": {"c": 2}} - - config.save() - config.clean() - - def test_not_json_serializable(self): + })) + config = run.config + _config = get_config() + assert config["a"] == _config["a"] == 1 + assert config["b"] == _config["b"] == "mnist" + assert config["c/d"] == _config["c/d"] == [1, 2, 3] + + def test_use_argparse(self): """ - 测试不可json化的数据 + 正常流程,输入argparse.Namespace """ - import math, json - - config_data = { + run = SwanLabRun(run_config=argparse.Namespace(a=1, b="mnist", c=[1, 2, 3], d={"a": 1, "b": {"c": 2}})) + config = run.config + _config = get_config() + assert config["a"] == _config["a"] == 1 + assert config["b"] == _config["b"] == "mnist" + assert config["c"] == _config["c"] == [1, 2, 3] + + def test_use_config(self): + """ + 正常流程,输入SwanLabConfig + """ + run = SwanLabRun(run_config=SwanLabConfig({ "a": 1, "b": "mnist", - "c/d": [1, 2, 3], + "c": [1, 2, 3], "e/f/h": {"a": 1, "b": {"c": 2}}, - "test_nan": math.nan, - "test_inf": math.inf, - } - - config = SwanLabConfig(config=config_data) - - json_data = json.dumps(dict(config)) - - config.save() - config.clean() - - def test_insert_class(self): + })) + config = run.config + _config = get_config() + assert config["a"] == _config["a"] == 1 + assert config["b"] == _config["b"] == "mnist" + assert config["c"] == _config["c"] == [1, 2, 3] + + def test_after_finish(self): """ - 测试插入类 + 测试在finish之后config的变化 """ - from collections.abc import MutableMapping - - class Test(MutableMapping): - def __init__(self, a, b): - self.data = {"a": a, "b": b} - - def __setitem__(self, __key, __value): - self.data[__key] = __value - - def __delitem__(self, __key): - del self.data[__key] - - def __getitem__(self, __key): - return self.data.get(__key, None) - - def __len__(self): - return len(self.data) - - def __iter__(self): - return iter(self.data) - - config_data = {"a": 1, "b": "mnist", "c/d": [1, 2, 3], "e/f/h": {"a": 1, "b": {"c": 2}}, "test": Test(1, 2)} - - config = SwanLabConfig(config=config_data) - - assert config.test.data["a"] == 1 - assert config.test.data["b"] == 2 - - config.save() - config.clean() + run = SwanLabRun(run_config={ + "a": 1, + "b": "mnist", + "c/d": [1, 2, 3], + "e/f/h": {"a": 1, "b": {"c": 2}}, + }) + run.finish() + config = run.config + _config = get_config() + assert len(config) == 4 + assert len(_config) == 0 diff --git a/test/unit/data/run/pytest_main.py b/test/unit/data/run/pytest_main.py index a1dfc1726..01ccfb52a 100644 --- a/test/unit/data/run/pytest_main.py +++ b/test/unit/data/run/pytest_main.py @@ -73,19 +73,19 @@ def test_not_started(self): def test_running(self): run = SwanLabRun() assert run.state == SwanLabRunState.RUNNING - assert run.is_running is True + assert run.running is True def test_crashed(self): run = SwanLabRun() run.finish(SwanLabRunState.CRASHED, error="error") assert run.state == SwanLabRunState.CRASHED - assert run.is_crashed is True + assert run.crashed is True def test_success(self): run = SwanLabRun() run.finish(SwanLabRunState.SUCCESS) assert run.state == SwanLabRunState.SUCCESS - assert run.is_success is True + assert run.success is True class TestSwanLabRunLog: