当前位置:   article > 正文

国密算法 SM9 公钥加密 数字签名 密钥交换 基于身份的密码算法(IBC)高效python代码_sm9算法

sm9算法

接上篇:国密算法 SM9 公钥加密 数字签名 密钥交换 基于身份的密码算法(IBC)完整高效的开源python代码-CSDN博客

接触过国密算法Python库——hggm的朋友可能知道,库中SM2、SM3、SM4、ZUC都区分了慢速版和快速版:慢速版纯Python实现,代码逻辑清晰适合教学;快速版用了外挂(如numba),牺牲代码可读性或部分Python特性以追求更高的效率。上篇SM9代码虽然集成了若干关于效率优化的文献研究成果,但毕竟是纯Python实现的,我还是放在了慢速版(/slow目录下),这显然是给自己挖了坑。没发现合适的外挂,故而继续在数学上做文章,在上篇代码的基础上,主要进行了3项优化:

①固定点乘预计算。SM9的两个椭圆曲线群G1、G2的生成元分别是P1、P2,还有用户签名私钥ds,算法中多次出现与这些固定点的点乘运算(k·P),类似SM2,可提前计算好k的每一个字节位置与P相乘的结果并保存,后续点乘运算则转变为31次点加法。实测有6.~倍的效率提升。

②采用Comb固定基的高次幂运算。算法中多次出现以gs和ge为底的幂运算(gs=e(P1, Ppub_s), ge=e(Ppub_e, P2)),提前计算好gs和ge的预计算表,后续幂运算转变为31次乘法和31次平方计算(见文献:王江涛,樊荣,黄哲. SM9中高次幂运算的快速实现方法[J]. 计算机工程,2023,49(9):118-124,136. DOI:10.19678/j.issn.1000-3428.0065618)。实测有4倍左右的效率提升。

③固定G2点的双线性对预计算。用户加密私钥de(椭圆曲线群G2上的点)经常参与双线性对计算,可预先计算好双线性对Miller循环中由该点衍生出的所有系数,后续计算双线性对时可简化线函数,并省略点加。不过实测只有10%左右的效率提升。

当然上述优化都是针对特定的计算步骤,签名验签、加解密、密钥交换等算法中也包含了许多不符合上述情况的耗时步骤,因此实际提升幅度没有那么大。虽然依旧是纯Python实现,但上述优化也以牺牲可读性为代价,所以放在了快速版这边。

