diff --git a/pyproject.toml b/pyproject.toml index 558a8d5..f4c3963 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "hatchling.build" [project] name = "swankit" -version = "0.1.2b1" +version = "0.1.2b2" dynamic = ["readme", "dependencies"] description = "Base toolkit for SwanLab" license = "Apache-2.0" diff --git a/swankit/callback/__init__.py b/swankit/callback/__init__.py index 01fef7b..339472d 100644 --- a/swankit/callback/__init__.py +++ b/swankit/callback/__init__.py @@ -107,4 +107,4 @@ def __str__(self) -> str: pass -__all__ = ["SwanKitCallback", "MediaBuffer", "MetricInfo", "ColumnInfo", "OperateErrorInfo", "RuntimeInfo"] +__all__ = ["SwanKitCallback", "models"] diff --git a/swankit/callback/models/__init__.py b/swankit/callback/models/__init__.py index 6a26ada..0b60cbf 100644 --- a/swankit/callback/models/__init__.py +++ b/swankit/callback/models/__init__.py @@ -7,7 +7,7 @@ @Description: 与回调函数通信时的模型 """ -from .key import MediaBuffer, MetricInfo, ColumnInfo, MetricErrorInfo +from .key import MediaBuffer, MetricInfo, ColumnInfo, MetricErrorInfo, ColumnClass, SectionType, ColumnConfig, YRange from .error import OperateErrorInfo from .runtime import RuntimeInfo @@ -18,4 +18,8 @@ "MetricErrorInfo", "OperateErrorInfo", "RuntimeInfo", + "ColumnClass", + "SectionType", + "ColumnConfig", + "YRange", ] diff --git a/swankit/callback/models/key.py b/swankit/callback/models/key.py index 1cd33cb..6c33cb6 100644 --- a/swankit/callback/models/key.py +++ b/swankit/callback/models/key.py @@ -7,11 +7,24 @@ @Description: 与Key相关的回调函数触发时的模型 """ -from typing import Union, Optional, Dict, List, Literal -from swankit.core import ChartType, ParseErrorInfo, MediaBuffer +from typing import Union, Optional, Dict, List, Literal, Tuple, TypedDict + +from swankit.core import ChartType, ParseErrorInfo, MediaBuffer, ChartReference from urllib.parse import quote import os +ColumnClass = Literal["CUSTOM", "SYSTEM"] +SectionType = Literal["PINNED", "HIDDEN", "PUBLIC", "SYSTEM"] +YRange = Optional[Tuple[Optional[float], Optional[float]]] + + +class ColumnConfig(TypedDict): + """ + 列信息配置 + """ + + y_range: YRange + class ColumnInfo: """ @@ -21,22 +34,23 @@ class ColumnInfo: def __init__( self, key: str, - key_id: str, - key_name: str, - key_class: Literal["CUSTOM", "SYSTEM"], + kid: str, + name: str, + cls: ColumnClass, chart_type: ChartType, - chart_reference: Literal["step", "time"], + chart_reference: ChartReference, section_name: Optional[str], + section_type: SectionType, section_sort: Optional[int] = None, error: Optional[ParseErrorInfo] = None, - config: Optional[Dict] = None, + config: Optional[ColumnConfig] = None, ): """ 生成的列信息对象 - :param key: 生成的列名称 - :param key_id: 当前实验下,列的唯一id,与保存路径等信息有关 - :param key_name: key的别名 - :param key_class: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列 + :param key: 生成的列名称,作为索引键值 + :param kid: 当前实验下,列的唯一id,与保存路径等信息有关,与云端请求无关 + :param name: 列的别名 + :param cls: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列 :param chart_type: 列对应的图表类型 :param chart_reference: 这个列对应图表的参考系,step为步数,time为时间 :param section_name: 列的组名 @@ -45,18 +59,19 @@ def __init__( :param config: 列的额外配置信息 """ self.key = key - self.key_id = key_id - self.key_name = key_name - self.key_class = key_class + self.kid = kid + self.name = name + self.cls = cls self.chart_type = chart_type self.chart_reference = chart_reference self.section_name = section_name self.section_sort = section_sort + self.section_type = section_type self.error = error - self.config = config if config is not None else {} + self.config = config @property def got(self): @@ -124,7 +139,7 @@ def __init__( self.metric_summary = metric_summary self.metric_step = metric_step self.metric_epoch = metric_epoch - _id = self.column_info.key_id + _id = self.column_info.kid self.metric_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, metric_file_name) self.summary_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, self.__SUMMARY_NAME) self.swanlab_media_dir = swanlab_media_dir diff --git a/swankit/callback/models/runtime.py b/swankit/callback/models/runtime.py index dbfb55d..4b2de84 100644 --- a/swankit/callback/models/runtime.py +++ b/swankit/callback/models/runtime.py @@ -8,7 +8,7 @@ 运行时信息模型 """ from abc import ABC, abstractmethod -from typing import Optional, Any +from typing import Optional import json import yaml import os @@ -108,13 +108,7 @@ def to_dict(self): # 没有在__init__中直接修改是因为可能会有其他地方需要原始数据,并且会丢失一些性能 if self.__data is not None: return self.__data - self.__data = { - k: { - "value": v, - "sort": i, - "desc": "" - } for i, (k, v) in enumerate(self.info.items()) - } + self.__data = {k: {"value": v, "sort": i, "desc": ""} for i, (k, v) in enumerate(self.info.items())} return self.__data @@ -130,14 +124,10 @@ def __init__(self, requirements: str = None, metadata: dict = None, config: dict :param metadata: 系统信息 :param config: 上传的配置信息 """ - self.requirements: Optional[RequirementInfo] = RequirementInfo( - requirements - ) if requirements is not None else None + self.requirements: Optional[RequirementInfo] = ( + RequirementInfo(requirements) if requirements is not None else None + ) - self.metadata: Optional[MetadataInfo] = MetadataInfo( - metadata - ) if metadata is not None else None + self.metadata: Optional[MetadataInfo] = MetadataInfo(metadata) if metadata is not None else None - self.config: Optional[ConfigInfo] = ConfigInfo( - config - ) if config is not None else None + self.config: Optional[ConfigInfo] = ConfigInfo(config) if config is not None else None diff --git a/swankit/core/__init__.py b/swankit/core/__init__.py index b7e0d72..152429d 100644 --- a/swankit/core/__init__.py +++ b/swankit/core/__init__.py @@ -7,7 +7,7 @@ @Description: 核心解析模块工具 """ -from .data import BaseType, MediaType, DataSuite, MediaBuffer, ParseResult, ParseErrorInfo +from .data import BaseType, MediaType, DataSuite, MediaBuffer, ParseResult, ParseErrorInfo, ChartReference from .settings import SwanLabSharedSettings ChartType = BaseType.Chart @@ -20,5 +20,6 @@ "MediaBuffer", "ParseResult", "ParseErrorInfo", - "SwanLabSharedSettings" + "SwanLabSharedSettings", + "ChartReference", ] diff --git a/swankit/core/data.py b/swankit/core/data.py index 7fc8820..b03f049 100644 --- a/swankit/core/data.py +++ b/swankit/core/data.py @@ -7,7 +7,7 @@ @Description: 数据处理模型 """ -from typing import List, Dict, Optional, ByteString, Union, Tuple +from typing import List, Dict, Optional, ByteString, Union, Tuple, Literal from abc import ABC, abstractmethod from enum import Enum from io import BytesIO @@ -15,6 +15,8 @@ import math import io +ChartReference = Literal["STEP", "TIME"] + class DataSuite: """ @@ -172,14 +174,6 @@ def get_section(self) -> str: """ return "default" - # noinspection PyMethodMayBeStatic - def get_config(self) -> Optional[Dict]: - """ - 获取图表的config配置信息,应该返回一个字典,或者为None - 为None时代表不需要配置 - """ - return None - # noinspection PyMethodMayBeStatic def get_more(self) -> Optional[Dict]: """ @@ -193,6 +187,7 @@ class MediaType(BaseType): # noqa """ 媒体类型,用于区分标量和媒体,不用做实例化,应该由子类继承 """ + pass @@ -220,30 +215,28 @@ class ParseResult: """ def __init__( - self, - section: str = None, - chart: BaseType.Chart = None, - data: Union[List[str], float] = None, - config: Optional[List[Dict]] = None, - more: Optional[List[Dict]] = None, - buffers: Optional[List[MediaBuffer]] = None, + self, + section: str = None, + chart: BaseType.Chart = None, + data: Union[List[str], float] = None, + more: Optional[List[Dict]] = None, + buffers: Optional[List[MediaBuffer]] = None, + reference: ChartReference = "STEP", ): """ :param section: 转换后数据对应的section :param chart: 转换后数据对应的图表类型,枚举类型 :param data: 存储在.log中的数据 - :param config: 存储在.log中的配置 :param more: 存储在.log中的更多信息 :param buffers: 存储于media文件夹中的原始数据,比特流,特别的,对于某些字符串即原始数据的情况,此处为None + :param reference: 图表数据的参考类型 """ self.__data = data - self.config = config self.more = more self.buffers = buffers self.section = section self.chart = chart - # 默认的reference - self.reference = "step" + self.reference = reference self.step = None @property @@ -291,11 +284,11 @@ class ParseErrorInfo: """ def __init__( - self, - expected: Optional[str], - got: Optional[str], - chart: Optional[BaseType.Chart], - duplicated: bool = False + self, + expected: Optional[str], + got: Optional[str], + chart: Optional[BaseType.Chart], + duplicated: bool = False, ): """ :param expected: 期望的数据类型 diff --git a/test/unit/callback/models/test_key.py b/test/unit/callback/models/test_key.py index c491ccf..4ef406a 100644 --- a/test/unit/callback/models/test_key.py +++ b/test/unit/callback/models/test_key.py @@ -5,40 +5,43 @@ def test_column_info(): c = K.ColumnInfo( key="a/1", - key_id="b", - key_name="c", - key_class="SYSTEM", + kid="b", + name="c", + cls="SYSTEM", section_name="e", section_sort=1, + section_type="PUBLIC", chart_type=ChartType.TEXT, - chart_reference="step", + chart_reference="STEP", error=None, config=None, ) assert c.got is None assert c.key == "a/1" - assert c.key_id == "b" - assert c.key_name == "c" - assert c.key_class == "SYSTEM" + assert c.kid == "b" + assert c.name == "c" + assert c.cls == "SYSTEM" assert c.section_name == "e" assert c.section_sort == 1 assert c.chart_type == ChartType.TEXT - assert c.chart_reference == "step" + assert c.chart_reference == "STEP" + assert c.section_type == "PUBLIC" assert c.error is None - assert c.config == {} + assert c.config is None assert c.key_encode == "a%2F1" def test_metric_info(): c = K.ColumnInfo( key="a/1", - key_id="b", - key_name="c", - key_class="SYSTEM", + kid="b", + name="c", + cls="SYSTEM", section_name="e", section_sort=1, chart_type=ChartType.TEXT, - chart_reference="step", + section_type="PUBLIC", + chart_reference="STEP", error=None, config=None, ) @@ -56,25 +59,26 @@ def test_metric_info(): ) assert m.column_info.got is None assert m.column_info.key == "a/1" - assert m.column_info.key_id == "b" - assert m.column_info.key_name == "c" - assert m.column_info.key_class == "SYSTEM" + assert m.column_info.kid == "b" + assert m.column_info.name == "c" + assert m.column_info.cls == "SYSTEM" assert m.column_info.section_name == "e" assert m.column_info.section_sort == 1 assert m.column_info.chart_type == ChartType.TEXT - assert m.column_info.chart_reference == "step" + assert m.column_info.chart_reference == "STEP" + assert m.column_info.section_type == "PUBLIC" assert m.column_info.error is None - assert m.column_info.config == {} + assert m.column_info.config is None assert m.column_info.key_encode == "a%2F1" assert m.column_info.got is None assert m.column_info.expected is None assert m.column_info.key_encode == "a%2F1" assert m.column_info.key == "a/1" - assert m.column_info.key_id == "b" + assert m.column_info.kid == "b" assert m.metric == {"data": 1} assert m.metric_buffers is None assert m.metric_summary == {"data": 1} assert m.metric_step == 1 assert m.metric_epoch == 1 assert m.swanlab_media_dir == "." - assert m.metric_file_path == f"./{c.key_id}/1.log" + assert m.metric_file_path == f"./{c.kid}/1.log"