跳到主要内容

控制流程进阶

Python 的 for 循环和 with 语句看似简单,但背后有一套精心设计的协议。理解迭代器、生成器和上下文管理器,能让你写出更优雅、更高效的代码。

迭代协议

Python 中所有可以用 for 循环遍历的对象都实现了迭代协议。这个协议包含两个部分:

  1. 可迭代对象(Iterable):实现了 __iter__() 方法,返回一个迭代器
  2. 迭代器(Iterator):实现了 __next__() 方法,每次返回下一个元素,没有元素时抛出 StopIteration
class CountDown:
def __init__(self, start):
self.start = start

def __iter__(self):
return CountDownIterator(self.start)

class CountDownIterator:
def __init__(self, start):
self.current = start

def __next__(self):
if self.current < 0:
raise StopIteration
num = self.current
self.current -= 1
return num

def __iter__(self):
return self

for n in CountDown(5):
print(n, end=" ") # 5 4 3 2 1 0

通常我们会把 __iter____next__ 放在同一个类里:

class CountDown:
def __init__(self, start):
self.start = start

def __iter__(self):
self.current = self.start
return self

def __next__(self):
if self.current < 0:
raise StopIteration
num = self.current
self.current -= 1
return num

iter() 函数

iter() 函数可以接受两个参数来创建一个迭代器:

import random

# 不断调用 random.randint(1, 6),直到返回 6 为止
for roll in iter(lambda: random.randint(1, 6), 6):
print(f"掷出了 {roll}")
print("掷出了 6,结束!")

第一个参数是无参函数,第二个参数是哨兵值。当函数的返回值等于哨兵值时,迭代停止。

生成器

生成器(generator)是一种更简洁的迭代器实现方式。生成器函数使用 yield 关键字返回数据,而不是 return

def countdown(start):
while start >= 0:
yield start
start -= 1

for n in countdown(5):
print(n, end=" ") # 5 4 3 2 1 0

生成器函数被调用时不会立即执行,而是返回一个生成器对象。每次调用生成器对象的 __next__() 时,函数执行到 yield 处暂停并返回值,下次从暂停处继续执行。

生成器 vs 列表

生成器是惰性计算的,它不会一次性把所有值都生成出来,而是按需逐个生成。这在处理大数据时非常有用:

# 列表:一次性占用大量内存
squares_list = [x**2 for x in range(1000000)]

# 生成器:几乎不占内存
squares_gen = (x**2 for x in range(1000000))

print(sum(squares_gen)) # 按需计算

生成器表达式和列表推导式的语法很像,只是用圆括号代替了方括号。

yield from

yield from 可以把一个子生成器的值逐个 yield 出来:

def flatten(nested):
for item in nested:
if isinstance(item, list):
yield from flatten(item)
else:
yield item

nested = [1, [2, [3, 4], 5], 6]
print(list(flatten(nested))) # [1, 2, 3, 4, 5, 6]

yield from 不仅简化了代码,还能正确传递子生成器的返回值(Python 3.3+)。

itertools

itertools 模块提供了很多实用的迭代器工具:

import itertools

# 无限计数器
for i in itertools.count(10, 2): # 从 10 开始,步长 2
if i > 20:
break
print(i) # 10 12 14 16 18 20

# 循环迭代
for item in itertools.cycle(["A", "B", "C"]):
print(item) # A B C A B C ... 无限循环

# 组合
print(list(itertools.combinations([1, 2, 3], 2))) # [(1,2), (1,3), (2,3)]
print(list(itertools.permutations([1, 2, 3], 2))) # 所有排列
print(list(itertools.product([1, 2], ["A", "B"]))) # 笛卡尔积

# 链式连接
print(list(itertools.chain([1, 2], [3, 4], [5, 6]))) # [1, 2, 3, 4, 5, 6]

# 分组
for key, group in itertools.groupby("AAABBBCCAAA"):
print(key, list(group))
# A ['A', 'A', 'A']
# B ['B', 'B', 'B']
# C ['C', 'C']
# A ['A', 'A', 'A']

groupby 要求输入已经按分组键排序,否则同一组可能会被拆成多个。

上下文管理器

上下文管理器(context manager)用于管理资源的获取和释放,最常见的用法是 with 语句:

with open("data.txt") as f:
content = f.read()
# 文件自动关闭,即使发生异常

自定义上下文管理器

实现 __enter____exit__ 方法:

class Timer:
def __enter__(self):
import time
self.start = time.time()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
import time
elapsed = time.time() - self.start
print(f"耗时: {elapsed:.4f} 秒")
# 返回 False 表示不吞掉异常
return False

with Timer():
# 执行业务逻辑
result = sum(range(1000000))

__exit__ 接收三个参数:异常类型、异常值、回溯信息。如果没有异常,这三个参数都是 None

@contextmanager

contextlib.contextmanager 装饰器可以更简洁地创建上下文管理器:

from contextlib import contextmanager

@contextmanager
def managed_resource(name):
print(f"获取资源: {name}")
resource = {"name": name}
try:
yield resource
finally:
print(f"释放资源: {name}")

with managed_resource("数据库连接") as r:
print(f"使用 {r['name']}")

yield 之前的代码相当于 __enter__yield 之后的代码相当于 __exit__

多个上下文管理器

可以同时使用多个 with

with open("input.txt") as fin, open("output.txt", "w") as fout:
fout.write(fin.read().upper())

忽略异常

contextlib.suppress 可以优雅地忽略特定异常:

from contextlib import suppress

with suppress(FileNotFoundError):
os.remove("临时文件.txt")
# 如果文件不存在,不会报错

else 块

Python 的 forwhile 循环可以带一个 else 块。很多人不知道这个特性,或者误以为它和 if-else 有关。实际上,循环的 else 只在循环正常结束(没有被 break)时执行

for i in range(5):
if i == 10:
break
else:
print("循环正常完成,没有触发 break")

# 输出:循环正常完成,没有触发 break

这在搜索场景中很有用:判断一个列表中是否没有满足条件的元素:

def is_all_passed(scores):
for score in scores:
if score < 60:
print("有人不及格")
break
else:
print("全部及格!")

is_all_passed([85, 92, 78]) # 全部及格!
is_all_passed([85, 55, 78]) # 有人不及格

如果不用 else,通常需要设置一个标志位:

def is_all_passed(scores):
has_failed = False
for score in scores:
if score < 60:
has_failed = True
break
if not has_failed:
print("全部及格!")

for-else 的版本更简洁。

小结

  • 迭代协议__iter__ 返回迭代器,__next__ 逐个返回值
  • 生成器:用 yield 实现惰性计算,节省内存
  • 生成器表达式(x for x in ...),比列表推导式更省内存
  • yield from:委托给子生成器
  • itertools:组合、排列、分组、链式连接等工具
  • 上下文管理器with 语句自动管理资源
  • @contextmanager:用生成器函数快速创建上下文管理器
  • for-elseelse 只在循环未被 break 时执行