# -*- coding: utf-8 -*-
"""
提供SPQC求解器
"""
import hashlib
import io
import json
import os
import time
import pickle
import logging
from urllib.parse import urljoin
import numpy as np
import requests
import pandas as pd
from kaiwu.core import IsingSolver, QuboSolver
from kaiwu.license._license_utils import (
_read_license,
ensure_license,
LICENSE_FILE_PATH,
)
from kaiwu.common import CheckpointManager as ckpt
from kaiwu.license.license_settings import GATEWAY_DOMAIN_NAME
# Module-level logger for this solver module.
logger = logging.getLogger(__name__)
# Relative API paths below are joined onto GATEWAY_DOMAIN_NAME with urljoin
# right before each request is sent.
# Create an SDK task (sent with requests.post).
CREATE_TASK_URL = "/api/system/software_task_manager_pro/create_sdk_task/"
# Query the status/result of an SDK task (sent with requests.get).
GET_TASK_RESULT_URL = "/api/system/software_task_manager_pro/get_sdk_task_result/"
# Obtain a signed OSS upload URL for a file (sent with requests.get).
SIGN_OSS_URL = "/api/system/file/oss_signed_url/"
class TaskMode:
    """Task mode constants.

    OPTIMIZATION: optimization mode (formerly called the QUOTA mode).
    SAMPLING: sampling mode (formerly called the SAMPLE mode).
    """

    OPTIMIZATION = "optimization"
    SAMPLING = "sampling"
class TaskStatusCode:
    """Constants for SPQC task status codes.

    Every value is a string code identifying the processing stage a task is in.
    """

    # Calculation succeeded.
    COMPLETED: str = "1000"
    # Validation in progress.
    VALIDATING: str = "2001"
    # Calculation in progress.
    CALCULATING: str = "2002"
    # Validation failed.
    VALIDATE_FAILED: str = "4001"
    # The requested task does not exist.
    TASK_DOES_NOT_EXISTS: str = "10320"
class ResponseStatusCode:
    """Constants for the ``code`` field of API responses.

    They are used to decide whether a request succeeded. Note that the types
    are mixed: ``SUCCESS`` is a string while ``AUTH_FAILED`` is an int, so
    comparisons against a response code should take that into account.
    """

    AUTH_FAILED = 401  # authentication failed
    SUCCESS = "0"  # request succeeded
