kaiwu.license._license_utils 源代码

"""
用户授权校验相关内容
"""

import os
import json
import logging
from urllib.parse import urljoin
import requests
import jwt
from kaiwu.license.license_settings import (
    KDEV_URL,
    SECRET_KEY,
    ALGORITHM,
    PLATFORM_DOMAIN,
)

logger = logging.getLogger(__name__)

# Root directory of the kaiwu package: the parent of this file's directory.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Where the cached license file is stored.
LICENSE_FILE_PATH = os.path.join(BASE_DIR, "license.lic")
# Endpoint that issues licenses. The path starts with "/", so urljoin replaces
# any path component of KDEV_URL instead of appending to it.
LICENSE_URL = urljoin(KDEV_URL, "/kdev/sdk-code/license")
# Value of the response "code" field that signals success.
SUCCESSCODE = "0"

# JWT verification settings. At module scope globals() is the module namespace,
# so an SDK_CODE_LICENSE_* override is used only if such a name was defined in
# this module before this point; otherwise the imported defaults apply.
SDK_CODE_SETTINGS = dict(
    ALGORITHM=globals().get("SDK_CODE_LICENSE_ALGORITHM", ALGORITHM),
    VERIFYING_KEY=globals().get("SDK_CODE_LICENSE_VERIFYING_KEY", SECRET_KEY),
)


def _decode_license(encrypted_message):
    """
    Decode and verify a license token.

    :param encrypted_message: the encrypted license string (a JWT)
    :return: the user's authorization info, e.g.
        {
            "user_id": 468753245829857281,
            "user_level": 1,
            "qubits": 30,
            "exp": 1691549171
        }
    :raises ValueError: if the token cannot be decoded or verified with the
        configured algorithm and key, including an expired license
    """
    verifying_key = SDK_CODE_SETTINGS["VERIFYING_KEY"]
    try:
        payload = jwt.decode(
            encrypted_message,
            verifying_key,
            algorithms=[SDK_CODE_SETTINGS["ALGORITHM"]],
        )
    except Exception as exc:
        # An algorithm or key that does not match the developer platform, or an
        # expired license, ends up here. Catching Exception (rather than a bare
        # except) lets KeyboardInterrupt/SystemExit propagate untouched.
        logger.debug(str(exc))
        raise ValueError(
            f"SDK authorization code expired. Please log in to {PLATFORM_DOMAIN} for support."
        ) from exc
    return payload


def _download_license(user_id, sdk_code):
    """
    Fetch a license from the developer platform API.

    :param user_id: user ID
    :param sdk_code: SDK authorization code
    :return: the license string (the ``data`` field of the response)
    :raises ValueError: if the request fails, the response body is not a JSON
        object, or the response ``code`` is not ``SUCCESSCODE``
    """
    data = {"user_id": user_id, "sdk_code": sdk_code}
    try:
        response = requests.post(
            LICENSE_URL,
            data=json.dumps(data),
            timeout=10,
            headers={"Content-Type": "application/json"},
        ).json()
    except Exception as exc:
        # Connection errors, timeouts and non-JSON bodies all land here.
        logger.debug(str(exc))
        raise ValueError(
            f"License download failed, please log in to {PLATFORM_DOMAIN} for support"
        ) from exc
    # Valid JSON that is not an object (list, string, null) has no .get();
    # report it as a failed download instead of leaking an AttributeError.
    if not isinstance(response, dict):
        logger.debug("Unexpected license response: %r", response)
        raise ValueError(
            f"License download failed, please log in to {PLATFORM_DOMAIN} for support"
        )
    if response.get("code") != SUCCESSCODE:
        logger.debug(response.get("msg"))
        raise ValueError(
            "SDK authorization code expired or incorrect, "
            f"please log in to {PLATFORM_DOMAIN} for support."
        )
    return response.get("data")


def _save_license_file(encrypted_message):
    """
    Write the license text to LICENSE_FILE_PATH, overwriting any existing file.

    :param encrypted_message: text to store; it is written as UTF-8 bytes
    """
    with open(LICENSE_FILE_PATH, mode="wb") as handle:
        handle.write(encrypted_message.encode("utf-8"))


