pytorch中的Tensor乘法运算

pytorch中的Tensor乘法分为点乘(dot multiplication/point-wise multiplication/element-wise multiplication)和矩阵乘法(matrix multiplication)。

1. Tensor dot multiplication

Tensor dot multiplcation 采用 *符号,如:C=A * B,如果是矩阵A的point-wise的平方运算,可以采用C=A ** 2来实现。

:todo dot product是否要求shape一定相同?

2. Tensor matrix multiplication

矩阵运算分为两类,dimension相同的tensor的运算,dimension不同的tensor的运算。

(1)matrix product with same dimension

设Tensor A.size()=[b,i,j], tensor B.size()=[b,j,k],则下面几种形式可以得到它们矩阵乘法的结果C.size()=[b,i,k]。

  • C = A @ B
  • C = torch.mulmat(A,B)
  • C = torch.bmm(A,B)
  • C = torch.einsum(‘bij,bjk->bik’,A,B)

(2)matrix product with different dimension

设Tensor A.size()=[b,i,j], tensor B.size()=[j,k],则下面几种形式可以得到它们矩阵乘法的结果C.size()=[b,i,k]。

  • C = torch.mulmat(A,B)
  • C = torch.einsum(‘bij,jk->bik’,A,B)

torch.bmm不支持broadcast, 而torch.matmul()支持broadcast。 broadcast是从numpy借鉴过来的一种性质,可以自动对维数不同的两个tensor用unsquzee()进行补齐,需要满足broadcast的条件才能进行矩阵乘法运算,条件是:

  • same shape, 比如都是 [b,i,j],此时结果为dot product. [:todo, 需要确认]
  • the size of the trailing axes for both arrays in an operation must either be the same size or one of them must be one.

参考资料