各種張量初始化
建立特殊型別的tensor
a = torch.FloatTensor(2,3)
a = torch.DoubleTensor(2,3)
...
設定pytorch中tensor的預設型別
torch.set_default_tensor_type(torch.DoubleTensor)
更改tensor型別
a.float()
各種常用初始化
torch.randn_like()
torch.rand(3,3) #建立 0-1 (3,3)矩陣
torch.randn(3,3) #建立 -1-1 (3,3)矩陣
torch.randint(1,10,[2,2]) #建立 1-10 (2,2) int型矩陣
按照不同的均值和方差進行初始化
torch.normal(mean=torch.full([20],0),std=torch.arange(0,1,0.1))
按照間隔初始化
torch.linspace(0,10,step=3)
torch.arange(1,10,5)
建立單位矩陣
torch.eye(4,4)
建立打亂的數列
torch.randperm(10)
返回tensor元素個數
torch.numel(torch.rand(2,2))
維度操作
矩陣拼接
torch.cat((x,x),0)
torch.stack((x,x),0) #與cat不同的是,stack在拼接的時候,要增加一個維度
矩陣拆分
chuck直接按照數量來拆分,輸入N就拆分成N個
torch.chunk(a,N,dim)
split的兩種用法,第一種是輸入一個數字,這樣就會拆分成這個總維度/數字個維度,第二個是如輸入一個列表,會按照列表指定的維度進行拆分
torch.split(a,[1,2],dim)
矩陣選取
在某個維度上選擇連續的N 列或者行
torch.narrow(dim,index,size)
選擇一個維度dim,從index開始取size個列或者行
a.index_select(dim, list)
各種選取
a[ : , 1:10, ::2 , 1:10:2]
矩陣打平後選取
torch.take( tensor , list)
維度變化
a.view(1,5)
a.reshape(1,5)
維度減少和增加
只有一個維度的時候,就是0在前面插入,-1或1在後面插入,可以把list當成是0.5維度
a.unsqueeze(1)
a.squeeze(1)
維度擴張
a.expand()
維度擴充套件expand,注意這裡的維度只能由1擴張成N,其他情況下是不能擴張的,另外維度不變的時候也可以用-1代替
a.repead()
另外一種方式是使用repeat函式,repeat表示將之前的維度複製多少次,通過複製來進行擴張
維度交換
transpose(2,3) # 交換兩個維度
permute(4,2,1,3) # 交換多個維度
數學運算
基礎運算
其中加減除法都可以使用運算子直接計算,乘法需要額外注意兩種不同的乘法,其中:
mul或者*是矩陣對應元素相乘
mm是針對於二維的矩陣正常乘法
matmul是針對任意維度矩陣的正常乘法,@是其符號過載
數字近似
floor() 向下取整
ceil() 向上取整
trunc() 保留整數
frac() 保留小數
數值裁剪
clamp(min)
clamp(min,max) #在這個閾值之外的都變成閾值
累乘
prod()
線性代數相關
trace #矩陣的跡
diag #獲取主對角線元素
triu/tril #獲取上下三角矩陣
t #轉置
dot/cross #內積與外積
其他
Numpy Tensor 互相轉換
np_data = np.arange(6).reshape((2, 3))
torch_data = torch.from_numpy(np_data)
tensor2array = torch_data.numpy()
型別判斷
isinstance(a,torch.FloatTensor)
廣播
什麼時候可以使用廣播,廣播將從最後一個維度開始,從後往前開始匹配,當一個物件的維度是1或者與另一個物件的維度大小一樣的時候,可以匹配上,另外,如果一個物件的維度少於另外一個維度的物件,只要從後往前開始的維度匹配,那麼就可以使用廣播。
例如
(1,2,3,4) 和 (2,3,4) or (1,2,3,4) 可以廣播
(1,2,3,4) 和 (1,1,1) or (1,1,1,1) 可以廣播
topk
topk可以幫助返回在某一維度上最大的k個值以及下標,只需要將largest=False,就可以返回最小的k個值
where條件選擇
根據條件是否成立,選擇矩陣X或者矩陣Y中的元素
where(condition > 0.5 , X , Y )
gather
本質就是在查表,第一個引數是表格,第二個是維度,第三個是要查詢的索引
操作就是,在inpu中選擇維度dim,然後根據index編號,讀取input中的元素
torch.gather(input,dim,index,out=None)