def _read_license():
    """
    Read the cached license file and return the decoded license and SDK code.

    The file holds the license token and the SDK code joined by
    ``"_sdkCode_"`` (as written by ``_generate_license_file``).

    :return: tuple ``(decoded_license, sdk_code)``
    :raises ValueError: if the file is missing, cannot be parsed (e.g. it was
        edited by hand), or its token fails verification
    """
    # The license has not been generated yet.
    if not os.path.exists(LICENSE_FILE_PATH):
        raise ValueError(
            "Please follow the tutorial to generate the license first. "
            f"If there is no sdk authorization code, please log in to {PLATFORM_DOMAIN} for support."
        )
    with open(LICENSE_FILE_PATH, "rb") as license_file:
        raw_license = license_file.read()
    # Split the stored text into the real license token and the SDK code.
    try:
        # Decoding happens inside the try: UnicodeDecodeError is a ValueError,
        # so a corrupted file gets the friendly message below.
        encrypted_message, sdk_code = raw_license.decode().split("_sdkCode_")
        if not encrypted_message or not sdk_code:
            raise ValueError("empty license token or sdk code")
    except ValueError:
        # A manually edited or corrupted license file ends up here: invalid
        # UTF-8, a missing or repeated separator, or an empty part.
        raise ValueError(
            "License error, please follow the tutorial to regenerate license. "
            f"If there is no sdk code, please log in to {PLATFORM_DOMAIN} for support."
        )
    decrypted_message = _decode_license(encrypted_message)
    return decrypted_message, sdk_code


def _generate_license_file(user_id, sdk_code):
    """
    Download a license, verify it, cache it on disk and return its contents.

    The cached file stores the license token followed by ``"_sdkCode_"`` and
    the SDK code, so the code can be recovered when the file is read back.

    :param user_id: user ID
    :param sdk_code: SDK authorization code
    :return: decoded license info, e.g.
        {'user_id': 6, 'user_level': 1, 'qubits': 30, 'exp': 1691668756}
    """
    token = _download_license(user_id, sdk_code)
    payload = _decode_license(token)
    if payload is None:
        # Nothing usable to cache.
        return payload
    _save_license_file(token + "_sdkCode_" + sdk_code)
    return payload


def init(user_id, sdk_code):
    """
    Initialize the license by downloading it and writing the license file.

    The license file is regenerated on every call.

    :param user_id: user ID
    :param sdk_code: SDK authorization code
    :raises ValueError: carrying the message of any error that occurred
    """
    try:
        _generate_license_file(user_id, sdk_code)
    except Exception as err:
        # Re-raise every failure as a plain ValueError without the chained traceback.
        raise ValueError(str(err)) from None


def ensure_license():
    """
    Make sure a license file exists, prompting for credentials if it does not.

    If LICENSE_FILE_PATH is missing, the user is asked on the console for a
    user ID and an SDK code, and ``init`` downloads and caches the license.
    Nothing happens when the file already exists.

    :raises ValueError: if reading input or generating the license fails, or
        if the entered user ID or SDK code is empty
    """
    try:
        if not os.path.exists(LICENSE_FILE_PATH):
            # NOTE: logged at INFO level, which Python's default logging setup
            # does not display unless the application configures it to.
            logger.info(
                "Please log in to %s for support to get user_id and sdk_code.",
                PLATFORM_DOMAIN,
            )
            user_id = input("Please enter user ID: ").strip()
            sdk_code = input("Please enter sdk code: ").strip()
            # Reject blank input up front. An empty SDK code would produce a
            # license file that _read_license rejects, even if the server
            # accepted the request.
            if not user_id or not sdk_code:
                raise ValueError(
                    "User ID and sdk code must not be empty, "
                    f"please log in to {PLATFORM_DOMAIN} for support."
                )
            init(user_id, sdk_code)
    except Exception as exc:
        raise ValueError(str(exc)) from None


def verify_license(qubits):
    """
    Check that the local license allows the requested number of qubits.

    Ensures a license file exists (prompting for credentials if needed), reads
    and verifies it, then compares ``qubits`` with the license's limit.

    :param qubits: number of qubits (bits) the caller wants to use
    :raises ValueError: if the license is missing, invalid or expired, does
        not state a qubit limit, or if ``qubits`` exceeds that limit
    """
    try:
        ensure_license()
        decrypted_message, _ = _read_license()
        max_qubits = decrypted_message.get("qubits")
        # Without this check a license lacking "qubits" would fail on the
        # comparison below with an unhelpful TypeError message.
        if max_qubits is None:
            raise ValueError(
                "License error, the license does not specify the supported number of bits. "
                f"Please log in to {PLATFORM_DOMAIN} for support."
            )
        if qubits > max_qubits:
            raise ValueError(
                f"The maximum number of bits supported by the license: {max_qubits}, "
                f"Please log in to {PLATFORM_DOMAIN} for more bit support."
            )
    except Exception as exc:
        raise ValueError(f"{exc}") from None