Pytorch系列之常用基礎操作

Taaccoo發表於2020-12-12

各種張量初始化

建立特殊型別的tensor

a = torch.FloatTensor(2,3) 
a = torch.DoubleTensor(2,3)
...  

設定pytorch中tensor的預設型別

torch.set_default_tensor_type(torch.DoubleTensor)

更改tensor型別

a.float()

各種常用初始化

torch.randn_like()

torch.rand(3,3)   #建立 0-1 (3,3)矩陣

torch.randn(3,3)  #建立 -1-1 (3,3)矩陣

torch.randint(1,10,[2,2])  #建立 1-10 (2,2) int型矩陣

按照不同的均值和方差進行初始化

torch.normal(mean=torch.full([20],0),std=torch.arange(0,1,0.1))

按照間隔初始化

torch.linspace(0,10,step=3)

torch.arange(1,10,5)

建立單位矩陣

torch.eye(4,4)

建立打亂的數列

torch.randperm(10)

返回tensor元素個數

torch.numel(torch.rand(2,2))

維度操作

矩陣拼接

torch.cat((x,x),0)
torch.stack((x,x),0)   #與cat不同的是,stack在拼接的時候,要增加一個維度

矩陣拆分

chuck直接按照數量來拆分,輸入N就拆分成N個

torch.chunk(a,N,dim) 

split的兩種用法,第一種是輸入一個數字,這樣就會拆分成這個總維度/數字個維度,第二個是如輸入一個列表,會按照列表指定的維度進行拆分

torch.split(a,[1,2],dim)

矩陣選取

在某個維度上選擇連續的N 列或者行

torch.narrow(dim,index,size)

選擇一個維度dim,從index開始取size個列或者行

a.index_select(dim, list)

各種選取

a[ : , 1:10,  ::2 , 1:10:2]

矩陣打平後選取

torch.take( tensor , list)

維度變化

a.view(1,5)
a.reshape(1,5) 

維度減少和增加

只有一個維度的時候,就是0在前面插入,-1或1在後面插入,可以把list當成是0.5維度

a.unsqueeze(1)
a.squeeze(1) 

維度擴張

a.expand()  

維度擴充套件expand,注意這裡的維度只能由1擴張成N,其他情況下是不能擴張的,另外維度不變的時候也可以用-1代替

a.repead()  

另外一種方式是使用repeat函式,repeat表示將之前的維度複製多少次,通過複製來進行擴張

維度交換

transpose(2,3)  # 交換兩個維度
permute(4,2,1,3) # 交換多個維度

數學運算

基礎運算

其中加減除法都可以使用運算子直接計算,乘法需要額外注意兩種不同的乘法,其中:

mul或者*是矩陣對應元素相乘

mm是針對於二維的矩陣正常乘法

matmul是針對任意維度矩陣的正常乘法,@是其符號過載

數字近似

floor() 向下取整

ceil() 向上取整

trunc() 保留整數

frac() 保留小數

數值裁剪

clamp(min)

clamp(min,max) #在這個閾值之外的都變成閾值

累乘

prod()

線性代數相關

trace           #矩陣的跡

diag            #獲取主對角線元素

triu/tril       #獲取上下三角矩陣

t               #轉置

dot/cross       #內積與外積

其他

Numpy Tensor 互相轉換

np_data = np.arange(6).reshape((2, 3))
torch_data = torch.from_numpy(np_data)
tensor2array = torch_data.numpy()

型別判斷

isinstance(a,torch.FloatTensor)

廣播

什麼時候可以使用廣播,廣播將從最後一個維度開始,從後往前開始匹配,當一個物件的維度是1或者與另一個物件的維度大小一樣的時候,可以匹配上,另外,如果一個物件的維度少於另外一個維度的物件,只要從後往前開始的維度匹配,那麼就可以使用廣播。

例如

(1,2,3,4) 和 (2,3,4) or (1,2,3,4) 可以廣播

(1,2,3,4) 和 (1,1,1) or (1,1,1,1) 可以廣播

topk

topk可以幫助返回在某一維度上最大的k個值以及下標,只需要將largest=False,就可以返回最小的k個值

where條件選擇

根據條件是否成立,選擇矩陣X或者矩陣Y中的元素

where(condition > 0.5 , X , Y )  

gather

本質就是在查表,第一個引數是表格,第二個是維度,第三個是要查詢的索引

操作就是,在inpu中選擇維度dim,然後根據index編號,讀取input中的元素

torch.gather(input,dim,index,out=None) 
 

相關文章