kaiwu.common._json_serializable_mixin 源代码

# -*- coding: utf-8 -*-
"""
SDK 公用模块
"""
import numpy as np


[文档] class JsonSerializableMixin: """ 序列化器 """
[文档] def to_json_dict(self, exclude_fields=("_optimizer",)): """转化为json字典 Returns: dict: json字典 """ object_dict = self.__dict__ json_dict = {} for attr_name, attr_value in object_dict.items(): if attr_name in exclude_fields: continue attr_type = attr_name + "$type" if isinstance(attr_value, np.ndarray): json_dict[attr_name] = attr_value.tolist() json_dict[attr_type] = "np.ndarray" elif isinstance(attr_value, list): json_dict[attr_type] = "list" json_dict[attr_name] = [ item.to_json_dict() if hasattr(item, "to_json_dict") else item for item in attr_value ] elif hasattr(attr_value, "to_json_dict"): data = attr_value.to_json_dict() json_dict[attr_name] = data json_dict[attr_type] = "JsonSerializableMixin" elif isinstance(attr_value, np.number): json_dict[attr_name] = float(attr_value) elif attr_name == "sub_indices" and attr_value is not None: json_dict[attr_name] = np.array(attr_value).tolist() elif attr_name == "rng": state = self.rng.__getstate__() json_dict[attr_name] = state else: json_dict[attr_name] = attr_value return json_dict
[文档] def load_json_dict(self, json_dict): """从json文件读取的dict恢复对象 Returns: dict: json字典 """ # Nothing stored before, just return if json_dict is None: return param_dict = json_dict.copy() for attr_name, attr_value in param_dict.items(): if "$type" in attr_name: continue attr_type = json_dict.get(attr_name + "$type") if isinstance(attr_value, dict) and attr_type == "JsonSerializableMixin": instance = getattr(self, attr_name) instance.load_json_dict(attr_value) continue if attr_type == "np.ndarray": attr_value = np.array(attr_value) elif attr_name == "rng": rng_obj = np.random.default_rng() rng_obj.__setstate__(attr_value) attr_value = rng_obj setattr(self, attr_name, attr_value)