完整代码如下:

  1. import os
  2. from random import randrange
  3. from math import ceil
  4. from .SM3 import digest as sm3
  5. # SM9总则(GB_T 38635.1-2020) A.1 系统参数
  6. q = 0XB640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D # 基域特征
  7. N = 0XB640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25 # 群的阶
  8. # 群G1的生成元 P1=(x_p1, y_p1)
  9. x_p1 = 0X93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD
  10. y_p1 = 0X21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616
  11. # 群G2的生成元 P2=(x_p2, y_p2)
  12. x_p2 = (0X85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141,
  13. 0X3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B)
  14. y_p2 = (0X17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96,
  15. 0XA7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7)
  16. HASH_SIZE = 32 # sm3输出256位(32字节)
  17. N_SIZE = 32 # 阶的字节数
  18. KEY_LEN = 128 # 默认密钥位数
  19. K2_len = 256 # MAC函数中密钥K2的位数
  20. def to_byte(x, size=None):
  21. if type(x) is int:
  22. return x.to_bytes(size if size else ceil(x.bit_length() / 8), byteorder='big')
  23. elif type(x) in (str, bytes):
  24. x = x.encode() if type(x) is str else x
  25. return x[:size] if size and len(x) > size else x # 超过指定长度,则截取左侧字符
  26. elif type(x) in (tuple, list):
  27. return b''.join(to_byte(c, size) for c in x)
  28. return bytes(x)[:size] if size else bytes(x)
  29. # 将字节转换为int
  30. def to_int(byte):
  31. return int.from_bytes(byte, byteorder='big')
  32. # 广义的欧几里得除法求模逆(耗时约为slow/SM2代码内get_inverse函数的43%)
  33. def mod_inv(a, mod=q):
  34. if a == 0:
  35. return 0
  36. lm, low, hm, high = 1, a % mod, 0, mod
  37. while low > 1:
  38. r = high // low
  39. lm, low, hm, high = hm - lm * r, high - low * r, lm, low
  40. return lm % mod
  41. class FQ:
  42. def __init__(self, n):
  43. self.n = n
  44. def __add__(self, other):
  45. return FQ(self.n + other.n)
  46. def __sub__(self, other):
  47. return FQ(self.n - other.n)
  48. def __mul__(self, other): # 右操作数可为int
  49. return FQ(self.n * (other.n if type(other) is FQ else other) % q)
  50. def __truediv__(self, other): # 右操作数可为int
  51. return FQ(self.n * mod_inv(other.n if type(other) is FQ else other) % q)
  52. def __pow__(self, other): # 操作数应为int
  53. return FQ(pow(self.n, other, q) if other else 1)
  54. def __eq__(self, other): # 右操作数可为int
  55. return self.n % q == (other.n if type(other) is FQ else other) % q
  56. def __neg__(self):
  57. return FQ(-self.n)
  58. def __repr__(self):
  59. return 'FQ(%064X)' % (self.n % q)
  60. def __bytes__(self):
  61. return to_byte(self.n % q, N_SIZE)
  62. def is_zero(self):
  63. return self.n % q == 0
  64. def inv(self):
  65. return FQ(mod_inv(self.n))
  66. def sqr(self):
  67. return FQ(self.n * self.n % q)
  68. @classmethod
  69. def one(cls):
  70. return cls(1)
  71. @classmethod
  72. def zero(cls):
  73. return cls(0)
  74. class FQ2:
  75. def __init__(self, *coeffs): # 国标中的表示是高位在前,而此处coeffs是低位在前
  76. self.coeffs = coeffs
  77. def __add__(self, other):
  78. (a0, a1), (b0, b1) = self.coeffs, other.coeffs
  79. return FQ2(a0 + b0, a1 + b1)
  80. def __sub__(self, other):
  81. (a0, a1), (b0, b1) = self.coeffs, other.coeffs
  82. return FQ2(a0 - b0, a1 - b1)
  83. def sqr(self):
  84. a0, a1 = self.coeffs
  85. return FQ2((a0 * a0 - (a1 * a1 << 1)) % q, (a0 * a1 << 1) % q) # (a0^2 - 2 * a1^2, 2 * a0 * a1)
  86. def sqr_u(self):
  87. a0, a1 = self.coeffs
  88. return FQ2(-(a0 * a1 << 2) % q, (a0 * a0 - (a1 * a1 << 1)) % q) # (-4 * a0 * a1, a0^2 - 2 * a1^2)
  89. def mul_b_u(self, b): # 带参数乘法
  90. (a0, a1), (b0, b1) = self.coeffs, b.coeffs
  91. return FQ2(-(a0 * b1 + a1 * b0 << 1) % q, (a0 * b0 - (a1 * b1 << 1)) % q) # (-2*(a0*b1+a1*b0), a0*b0-2*a1*b1)
  92. def __mul__(self, other):
  93. if type(other) is int:
  94. a0, a1 = self.coeffs
  95. return FQ2(a0 << 1, a1 << 1) if other == 2 else FQ2(a0 * other % q, a1 * other % q)
  96. (a0, a1), (b0, b1) = self.coeffs, other.coeffs
  97. a0b0, a1b1 = a0 * b0, a1 * b1 # Karatsuba 思想方法(节约一次乘法),实测此处约有5%提升,用在其他地方未见性能提升
  98. return FQ2((a0b0 - (a1b1 << 1)) % q, ((a0 + a1) * (b0 + b1) - (a0b0 + a1b1)) % q) # (a0*b0-2*a1*b1,a0*b1+a1*b0)
  99. def __rmul__(self, other):
  100. return self.__mul__(other)
  101. def __truediv__(self, other):
  102. if type(other) is int:
  103. other_inv = mod_inv(other)
  104. return FQ2([c * other_inv % q for c in self.coeffs])
  105. return self * other.inv()
  106. def inv(self):
  107. a0, a1 = self.coeffs
  108. if a0 == 0:
  109. return FQ2(0, -mod_inv(a1 << 1)) # (0, -(2 * a1)^-1)
  110. if a1 == 0:
  111. return FQ2(mod_inv(a0), 0) # (a0^-1, 0)
  112. k = mod_inv(a0 * a0 + (a1 * a1 << 1)) # k = (a0^2 + 2 * a1^2)^-1
  113. return FQ2(a0 * k % q, -a1 * k % q) # (a0 * k, -a1 * k)
  114. def conjugate(self): # 共轭
  115. a0, a1 = self.coeffs
  116. return self.__class__(a0, -a1)
  117. def get_fp_list(self): # 返回所有基域元素(高位在前)
  118. if type(self) is FQ2:
  119. return [i % q for i in self[::-1]]
  120. return [y for x in self[::-1] for y in x.get_fp_list()] if self.coeffs else [0] * 4 # 注意FQ4对象零值的处理
  121. def __repr__(self):
  122. return '%s(%s)' % (self.__class__.__name__, ', '.join('%064X' % i for i in self.get_fp_list()))
  123. def __bytes__(self): # 字节串高位在前
  124. return to_byte(self.get_fp_list(), N_SIZE)
  125. def __eq__(self, other):
  126. return self.get_fp_list() == other.get_fp_list()
  127. def __neg__(self):
  128. return self.__class__(*[-c for c in self.coeffs])
  129. def __getitem__(self, item):
  130. return self.coeffs[item]
  131. def is_zero(self):
  132. return all(c % q == 0 for c in self.coeffs) if type(self) is FQ2 else all(c.is_zero() for c in self.coeffs)
  133. @classmethod
  134. def one(cls):
  135. return FQ2_one if cls is FQ2 else (FQ12_one if cls is FQ12 else FQ4_one)
  136. @classmethod
  137. def zero(cls):
  138. return FQ2_zero if cls is FQ2 else ()
  139. class FQ4(FQ2): # 零元的coeffs为空,可优化FQ12稀疏乘法运算
  140. def __add__(self, other):
  141. if not self.coeffs:
  142. return other
  143. if not other.coeffs:
  144. return self
  145. (a0, a1), (b0, b1) = self.coeffs, other.coeffs
  146. return FQ4(a0 + b0, a1 + b1)
  147. def __sub__(self, other):
  148. if not self.coeffs:
  149. return -other
  150. if not other.coeffs:
  151. return self
  152. (a0, a1), (b0, b1) = self.coeffs, other.coeffs
  153. return FQ4(a0 - b0, a1 - b1)
  154. def sqr(self):
  155. if not self.coeffs:
  156. return FQ4_zero
  157. a0, a1 = self.coeffs
  158. return FQ4(a0.sqr() + a1.sqr_u(), a0 * a1 * 2) # (a0^2 + a1^2 * u, 2 * a0 * a1)
  159. def sqr_v(self):
  160. if not self.coeffs:
  161. return FQ4_zero
  162. a0, a1 = self.coeffs
  163. return FQ4(a0.mul_b_u(a1) * 2, a0.sqr() + a1.sqr_u()) # (2 * a0 * a1 * u, a0^2 + a1^2 * u)
  164. def mul_b_v(self, b): # 带参数乘法
  165. if not self.coeffs or not b.coeffs:
  166. return FQ4_zero
  167. (a0, a1), (b0, b1) = self.coeffs, b.coeffs
  168. return FQ4(a0.mul_b_u(b1) + a1.mul_b_u(b0), a0 * b0 + a1.mul_b_u(b1)) # (a0*b1*u+a1*b0*u, a0*b0+a1*b1*u)
  169. def __mul__(self, other):
  170. if not self.coeffs:
  171. return FQ4_zero
  172. if type(other) is int:
  173. a0, a1 = self.coeffs
  174. return FQ4(a0 * other, a1 * other)
  175. if not other.coeffs:
  176. return FQ4_zero
  177. (a0, a1), (b0, b1) = self.coeffs, other.coeffs
  178. return FQ4(a0 * b0 + a1.mul_b_u(b1), a0 * b1 + a1 * b0) # (a0*b0+a1*b1*u, a0*b1+a1*b0)
  179. def inv(self):
  180. if not self.coeffs:
  181. return FQ4_zero
  182. a0, a1 = self.coeffs
  183. k = (a1.sqr_u() - a0.sqr()).inv()
  184. return FQ4((-a0 * k), a1 * k)
  185. class FQ12(FQ2):
  186. def __add__(self, other):
  187. (a0, a1, a2), (b0, b1, b2) = self.coeffs, other.coeffs
  188. return FQ12(a0 + b0, a1 + b1, a2 + b2)
  189. def __sub__(self, other):
  190. (a0, a1, a2), (b0, b1, b2) = self.coeffs, other.coeffs
  191. return FQ12(a0 - b0, a1 - b1, a2 - b2)
  192. def sqr(self):
  193. a0, a1, a2 = self.coeffs
  194. return FQ12(a0.sqr() + a1.mul_b_v(a2) * 2, a0 * a1 * 2 + a2.sqr_v(), a0 * a2 * 2 + a1.sqr())
  195. def __mul__(self, other):
  196. (a0, a1, a2), (b0, b1, b2) = self.coeffs, other.coeffs
  197. return FQ12(a0 * b0 + a1.mul_b_v(b2) + a2.mul_b_v(b1), a0 * b1 + a1 * b0 + a2.mul_b_v(b2),
  198. a0 * b2 + a1 * b1 + a2 * b0)
  199. def sqr2(self): # 分圆循环子群Gϕ6(Fp2)中的元素平方
  200. a, b, c = self.coeffs
  201. a2, b2, c2v = a.sqr(), b.sqr(), c.sqr_v()
  202. return FQ12(a2 + (a2 - a.conjugate()) * 2, c2v + (c2v + b.conjugate()) * 2, b2 + (b2 - c.conjugate()) * 2)
  203. def __pow__(self, other): # 实际运行此函数的对象都是分圆循环子群Gϕ6(Fp2)中的元素
  204. if other > 10: # 加减法
  205. h, k = bin(3 * other)[2:], bin(other)[2:]
  206. k, t, nf = '0' * (len(h) - len(k)) + k, self, self.frobenius6()
  207. for i in range(1, len(h) - 1):
  208. t = t.sqr2()
  209. if h[i] == '1' and k[i] == '0':
  210. t = t * self
  211. elif h[i] == '0' and k[i] == '1':
  212. t = t * nf
  213. else:
  214. t = self
  215. for ri in bin(other)[3:]:
  216. t = t.sqr2() * self if ri == '1' else t.sqr2()
  217. return t
  218. def inv(self):
  219. a0, a1, a2 = self.coeffs
  220. a0_2, a1_2 = a0.sqr(), a1.sqr()
  221. if a2.is_zero():
  222. k = (a0 * a0_2 + a1.mul_b_v(a1_2)).inv()
  223. return FQ12(a0_2 * k, (-a0 * a1 * k), a1_2 * k)
  224. t0, t1, t2 = a1_2 - a0 * a2, a0 * a1 - a2.sqr_v(), a0_2 - a1.mul_b_v(a2)
  225. t3 = a2 * (t1.sqr() - t0 * t2).inv()
  226. return FQ12(t2 * t3, (-t1 * t3), t0 * t3)
  227. def frobenius(self):
  228. (a0, a1), (b0, b1), (c0, c1) = self.coeffs
  229. a = FQ4(a0.conjugate(), a1.conjugate() * alpha3)
  230. b = FQ4(b0.conjugate() * alpha1, b1.conjugate() * alpha4)
  231. c = FQ4(c0.conjugate() * alpha2, c1.conjugate() * alpha5)
  232. return FQ12(a, b, c)
  233. def frobenius2(self):
  234. a, b, c = self.coeffs
  235. return FQ12(a.conjugate(), b.conjugate() * alpha2, c.conjugate() * alpha4)
  236. def frobenius3(self):
  237. (a0, a1), (b0, b1), (c0, c1) = self.coeffs
  238. a = FQ4(a0.conjugate(), -a1.conjugate() * alpha3)
  239. b = FQ4(b0.conjugate() * alpha3, b1.conjugate())
  240. c = FQ4(-c0.conjugate(), c1.conjugate() * alpha3)
  241. return FQ12(a, b, c)
  242. def frobenius6(self):
  243. a, b, c = self.coeffs
  244. return FQ12(a.conjugate(), -b.conjugate(), c.conjugate())
  245. class ECC_Point:
  246. def __init__(self, *pt): # 采用Jacobian射影坐标计算,输入仿射坐标后会转换为Jacobian射影坐标
  247. self.pt = pt if len(pt) == 3 else (*pt, pt[0].one())
  248. @classmethod
  249. def from_byte(cls, byte): # 输入bytes类型仿射坐标,构建点对象
  250. fp_num = len(byte) // (N_SIZE << 1) # 单个坐标包含的域元素个数
  251. if fp_num in (1, 2) and len(byte) % N_SIZE == 0:
  252. fp_list = [to_int(byte[i:i + N_SIZE]) for i in range(0, len(byte), N_SIZE)] # 将bytes转换为域元素列表
  253. if fp_num == 1:
  254. return cls(FQ(fp_list[0]), FQ(fp_list[1]))
  255. x_list, y_list = fp_list[fp_num - 1::-1], fp_list[:fp_num - 1:-1] # 从bytes到FQ2对象保存的域元素,需翻转高低位顺序
  256. return cls(FQ2(*x_list), FQ2(*y_list))
  257. return False
  258. def is_inf(self):
  259. return self[2].is_zero()
  260. def is_on_curve(self): # 检查点是否满足曲线方程 y^2 == x^3 + b
  261. x, y, z = self.pt
  262. return y ** 2 == x ** 3 + (_b1 if type(x) is FQ else _b2) * z ** 6
  263. def double(self):
  264. x, y, z = self.pt
  265. _3x2, _2y = x.sqr() * 3, y * 2
  266. _4y2 = _2y.sqr()
  267. _4xy2 = x * _4y2
  268. x3 = _3x2.sqr() - _4xy2 * 2
  269. return ECC_Point(x3, _3x2 * (_4xy2 - x3) - _4y2.sqr() * _2_inv, _2y * z)
  270. def zero(self):
  271. cls = self[0].__class__
  272. return ECC_Point(cls.one(), cls.one(), cls.zero())
  273. def __add__(self, p2):
  274. if self.is_inf():
  275. return p2
  276. if p2.is_inf():
  277. return self
  278. (x1, y1, z1), (x2, y2, z2) = self.pt, p2.pt
  279. z1_2, z2_2 = z1.sqr(), z2.sqr()
  280. T1, T2 = x1 * z2_2, x2 * z1_2
  281. T3, T4, T5 = T1 - T2, y1 * z2_2 * z2, y2 * z1_2 * z1
  282. T6, T7, T3_2 = T4 - T5, T1 + T2, T3.sqr()
  283. T8, T9 = T4 + T5, T7 * T3_2
  284. x3 = T6.sqr() - T9
  285. T10 = T9 - x3 * 2
  286. y3 = (T10 * T6 - T8 * T3_2 * T3) * _2_inv
  287. z3 = z1 * z2 * T3
  288. return ECC_Point(x3, y3, z3)
  289. def multiply(self, n): # 算法一:二进制展开法
  290. if n in (0, 1):
  291. return self if n else self.zero()
  292. Q = self
  293. for i in bin(n)[3:]:
  294. Q = Q.double() + self if i == '1' else Q.double()
  295. return Q
  296. def __mul__(self, n): # 算法三:滑动窗法
  297. k = bin(n)[2:]
  298. l, r = len(k), 5 # 滑动窗口为5效果较好
  299. if r >= l: # 如果窗口大于k的二进制位数,则本算法无意义
  300. return self.multiply(n)
  301. P_ = {1: self, 2: self.double()} # 保存P[j]值的字典
  302. for i in range(1, 1 << (r - 1)):
  303. P_[(i << 1) + 1] = P_[(i << 1) - 1] + P_[2]
  304. t = r
  305. while k[t - 1] != '1':
  306. t -= 1
  307. hj = int(k[:t], 2)
  308. Q, j = P_[hj], t
  309. while j < l:
  310. if k[j] == '0':
  311. Q = Q.double()
  312. j += 1
  313. else:
  314. t = min(r, l - j)
  315. while k[j + t - 1] != '1':
  316. t -= 1
  317. hj = int(k[j:j + t], 2)
  318. Q = Q.multiply(1 << t) + P_[hj]
  319. j += t
  320. return Q
  321. def __rmul__(self, n):
  322. return self.__mul__(n)
  323. def __eq__(self, p2):
  324. (x1, y1, z1), (x2, y2, z2) = self.pt, p2.pt
  325. z1_2, z2_2 = z1.sqr(), z2.sqr()
  326. return x1 * z2_2 == x2 * z1_2 and y1 * z2_2 * z2 == y2 * z1_2 * z1
  327. def __neg__(self):
  328. x, y, z = self.pt
  329. return ECC_Point(x, -y, z)
  330. def __getitem__(self, item):
  331. return self.pt[item]
  332. def __repr__(self):
  333. return '%s%s' % (self.__class__.__name__, self.normalize())
  334. def __bytes__(self):
  335. return to_byte(self.normalize(), N_SIZE if type(self[0]) is FQ else None)
  336. def normalize(self):
  337. x, y, z = self.pt
  338. if not hasattr(self, 'normalize_tuple'):
  339. if z != z.one():
  340. z_inv = z.inv()
  341. z_inv_2 = z_inv.sqr()
  342. x, y = x * z_inv_2, y * z_inv_2 * z_inv
  343. self.normalize_tuple = (x.n, y.n) if type(x) is FQ else (x, y)
  344. return self.normalize_tuple
  345. def frobenius(self):
  346. x, y, z = self.pt
  347. return ECC_Point(x.conjugate(), y.conjugate(), z.conjugate() * alpha1)
  348. def frobenius2_neg(self):
  349. x, y, z = self.pt
  350. return ECC_Point(x, -y, z * alpha2)
  351. FQ2_one, FQ2_zero = FQ2(1, 0), FQ2(0, 0) # FQ2单位元、零元
  352. FQ4_one, FQ4_zero = FQ4(FQ2_one, FQ2_zero), FQ4() # FQ4单位元、零元
  353. FQ12_one = FQ12(FQ4_one, FQ4_zero, FQ4_zero) # FQ12单位元
  354. P1 = ECC_Point(FQ(x_p1), FQ(y_p1)) # 群G1的生成元
  355. P2 = ECC_Point(FQ2(*x_p2[::-1]), FQ2(*y_p2[::-1])) # 群G2的生成元
  356. _b1, _b2 = FQ(5), FQ2(0, 5) # b2=βb=(1,0)*5
  357. alpha1 = 0X3F23EA58E5720BDB843C6CFA9C08674947C5C86E0DDD04EDA91D8354377B698B # -2^((q - 1)/12)
  358. alpha2 = 0XF300000002A3A6F2780272354F8B78F4D5FC11967BE65334 # -2^((q - 1)/6)
  359. alpha3 = 0X6C648DE5DC0A3F2CF55ACC93EE0BAF159F9D411806DC5177F5B21FD3DA24D011 # -2^((q - 1)/4)
  360. alpha4 = 0XF300000002A3A6F2780272354F8B78F4D5FC11967BE65333 # -2^((q - 1)/3)
  361. alpha5 = 0X2D40A38CF6983351711E5F99520347CC57D778A9F8FF4C8A4C949C7FA2A96686
  362. _2_inv = 0X5B2000000151D378EB01D5A7FAC763A290F949A58D3D776DF2B7CD93F1A8A2BF # 1/2
  363. _3div2 = 0X5B2000000151D378EB01D5A7FAC763A290F949A58D3D776DF2B7CD93F1A8A2C0 # 3/2
  364. R_ate_a_NAF = '0100000000000000000000000000000000000010001020200020200101000020' # a=6t+2的二进制非相邻表示(2-NAF)(去首10)
  365. hlen = 320 # 8 * ceil(5 * log(N, 2) / 32)
  366. _t, _6t, _6t_3 = 0x600000000058F98A, 0X2400000000215D93C, 0X2400000000215D93F
  367. # 输入系数值和点P,求线函数值
  368. def g_value(a_tuple, P):
  369. (a0, a1, a4), (xP, yP) = a_tuple, P
  370. return FQ12(FQ4(a0, a1 * yP), FQ4_zero, FQ4(a4 * xP, FQ2_zero))
  371. # 获取线函数g T,Q(P)的系数值(分母在最终模幂时值为1,可消去)
  372. def get_g_a_tuple(T, Q):
  373. (xT, yT, zT), (xQ, yQ, zQ) = T, Q
  374. zT_2, zQ_2 = zT.sqr(), zQ.sqr()
  375. zQ_3, t1 = zQ * zQ_2, (xT * zQ_2 - xQ * zT_2) * zT * zQ
  376. a1, t2 = t1 * zQ_3, (yT * zQ_3 - yQ * zT * zT_2) * zQ
  377. a0, a4 = t1 * yQ - t2 * xQ, t2 * zQ_2
  378. return a0, a1, a4
  379. # 线函数g T,Q(P),求过点T和Q的直线在P上的值
  380. def g(T, Q, nP):
  381. return g_value(get_g_a_tuple(T, Q), nP)
  382. # 获取线函数g T,T(P)的系数值(分母在最终模幂时值为1,可消去),利用中间值完成倍点计算
  383. def get_g2_a_tuple(T):
  384. x, y, z = T
  385. _z2, _3x2, _2y = z.sqr(), x.sqr() * 3, y * 2
  386. _4y2, _2yz = _2y.sqr(), _2y * z
  387. a1, a0, a4, _4xy2 = _z2 * _2yz, _4y2 * _2_inv - _3x2 * x, _3x2 * _z2, x * _4y2
  388. x3 = _3x2.sqr() - _4xy2 * 2
  389. y3 = _3x2 * (_4xy2 - x3) - _4y2.sqr() * _2_inv
  390. return (a0, a1, a4), ECC_Point(x3, y3, _2yz)
  391. # 线函数g T,T(P),求过点T的切线在P上的值,利用中间值完成倍点计算
  392. def g2(T, nP):
  393. a_tuple, double_T = get_g2_a_tuple(T)
  394. return g_value(a_tuple, nP), double_T
  395. # BN曲线上R_ate对的计算
  396. def e(P, Q):
  397. nQ, nP_xy = -Q, (-P).normalize()
  398. f, T = g2(Q, nP_xy)
  399. for ai in R_ate_a_NAF:
  400. new_g, T = g2(T, nP_xy)
  401. f = f.sqr() * new_g
  402. if ai == '1':
  403. f, T = f * g(T, Q, nP_xy), T + Q
  404. elif ai == '2': # 用2代替-1
  405. f, T = f * g(T, nQ, nP_xy), T + nQ
  406. Q1, nQ2 = Q.frobenius(), Q.frobenius2_neg()
  407. return final_exp(f * g(T, Q1, nP_xy) * g(T + Q1, nQ2, nP_xy))
  408. # 最终模幂
  409. def final_exp(f):
  410. m = f.frobenius6() * f.inv() # f^(p^6 - 1)
  411. s = m.frobenius2() * m # m^(p^2 + 1)
  412. # 困难部分 s^(p^3 + (6t^2+1)p^2 + (-36t^3-18t^2-12t+1)p + (-36t^3-30t^2-18t-2))
  413. s_6t = s ** _6t
  414. s_6t2 = s_6t ** _t
  415. s_36t3_18t2_12t, a2 = s_6t2 ** _6t_3 * s_6t.sqr2(), s_6t2 * s
  416. a1, a0 = s_36t3_18t2_12t.frobenius6() * s, (s_36t3_18t2_12t * s_6t * a2.sqr2()).frobenius6()
  417. return s.frobenius3() * a2.frobenius2() * a1.frobenius() * a0
  418. # 获取线函数的系数值序列
  419. def get_a_list(Q):
  420. a_tuple, T = get_g2_a_tuple(Q)
  421. a_list, nQ = [a_tuple], -Q
  422. for ai in R_ate_a_NAF:
  423. a_tuple, T = get_g2_a_tuple(T)
  424. a_list.append(a_tuple)
  425. if ai != '0':
  426. a_list.append(get_g_a_tuple(T, nQ if ai == '2' else Q))
  427. T = T + (nQ if ai == '2' else Q)
  428. Q1, nQ2 = Q.frobenius(), Q.frobenius2_neg()
  429. return a_list + [get_g_a_tuple(T, Q1), get_g_a_tuple(T + Q1, nQ2)]
  430. def e_fast(P, a_list):
  431. nP_xy = (-P).normalize()
  432. f, i = g_value(a_list[0], nP_xy), 1
  433. for ai in R_ate_a_NAF:
  434. f, i = f.sqr() * g_value(a_list[i], nP_xy), i + 1
  435. if ai != '0':
  436. f, i = f * g_value(a_list[i], nP_xy), i + 1
  437. return final_exp(f * g_value(a_list[i], nP_xy) * g_value(a_list[-1], nP_xy))
  438. # 获取Comb固定基的预计算表(256个FQ12的列表)
  439. def get_comb_list(n):
  440. comb_list = [FQ12_one, n]
  441. for i in range(7):
  442. tmp = comb_list[2**i]
  443. for _ in range(32):
  444. tmp = tmp.sqr2()
  445. comb_list += [tmp * c for c in comb_list]
  446. return comb_list
  447. # Comb固定基的幂运算
  448. def comb_pow(r, comb_list):
  449. r_bin, res = '0' * (256 - r.bit_length()) + bin(r)[2:], FQ12_one
  450. for i in range(32):
  451. a = int(''.join(r_bin[j] for j in range(i, 256, 32)), 2)
  452. res = comb_list[a] if res is FQ12_one else res.sqr2() * comb_list[a]
  453. return res
  454. # 获取固定点乘的预计算表(32行256列的椭圆曲线点矩阵)
  455. def get_kP_list(P):
  456. one, kP_list = P, []
  457. for i in range(32):
  458. line_list = [0, one] # O, P
  459. for j in range(1, 128):
  460. line_list.append(line_list[j].double()) # 2j·P
  461. line_list.append(line_list[-1] + one) # (2j+1)·P
  462. kP_list.append(line_list)
  463. one = line_list[128].double() if i < 31 else 0
  464. return kP_list
  465. # 使用预计算表的快速点乘
  466. def fast_kG(k, kP_list):
  467. P_list = [kP_list[i][byte] for i, byte in enumerate(k.to_bytes(32, byteorder='little')) if byte]
  468. return sum(P_list[1:], P_list[0])
  469. # SM9算法(GB_T 38635.2-2020) 5.3.6定义的密钥派生函数
  470. # Z为bytes类型,klen表示输出密钥比特长度(8的倍数);输出为bytes类型
  471. def KDF(Z, klen=KEY_LEN):
  472. ksize, K = klen >> 3, bytearray()
  473. for ct in range(1, ceil(ksize / HASH_SIZE) + 1):
  474. K.extend(sm3(Z + to_byte(ct, 4)))
  475. return K[:ksize]
  476. # SM9算法(GB_T 38635.2-2020) 5.3.2.2和5.3.2.3定义的密码函数
  477. def H(i, Z):
  478. Ha = to_int(KDF(to_byte(i, 1) + Z, hlen))
  479. return Ha % (N - 1) + 1
  480. # SM9算法(GB_T 38635.2-2020) 5.3.5定义的消息认证码函数
  481. def MAC(K2, Z):
  482. return sm3(Z + K2)
  483. class SM9: # SM9算法(GB_T 38635.2-2020)
  484. def __init__(self, ID='', ds=None, Ppub_s=None, de=None, Ppub_e=None, hid_s=1, hid_e=3, ks=None, ke=None):
  485. self.ID, self.ID_byte, self.hid_s_byte, self.hid_e_byte = ID, to_byte(ID), to_byte(hid_s, 1), to_byte(hid_e, 1)
  486. if ks: # 作为密钥生成中心,给定签名主私钥(若要随机生成,可指定ks=-1)
  487. self.ks = ks if 0 < ks < N else randrange(1, N)
  488. self.Ppub_s = fast_kG(self.ks, _kP2)
  489. if ke: # 作为密钥生成中心,给定加密主私钥(若要随机生成,可指定ke=-1)
  490. self.ke = ke if 0 < ke < N else randrange(1, N)
  491. self.Ppub_e = fast_kG(self.ke, _kP1)
  492. if ds and Ppub_s: # 作为用户,给定用户签名私钥和签名主公钥
  493. self.k_ds_list, self.Ppub_s, self.gs = get_kP_list(ds), Ppub_s, e(P1, Ppub_s)
  494. self.gs_comb_list = get_comb_list(self.gs)
  495. if de and Ppub_e: # 作为用户,给定用户加密私钥和加密主公钥
  496. self.de_a_list, self.Ppub_e, self.ge = get_a_list(de), Ppub_e, e(Ppub_e, P2)
  497. self.ge_comb_list = get_comb_list(self.ge)
  498. def KGC_gen_user(self, ID):
  499. ID_byte, ds, Ppub_s, de, Ppub_e = to_byte(ID), None, None, None, None
  500. if hasattr(self, 'ks'):
  501. t1 = (H(1, ID_byte + self.hid_s_byte) + self.ks) % N
  502. if t1 == 0: # 需重新产生签名主密钥,并更新所有用户的签名密钥
  503. return False
  504. t2 = self.ks * mod_inv(t1, N) % N
  505. ds, Ppub_s = fast_kG(t2, _kP1), self.Ppub_s # 用户签名私钥和签名主公钥
  506. if hasattr(self, 'ke'):
  507. t1 = (H(1, ID_byte + self.hid_e_byte) + self.ke) % N
  508. if t1 == 0: # 需重新产生加密主密钥,并更新所有用户的加密密钥
  509. return False
  510. t2 = self.ke * mod_inv(t1, N) % N
  511. de, Ppub_e = fast_kG(t2, _kP2), self.Ppub_e # 用户加密私钥和加密主公钥
  512. return SM9(ID, ds, Ppub_s, de, Ppub_e, self.hid_s_byte, self.hid_e_byte)
  513. # 6.2 数字签名生成算法
  514. def sign(self, M, r=None, outbytes=True):
  515. l = 0
  516. while l == 0:
  517. r = r if r else randrange(1, N) # A2
  518. w = bytes(self.gs_pow(r)) # A3
  519. h = H(2, to_byte(M) + w) # A4
  520. l = (r - h) % N # A5
  521. S = fast_kG(l, self.k_ds_list) # A6
  522. return to_byte([h, S]) if outbytes else (h, S)
  523. # 6.4 数字签名验证算法
  524. def verify(self, ID, M_, sig):
  525. h_, S_ = (to_int(sig[:N_SIZE]), ECC_Point.from_byte(sig[N_SIZE:])) if type(sig) is bytes else sig
  526. if not 0 < h_ < N or not S_ or not S_.is_on_curve(): # B1、B2
  527. return False
  528. t = self.gs_pow(h_) # B4
  529. h1 = H(1, to_byte(ID) + self.hid_s_byte) # B5
  530. P = fast_kG(h1, _kP2) + self.Ppub_s # B6
  531. u = e(S_, P) # B7
  532. w_ = bytes(u * t) # B8
  533. h2 = H(2, to_byte(M_) + w_) # B9
  534. return h_ == h2
  535. # A 发起协商(也可用作B生成rB、RB;outbytes=True时输出bytes)
  536. # 7.2 密钥交换协议 A1-A3
  537. def agreement_initiate(self, IDB, r=None, outbytes=True):
  538. QB = fast_kG(H(1, to_byte(IDB) + self.hid_e_byte), _kP1) + self.Ppub_e # A1
  539. rA = r if r else randrange(1, N) # A2
  540. RA = QB * rA # A3
  541. return rA, bytes(RA) if outbytes else RA
  542. # B 响应协商(option=True时计算选项部分)
  543. # 7.2 密钥交换协议 B1-B6
  544. def agreement_response(self, RA, IDA, option=False, rB=None, klen=KEY_LEN, outbytes=True):
  545. RA = ECC_Point.from_byte(RA) if type(RA) is bytes else RA
  546. if not RA or not RA.is_on_curve(): # B4
  547. return False, 'RA不属于椭圆曲线群G1'
  548. rB, RB = self.agreement_initiate(IDA, rB, outbytes) # B1-B3
  549. g1, g2 = self.e_de(RA), bytes(self.ge_pow(rB)) # B4
  550. g1, g3 = bytes(g1), bytes(g1 ** rB) # B4
  551. tmp_byte = to_byte([IDA, self.ID_byte, RA, RB])
  552. SKB = KDF(tmp_byte + g1 + g2 + g3, klen) # B5
  553. if not option:
  554. return True, (RB, SKB)
  555. self.tmp_byte2 = g1 + sm3(g2 + g3 + tmp_byte)
  556. SB = sm3(to_byte(0x82, 1) + self.tmp_byte2) # B6(可选部分)
  557. return True, (RB, SKB, SB)
  558. # A 协商确认
  559. # 7.2 密钥交换协议 A5-A8
  560. def agreement_confirm(self, rA, RA, RB, IDB, SB=None, option=False, klen=KEY_LEN):
  561. RB = ECC_Point.from_byte(RB) if type(RB) is bytes else RB
  562. if not RB or not RB.is_on_curve(): # A5
  563. return False, 'RB不属于椭圆曲线群G1'
  564. g1_, g2_ = bytes(self.ge_pow(rA)), self.e_de(RB) # A5
  565. g2_, g3_ = bytes(g2_), bytes(g2_ ** rA) # A5
  566. tmp_byte = to_byte([self.ID_byte, IDB, RA, RB])
  567. if option and SB: # A6(可选部分)
  568. tmp_byte2 = g1_ + sm3(g2_ + g3_ + tmp_byte)
  569. S1 = sm3(to_byte(0x82, 1) + tmp_byte2)
  570. if S1 != SB:
  571. return False, 'S1 != SB'
  572. SKA = KDF(tmp_byte + g1_ + g2_ + g3_, klen) # A7
  573. if not option or not SB:
  574. return True, SKA
  575. SA = sm3(to_byte(0x83, 1) + tmp_byte2) # A8
  576. return True, (SKA, SA)
  577. # B 协商确认(可选部分)
  578. # 7.2 密钥交换协议 B8
  579. def agreement_confirm2(self, SA):
  580. if not hasattr(self, 'tmp_byte2'):
  581. return False, 'step error'
  582. S2 = sm3(to_byte(0x83, 1) + self.tmp_byte2)
  583. if S2 == SA:
  584. del self.tmp_byte2
  585. return True, ''
  586. return False, 'S2 != SA'
  587. # 8.2 密钥封装算法
  588. def encaps(self, IDB, klen, r=None, outbytes=True):
  589. K = bytes()
  590. while K == bytes(len(K)):
  591. r, C = self.agreement_initiate(IDB, r, outbytes) # A1-A3
  592. w = bytes(self.ge_pow(r)) # A5
  593. K = KDF(to_byte([C, w, IDB]), klen)
  594. return K, C
  595. # 8.4 密钥封装算法
  596. def decaps(self, C, klen):
  597. C = ECC_Point.from_byte(C) if type(C) is bytes else C
  598. if not C or not C.is_on_curve(): # B1
  599. return False, 'C不属于椭圆曲线群G1'
  600. w_ = bytes(self.e_de(C)) # B2
  601. K_ = KDF(to_byte([C, w_, self.ID_byte]), klen) # B3
  602. return (True, K_) if K_ != bytes(len(K_)) else (False, 'K为全0比特串')
  603. # 9.2 加密算法
  604. def encrypt(self, IDB, M, r=None, outbytes=True):
  605. M = to_byte(M)
  606. K, C1 = self.encaps(IDB, (len(M) << 3) + K2_len, r, outbytes) # A1-A6.a.1
  607. K1, K2 = K[:len(M)], K[len(M):] # A6.a.1
  608. C2 = bytes(M[i] ^ K1[i] for i in range(len(M))) # A6.a.2
  609. C3 = MAC(K2, C2) # A7
  610. return to_byte([C1, C3, C2]) if outbytes else (C1, C3, C2)
  611. # 9.4 解密算法
  612. def decrypt(self, C):
  613. C3_start, C3_end = N_SIZE << 1, (N_SIZE << 1) + HASH_SIZE
  614. C1, C3, C2 = (C[:C3_start], C[C3_start:C3_end], C[C3_end:]) if type(C) is bytes else C
  615. res, K_ = self.decaps(C1, (len(C2) << 3) + K2_len) # B1-B3.a.1
  616. if not res:
  617. return False, K_.replace('C', 'C1')
  618. K1_, K2_ = K_[:len(C2)], K_[len(C2):] # B3.a.1
  619. if K1_ == bytes(len(K_)):
  620. return False, 'K1\'为全0比特串'
  621. u = MAC(K2_, C2) # B4
  622. if u != C3:
  623. return False, 'u != C3'
  624. return True, bytes(C2[i] ^ K1_[i] for i in range(len(C2))) # B3.a.2
  625. def e_de(self, P):
  626. return e_fast(P, self.de_a_list)
  627. def gs_pow(self, r):
  628. return comb_pow(r, self.gs_comb_list)
  629. def ge_pow(self, r):
  630. return comb_pow(r, self.ge_comb_list)
  631. _SM9kG_file = 'hggm/SM9_kG.bin' # 预计算数据文件的位置
  632. _kP1, _kP2 = [], [] # P1、P2点的预计算表
  633. if os.path.exists(_SM9kG_file):
  634. with open(_SM9kG_file, 'rb') as f: # 读取预计算数据文件
  635. data = f.read()
  636. G1_size, G2_size = N_SIZE << 1, N_SIZE << 2 # G1点坐标字节数、G2点坐标字节数
  637. P1_line, line = 255 * G1_size, 255 * (G1_size + G2_size) # 一行G1点坐标字节数、一行总字节数
  638. for i in range(0, N_SIZE * line, line):
  639. _kP1.append([0] + [ECC_Point.from_byte(data[j:j + G1_size]) for j in range(i, i + P1_line, G1_size)])
  640. _kP2.append([0] + [ECC_Point.from_byte(data[j:j + G2_size]) for j in range(i + P1_line, i + line, G2_size)])
  641. else: # 预计算数据文件不存在
  642. _kP1, _kP2 = get_kP_list(P1), get_kP_list(P2) # 生成P1、P2点的预计算表
  643. with open(_SM9kG_file, 'wb') as f: # 将预计算表写入二进制文件
  644. f.write(b''.join(map(bytes, [P for x, y in zip(_kP1, _kP2) for P in x[1:] + y[1:]])))

