Basic Tensor Operations

Matrix Multiplication

torch.mm: only for the 2-D (matrix) case (see the sketch after the example below)

torch.matmul: works for tensors of any dimensionality (batch dimensions are broadcast)

@: operator shorthand for torch.matmul

>>>import torch
>>>a=torch.rand(4,784)
>>>x=torch.rand(512,784)
>>>(a@x.t()).shape
#torch.Size([4, 512])
>>>a=torch.rand(4,3,28,64)
>>>b=torch.rand(4,3,64,32)
>>>torch.matmul(a,b).shape
#torch.Size([4, 3, 28, 32])
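
A minimal sketch of the 2-D restriction (m1 and m2 are arbitrary new tensors; the last call is expected to fail):

>>>m1=torch.rand(4,784)
>>>m2=torch.rand(512,784)
>>>torch.mm(m1,m2.t()).shape   # fine: both operands are 2-D
#torch.Size([4, 512])
>>>torch.mm(a,b)               # a and b are 4-D at this point in the session
#fails with a RuntimeError, since mm only accepts 2-D tensors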

Exponentiation

.pow(n): element-wise n-th power

>>>a=torch.full([2,2],3)
>>>a.pow(2)
#tensor([[9, 9],
# [9, 9]])

Rounding

.floor(): round down

.ceil(): round up

.trunc(): keep the integer part (truncate toward zero)

.frac(): keep the fractional part

.round(): round to the nearest integer

a=torch.tensor(3.14)
a.floor(),a.ceil(),a.trunc(),a.frac(),a.round()
#(tensor(3.), tensor(4.), tensor(3.), tensor(0.1400), tensor(3.))

Clamping

.clamp(min): every element smaller than min is set to min (see the sketch after the example below)

.clamp(min, max): every element smaller than min is set to min, and every element larger than max is set to max

a=torch.rand(2,3)*15
a.clamp(1,10)
#tensor([[10.0000, 10.0000, 2.5097],
# [10.0000, 1.2573, 8.4877]])
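
The single-argument form can be sketched the same way (values are random, so the output is omitted):

a=torch.rand(2,3)*15
a.clamp(10)        # min only: every element below 10 becomes 10
a.clamp(min=10)    # equivalent keyword form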

Autograd

Method 1
1. Mark each parameter that needs a gradient with requires_grad_()
2. Compute the MSE: torch.nn.functional.mse_loss($y$,$\hat y$)
3. Call torch.autograd.grad(mse,[w]) to compute the gradient (see the sketch after this list)
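
A minimal sketch of Method 1, using the same toy setup as the Method 2 code further below (x, w, and the target are arbitrary values):

import torch
import torch.nn.functional as F
x=torch.ones(1)
w=torch.Tensor([2])
w.requires_grad_()                   # step 1: mark w as requiring a gradient
mse=F.mse_loss(torch.ones(1),x*w)    # step 2: mse=(1-x*w)^2=1
torch.autograd.grad(mse,[w])         # step 3: returns a tuple of gradients
#(tensor([2.]),)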

Method 2

  1. Mark each parameter that needs a gradient with requires_grad_()

  2. Compute the MSE: torch.nn.functional.mse_loss($y,\hat y$)

  3. Call mse.backward(), which computes the gradient of mse with respect to every variable marked as requiring one

  4. Read the gradient from w.grad

    backward() can be called a second time only if the first call set retain_graph=True; otherwise it raises an error (see the sketch after the code below)

import torch
import torch.nn.functional as F
x=torch.ones(1)
w=torch.Tensor([2])
w.requires_grad_()
mse=F.mse_loss(torch.ones(1),x*w)
#torch.autograd.grad(mse,[w])
mse.backward()
print(w.grad)

#tensor([2.])
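
To illustrate the retain_graph note, a minimal sketch continuing from the code above:

w.grad.zero_()                       # clear the gradient accumulated above
mse=F.mse_loss(torch.ones(1),x*w)    # rebuild the graph
mse.backward(retain_graph=True)      # keep the graph alive for a second call
mse.backward()                       # now legal; gradients accumulate
print(w.grad)

#tensor([4.])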

Tensor Statistics

Norms

.norm(p): computes the p-norm over all elements of the tensor (see the sketch after the example below)

.norm(p, dim=x): computes the p-norm along dimension x; the output shape is the input shape with dimension x removed

a = torch.rand([8])
print(a)
b = a.reshape(2,2,2)
print(b)
b.norm(1,dim=0)
#tensor([0.3336, 0.0033, 0.5679, 0.7974, 0.1241, 0.4108, 0.2766, 0.8038])
#tensor([[[0.3336, 0.0033],
# [0.5679, 0.7974]],
#
# [[0.1241, 0.4108],
# [0.2766, 0.8038]]])
#tensor([[0.4577, 0.4141],
# [0.8445, 1.6012]])
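
Without a dim argument the norm is taken over all elements. A minimal sketch on a tensor of ones, so the values are easy to check by hand:

a = torch.full([8], 1.0)
a.norm(1), a.norm(2)    # L1: eight ones sum to 8; L2: sqrt(8)
#(tensor(8.), tensor(2.8284))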

Statistics

.prod(): product of all elements

a = torch.rand([8])
print(a)
a.prod()
#tensor(0.0008)

.argmax(): returns the index of the largest element, counted as if the tensor were flattened to 1-D

.argmin(): returns the index of the smallest element, counted as if the tensor were flattened to 1-D

.argmax(dim=x): returns the indices of the largest elements along dimension x

a = torch.rand([2,2,3])
print(a)
a.argmax()
# tensor([[[0.2517, 0.9526, 0.5908],
# [0.1431, 0.3951, 0.5676]],

# [[0.7481, 0.8191, 0.4051],
# [0.7140, 0.4541, 0.5540]]])
# tensor(1)
a = torch.rand([2,2,3])
print(a)
a.argmax(dim=1)
# tensor([[[0.0630, 0.4025, 0.8124],
# [0.2175, 0.4514, 0.5231]],

# [[0.8366, 0.4124, 0.6334],
# [0.3470, 0.0701, 0.2093]]])
# tensor([[1, 1, 0],
# [0, 0, 0]])

keepdim=True: the result keeps the same number of dimensions as the original tensor (the reduced dimension is kept with size 1); see the sketch below
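
A minimal sketch of keepdim, looking only at shapes (values are random):

a = torch.rand([2,2,3])
a.argmax(dim=1).shape                  # dimension 1 is removed
# torch.Size([2, 3])
a.argmax(dim=1, keepdim=True).shape    # dimension 1 is kept with size 1
# torch.Size([2, 1, 3])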

Top-k and k-th value

.topk(k, dim=x, largest=True): with largest=True, returns the k largest values along dimension x; with largest=False, the k smallest (see the sketch after the example below). The first field of the result holds the values, the second their indices.

a = torch.rand([4,4])
print(a)
a.topk(3,dim=1)

# tensor([[0.2393, 0.7239, 0.3985, 0.5578],
# [0.8645, 0.0815, 0.7446, 0.3979],
# [0.6933, 0.7192, 0.4393, 0.2296],
# [0.1022, 0.7430, 0.6715, 0.9983]])
# torch.return_types.topk(
# values=tensor([[0.7239, 0.5578, 0.3985],
# [0.8645, 0.7446, 0.3979],
# [0.7192, 0.6933, 0.4393],
# [0.9983, 0.7430, 0.6715]]),
# indices=tensor([[1, 3, 2],
# [0, 2, 3],
# [1, 0, 2],
# [3, 1, 2]]))
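
And with largest=False, a minimal sketch (values are random, so the output is omitted):

a = torch.rand([4,4])
a.topk(3, dim=1, largest=False)    # per row: the 3 smallest values and their indices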

.kthvalue(k, dim=x): returns the k-th smallest value along dimension x, together with its index

a = torch.rand([4,4])
print(a)
a.kthvalue(3,dim=1)

# tensor([[0.4287, 0.7747, 0.8699, 0.7784],
# [0.1043, 0.4982, 0.5863, 0.3341],
# [0.1408, 0.0510, 0.4056, 0.9592],
# [0.3366, 0.1080, 0.8596, 0.3885]])
# torch.return_types.kthvalue(
# values=tensor([0.7784, 0.4982, 0.4056, 0.3885]),
# indices=tensor([3, 1, 2, 3]))

Advanced Operations

torch.where

torch.where(condition, x, y) → Tensor

$$
out_i=\begin{cases}x_i & \text{if } condition_i\\ y_i & \text{otherwise}\end{cases}
$$

a = torch.zeros([4,4])
b=torch.ones([4,4])
condition=torch.rand([4,4])
print(condition)
torch.where(condition>0.5,a,b)

# tensor([[0.5633, 0.7544, 0.6521, 0.6338],
# [0.5439, 0.5644, 0.6126, 0.1168],
# [0.6247, 0.4382, 0.4246, 0.2221],
# [0.0017, 0.7347, 0.6782, 0.9357]])
# tensor([[0., 0., 0., 0.],
# [0., 0., 0., 1.],
# [0., 1., 1., 1.],
# [1., 0., 0., 0.]])

torch.gather

torch.gather(input, dim, index) → Tensor: uses index as indices along dimension dim to look up the corresponding elements of input, returning them as a tensor shaped like index
$$
input=\begin{bmatrix}\text{cat}\\\text{dog}\\\text{fish}\end{bmatrix}\ \ dim=0\ \ index=\begin{bmatrix}1\\2\\0\end{bmatrix}\Rightarrow\ \ out=\begin{bmatrix}\text{dog}\\\text{fish}\\\text{cat}\end{bmatrix}
$$

a=torch.rand(4,10)
a1=a.topk(3,dim=1)                 # values and indices of each row's top 3
i=a1[1]                            # the indices, shape (4, 3)
b=torch.arange(10)+100             # lookup table mapping index k to 100+k
torch.gather(b.expand(4,10),dim=1,index=i)
# tensor([[100, 105, 101],
# [101, 105, 108],
# [107, 104, 100],
# [101, 107, 106]])