跳到主要内容

NumPy 数值计算

NumPy(Numerical Python)是 Python 科学计算的基础库,提供了高性能的多维数组对象 ndarray,以及大量用于数组运算的函数。

安装

pip install numpy

创建数组

import numpy as np

# 从列表创建
arr1 = np.array([1, 2, 3, 4, 5])

# 等差数列
arr2 = np.arange(0, 10, 2) # [0, 2, 4, 6, 8]
arr3 = np.linspace(0, 1, 5) # [0, 0.25, 0.5, 0.75, 1]

# 特殊数组
zeros = np.zeros((3, 4)) # 3×4 零矩阵
ones = np.ones((2, 3)) # 2×3 全1矩阵
eye = np.eye(3) # 3×3 单位矩阵
empty = np.empty((2, 2)) # 未初始化的数组

# 随机数组
rand = np.random.rand(3, 3) # [0, 1) 均匀分布
randn = np.random.randn(3, 3) # 标准正态分布
randint = np.random.randint(0, 10, (3, 3)) # 随机整数

数组属性

arr = np.array([[1, 2, 3], [4, 5, 6]])

print(arr.ndim) # 2(维度数)
print(arr.shape) # (2, 3)(形状)
print(arr.size) # 6(元素总数)
print(arr.dtype) # int64(数据类型)
print(arr.itemsize) # 8(每个元素字节数)

索引与切片

arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 索引
print(arr[0, 1]) # 2
print(arr[0][1]) # 2(等价写法)

# 切片
print(arr[0:2, 1:3]) # [[2, 3], [5, 6]]
print(arr[:, 0]) # 第一列: [1, 4, 7]
print(arr[1, :]) # 第二行: [4, 5, 6]

# 布尔索引
print(arr[arr > 5]) # [6, 7, 8, 9]

# 花式索引
print(arr[[0, 2], [1, 2]]) # [2, 9](取 (0,1) 和 (2,2))

数组运算

NumPy 的数组运算都是逐元素的:

a = np.array([1, 2, 3])
b = np.array([4, 5, 6])

print(a + b) # [5, 7, 9]
print(a - b) # [-3, -3, -3]
print(a * b) # [4, 10, 18]
print(a / b) # [0.25, 0.4, 0.5]
print(a ** 2) # [1, 4, 9]
print(np.sqrt(a)) # [1, 1.414, 1.732]

广播机制

广播让不同形状的数组也能进行运算:

arr = np.array([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3)
scalar = 10
row = np.array([1, 0, 1]) # shape: (3,)
col = np.array([[1], [2]]) # shape: (2, 1)

print(arr + scalar) # 每个元素加 10
print(arr + row) # 每行加上 [1, 0, 1]
print(arr + col) # 每列加上 [1, 2]

广播规则:从最后一个维度开始比较,维度相等或其中一个为 1 时可以广播。

常用函数

arr = np.array([[1, 2, 3], [4, 5, 6]])

# 统计函数
print(arr.sum()) # 21
print(arr.sum(axis=0)) # 按列求和: [5, 7, 9]
print(arr.sum(axis=1)) # 按行求和: [6, 15]
print(arr.mean()) # 平均值
print(arr.max()) # 最大值
print(arr.min()) # 最小值
print(arr.argmax()) # 最大值的索引(展平后)
print(arr.std()) # 标准差

# 形状操作
print(arr.reshape(3, 2)) # 改变形状
print(arr.flatten()) # 展平为一维
print(arr.T) # 转置
print(arr.transpose()) # 转置(等价)

# 条件函数
print(np.where(arr > 3, arr, 0)) # 大于3保留,否则置0

线性代数

a = np.array([[1, 2], [3, 4]])
b = np.array([[5, 6], [7, 8]])

# 矩阵乘法
print(a @ b) # 矩阵乘法(Python 3.5+)
print(np.dot(a, b)) # 等价写法

# 线性代数函数
from numpy import linalg

print(linalg.inv(a)) # 逆矩阵
print(linalg.det(a)) # 行列式
print(linalg.eig(a)) # 特征值与特征向量
print(linalg.solve(a, b)) # 解线性方程组 ax = b

数组拷贝

arr = np.array([1, 2, 3])

# 视图:共享数据
view = arr.view()
view[0] = 100
print(arr) # [100, 2, 3](原数组也被修改)

# 深拷贝:独立数据
copy = arr.copy()
copy[0] = 999
print(arr) # [100, 2, 3](原数组不变)

小结

功能常用方法
创建数组np.array, np.zeros, np.ones, np.arange, np.linspace
随机数np.random.rand, np.random.randn, np.random.randint
索引切片arr[i, j], arr[:, 0], arr[arr > 0]
统计sum, mean, max, min, std, argmax
形状reshape, flatten, T, transpose
线性代数@, dot, linalg.inv, linalg.det