完善了测试代码,创建KGC的代码比上版稍有变动:

  1. IDA, IDB, message = 'Alice', 'Bob', 'Chinese IBS standard'
  2. kgc = SM9(ks=0x130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4,
  3. ke=0x2E65B0762D042F51F0D23542B13ED8CFA2E9A0E7206361E013A283905E31F)
  4. sm9_A, sm9_B = kgc.KGC_gen_user(IDA), kgc.KGC_gen_user(IDB)
  5. assert bytes(sm9_A.gs).hex().swapcase().endswith('F0F071D7D284FCFB')
  6. print("-----------------test sign and verify---------------")
  7. r = 0x033C8616B06704813203DFD00965022ED15975C662337AED648835DC4B1CBE
  8. signature = sm9_A.sign(message, r)
  9. assert signature.hex().swapcase().endswith('827CC2ACED9BAA05')
  10. assert sm9_B.verify(IDA, message, signature)
  11. print("success")
  12. print("-----------------test key agreement---------------")
  13. rA = 0x5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8
  14. rA, RA = sm9_A.agreement_initiate(IDB, rA) # A发起协商
  15. # A将RA发送给B
  16. rB = 0x018B98C44BEF9F8537FB7D071B2C928B3BC65BD3D69E1EEE213564905634FE
  17. res, content = sm9_B.agreement_response(RA, IDA, True, rB) # B响应协商
  18. if not res:
  19. print('B报告协商错误:', content)
  20. return
  21. RB, SKB, SB = content
  22. # B将RB、SB发送给A
  23. res, content = sm9_A.agreement_confirm(rA, RA, RB, IDB, SB, True) # A协商确认
  24. if not res:
  25. print('A报告协商错误:', content)
  26. return
  27. SKA, SA = content
  28. assert SKA.hex().swapcase() == '68B20D3077EA6E2B825315836FDBC633'
  29. # A将SA发送给B
  30. res, content = sm9_B.agreement_confirm2(SA) # B协商确认
  31. if not res:
  32. print('B报告协商错误:', content)
  33. return
  34. assert SKA == SKB
  35. print("success")
  36. print("-----------------test encrypt and decrypt---------------")
  37. message = 'Chinese IBE standard'
  38. kgc = SM9(ke=0x01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22)
  39. sm9_A, sm9_B = kgc.KGC_gen_user(IDA), kgc.KGC_gen_user(IDB)
  40. C = sm9_A.encrypt(IDB, message, 0xAAC0541779C8FC45E3E2CB25C12B5D2576B2129AE8BB5EE2CBE5EC9E785C)
  41. assert C.hex().swapcase().endswith('378CDD5DA9513B1C')
  42. res, content = sm9_B.decrypt(C)
  43. if not res:
  44. print('解密错误:', content)
  45. return
  46. assert message == content.decode()
  47. print("success")

运行结果如下,可见实际算法比上版有20%~4倍提升不等:

以上全部代码在:hggm - 国密算法 SM2 SM3 SM4 SM9 ZUC Python实现完整代码: 国密算法 SM2公钥密码 SM3杂凑算法 SM4分组密码 SM9标识密码 ZUC序列密码 Python代码完整实现 效率高于所有公开的Python国密算法库 (gitee.com)

hggm国密算法Python库自2022年3月开始开发,2022年4月首次公开,目前已经包含SM2、SM3、SM4、SM9、ZUC的慢速版和快速版,基本完善了。出于个人研究学习的兴趣,纰漏在所难免,可继续改进提升的地方还很多,望各位不吝赐教,我会持续更新优化。从一开始就坚持完全开源,希望能帮到大家,也希望能以个人之所学为网安事业尽绵薄之力。

但愿这不止是完结,而是全新的开始。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/489169
推荐阅读
相关标签
  

闽ICP备14008679号