Python: 列表、陣列及迭代器切片的區別及聯絡

orion發表於2022-07-10

1. 對列表和陣列進行切片

1.1 切片索引

眾所周知,Python中的列表和numpy陣列都支援用begin: end語法來表示[begin, end)區間的的切片索引:

import numpy as np
my_list= [1, 2, 3, 4, 5]
print(my_list[2: 4]) # [3, 4]

my_arr = np.array([1, 2, 3, 4, 5])
print(my_arr[2: 4]) # [3 4]

以上操作實際上等同於用slice切片索引物件對其進行切片:

print(my_list[slice(2, 4)]) # [3, 4]
print(my_arr[slice(2, 4)]) # [3 4]

numpy陣列還支援用列表和numpy陣列來表示切片索引,而列表則不支援:

print(my_arr[[2, 3]]) # [3 4]
print(my_arr[np.arange(2, 4)]) # [3, 4]

print(my_list[[2, 3]]) # TypeError: list indices must be integers or slices, not list
print(my_list[np.arange(2, 4)]) # TypeError: only integer scalar arrays can be converted to a scalar index

Pytorch的torch.utils.data.Dataset資料集支援單元素索引,但不支援切片:

from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, ToTensor, Normalize

transform = Compose(
        [ToTensor(),
         Normalize((0.1307,), (0.3081,))
         ]
)
        
data = FashionMNIST(
        root="data",
        download=True,
        train=True,
        transform=transform
    )

print(data[0], data[1]) # (tensor(...), 0) (tensor(...), 0)
print(data[[0, 1]]) # ValueError: only one element tensors can be converted to Python scalars
print(data[: 2]) # ValueError: only one element tensors can be converted to Python scalars

要想對torch.utils.data.Dataset進行切片,需要建立Subset物件:

import torch
indices = [0, 1] # or indices = np.arange(2)
data_0to1 = torch.utils.data.Subset(data, indices)
print(type(data_0to1)) # <class 'torch.utils.data.dataset.Subset'>

Subset物件同樣支援單元素索引操作且不支援切片:

print(data_0to1[0]) # (tensor(...), 0)

檢視Pytorch原始碼可知,Subset類的定義實際上是這樣的:

class Subset(Dataset[T_co]):
    r"""
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

從以上程式碼片段可以清晰地看到Subset類用indices來儲存本身做為子集的索引集合,然後重寫(override)了__getitem__()方法來實現對子集的單元素索引。

1.2 對切片索引進行命名

有時我們會使用充滿硬編碼的切片索引,這使得程式碼難以閱讀,比如下面這段程式碼:

record = ".....100...513.25.."
cost = int(record[5: 8]) * float(record[11: 17])
print(cost)  # 51325.0

與其這樣做,我們不如對切片進行命名:

SHARES = slice(5, 8)
PRICE = slice(11, 17)
cost = int(record[SHARES]) * float(record[PRICE])
print(cost) # 51325.0

在後一種版本中,由於避免了使用許多神祕難懂的硬編碼索引,我們的程式碼就變得清晰了許多。

正如我們前面所說,這裡的slice()函式會建立一個slice型別的切片物件,可以用在任何執行切片的地方:

items = [0, 1, 2, 3, 4, 5, 6]
a = slice(2, 4)
print(items[2: 4]) # [2, 3]
print(items[a]) # [2, 3]
items[a] = [10, 11] 
print(items) # [0, 1, 10, 11, 4, 5, 6]
del items[a]
print(items) # [0, 1, 4, 5, 6]

如果有一個slice物件的例項s,可以分別用過s.starts.stop以及s.step屬性來跌倒關於該物件的資訊。例如:

a = slice(5, 50, 2)
print(a.start, a.stop, a.step) # 5 10 2

此外,可以通過使用indices(size)方法將切片對映到特定大小的序列上。這會返回一個[start, stop, step)元組,所有的值都已經恰當地限制在邊界以內(當做索引操作時可避免出現IndexError異常)。例如:

s = 'HelloWorld'
print(a.indices(len(s)))
print(*a.indices(len(s)))
for i in range(*a.indices(len(s))):
    print(s[i])
# W
# r
# d

2. 對迭代器做切片操作

要對迭代器和生成器做切片操作,普通的切片操作符在這裡是不管用的:

def count(n):
    while True:
        yield n
        n += 1
c = count(0)
print(c[10: 20]) # TypeError: 'generator' object is not subscriptable

此時,itertools.islice()函式是最完美的選擇:

import itertools
for x in itertools.islice(c, 10, 20):
    print(x)
# 10
# 11
# 12
# 13
# 14
# 15
# 16
# 17
# 18
# 19

注意,迭代器和生成器之所以沒法執行普通的切片操作,這是因為不知道它們的長度是多少(而且它們也沒有實現索引)。islice()產生的結果是一個迭代器,它可以產生出所需要的切片元素,但這是通過訪問並丟棄起始索引之前的元素來實現的。之後的元素會由islice物件產生出來,直到到達結束索引為止。

還有一點需要重點強調的是islice()會消耗掉所提供的的迭代器中資料。由於迭代器中的元素只能訪問一次,沒法倒回去,因此這裡就需要引起我們的注意了。如果之後還需要倒回去訪問前面的元素,那也許就應該先將資料轉到列表中去。

參考

相關文章