Skip to content

Commit

Permalink
chore: sys-model (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
SAKURA-CAT authored Dec 5, 2024
1 parent d2f99b7 commit 0c1d1c0
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 83 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ build-backend = "hatchling.build"

[project]
name = "swankit"
version = "0.1.2b1"
version = "0.1.2b2"
dynamic = ["readme", "dependencies"]
description = "Base toolkit for SwanLab"
license = "Apache-2.0"
Expand Down
2 changes: 1 addition & 1 deletion swankit/callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,4 @@ def __str__(self) -> str:
pass


__all__ = ["SwanKitCallback", "MediaBuffer", "MetricInfo", "ColumnInfo", "OperateErrorInfo", "RuntimeInfo"]
__all__ = ["SwanKitCallback", "models"]
6 changes: 5 additions & 1 deletion swankit/callback/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@Description:
与回调函数通信时的模型
"""
from .key import MediaBuffer, MetricInfo, ColumnInfo, MetricErrorInfo
from .key import MediaBuffer, MetricInfo, ColumnInfo, MetricErrorInfo, ColumnClass, SectionType, ColumnConfig, YRange
from .error import OperateErrorInfo
from .runtime import RuntimeInfo

Expand All @@ -18,4 +18,8 @@
"MetricErrorInfo",
"OperateErrorInfo",
"RuntimeInfo",
"ColumnClass",
"SectionType",
"ColumnConfig",
"YRange",
]
47 changes: 31 additions & 16 deletions swankit/callback/models/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,24 @@
@Description:
与Key相关的回调函数触发时的模型
"""
from typing import Union, Optional, Dict, List, Literal
from swankit.core import ChartType, ParseErrorInfo, MediaBuffer
from typing import Union, Optional, Dict, List, Literal, Tuple, TypedDict

from swankit.core import ChartType, ParseErrorInfo, MediaBuffer, ChartReference
from urllib.parse import quote
import os

ColumnClass = Literal["CUSTOM", "SYSTEM"]
SectionType = Literal["PINNED", "HIDDEN", "PUBLIC", "SYSTEM"]
YRange = Optional[Tuple[Optional[float], Optional[float]]]


class ColumnConfig(TypedDict):
"""
列信息配置
"""

y_range: YRange


class ColumnInfo:
"""
Expand All @@ -21,22 +34,23 @@ class ColumnInfo:
def __init__(
self,
key: str,
key_id: str,
key_name: str,
key_class: Literal["CUSTOM", "SYSTEM"],
kid: str,
name: str,
cls: ColumnClass,
chart_type: ChartType,
chart_reference: Literal["step", "time"],
chart_reference: ChartReference,
section_name: Optional[str],
section_type: SectionType,
section_sort: Optional[int] = None,
error: Optional[ParseErrorInfo] = None,
config: Optional[Dict] = None,
config: Optional[ColumnConfig] = None,
):
"""
生成的列信息对象
:param key: 生成的列名称
:param key_id: 当前实验下,列的唯一id,与保存路径等信息有关
:param key_name: key的别名
:param key_class: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列
:param key: 生成的列名称,作为索引键值
:param kid: 当前实验下,列的唯一id,与保存路径等信息有关,与云端请求无关
:param name: 列的别名
:param cls: 列的类型,CUSTOM为自定义列,SYSTEM为系统生成列
:param chart_type: 列对应的图表类型
:param chart_reference: 这个列对应图表的参考系,step为步数,time为时间
:param section_name: 列的组名
Expand All @@ -45,18 +59,19 @@ def __init__(
:param config: 列的额外配置信息
"""
self.key = key
self.key_id = key_id
self.key_name = key_name
self.key_class = key_class
self.kid = kid
self.name = name
self.cls = cls

self.chart_type = chart_type
self.chart_reference = chart_reference

self.section_name = section_name
self.section_sort = section_sort
self.section_type = section_type

self.error = error
self.config = config if config is not None else {}
self.config = config

@property
def got(self):
Expand Down Expand Up @@ -124,7 +139,7 @@ def __init__(
self.metric_summary = metric_summary
self.metric_step = metric_step
self.metric_epoch = metric_epoch
_id = self.column_info.key_id
_id = self.column_info.kid
self.metric_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, metric_file_name)
self.summary_file_path = None if self.is_error else os.path.join(swanlab_logdir, _id, self.__SUMMARY_NAME)
self.swanlab_media_dir = swanlab_media_dir
Expand Down
24 changes: 7 additions & 17 deletions swankit/callback/models/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
运行时信息模型
"""
from abc import ABC, abstractmethod
from typing import Optional, Any
from typing import Optional
import json
import yaml
import os
Expand Down Expand Up @@ -108,13 +108,7 @@ def to_dict(self):
# 没有在__init__中直接修改是因为可能会有其他地方需要原始数据,并且会丢失一些性能
if self.__data is not None:
return self.__data
self.__data = {
k: {
"value": v,
"sort": i,
"desc": ""
} for i, (k, v) in enumerate(self.info.items())
}
self.__data = {k: {"value": v, "sort": i, "desc": ""} for i, (k, v) in enumerate(self.info.items())}
return self.__data


Expand All @@ -130,14 +124,10 @@ def __init__(self, requirements: str = None, metadata: dict = None, config: dict
:param metadata: 系统信息
:param config: 上传的配置信息
"""
self.requirements: Optional[RequirementInfo] = RequirementInfo(
requirements
) if requirements is not None else None
self.requirements: Optional[RequirementInfo] = (
RequirementInfo(requirements) if requirements is not None else None
)

self.metadata: Optional[MetadataInfo] = MetadataInfo(
metadata
) if metadata is not None else None
self.metadata: Optional[MetadataInfo] = MetadataInfo(metadata) if metadata is not None else None

self.config: Optional[ConfigInfo] = ConfigInfo(
config
) if config is not None else None
self.config: Optional[ConfigInfo] = ConfigInfo(config) if config is not None else None
5 changes: 3 additions & 2 deletions swankit/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@Description:
核心解析模块工具
"""
from .data import BaseType, MediaType, DataSuite, MediaBuffer, ParseResult, ParseErrorInfo
from .data import BaseType, MediaType, DataSuite, MediaBuffer, ParseResult, ParseErrorInfo, ChartReference
from .settings import SwanLabSharedSettings

ChartType = BaseType.Chart
Expand All @@ -20,5 +20,6 @@
"MediaBuffer",
"ParseResult",
"ParseErrorInfo",
"SwanLabSharedSettings"
"SwanLabSharedSettings",
"ChartReference",
]
43 changes: 18 additions & 25 deletions swankit/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
@Description:
数据处理模型
"""
from typing import List, Dict, Optional, ByteString, Union, Tuple
from typing import List, Dict, Optional, ByteString, Union, Tuple, Literal
from abc import ABC, abstractmethod
from enum import Enum
from io import BytesIO
import hashlib
import math
import io

ChartReference = Literal["STEP", "TIME"]


class DataSuite:
"""
Expand Down Expand Up @@ -172,14 +174,6 @@ def get_section(self) -> str:
"""
return "default"

# noinspection PyMethodMayBeStatic
def get_config(self) -> Optional[Dict]:
"""
获取图表的config配置信息,应该返回一个字典,或者为None
为None时代表不需要配置
"""
return None

# noinspection PyMethodMayBeStatic
def get_more(self) -> Optional[Dict]:
"""
Expand All @@ -193,6 +187,7 @@ class MediaType(BaseType): # noqa
"""
媒体类型,用于区分标量和媒体,不用做实例化,应该由子类继承
"""

pass


Expand Down Expand Up @@ -220,30 +215,28 @@ class ParseResult:
"""

def __init__(
self,
section: str = None,
chart: BaseType.Chart = None,
data: Union[List[str], float] = None,
config: Optional[List[Dict]] = None,
more: Optional[List[Dict]] = None,
buffers: Optional[List[MediaBuffer]] = None,
self,
section: str = None,
chart: BaseType.Chart = None,
data: Union[List[str], float] = None,
more: Optional[List[Dict]] = None,
buffers: Optional[List[MediaBuffer]] = None,
reference: ChartReference = "STEP",
):
"""
:param section: 转换后数据对应的section
:param chart: 转换后数据对应的图表类型,枚举类型
:param data: 存储在.log中的数据
:param config: 存储在.log中的配置
:param more: 存储在.log中的更多信息
:param buffers: 存储于media文件夹中的原始数据,比特流,特别的,对于某些字符串即原始数据的情况,此处为None
:param reference: 图表数据的参考类型
"""
self.__data = data
self.config = config
self.more = more
self.buffers = buffers
self.section = section
self.chart = chart
# 默认的reference
self.reference = "step"
self.reference = reference
self.step = None

@property
Expand Down Expand Up @@ -291,11 +284,11 @@ class ParseErrorInfo:
"""

def __init__(
self,
expected: Optional[str],
got: Optional[str],
chart: Optional[BaseType.Chart],
duplicated: bool = False
self,
expected: Optional[str],
got: Optional[str],
chart: Optional[BaseType.Chart],
duplicated: bool = False,
):
"""
:param expected: 期望的数据类型
Expand Down
44 changes: 24 additions & 20 deletions test/unit/callback/models/test_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,43 @@
def test_column_info():
c = K.ColumnInfo(
key="a/1",
key_id="b",
key_name="c",
key_class="SYSTEM",
kid="b",
name="c",
cls="SYSTEM",
section_name="e",
section_sort=1,
section_type="PUBLIC",
chart_type=ChartType.TEXT,
chart_reference="step",
chart_reference="STEP",
error=None,
config=None,
)
assert c.got is None
assert c.key == "a/1"
assert c.key_id == "b"
assert c.key_name == "c"
assert c.key_class == "SYSTEM"
assert c.kid == "b"
assert c.name == "c"
assert c.cls == "SYSTEM"
assert c.section_name == "e"
assert c.section_sort == 1
assert c.chart_type == ChartType.TEXT
assert c.chart_reference == "step"
assert c.chart_reference == "STEP"
assert c.section_type == "PUBLIC"
assert c.error is None
assert c.config == {}
assert c.config is None
assert c.key_encode == "a%2F1"


def test_metric_info():
c = K.ColumnInfo(
key="a/1",
key_id="b",
key_name="c",
key_class="SYSTEM",
kid="b",
name="c",
cls="SYSTEM",
section_name="e",
section_sort=1,
chart_type=ChartType.TEXT,
chart_reference="step",
section_type="PUBLIC",
chart_reference="STEP",
error=None,
config=None,
)
Expand All @@ -56,25 +59,26 @@ def test_metric_info():
)
assert m.column_info.got is None
assert m.column_info.key == "a/1"
assert m.column_info.key_id == "b"
assert m.column_info.key_name == "c"
assert m.column_info.key_class == "SYSTEM"
assert m.column_info.kid == "b"
assert m.column_info.name == "c"
assert m.column_info.cls == "SYSTEM"
assert m.column_info.section_name == "e"
assert m.column_info.section_sort == 1
assert m.column_info.chart_type == ChartType.TEXT
assert m.column_info.chart_reference == "step"
assert m.column_info.chart_reference == "STEP"
assert m.column_info.section_type == "PUBLIC"
assert m.column_info.error is None
assert m.column_info.config == {}
assert m.column_info.config is None
assert m.column_info.key_encode == "a%2F1"
assert m.column_info.got is None
assert m.column_info.expected is None
assert m.column_info.key_encode == "a%2F1"
assert m.column_info.key == "a/1"
assert m.column_info.key_id == "b"
assert m.column_info.kid == "b"
assert m.metric == {"data": 1}
assert m.metric_buffers is None
assert m.metric_summary == {"data": 1}
assert m.metric_step == 1
assert m.metric_epoch == 1
assert m.swanlab_media_dir == "."
assert m.metric_file_path == f"./{c.key_id}/1.log"
assert m.metric_file_path == f"./{c.kid}/1.log"

0 comments on commit 0c1d1c0

Please sign in to comment.