pytorch中的Tensor乘法分为点乘(dot multiplication/point-wise multiplication/element-wise multiplication)和矩阵乘法(matrix multiplication)。
1. Tensor dot multiplication
Tensor dot multiplcation 采用 *
符号,如:C=A * B
,如果是矩阵A的point-wise的平方运算,可以采用C=A ** 2
来实现。
:todo dot product是否要求shape一定相同?
2. Tensor matrix multiplication
矩阵运算分为两类,dimension相同的tensor的运算,dimension不同的tensor的运算。
(1)matrix product with same dimension
设Tensor A.size()=[b,i,j], tensor B.size()=[b,j,k],则下面几种形式可以得到它们矩阵乘法的结果C.size()=[b,i,k]。
- C = A @ B
- C = torch.mulmat(A,B)
- C = torch.bmm(A,B)
- C = torch.einsum(‘bij,bjk->bik’,A,B)
(2)matrix product with different dimension
设Tensor A.size()=[b,i,j], tensor B.size()=[j,k],则下面几种形式可以得到它们矩阵乘法的结果C.size()=[b,i,k]。
- C = torch.mulmat(A,B)
- C = torch.einsum(‘bij,jk->bik’,A,B)
torch.bmm不支持broadcast, 而torch.matmul()支持broadcast。 broadcast是从numpy借鉴过来的一种性质,可以自动对维数不同的两个tensor用unsquzee()进行补齐,需要满足broadcast的条件才能进行矩阵乘法运算,条件是:
- same shape, 比如都是 [b,i,j],此时结果为dot product. [:todo, 需要确认]
- the size of the trailing axes for both arrays in an operation must either be the same size or one of them must be one.
参考资料