numpy 如何用 np.lib.stride_tricks.as_strided 实现零拷贝视图

as_strided的核心原理是零拷贝内存重解释:通过新shape、strides和offset直接映射原数组内存,要求总字节数≤原nbytes且strides以字节为单位;需手动保证安全,推荐writeable=False。

as_strided 的核心原理:绕过内存复制,直接重解释指针

它不创建新数组,只是用新的 shape、strides 和 offset 去“重新读取”原数组的内存块。只要新视图的内存访问范围完全落在原数组的连续内存内,就是安全的零拷贝;一旦越界或 stride 设错,就会读到脏数据甚至 segfault —— as_strided 不做边界检查,全靠你手动保证。

  • 必须确保 shape × 对应 strides 所覆盖的总字节数 ≤ 原数组 nbytes
  • strides 单位是字节,不是元素个数:比如 arr.dtype == np.float64(8 字节),则跨 1 行的 stride 应为 8 * arr.shape[1],不是 arr.shape[1]
  • 推荐配合 writeable=False 创建只读视图,避免意外修改底层内存引发未定义行为

滑动窗口(rolling window)的经典用法

这是最常用也最容易出错的场景。例如把一维数组 a = np.arange(10) 变成每 3 个元素一组、步长为 1 的窗口:[[0,1,2], [1,2,3], ..., [7,8,9]],形状应为 (8, 3)

from numpy.lib.stride_tricks import as_strided
a = np.arange(10)
win_size = 3
shape = (len(a) - win_size + 1, win_size)
strides = (a.strides[0], a.strides[0])  # 每行起始地址差 1 个元素(字节),列内也是 1 元素步长
windows = as_strided(a, shape=shape, strides=strides, writeable=False)

注意:strides 不能写成 (1, 1) —— 这是字节单

位错误;也不能写成 (a.itemsize, a.itemsize) 而不乘维度,因为 a.strides[0] 已经是字节跨度,直接复用最安全。

二维图像 patch 提取:避免 reshape + loop

对图像 img(如 (64, 64))提取 (16, 16) 大小、无重叠的 patches,目标 shape 是 (4, 4, 16, 16)(即 4×4 个 patch)。

  • img.strides(64 * img.itemsize, img.itemsize)
  • 新视图在 patch 行/列方向的 stride 应为 16 * img.strides[0]16 * img.strides[1]
  • patch 内部 stride 不变:img.strides
h, w = img.shape
ph, pw = 16, 16
shape = (h//ph, w//pw, ph, pw)
strides = (ph * img.strides[0], pw * img.strides[1], img.strides[0], img.strides[1])
patches = as_strided(img, shape=shape, strides=strides, writeable=False)

如果想支持重叠 patch(如步长=8),只需把 strides 中前两个分量改为 8 * img.strides[0]8 * img.strides[1],并相应调整 shape

为什么不用 sliding_window_view?

np.sliding_window_view(v1.20+)本质就是封装了 as_strided,但它做了安全校验、自动推导 stride,并返回可写视图(底层仍零拷贝)。如果你用的是旧 NumPy 或需要极致控制(比如自定义非均匀 stride),才需手写 as_strided;否则优先用 sliding_window_view —— 它更安全、语义清晰,且同样零拷贝。

真正容易被忽略的是:即使零拷贝,视图仍共享原数组内存;若原数组被 del 或超出作用域,视图会变成悬空指针,读取结果不可预测 —— 所以务必确保源数组生命周期 ≥ 视图使用期。