Jax.numpy旋转矩阵

，我们通过给定三个欧拉角，来旋转整个分子系统。这就需要我们分别定义三个维度的旋转矩阵$$R_x(\phi),R_y(\psi),R_z(\theta)$$
，分别表示绕$$X$$

def rotation(psi,phi,theta,v):
""" Module of rotation in 3 Euler angles. """
RY = np.array([[np.cos(psi),0,-np.sin(psi)],
[0, 1, 0],
[np.sin(psi),0,np.cos(psi)]])
RX = np.array([[1,0,0],
[0,np.cos(phi),-np.sin(phi)],
[0,np.sin(phi),np.cos(phi)]])
RZ = np.array([[np.cos(theta),-np.sin(theta),0],
[np.sin(theta),np.cos(theta),0],
[0,0,1]])
return np.dot(RZ,np.dot(RX,np.dot(RY,v)))
multi_rotation = jit(vmap(rotation,(None,None,None,0)))

In [1]: from jax import numpy as np
In [2]: from jax import jit, vmap
In [3]: def rotation(psi,phi,theta,v):
...:     """ Module of rotation in 3 Euler angles. """
...:     RY = np.array([[np.cos(psi),0,-np.sin(psi)],
...:                    [0, 1, 0],
...:                    [np.sin(psi),0,np.cos(psi)]])
...:     RX = np.array([[1,0,0],
...:                    [0,np.cos(phi),-np.sin(phi)],
...:                    [0,np.sin(phi),np.cos(phi)]])
...:     RZ = np.array([[np.cos(theta),-np.sin(theta),0],
...:                    [np.sin(theta),np.cos(theta),0],
...:                    [0,0,1]])
...:     return np.dot(RZ,np.dot(RX,np.dot(RY,v)))
...:
In [4]: multi_rotation = jit(vmap(rotation,(None,None,None,0)))
In [5]: import numpy as onp
In [6]: v=onp.random.random((3,3))
In [7]: v
Out[7]:
array([[0.97911664, 0.48098486, 0.44966794],
[0.25350689, 0.50949849, 0.77506796],
[0.24502845, 0.23313826, 0.72014647]])
In [8]: multi_rotation(onp.pi, onp.pi, 0, v)
Out[8]:
DeviceArray([[-0.97911656, -0.4809849 ,  0.449668  ],
[-0.25350684, -0.50949854,  0.7750679 ],
[-0.24502839, -0.23313832,  0.7201465 ]], dtype=float32)

MindSpore旋转矩阵

In [1]: from mindspore import ops, Tensor
In [2]: import mindspore as ms
In [3]: import numpy as np
In [4]: psi = Tensor([np.pi], ms.float32)
In [5]: phi = Tensor([np.pi], ms.float32)
In [6]: theta = Tensor([0.], ms.float32)
In [7]: v = Tensor(np.random.random((3,3)), ms.float32)
In [8]: v
Out[8]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[ 4.51581478e-01,  7.52180338e-01,  2.84639597e-01],
[ 8.46439958e-01,  2.95659006e-01,  1.81022584e-01],
[ 8.94563913e-01,  2.25287616e-01,  1.71754003e-01]])
In [9]: zero = Tensor([0.], ms.float32)
In [10]: one = Tensor([1.], ms.float32)
In [11]: def rotation(psi, phi, theta, v):
...:     RY = ops.Concat(-1)((ops.Cos()(psi), zero, -ops.Sin()(psi),
...:                          zero, one, zero,
...:                          ops.Sin()(psi), zero, ops.Cos()(psi)))
...:     RY = RY.reshape(3, 3)
...:     RX = ops.Concat(-1)((one, zero, zero,
...:                          zero, ops.Cos()(phi), -ops.Sin()(phi),
...:                          zero, ops.Sin()(phi), ops.Cos()(phi)))
...:     RX = RX.reshape(3, 3)
...:     RZ = ops.Concat(-1)((ops.Cos()(theta), -ops.Sin()(theta), zero,
...:                        ops.Sin()(theta), ops.Cos()(theta), zero,
...:                        zero, zero, one))
...:     RZ = RZ.reshape(3, 3)
...:     dot = ops.Einsum('ij,kj->ki')
...:     return dot((RZ, dot((RX, dot((RY, v))))))
...:
In [12]: rotation(psi, phi, theta, v)
Out[12]:
Tensor(shape=[3, 3], dtype=Float32, value=
[[-4.51581448e-01, -7.52180338e-01,  2.84639567e-01],
[-8.46439958e-01, -2.95659035e-01,  1.81022629e-01],
[-8.94563913e-01, -2.25287631e-01,  1.71754062e-01]])

1. 在上述案例中，我们先定义了一系列的一维Tensor来作为旋转矩阵的元素，使用MindSpore的Concat算子将这些一维Tensor的最后一维取出组成一个新的Tensor，再对其做reshape操作，得到一个我们所需要的旋转矩阵。

1. 在Jax中我们是使用了vmap将旋转矩阵对单个矢量旋转的操作扩展到对多个矢量的旋转操作，而在MindSpore中虽然也支持了Vmap的算子，但是这里我们使用的是MindSpore所支持的另外一个功能：爱因斯坦求和算子。使用这个算子，我们就允许了旋转矩阵直接对多个矢量输入的指定维度进行运算，一样也可以得到我们想要的计算结果。