kaiwu.core._binary_expression 源代码
"""
BinaryExpression
Binary变量构成的表达式
"""
import copy
import math
import numbers
import numpy as np
from kaiwu.core._expression import Expression, expr_add, expr_neg, expr_mul, expr_pow
from kaiwu.core._error import KaiwuError
[文档]
class BinaryExpression(Expression):
    """Expression over binary variables; the basic QUBO data structure.

    Each arithmetic operator allocates a fresh ``BinaryExpression`` and lets
    the matching ``expr_*`` helper from ``kaiwu.core._expression`` fill it in.
    ``numpy.ndarray`` operands of ``+`` and ``*`` are handed back to numpy so
    the operation is applied element-wise over the array instead of being
    squeezed into a single scalar expression.
    """

    def feed(self, feed_dict):
        """Assign values to placeholders and return the resulting new expression.

        The expression is deep-copied first, so ``self`` is left unchanged.
        Every coefficient, and the offset, that is not a plain number is
        treated as a placeholder term and replaced by the result of its own
        ``feed(feed_dict)``.

        Args:
            feed_dict (dict): values for the placeholders, keyed by placeholder name.

        Returns:
            BinaryExpression: a copy of this expression with placeholders substituted.

        Examples:
            >>> import kaiwu as kw
            >>> p = kw.core.Placeholder('p')
            >>> a = kw.core.Binary('a')
            >>> y = p * a
            >>> str(y) # doctest: +NORMALIZE_WHITESPACE
            '(p)*a'
            >>> y= y.feed({'p': 2})
            >>> str(y) # doctest: +NORMALIZE_WHITESPACE
            '2*a'
        """
        ret = copy.deepcopy(self)
        # Only values are reassigned (no keys added/removed), so mutating the
        # dict while iterating over its keys is safe.
        for expr_vars in ret.coefficient:
            if not isinstance(ret.coefficient[expr_vars], numbers.Number):
                ret.coefficient[expr_vars] = ret.coefficient[expr_vars].feed(feed_dict)
        if not isinstance(ret.offset, numbers.Number):
            ret.offset = ret.offset.feed(feed_dict)
        return ret

    def __repr__(self):
        return self.__str__()

    def __radd__(self, other):
        # Addition is commutative, so ``other + self`` == ``self + other``.
        return self.__add__(other)

    def __add__(self, other):
        if isinstance(other, np.ndarray):
            # Let numpy broadcast: each array element is added to this expression.
            return other.__add__(self)
        result = BinaryExpression()
        expr_add(self, other, result)
        return result

    def __rsub__(self, other):
        # other - self == (-self) + other
        return (-self).__add__(other)

    def __sub__(self, other):
        result = self.__add__(-other)
        return result

    def __neg__(self):
        result = BinaryExpression()
        expr_neg(self, result)
        return result

    def __rmul__(self, other):
        # Multiplication of these expressions is commutative.
        return self.__mul__(other)

    def __mul__(self, other):
        if isinstance(other, np.ndarray):
            # Mirror __add__: let numpy broadcast the product over the array
            # elements instead of passing an ndarray into expr_mul, which can
            # only fill a single scalar BinaryExpression.
            return other.__mul__(self)
        result = BinaryExpression()
        expr_mul(self, other, result)
        return result

    def __pow__(self, other):
        result = BinaryExpression()
        expr_pow(self, other, result)
        return result
[文档]
class Binary(BinaryExpression):
    """A single binary variable.

    It is a ``BinaryExpression`` holding exactly one linear term, ``name`` with
    weight 1, and a zero offset. The variable name is also kept on ``name``.
    """

    def __init__(self, name: str = ""):
        linear_term = {(name,): 1}
        super().__init__(linear_term, 0)
        self.name = name

    def clear(self):
        """Reset this variable in place: empty name, no terms, zero offset."""
        self.name = ""
        self.coefficient = {}
        self.offset = 0
