# Source code for kaiwu.common._checkpoint
# -*- coding: utf-8 -*-
"""
管理 checkpoint 保存的路径和各个对象 checkpoint 的命名。
该模块提供了一个 `CheckpointManager` 类,用于管理 checkpoint 的保存路径和命名规则。
通过设置 `CheckpointManager.save_dir`,可以指定 checkpoint 的保存目录。
"""
import os
import json
import logging
# Module-level logger named after this module; used by CheckpointManager.load
# to report when a previously saved state has been read from disk.
logger = logging.getLogger(__name__)
# [docs]
class CheckpointManager:
    """Manage checkpoint file locations and per-object checkpoint names.

    A checkpoint stores an object's running state as JSON so that a later run
    can resume from it. All state lives on the class itself; nothing is
    instantiated. Checkpointing is switched on by assigning a directory to
    ``CheckpointManager.save_dir``; while it is ``None``, :meth:`load` and
    :meth:`dump` do nothing.

    Attributes:
        save_dir (str): Directory in which checkpoint files are stored, or
            ``None`` to disable checkpointing.
    """

    save_dir = None
    # Per-class-name counter used to generate the numeric suffix of identities.
    _class_name_counter = {}
    # Maps each registered object to its checkpoint identity. Objects are used
    # as dict keys, so they must be hashable and are kept alive by this map
    # until `_clear` is called.
    _dict_obj_identity = {}

    @classmethod
    def _clear(cls):
        """Reset all naming state (counters and object identities)."""
        cls._class_name_counter = {}
        # The original assigned to an unrelated local name here, so the
        # object -> identity map was never reset and stale identities could
        # collide with the ones generated after the counters restarted.
        cls._dict_obj_identity = {}

    @classmethod
    def _get_identity(cls, obj):
        """Return the checkpoint name of ``obj``, creating it on first use.

        The name has the form ``<ClassName>_<k>``, where ``k`` counts (from 1)
        the objects of that class name registered so far. Repeated calls with
        the same object return the same name.

        Args:
            obj (object): Object to be checkpointed; must be hashable.

        Returns:
            str: Name used to build the checkpoint file name.
        """
        identity = cls._dict_obj_identity.get(obj)
        if identity is not None:
            return identity

        class_name = obj.__class__.__name__
        count = cls._class_name_counter.get(class_name, 0) + 1
        cls._class_name_counter[class_name] = count

        identity = f"{class_name}_{count}"
        cls._dict_obj_identity[obj] = identity
        return identity

    @classmethod
    def get_path(cls, obj):
        """Return the checkpoint file path for ``obj``.

        Args:
            obj (object): Object being checkpointed.

        Returns:
            str: ``<save_dir>/<identity>_checkpoint.json``.

        Raises:
            TypeError: If ``CheckpointManager.save_dir`` is ``None`` (raised
                by ``os.path.join``).
        """
        identity = CheckpointManager._get_identity(obj)
        return os.path.join(CheckpointManager.save_dir, identity + "_checkpoint.json")

    @classmethod
    def load(cls, obj):
        """Load the serialized state previously saved for ``obj``.

        Args:
            obj (object): Object whose checkpoint should be read.

        Returns:
            The parsed JSON content of the checkpoint file, or ``None`` when
            ``save_dir`` is unset or no checkpoint file exists for ``obj``.
        """
        if CheckpointManager.save_dir is None:
            return None

        # Resolve the path once; the identity is stable across calls.
        path = CheckpointManager.get_path(obj)
        if not os.path.exists(path):
            return None

        with open(path, "r", encoding="utf8") as load_file:
            json_dict = json.load(load_file)
        # Tell the user that a resumed state is in effect and how to discard it.
        logger.info(
            "The previous state loaded in %s. clear the folder %s if it is not your will.",
            CheckpointManager.save_dir,
            CheckpointManager.save_dir,
        )
        return json_dict

    @classmethod
    def dump(cls, obj):
        """Serialize ``obj`` to its checkpoint file on disk.

        Does nothing when ``save_dir`` is unset.

        Args:
            obj (object): Object to save; must provide ``to_json_dict()``
                returning a JSON-serializable value.

        Returns:
            None
        """
        if CheckpointManager.save_dir is None:
            return
        with open(CheckpointManager.get_path(obj), "w", encoding="utf8") as save_file:
            json.dump(obj.to_json_dict(), save_file)