场景:处理旅游数据

Ethan 打开笔记本电脑,从电脑桌面上的文件夹里翻出两个 Excel 表格文件,在这两个文件里,分别存着最近去过普吉岛和新西兰旅游的旅客信息,需要从这两份数据里,找出那些去过普吉岛但没去过新西兰的人,再让销售人员向他们推销一些新西兰精品旅游路线。将文件转换为 JSON 格式后,里面的内容大致如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 去过普吉岛的人员数据
users_visited_phuket = [
{
"first_name": "Sirena",
"last_name": "Gross",
"phone_number": "650-568-0388",
"date_visited": "2018-03-14",
},
...
]

# 去过新西兰的人员数据
users_visited_nz = [
{
"first_name": "Justin",
"last_name": "Malcom",
"phone_number": "267-282-1964",
"date_visited": "2011-03-13",
},
...
]

第一次蛮力尝试

因为拿到的旅客数据里,并没有旅客 ID 之类的唯一标识符,所以我们其实无法精确地找出重复旅客,只能用姓名 + 电话号码来判断两位旅客是不是同一个人:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def find_potential_customers():
for puket_record in users_visited_puket:
is_potential = True
for nz_record in users_visited_nz:
if (
puket_record['first_name'] == nz_record['first_name']
and puket_record['last_name'] == nz_record['last_name']
and puket_record['phone_number'] == nz_record['phone_number']
):
is_potential = False
break

if is_potential:
yield puket_record

虽然这段代码能完成任务,但它有非常严重的性能问题。对于每条普吉岛旅客记录,我们都需要轮询所有的新西兰旅客记录,尝试找到匹配项,时间复杂度是 O(n*m)。

使用集合优化函数

Python 里的集合是基于哈希表实现的,判断一个东西是否在集合里,速度非常快,平均时间复杂度是 O(1)。我们可以先将所有的新西兰旅客记录转换成一个集合,之后查找匹配时,程序就不需要再遍历所有记录:

1
2
3
4
5
6
7
8
9
10
def find_potential_customers():
nz_records_idx = {
(rec['first_name'], rec['last_name'], rec['phone_number'])
for rec in users_visited_nz
}

for rec in users_visited_puket:
key = (rec['first_name'], rec['last_name'], rec['phone_number'])
if key not in nz_records_idx:
yield rec

引入集合后,新函数的性能有了突破性的增长,时间复杂度会直接线性下降:O(n+m),足够满足需求。

利用集合的游戏规则

计算去过普吉岛但没去过新西兰的人,其实就是一次集合的差值运算。对于任何自定义类型来说,当你对两个对象进行相等比较时,Python 只会判断它们是不是指向内存里的同一个地址。为了让集合能够正确处理类型,我们首先要重写类型的 __eq__ 魔法方法

对于哈希表来说,两个相等的对象,其哈希值也必须一样,否则一切算法逻辑都不再成立。因此,Python 会在发现重写了类型的 __eq__ 方法后,直接将其变为不可哈希,以此强制要求你为其设计新的哈希值算法

完成 VisitRecord 建模,做完所有的准备工作后,剩下的事情便顺水推舟了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class VisitRecord:
def __init__(self, first_name, last_name, phone_number, date_visited):
self.first_name = first_name
self.last_name = last_name
self.phone_number = phone_number
self.date_visited = date_visited

def __hash__(self):
return hash(self.comparable_fields)

def __eq__(self, other):
if isinstance(other, self.__class__):
return self.comparable_fields == other.comparable_fields
return False

@property
def comparable_fields(self):
return (self.first_name, self.last_name, self.phone_number)

def find_potential_customers():
return set(VisitRecord(**r) for r in users_visited_puket) - set(
VisitRecord(**r) for r in users_visited_nz
)

使用 dataclasses

dataclasses 最主要的用途是利用类型注解语法来快速定义 VisitRecord 数据类。在默认情况下,由 @dataclass 创建的数据类都是可修改的,不支持任何哈希操作,因此你必须指定 frozen=True,显式地将当前类变为不可变类型,这样才能正确计算对象的哈希值:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dataclasses import dataclass, field

@dataclass(frozen=True)
class VisitRecord:
first_name: str
last_name: str
phone_number: str
date_visited: str = field(compare=False)


def find_potential_customers():
return set(VisitRecord(**r) for r in users_visited_puket) - set(
VisitRecord(**r) for r in users_visited_nz
)