Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.1.0b1 #3

Merged
merged 5 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions swankit/callback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,99 @@
@Description:
回调类,规定回调函数的接口规范。
"""

from typing import Callable
from abc import ABC, abstractmethod
from .models import *


class SwanKitCallback(ABC):
"""
SwanKitCallback,回调函数注册类,所有以`on_`和`before_`开头的函数都会在对应的时机被调用
此处只定义会被调用的函数,用于接口规范
"""

def on_init(self, proj_name: str, workspace: str, logdir: str = None, **kwargs):
"""
执行`swanlab.init`时调用,此时运行时环境变量没有被设置,此时修改环境变量还是有效的
:param logdir: str, 用户设置的日志目录
:param proj_name: str, 项目名称
:param workspace: str, 工作空间
:param kwargs: dict, 其他参数,为了增加灵活性,可以在on_init的时候设置一些其他类内参数
"""
pass

def before_init_experiment(
self,
run_id: str,
exp_name: str,
description: str,
num: int,
suffix: str,
setter: Callable[[str, str, str, str], None],
):
"""
在初始化实验之前调用,此时SwanLabRun已经初始化完毕
FIXME setter函数实际上并不应该被传递,现在是因为存在实验名称不能重复的历史遗留问题
:param run_id: str, SwanLabRun的运行id
:param exp_name: str, 实验名称
:param description: str, 实验描述
:param num: int, 历史实验数量
:param suffix: str, 实验后缀
:param setter: Callable[[str, str, str, str], None], 设置实验信息的函数,在这里设置实验信息
"""
pass

def on_run(self):
"""
SwanLabRun初始化完毕时调用
"""
pass

def on_run_error_from_operator(self, e: OperateErrorInfo):
"""
执行`on_run`错误时被操作员调用
"""
pass

def on_runtime_info_update(self, r: RuntimeInfo):
"""
运行时信息更新时调用
:param r: RuntimeInfo, 运行时信息
"""
pass

def on_log(self):
"""
每次执行swanlab.log时调用
"""
pass

def on_column_create(self, column_info: ColumnInfo):
"""
列创建回调函数,新增列信息时调用
"""
pass

def on_metric_create(self, metric_info: MetricInfo):
"""
指标创建回调函数,新增指标信息时调用
"""
pass

def on_stop(self, error: str = None):
"""
训练结束时的回调函数
"""
pass

@abstractmethod
def __str__(self) -> str:
"""
返回当前回调函数的名称,这条应该是一个全局唯一的标识
在operator中会用到这个名称,必须唯一
"""
pass


__all__ = ["SwanKitCallback", "MediaBuffer", "MetricInfo", "ColumnInfo", "OperateErrorInfo", "RuntimeInfo"]
14 changes: 14 additions & 0 deletions swankit/callback/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/6/5 16:29
@File: __init__.py
@IDE: pycharm
@Description:
与回调函数通信时的模型
"""
from .key import MediaBuffer, MetricInfo, ColumnInfo
from .error import OperateErrorInfo
from .runtime import RuntimeInfo

__all__ = ["MediaBuffer", "MetricInfo", "ColumnInfo", "OperateErrorInfo", "RuntimeInfo"]
24 changes: 24 additions & 0 deletions swankit/callback/models/error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/6/5 16:30
@File: error.py
@IDE: pycharm
@Description:
错误模型
"""


class OperateErrorInfo:
"""
操作错误信息,当操作员操作回调发生错误时,会生成这个对象,传给相应的回调函数
"""

def __init__(self, desc: str):
self.desc = desc
"""
错误描述
"""

def __str__(self):
return f"SwanLabError: {self.desc}"
184 changes: 184 additions & 0 deletions swankit/callback/models/key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/6/5 16:32
@File: key.py
@IDE: pycharm
@Description:
与Key相关的回调函数触发时的模型
"""
from typing import Union, Optional, Dict, List
from swankit.core import ChartType, ParseErrorInfo, MediaBuffer
from urllib.parse import quote
import os


class ColumnInfo:
"""
列信息,当创建列时,会生成这个对象
"""

def __init__(
self,
key: str,
namespace: str,
chart: ChartType,
sort: Optional[int] = None,
error: Optional[ParseErrorInfo] = None,
reference: Optional[str] = None,
config: Optional[Dict] = None,
):
self.key = key
"""
列的key名称
"""
self.namespace = namespace
"""
列的命名空间
"""
self.chart = chart
"""
列的图表类型
"""
self.error = error
"""
列的类型错误信息
"""
self.reference = reference
"""
列的参考对象
"""
self.sort = sort
"""
列在namespace中的排序
"""
self.config = config if config is not None else {}
"""
列的额外配置信息
"""

@property
def got(self):
"""
传入的错误类型,如果列出错,返回错误类型,如果没出错,`暂时`返回None
"""
if self.error is None:
return None
return self.error.got

@property
def expected(self):
"""
期望的类型,如果列出错,返回期望的类型,如果没出错,`暂时`返回None
"""
if self.error is None:
return None
return self.error.expected


class MetricInfo:
"""
指标信息,当新的指标被log时,会生成这个对象
"""
__SUMMARY_NAME = "_summary.json"

def __init__(
self,
key: str,
column_info: ColumnInfo,
error: Optional[ParseErrorInfo],
metric: Union[Dict, None] = None,
summary: Union[Dict, None] = None,
step: int = None,
epoch: int = None,
logdir: str = None,
metric_file_name: str = None,
media_dir: str = None,
buffers: List[MediaBuffer] = None,
):
self.__error = error

self.key = quote(key, safe="")
"""
指标的key名称,被quote编码
"""
self.column_info = column_info
"""
指标对应的列信息
"""
self.metric = metric
"""
指标信息,error时为None
"""
self.summary = summary
"""
摘要信息,error时为None
"""
self.step = step
"""
当前指标的步数,error时为None
"""
self.epoch = epoch
"""
当前指标对应本地的行数,error时为None
"""
self.metric_path = None if self.error else os.path.join(logdir, self.key, metric_file_name)
"""
指标文件的路径,error时为None
"""
self.summary_path = None if self.error else os.path.join(logdir, self.key, self.__SUMMARY_NAME)
"""
摘要文件的路径,error时为None
"""
self.media_dir = media_dir
"""
静态文件的根文件夹
"""
self.buffers = buffers
"""
需要上传的媒体数据,比特流,error时为None,如果上传为非媒体类型(或Text类型),也为None
"""
# 写入文件名称,对应上传时的文件名称
if self.buffers is not None:
for i, buffer in enumerate(self.buffers):
buffer.file_name = "{}/{}".format(self.key, metric["data"][i])

@property
def error(self) -> bool:
"""
这条指标信息是否有错误,错误分几种:
1. 列错误,列一开始就出现问题
2. 重复错误
3. 指标错误
"""
return self.error_info is not None or self.column_error_info is not None

@property
def column_error_info(self) -> Optional[ParseErrorInfo]:
"""
列错误信息
"""
return self.column_info.error

@property
def error_info(self) -> Optional[ParseErrorInfo]:
"""
指标错误信息
"""
return self.__error

@property
def duplicated_error(self) -> bool:
"""
是否是重复的指标
"""
return self.__error and self.__error.duplicated

@property
def data(self) -> Union[Dict, None]:
"""
指标数据的data字段
"""
if self.error:
return None
return self.metric["data"]
Loading