Scheme Overview
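This section analyzes and derives the mathematics behind the figure in the previous section, in as much detail as practical. The derivation relies mainly on the formulas from the RSA scheme introduced above. Interested readers can also consult the relevant number theory and complete the derivation of the whole PSI protocol on their own.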
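Preparation:
- The Server's set of sample IDs {hc1, hc2, …, hcv} and the Client's set of sample IDs {hs1, hs2, …, hsw}.
- The Server generates an RSA public/private key pair. The private key stays on the Server; the public key (e, n) is sent to the Client.
- A full-domain hash H whose output is less than n and coprime to n. (With very large data volumes, storage space must also be considered.)
- The Client's random number R (less than n and coprime to n).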
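The preparation items above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the scheme's reference setup: the modulus is built from two small Mersenne primes purely so the snippet runs with the standard library, and `full_domain_hash` is a hypothetical helper that stretches SHA-256 output across the range of n and retries until the result is coprime to n.

```python
import hashlib
from math import gcd

# Toy parameters for illustration only: two Mersenne primes give a modulus
# of roughly 216 bits, which is far too small for real use.
P = 2**127 - 1
Q = 2**89 - 1
N = P * Q
E = 65537
PHI = (P - 1) * (Q - 1)
D = pow(E, -1, PHI)  # private exponent (needs Python 3.8+)


def full_domain_hash(item, n):
    """Hypothetical helper: hash `item` to an integer in (1, n) coprime to n."""
    counter = 0
    while True:
        # Concatenate SHA-256 blocks until they cover the bit length of n.
        stream = b""
        block = 0
        while len(stream) * 8 < n.bit_length() + 64:
            data = f"{item}|{counter}|{block}".encode("utf-8")
            stream += hashlib.sha256(data).digest()
            block += 1
        value = int.from_bytes(stream, "big") % n
        if value > 1 and gcd(value, n) == 1:
            return value
        counter += 1  # vanishingly rare: retry with a fresh counter
```

In a real deployment the key pair would come from a vetted library (the implementation later in this article uses `RSA.generate`), and only (e, n) would leave the Server.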
The interaction flow is described in detail below, following the figure above.
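Step 1: The Server computes the final signature of each of its alignment IDs (phone number, ID-card number, etc.). It hashes the ID, raises the hash to the private exponent d modulo n, and hashes the result once more; this is what enc_and_hash_serverlist does in the Python implementation below.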
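Step 2: The Client generates a random number Rc (greater than 1, less than n, and coprime to n) and encrypts it with the public key to obtain Rc^e mod n. It then multiplies the hash of each ID to be aligned by this value, which blinds the hash.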
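The blinding step can be sketched as follows. The snippet is only an illustration: the public key is a toy value, `hash_to_int` is a stand-in for H based on SHA-256, and `sample_blinding_factor` and `blind` are hypothetical helper names. It draws a fresh Rc per ID and enforces the coprimality condition explicitly instead of relying on it holding by chance.

```python
import hashlib
import secrets
from math import gcd

# Toy public key (an assumption for illustration; real moduli are much larger).
N = (2**127 - 1) * (2**89 - 1)
E = 65537


def hash_to_int(item, n):
    """Stand-in for the hash H: SHA-256 of the ID, reduced modulo n."""
    digest = hashlib.sha256(str(item).encode("utf-8")).digest()
    return int.from_bytes(digest, "big") % n


def sample_blinding_factor(n):
    """Draw Rc uniformly from (1, n) and retry until gcd(Rc, n) == 1."""
    while True:
        r = 2 + secrets.randbelow(n - 2)  # 2 <= r <= n - 1
        if gcd(r, n) == 1:
            return r


def blind(item, n, e):
    """Return (blinded value, Rc); only the blinded value leaves the Client."""
    r = sample_blinding_factor(n)
    blinded = (hash_to_int(item, n) * pow(r, e, n)) % n
    return blinded, r
```

The Client keeps Rc locally because it is needed again later to remove the blinding.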
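Step 3: The Client sends the blinded products to the Server.
Step 4: The Server receives the data sent by the Client and uses its private key to compute a preliminary signature on each blinded value.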

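Step 5: The Server returns the preliminary signatures to the Client, which removes the blinding factor to obtain the final signatures of the Client's IDs. The Server also sends the signatures of its own IDs, and the Client matches them against the signatures of its IDs to align the samples.
In this way the Client's IDs are signed as well and can be aligned with the Server's IDs.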
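The reason unblinding works can be written out in a short derivation. The notation here is an assumption introduced for illustration: x is a Client ID, z is a Server ID, H is the hash applied to the ID, H' is the outer hash applied to the signature, and t_c, t_s are the values the two sides finally compare. The key fact is that ed ≡ 1 (mod φ(n)), so R_c^{ed} ≡ R_c (mod n).

```latex
\begin{aligned}
y &= H(x)\cdot R_c^{\,e} \bmod n
  && \text{(Client blinds)}\\
y' &= y^{d} \bmod n
   = H(x)^{d}\cdot R_c^{\,ed} \bmod n
   = H(x)^{d}\cdot R_c \bmod n
  && \text{(Server signs)}\\
y'\cdot R_c^{-1} &\equiv H(x)^{d} \pmod n
  && \text{(Client unblinds)}\\
t_c &= H'\!\left(H(x)^{d} \bmod n\right), \qquad
t_s = H'\!\left(H(z)^{d} \bmod n\right)
  && \text{(values compared)}
\end{aligned}
```

When x = z the two values are equal, so t_c = t_s marks a common ID. The Server only ever sees y, which is masked by R_c^e, and the Client cannot compute H(z)^d for arbitrary z without the private exponent d.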
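This algorithm is especially advantageous when the two parties' ID counts differ greatly, for example by several orders of magnitude (1 billion IDs intersected with 100,000). The party with fewer IDs generates the random numbers (RSA encryption), and the party with more IDs holds the RSA private key (RSA blind signing), which significantly reduces computation and communication overhead.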
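A rough back-of-the-envelope calculation illustrates the communication side of this claim. The model is a simplifying assumption, not a measurement: each blinded value and each signed blinded value is taken to be as large as the modulus (128 bytes for the 1024-bit key used in the implementation below), each hashed signature is taken to be 20 bytes (SHA-1), and protocol framing is ignored. `psi_traffic_bytes` is a hypothetical helper name.

```python
def psi_traffic_bytes(blinder_ids, signer_ids, modulus_bytes=128, tag_bytes=20):
    """Rough traffic estimate for one run of the protocol, in bytes.

    The blinding party sends one blinded value per ID and receives one
    signed value back (both modulus-sized); the signing party additionally
    sends one hashed signature per ID of its own.
    """
    blinded_round_trip = 2 * blinder_ids * modulus_bytes
    signer_tags = signer_ids * tag_bytes
    return blinded_round_trip + signer_tags


small, large = 100_000, 1_000_000_000  # 100,000 IDs vs 1 billion IDs

small_side_blinds = psi_traffic_bytes(blinder_ids=small, signer_ids=large)
large_side_blinds = psi_traffic_bytes(blinder_ids=large, signer_ids=small)

print(small_side_blinds)  # 20025600000  (about 20 GB)
print(large_side_blinds)  # 256002000000 (about 256 GB)
```

Under these assumptions, letting the smaller side do the blinding cuts traffic by more than a factor of ten, and the blinding work itself scales with the smaller set.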
Python implementation:
#!/usr/bin/python3
#coding=utf8
from Cryptodome.PublicKey import RSA
import hashlib
import binascii
import gmpy2
import os
rand_bits = 128
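def hash_bignumber(num, method='sha1'):
    '''
    Hash an integer (via its decimal string) and return the digest as an int.
    num: an integer
    '''
    if method == 'sha1':
        hash_obj = hashlib.sha1(str(num).encode('utf-8'))
        digest_hex = hash_obj.hexdigest()
        return int(digest_hex, 16)
    # Fail loudly instead of silently returning None
    raise ValueError('unsupported hash method: ' + method)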
def gen_key():
    # 1024-bit key keeps the demo fast; use a larger modulus in practice
    key = RSA.generate(1024)
    pk = (key.n, key.e)
    sk = (key.n, key.d)
    return pk, sk
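def blind_msg_arr_use_pk(msg_arr, pk):
    # Client: blind each hashed ID with a fresh random factor ra
    msg_hash_number_blind = []
    rand_private_list = []
    for item in msg_arr:
        hash_num = hash_bignumber(item)
        hash_num = hash_num % pk[0]
        # os.urandom takes a byte count, so convert bits to bytes
        ra = int(binascii.hexlify(os.urandom(rand_bits // 8)), 16)
        cipher_ra = gmpy2.powmod(ra, pk[1], pk[0])  # ra^e mod n
        rand_private_list.append(ra)
        # reduce modulo n so the value sent stays modulus-sized
        msg_hash_number_blind.append((hash_num * cipher_ra) % pk[0])
    return msg_hash_number_blind, rand_private_list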
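def deblind_arr_from_client(hash_arr_blind, sk):
    # Server: despite the name, this signs the blinded values (item^d mod n);
    # the blinding factor is removed later on the client side
    deblind_hash_arr = []
    for item in hash_arr_blind:
        de_blind_number = gmpy2.powmod(item, sk[1], sk[0])
        deblind_hash_arr.append(de_blind_number)
    return deblind_hash_arr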
def enc_and_hash_serverlist(server_list, sk):
    # Server: sign the hash of each local ID, then hash the signature
    hash_server_list = []
    for item in server_list:
        hash_num = hash_bignumber(item)
        c_hash_num = gmpy2.powmod(hash_num, sk[1], sk[0])
        hash_server_list.append(hash_bignumber(c_hash_num))
    return hash_server_list
def hash_deblind_client_arr(deblind_hash_arr, rand_list, pk):
    # Client: remove the blinding factor, then hash the resulting signature
    db_client = []
    for item, ra in zip(deblind_hash_arr, rand_list):
        ra_inv = gmpy2.invert(ra, pk[0])  # ra*ra_inv == 1 mod n
        db_client.append(hash_bignumber((item * ra_inv) % pk[0]))
    return db_client
def get_common_elements_idx(db_client, db_server):
    # find the common elements with a set lookup instead of an O(n^2) scan
    # return the indices of the common elements in the local data list
    server_set = set(db_server)
    common_set_index = []
    for idx, rec_a in enumerate(db_client):
        if rec_a in server_set:
            common_set_index.append(idx)
    return common_set_index
# Server side:
# generate the RSA key pair and send pk to the client
pk, sk = gen_key()

# Client side: blind the local array and send it to the server
msg_arr_client = [12, 3, 4, 8, 10, 23]
blind_arr, rlist = blind_msg_arr_use_pk(msg_arr_client, pk)

# Server side:
# receive the blinded array from the client and sign it with the secret key
received_blind_arr = blind_arr.copy()
deblind_hash_arr = deblind_arr_from_client(received_blind_arr, sk)
server_list = [12, 3, 4, 5, 1, 32, 45]
# sign and hash the server list
hashed_server_list = enc_and_hash_serverlist(server_list, sk)
# send the signed blinded array and the hashed list to the client

# Client side:
# receive the signed blinded array and the hashed server list
received_deblind_hash_arr = deblind_hash_arr.copy()
db_server = hashed_server_list.copy()
# unblind the signed array and hash it
db_client = hash_deblind_client_arr(received_deblind_hash_arr, rlist, pk)
common_index_local_list = get_common_elements_idx(db_client, db_server)
print('The hash of db Client:', db_client)
print('The hash of db Server:', db_server)
# IDs 12, 3 and 4 (client indices 0, 1, 2) are the only ones shared;
# comparing whole lists also catches missing matches, unlike a zip loop
assert common_index_local_list == [0, 1, 2]