Skip to content

Commit

Permalink
预测前加载模型,预测后删除模型
Browse files Browse the repository at this point in the history
  • Loading branch information
YoungHector committed Dec 1, 2024
1 parent cb630e3 commit 1231889
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
11 changes: 10 additions & 1 deletion src/_data_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ class AutoMaintainer(object):
'''
def __init__(self):
self.pm = PostManager(max_workers=1) # max=1 即为同步
self.model = MODEL_DICT['decision_tree_model']
self.model = None


# 用[数据库里的datasource_id] 查询对应的[config].
Expand Down Expand Up @@ -312,6 +312,11 @@ def daily_model_predict(self):
每日更新模型全量预测结果
"""
# 拆成24个小时的数据运行

# 加载模型
MODEL_DICT.load_model('decision_tree_model_v2', path_prefix='../')
self.model = MODEL_DICT['decision_tree_model']

self._model_predicted_result_pool = []
for i in range(24):
try:
Expand Down Expand Up @@ -377,6 +382,10 @@ def limit_cpu(interval):
messager.send_to_bot_shortcut('出现报错,详细信息为:')
messager.send_to_bot_shortcut(str(e))

# 删除模型。
MODEL_DICT.model_dict.pop('decision_tree_model')
self.model = None

def get_post_data_list(self, pending_datasources_id_list, maintainer:Maintainer):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/auto_sche/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ def __init__(self):

def load_all_model(self):
for name in [
'decision_tree_model_v2',
# 'decision_tree_model_v2',
'weekday_encoder_v2',
'datasource_encoder_v2'
]:
self.load_model(name)

def load_model(self, name, suffix='.joblib'):
def load_model(self, name, suffix='.joblib', path_prefix=''):
# 加载模型、编码器等等
loaded_model = load('./ml_model/{}{}'.format(name, suffix))
loaded_model = load('{}./ml_model/{}{}'.format(path_prefix, name, suffix))
self.model_dict[name.split('_v')[0]] = loaded_model

def __getitem__(self, key):
Expand Down

0 comments on commit 1231889

Please sign in to comment.