Pytorch语法

数学运算

矩阵点乘,即对应元素相乘

1
2
3
numpy   : A * B 
matlab : A.*B
pytorch : A.mul(B)

矩阵运算相乘,即 $A\in{R^{n\times m}},$ $B\in{R^{m\times n}}$, $AB\in R^{n\times n}$

1
2
3
numpy   : np.dot(A, B) 
matlab : A*B
pytorch : A.mm(B)

操作

查到数组中大于3的元素及索引

1
2
3
numpy   : out = [i for i, x in enumerate(e) if x > 3]
matlab : out = find(e > 3)
pytorch : out = e[e.gt(3)] # gt大于 lt小于 eq等于