[文档]
class Integer(BinaryExpression):
    """Integer variable encoded as a weighted sum of binary variables.

    The expression is ``min_value + sum_j w_j * x_j`` over binary variables
    named ``f"{name}[{j}]"``. Bits ``0 .. k-1`` carry weights ``2**j`` and the
    last bit ``k`` carries the remainder weight
    ``max_value - min_value - 2**k + 1``. For integer bounds, the weights
    therefore sum to exactly ``max_value - min_value`` and every integer in
    ``[min_value, max_value]`` is reachable.

    Args:
        name (str): prefix for the generated binary variable names.
        min_value: lower bound; stored as the expression offset.
        max_value: upper bound; must be strictly greater than ``min_value``.

    Raises:
        KaiwuError: if ``max_value <= min_value``.
    """

    def __init__(self, name: str = "", min_value=0, max_value=127):
        super().__init__()
        self.offset = min_value
        self.coefficient = {}
        if max_value <= min_value:
            raise KaiwuError("max_value must be larger than min_value")
        span = max_value - min_value
        if isinstance(span, numbers.Integral):
            # Exact floor(log2(span)). math.log2 goes through a float and can
            # round up for large spans (e.g. 2**54 - 2), which would give the
            # last bit a zero or even negative weight.
            num_bits = int(span).bit_length() - 1
        else:
            # Non-integral bounds: keep the original float-based behaviour.
            num_bits = int(math.log2(span))
        for j in range(num_bits):
            self.coefficient[(f"{name}[{j}]",)] = 2**j
        # Remainder weight so the maximum reachable value is exactly max_value.
        self.coefficient[(f"{name}[{num_bits}]",)] = span - 2**num_bits + 1
[文档]
class Placeholder(BinaryExpression):
    """Placeholder (symbolic parameter) variable; only its name is stored.

    The placeholder is kept in the expression's offset as a ``_Placeholder``
    term, so it introduces no binary variables (``coefficient`` is empty).
    Its numeric value is supplied later through ``feed``.
    """

    def __init__(self, name: str = ""):
        super().__init__()
        self.name = name
        self.coefficient = {}
        self.offset = _Placeholder({(name,): 1}, 0)

    def get_placeholder_set(self):
        """Return the set of placeholder names referenced by this object.

        Names are collected from every placeholder-valued term, i.e. the
        offset and any coefficient that is an ``Expression`` rather than a
        plain number. Each key of such a term is a tuple of placeholder names.

        Returns:
            set: the placeholder names. It is empty once every placeholder has
            been replaced by a number (e.g. after ``feed``).
        """
        placeholder_set = set()
        for term in (self.offset, *self.coefficient.values()):
            if isinstance(term, Expression):
                for var_tuple in term.coefficient:
                    placeholder_set.update(var_tuple)
        # The original implementation built this set but never returned it.
        return placeholder_set
class _Placeholder(Expression):
    """Internal placeholder polynomial stored in a BinaryExpression's value slots.

    Each key of ``coefficient`` is a tuple of placeholder names. That term
    contributes ``weight * value(name_1) * value(name_2) * ...``, and
    ``offset`` is a constant term.
    """

    def feed(self, feed_dict):
        """Evaluate the placeholder polynomial using values from ``feed_dict``.

        Each name is looked up with ``feed_dict[name]``, so a missing
        placeholder raises ``KeyError``.
        """
        total = 0
        for names, weight in self.coefficient.items():
            term = weight
            for placeholder_name in names:
                term *= feed_dict[placeholder_name]
            total += term
        total += self.offset
        return total
[文档]
def quicksum(qubo_expr_list: list):
    """Sum many QUBO expressions efficiently.

    Chaining ``+`` allocates a new expression for every addition. This
    function instead accumulates every item into one ``BinaryExpression``:
    numbers go into the offset, expression terms are merged key by key, and
    any coefficient that cancels to zero is dropped. Coefficient values are
    stored by reference the first time a key is seen, not copied.

    Args:
        qubo_expr_list (list): QUBO expressions and/or numbers to add up.

    Returns:
        BinaryExpression: the sum of all items.

    Raises:
        KaiwuError: if an item is neither a number nor an ``Expression``.

    Examples:
        >>> import kaiwu as kw
        >>> qubo_list = [kw.core.Binary(f"b{i}") for i in range(10)] # Variables are also QUBO
        >>> output = kw.core.quicksum(qubo_list)
        >>> str(output)
        'b0+b1+b2+b3+b4+b5+b6+b7+b8+b9'
    """
    total = BinaryExpression()
    for term in qubo_expr_list:
        if isinstance(term, numbers.Number):
            total.offset += term
            continue
        if not isinstance(term, Expression):
            raise KaiwuError("qubo_expr_list should be a list of QUBO Expression")
        total.offset += term.offset
        for var_key, weight in term.coefficient.items():
            if var_key not in total.coefficient:
                total.coefficient[var_key] = weight
            else:
                total.coefficient[var_key] += weight
            # Drop terms that cancelled out so the result stays sparse.
            if total.coefficient[var_key] == 0:
                del total.coefficient[var_key]
    return total
if __name__ == "__main__":
    # Running this module directly executes its docstring examples as tests.
    from doctest import testmod

    testmod()