class CIMOptimizer(IsingSolver, QuboSolver):
    """CIM Optimizer Interface

    CIMOptimizer is an optimizer for Ising problems. It submits tasks to a
    Special-Purpose Quantum Computer (SPQC) for computation on the real
    machine and returns the solutions.

    Main features:

    1. **Task submission**: upload the Ising matrix to the SPQC computing
       platform and create a computation task.
    2. **Task query**: check the task status (polling periodically when
       ``wait=True``) and fetch the computation result.
    3. **Cache management**: cache the results of finished tasks locally to
       avoid repeated submissions.

    Args:
        task_name (str): Task name.
        wait (bool, optional): Whether to wait for the computation to finish.
            Defaults to False.
        interval (int, optional): Polling interval in minutes. Defaults to 1;
            the minimum is 1 minute.
        project_no (str, optional): Project number, i.e. the project ID in
            the CPQC-X project list; used to create the task under that
            project.
        task_mode (str): Computation mode. Defaults to
            ``TaskMode.OPTIMIZATION``; allowed values are
            ``TaskMode.OPTIMIZATION`` and ``TaskMode.SAMPLING``.
        sample_number (int): Number of samples. Required when
            ``task_mode=TaskMode.SAMPLING``; default 10, minimum 10,
            maximum 2000.

    Raises:
        ValueError: If ``task_mode`` is not supported, if ``sample_number``
            is invalid in sampling mode, or if the checkpoint save directory
            is not set.

    Example:
        >>> import numpy as np
        >>> import kaiwu as kw
        >>> from kaiwu.cim import TaskMode
        >>> kw.common.CheckpointManager.save_dir = '/tmp'
        >>> matrix = -np.array([[ 0. , 1. , 0. , 1. , 1. ],
        ...                     [ 1. , 0. , 0. , 1. , 1. ],
        ...                     [ 0. , 0. , 0. , 1. , 1. ],
        ...                     [ 1. , 1. , 1. , 0. , 1. ],
        ...                     [ 1. , 1. , 1. , 1. , 0. ]])
        >>> optimizer = kw.cim.CIMOptimizer(
        ...     task_name='cim_optimizer_test',
        ...     task_mode=TaskMode.OPTIMIZATION
        ... )  # doctest: +SKIP
        >>> solution = optimizer.solve(matrix)  # doctest: +SKIP
        >>> solution  # doctest: +SKIP
        array([[-1,  1,  1, -1, -1],
               [-1,  1,  1,  1, -1],
               [-1,  1,  1, -1,  1],
               [ 1, -1, -1, -1,  1],
               [ 1, -1,  1,  1, -1],
               [ 1, -1,  1, -1,  1],
               [ 1, -1,  1, -1, -1],
               [ 1, -1, -1,  1,  1],
               [-1, -1, -1,  1,  1],
               [ 1,  1,  1, -1, -1]], dtype=int8)
        >>> kw.common.CheckpointManager.save_dir = None

    Notes:
        1. The directory for intermediate files must be set through
           ``CheckpointManager`` (``save_dir``).
        2. A task is uniquely identified by ``ising_matrix`` together with
           ``task_name``; changing either one creates a new task.
        3. The same matrix can be submitted as a different task by changing
           ``task_name``; to only query a result, keep ``task_name``
           unchanged.
        4. ``task_name`` is required when instantiating CIMOptimizer.
    """

    def __init__(
        self,
        task_name,
        wait=False,
        interval=1,
        project_no=None,
        task_mode=TaskMode.OPTIMIZATION,
        sample_number=10,
    ):
        super().__init__()
        self.task_name = task_name
        self.wait = wait
        self.save_dir = ckpt.save_dir
        self.project_no = project_no
        self.task_mode = task_mode
        self.sample_number = sample_number
        # Polling interval in minutes, never shorter than 1 minute.
        self.interval = max(1, interval)
        # Explicit check rather than ``assert`` so the validation is not
        # stripped when Python runs with -O.
        if task_mode not in (TaskMode.OPTIMIZATION, TaskMode.SAMPLING):
            raise ValueError(
                f"task_mode 必须是 [{TaskMode.OPTIMIZATION}, {TaskMode.SAMPLING}]中的一个"
            )
        if self.task_mode == TaskMode.SAMPLING:
            if not sample_number:
                raise ValueError(
                    "sample_number is required when task_mode is TaskMode.SAMPLING"
                )
            if not isinstance(sample_number, int):
                raise ValueError("sample_number must be an integer value")
            if sample_number < 10 or sample_number > 2000:
                raise ValueError("sample_number must be between 10 and 2000 inclusive")
        if self.save_dir is None:
            raise ValueError("The save directory is required")
        ensure_license()
        self._authorization = self._get_authorization()

    @staticmethod
    def _ising_to_csv(ising_matrix: np.ndarray) -> str:
        """Convert an Ising matrix to a CSV string (no index, no header)."""
        return pd.DataFrame(ising_matrix).to_csv(
            index=False, header=False, lineterminator="\n"
        )

    @staticmethod
    def _str_to_md5_hash(data: str) -> str:
        """Return the hex MD5 digest of ``data`` encoded as UTF-8."""
        return hashlib.md5(data.encode("utf-8")).hexdigest()

    @staticmethod
    def _ising_to_md5_hash(ising_matrix: np.ndarray) -> str:
        """Return the MD5 digest of the matrix's CSV representation."""
        matrix_str = CIMOptimizer._ising_to_csv(ising_matrix)
        return CIMOptimizer._str_to_md5_hash(matrix_str)

    @staticmethod
    def _code_equals(code, expected) -> bool:
        """Compare an API status code with an expected constant.

        The status-code constants mix types (``ResponseStatusCode.SUCCESS``
        is the string ``"0"`` while ``ResponseStatusCode.AUTH_FAILED`` is the
        int ``401``), so both sides are compared as strings. ``None`` (a
        missing code) never matches.
        """
        return code is not None and str(code) == str(expected)

    @staticmethod
    def _process_auth_failed(callback_func, *args, **kwargs):
        """Discard the current license, obtain a new one, then retry.

        Deletes the license file at ``LICENSE_FILE_PATH`` (if present), calls
        ``ensure_license()`` and returns ``callback_func(*args, **kwargs)``.

        Raises:
            PermissionError: If the license file cannot be deleted due to
                permissions.
            OSError: For any other OS error while deleting the license file.
        """
        try:
            if os.path.exists(LICENSE_FILE_PATH):
                os.remove(LICENSE_FILE_PATH)
        except PermissionError as exc:
            raise PermissionError(
                f"Permission denied when deleting invalid license file: {LICENSE_FILE_PATH}"
            ) from exc
        except OSError as exc:
            raise OSError(
                f"OS error when deleting invalid license file: {LICENSE_FILE_PATH}: {exc}"
            ) from exc
        ensure_license()
        return callback_func(*args, **kwargs)

    def _check_existing_result(self, task_name: str, result_filename: str):
        """Return the locally cached result for a task, if any.

        Args:
            task_name (str): Task name (used for logging).
            result_filename (str): Cache file name inside ``self.save_dir``.

        Returns:
            np.ndarray | None: The cached result if the file exists, else None.
        """
        solution_file_path = os.path.join(self.save_dir, result_filename)
        if os.path.exists(solution_file_path):
            # SECURITY NOTE: pickle can execute code on load. This is only safe
            # because the cache files are written by this class; never point
            # save_dir at untrusted files.
            with open(solution_file_path, "rb") as file:
                result = pickle.load(file)
            logger.info("Task calculation successful!, Task name: %s", task_name)
            return result
        return None

    def _solve(self, ising_matrix=None):
        """Entry point: return cached results or query/submit the task.

        Args:
            ising_matrix (np.ndarray): Ising matrix.

        Returns:
            np.ndarray | None:
                - the set of solution vectors (when the task is finished)
                - None (when the task is still being computed and
                  ``wait`` is False)
        """
        self.set_matrix(ising_matrix)
        csv_data = self._ising_to_csv(ising_matrix)
        file_hash = self._str_to_md5_hash(csv_data)
        # Create the save directory, including missing parents; exist_ok avoids
        # a race between an existence check and the creation.
        os.makedirs(self.save_dir, exist_ok=True)
        task_name = self.task_name
        # Return the cached result if this task was already fetched.
        result_filename = f"{task_name}_{file_hash}_results.pkl"
        result = self._check_existing_result(task_name, result_filename)
        if result is not None:
            return result
        solution_file_path = os.path.join(self.save_dir, result_filename)
        # Query (or create) the task; when waiting, poll until a result arrives.
        while True:
            result = self._get_or_create_task(
                csv_data, task_name, solution_file_path, file_hash
            )
            if result is not None or not self.wait:
                return result
            time.sleep(self.interval * 60)

    def get_task_result(self, ising_matrix: np.ndarray) -> dict:
        """Query the server for this optimizer's task on ``ising_matrix``.

        The task is identified by ``self.task_name`` and the MD5 hash of the
        matrix's CSV representation.

        Returns:
            dict: The decoded JSON response of the result-query endpoint.
        """
        file_hash = self._ising_to_md5_hash(ising_matrix)
        return self._get_task_result_internal(
            task_name=self.task_name, file_hash=file_hash
        )

    def _load_from_file(self, file_path):
        """Read a header-less CSV into a NumPy array.

        Args:
            file_path (str): Location passed to ``pd.read_csv``; the task's
                ``result`` field is passed here (presumably a download URL —
                confirm against the API).

        Returns:
            np.ndarray: The CSV content.
        """
        return pd.read_csv(file_path, header=None, index_col=None).to_numpy()

    def _get_authorization(self) -> str:
        """Build the ``Authorization`` header value from the local license.

        The value has the form ``"<user_id>/<sdk_code>"`` and is also cached on
        ``self._authorization``. If reading the license raises ValueError, the
        license is discarded and re-obtained, then read again.
        """
        try:
            decrypted_message, sdk_code = _read_license()
        except ValueError:
            return self._process_auth_failed(self._get_authorization)
        user_id = decrypted_message.get("user_id")
        authorization = f"{user_id}/{sdk_code}"
        self._authorization = authorization
        return authorization

    def _get_task_result_internal(self, task_name: str, file_hash: str):
        """GET the task-result endpoint and return the decoded JSON.

        Raises:
            requests.RequestException: On network/HTTP errors (logged first).
        """
        url = urljoin(GATEWAY_DOMAIN_NAME, GET_TASK_RESULT_URL)
        try:
            params = {"task_name": task_name, "md5_num": file_hash}
            headers = {
                "Authorization": self._get_authorization(),
                "Content-Type": "application/json",
            }
            response = requests.get(url, headers=headers, params=params, timeout=10)
            response.raise_for_status()
            return response.json()
        except requests.RequestException as error:
            logger.error("Request error: %s", error)
            raise

    def _upload_file_to_oss(self, file_name: str, file_like_object: object):
        """Upload a file to OSS through a backend-signed URL.

        Args:
            file_name (str): Name of the file on OSS.
            file_like_object: Readable object whose content is uploaded.

        Returns:
            str | None: The ``oss_file_url`` reported by the signing endpoint.

        Raises:
            ValueError: On network errors, when no upload URL is returned, or
                on any other error during the upload.
        """
        headers = {"Content-Type": "text/csv"}
        absolute_url = urljoin(GATEWAY_DOMAIN_NAME, SIGN_OSS_URL)
        try:
            # Request a signed OSS upload URL. The license is read per request
            # (as in _get_task_result_internal) so a license refreshed after an
            # auth failure is actually used.
            response = requests.get(
                absolute_url,
                params={
                    "filename": file_name,
                    "request_headers": json.dumps(headers),
                    "need_sign": 1,
                },
                timeout=10,
                headers={"Authorization": self._get_authorization()},
            )
            response.raise_for_status()
            data = response.json().get("data", {})
            oss_upload_url = data.get("url")
            oss_file_url = data.get("oss_file_url")
            if not oss_upload_url:
                raise ValueError("未获取到上传 URL")
            # Upload the file content to OSS.
            upload_response = requests.put(
                oss_upload_url, data=file_like_object, headers=headers, timeout=10
            )
            upload_response.raise_for_status()
        except requests.RequestException as exc:
            raise ValueError(f"Network error during file upload: {str(exc)}") from exc
        except Exception as e:
            raise ValueError(f"Unexpected error during file upload: {str(e)}") from e
        return oss_file_url

    def _create_task(self, csv_data, task_name, file_hash, _auth_retried=False):
        """Upload the matrix and create a computation task.

        Args:
            csv_data (str): CSV data of the Ising matrix.
            task_name (str): Task name.
            file_hash (str): MD5 hash of ``csv_data``.
            _auth_retried (bool): Internal flag, True on the single retry made
                after refreshing the license.

        Raises:
            ValueError: If the upload or request fails, if the server rejects
                the task, or if authentication fails again after a license
                refresh.
        """
        try:
            # Send bytes so the uploaded body is exactly the UTF-8 text the
            # MD5 hash was computed from.
            file_like_object = io.BytesIO(csv_data.encode("utf-8"))
            file_name = f"{task_name}.csv"
            oss_file_url = self._upload_file_to_oss(
                file_name=file_name, file_like_object=file_like_object
            )
            data = {
                "task_name": task_name,
                "file_hash": file_hash,
                "task_source": 6,
                "bit_num": self.matrix.shape[0],
                "project_no": self.project_no,
                "oss_file_url": oss_file_url,
                "file_name": file_name,
                "task_mode": self.task_mode,
                "sample_number": self.sample_number,
                "sample_sort_mode": 2,  # keep the TOP1 result of each shot
            }
            absolute_url = urljoin(GATEWAY_DOMAIN_NAME, CREATE_TASK_URL)
            response = requests.post(
                absolute_url,
                timeout=(100, 300),
                # Re-read the license so a refreshed one is used on retry.
                headers={"Authorization": self._get_authorization()},
                data=data,
            )
            response.raise_for_status()
            json_response = response.json()
            code = json_response.get("code")
        except requests.RequestException as exc:
            raise ValueError(f"Network error during task creation: {str(exc)}") from exc
        except Exception as exc:
            raise ValueError(f"Task creation failed: {str(exc)}") from exc

        # Response-code checks live outside the try so their ValueError is not
        # re-wrapped, and a retried creation is not reported as a failure.
        if self._code_equals(code, ResponseStatusCode.AUTH_FAILED):
            if _auth_retried:
                raise ValueError(
                    "Task creation failed: authentication failed again "
                    "after refreshing the license"
                )
            # The retry performs the whole creation (and its logging) itself.
            self._process_auth_failed(
                callback_func=self._create_task,
                csv_data=csv_data,
                task_name=task_name,
                file_hash=file_hash,
                _auth_retried=True,
            )
            return
        if not self._code_equals(code, ResponseStatusCode.SUCCESS):
            raise ValueError(f"Task creation failed: {json_response.get('msg')}")
        logger.info(
            "Task submit successfully, waiting for data validation. Task name: %s",
            task_name,
        )

    def _query_task_status(self, task_name: str, file_hash: str, _auth_retried=False):
        """Query the status of a task.

        Args:
            task_name (str): Task name.
            file_hash (str): MD5 hash of the matrix CSV.
            _auth_retried (bool): Internal flag, True on the single retry made
                after refreshing the license.

        Returns:
            dict | None: The ``data`` part of the response, or None if the task
            does not exist.

        Raises:
            ValueError: If the response is not successful, lacks ``data``, or
                authentication fails again after a license refresh.
        """
        response_data = self._get_task_result_internal(
            task_name=task_name, file_hash=file_hash
        )
        code = response_data.get("code")
        if self._code_equals(code, TaskStatusCode.TASK_DOES_NOT_EXISTS):
            return None  # the task does not exist
        if self._code_equals(code, ResponseStatusCode.AUTH_FAILED):
            if _auth_retried:
                raise ValueError(
                    "Failed to retrieve task: authentication failed again "
                    f"after refreshing the license: {response_data}"
                )
            return self._process_auth_failed(
                callback_func=self._query_task_status,
                task_name=task_name,
                file_hash=file_hash,
                _auth_retried=True,
            )
        if (
            not self._code_equals(code, ResponseStatusCode.SUCCESS)
            or "data" not in response_data
        ):
            raise ValueError(f"Failed to retrieve task: {response_data}")
        return response_data["data"]

    def _handle_task_status(self, task_data: dict, task_name: str):
        """Interpret a task's status code.

        Args:
            task_data (dict): Task data; its ``code`` field holds the status.
            task_name (str): Task name (used for logging and errors).

        Returns:
            str | None: ``TaskStatusCode.COMPLETED`` when the task is finished,
            otherwise None.

        Raises:
            ValueError: If the task failed validation.
        """
        task_status_code = task_data["code"]
        if self._code_equals(task_status_code, TaskStatusCode.VALIDATE_FAILED):
            raise ValueError(
                f"Verification failed for task {task_name}: {task_data.get('desc')}"
            )
        if self._code_equals(task_status_code, TaskStatusCode.VALIDATING):
            logger.info("The task is being verified: %s", task_name)
            return None
        if self._code_equals(task_status_code, TaskStatusCode.COMPLETED):
            logger.info("Task completed: %s", task_name)
            return TaskStatusCode.COMPLETED
        logger.info("Task is still processing: %s", task_name)
        return None

    def _save_task_result(self, task_data: dict, solution_file_path: str):
        """Load a finished task's result and cache it to a pickle file.

        Args:
            task_data (dict): Task data; its ``result`` field locates the CSV.
            solution_file_path (str): Path of the cache file to write.

        Returns:
            np.ndarray: The solution array.

        Raises:
            ValueError: If the task data has no ``result`` location.
        """
        result_location = task_data.get("result")
        if not result_location:
            raise ValueError(f"Completed task has no result file: {task_data}")
        solutions = self._load_from_file(result_location)
        # Write to a temporary file and atomically move it into place so an
        # interrupted write never leaves a truncated cache file behind.
        tmp_path = f"{solution_file_path}.tmp"
        with open(tmp_path, "wb") as file:
            pickle.dump(solutions, file)
        os.replace(tmp_path, solution_file_path)
        return solutions

    def _get_or_create_task(self, csv_data, task_name, solution_file_path, file_hash):
        """Fetch the task result, or create the task if it does not exist.

        Args:
            csv_data (str): CSV data.
            task_name (str): Task name.
            solution_file_path (str): Path of the result cache file.
            file_hash (str): MD5 hash of ``csv_data``.

        Returns:
            np.ndarray | None: The task result, or None if it is not ready yet
            (including right after the task was created).
        """
        task_data = self._query_task_status(task_name, file_hash)
        if task_data is None:
            # Compatibility with version 1.3.0, which used
            # task_name_prefix + file_hash as the full task name; this keeps
            # tasks submitted by 1.3.0 retrievable after an upgrade. Later
            # versions no longer build task names this way.
            old_task_data_name = f"{task_name}_{file_hash}"
            task_data = self._query_task_status(old_task_data_name, file_hash)
        # The task exists under neither name: create it.
        if task_data is None:
            self._create_task(csv_data, task_name, file_hash)
            return None
        status_code = self._handle_task_status(task_data, task_name)
        if status_code != TaskStatusCode.COMPLETED:
            return None
        return self._save_task_result(task_data, solution_file_path)
if __name__ == "__main__":
    # Execute this module's docstring examples when run as a script.
    from doctest import testmod

    testmod()