AF3 Rotation class explained

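The Rotation class (in the rigid_utils module) is the core component AlphaFold3 uses to represent 3D rotations. It supports two rotation representations:
1️⃣ rotation matrix (3x3)
2️⃣ quaternion (4-component vector)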
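Both representations describe the same rotation. A unit quaternion converts to a rotation matrix through a closed-form formula. The minimal sketch below performs that conversion in batched form, from `[*, 4]` to `[*, 3, 3]`. It assumes the quaternion is stored as (w, x, y, z) with the real part first. `quat_to_rot_mat` is a hypothetical name, not necessarily the helper that rigid_utils provides.

```python
import math

import torch


def quat_to_rot_mat(quats: torch.Tensor) -> torch.Tensor:
    """Convert [*, 4] unit quaternions to [*, 3, 3] rotation matrices.

    Assumes (w, x, y, z) ordering with the real part first.
    """
    w, x, y, z = torch.unbind(quats, dim=-1)
    rows = [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]
    # Stack each row's entries into [*, 3], then the three rows into [*, 3, 3]
    return torch.stack([torch.stack(r, dim=-1) for r in rows], dim=-2)


# 90-degree rotation about z: q = (cos 45°, 0, 0, sin 45°)
q = torch.tensor([math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)])
R = quat_to_rot_mat(q)
x_axis_rotated = R @ torch.tensor([1.0, 0.0, 0.0])  # close to the y-axis (0, 1, 0)
```

The formula uses only products of components, so it is differentiable and works for any number of leading batch dims.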

👉 Design goals

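  • Let the caller choose either a rotation matrix or a quaternion as the underlying representation

  • Encapsulate common rotation operations (composition, inversion, applying a rotation to points, etc.)

  • Behave like a torch.Tensor, with support for indexing, concatenation, broadcasting, and similar operations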
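In matrix form, the operations listed above reduce to a few batched tensor ops:

- Composition is matrix multiplication.
- The inverse of an orthonormal matrix is its transpose.
- Rotating a point is a matrix-vector product.

The sketch below uses plain tensors rather than the Rotation class itself. The helper names (`compose`, `invert`, `apply_to_points`, `rot_z`) are hypothetical. The leading dims play the role of the "tensor of rotations" shape, and they index, concatenate, and broadcast like ordinary tensor dims.

```python
import math

import torch


def compose(r1: torch.Tensor, r2: torch.Tensor) -> torch.Tensor:
    """Rotation that applies r2 first, then r1; [*, 3, 3] batch dims broadcast."""
    return r1 @ r2


def invert(r: torch.Tensor) -> torch.Tensor:
    """The inverse of an orthonormal rotation matrix is its transpose."""
    return r.transpose(-1, -2)


def apply_to_points(r: torch.Tensor, pts: torch.Tensor) -> torch.Tensor:
    """Rotate [*, 3] points by [*, 3, 3] matrices (leading dims broadcast)."""
    return (r @ pts.unsqueeze(-1)).squeeze(-1)


def rot_z(theta: float) -> torch.Tensor:
    """3x3 rotation matrix about the z-axis by theta radians."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


# A "tensor of rotations" with batch shape (4,): 0, 90, 180 and 270 degrees about z
rots = torch.stack([rot_z(k * math.pi / 2) for k in range(4)])  # [4, 3, 3]
pts = torch.randn(4, 3)

restored = apply_to_points(invert(rots), apply_to_points(rots, pts))  # ~= pts
doubled = compose(rots, rots)             # each entry rotated by twice its angle
one = rots[1]                             # indexing the batch dim: a single [3, 3] matrix
more = torch.cat([rots, rots[:2]], dim=0)  # concatenation along the batch dim: [6, 3, 3]
```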


from __future__ import annotations

from typing import Any, Optional, Tuple

import torch

# identity_rot_mats() and identity_quats(), used by Rotation.identity below,
# are helper functions defined elsewhere in the rigid_utils module.


class Rotation:
    """
        A 3D rotation. Depending on how the object is initialized, the
        rotation is represented by either a rotation matrix or a
        quaternion, though both formats are made available by helper functions.
        To simplify gradient computation, the underlying format of the
        rotation cannot be changed in-place. Like Rigid, the class is designed
        to mimic the behavior of a torch Tensor, almost as if each Rotation
        object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
            Args:
                rot_mats:
                    A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                    quats
                quats:
                    A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                    normalize_quats is not True, must be a unit quaternion
                normalize_quats:
                    If quats is specified, whether to normalize quats
        """
        # Exactly one of the two representations must be given
        if((rot_mats is None and quats is None) or
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        # Trailing dims must be (3, 3) for matrices and (4,) for quaternions
        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision
        if(quats is not None):
            quats = quats.to(dtype=torch.float32)
        if(rot_mats is not None):
            rot_mats = rot_mats.to(dtype=torch.float32)

        # Rescale each quaternion to unit length along the last dim
        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        # Only one of the two is non-None; the format is fixed from here on
        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape: Tuple[int, ...],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
            Returns an identity Rotation.

            Args:
                shape:
                    The "shape" of the resulting Rotation object. See documentation
                    for the shape property
                dtype:
                    The torch dtype for the rotation
                device:
                    The torch device for the new rotation
                requires_grad:
                    Whether the underlying tensors in the new rotation object
                    should require gradient computation
                fmt:
                    One of "quat" or "rot_mat". Determines the underlying format
                    of the new object's rotation
            Returns:
                A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: {fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
            Allows torch-style indexing over the virtual shape of the rotation
            object. See documentation for the shape property.

            Args:
                index:
                    A torch index. E.g. (1, 3, 2), or (slice(None,))
            Returns:
                The indexed rotation
        """
        if not isinstance(index, tuple):
            index = (index,)

        # Append full slices so the index only touches the batch dims,
        # never the trailing (3, 3) or (4,) component dims
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
            Pointwise left multiplication of the rotation with a tensor. Can be
            used to e.g. mask the Rotation.
        """
        ...  # The rest of this method is cut off in this excerpt
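`Rotation.identity` delegates to two module-level helpers, `identity_rot_mats` and `identity_quats`, whose bodies are not shown here. Below is a minimal sketch of what they could look like. It assumes the quaternion stores its real part first, so the identity is (1, 0, 0, 0). The parameter order mirrors the calls in `identity`, but the implementations are an assumption, not the actual rigid_utils code.

```python
from typing import Optional, Tuple

import torch


def identity_rot_mats(
    batch_dims: Tuple[int, ...],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """[*batch_dims, 3, 3] tensor of identity matrices (sketch)."""
    eye = torch.eye(3, dtype=dtype, device=device)
    # expand() only creates a view; contiguous() materializes one copy per entry
    rots = eye.expand(*batch_dims, 3, 3).contiguous()
    return rots.requires_grad_(requires_grad)


def identity_quats(
    batch_dims: Tuple[int, ...],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """[*batch_dims, 4] identity quaternions, assuming real part first (sketch)."""
    quats = torch.zeros(*batch_dims, 4, dtype=dtype, device=device)
    quats[..., 0] = 1.0  # (w, x, y, z) = (1, 0, 0, 0)
    return quats.requires_grad_(requires_grad)
```

Whatever `dtype` is passed, `__init__` casts the stored tensor to float32 ("Force full-precision").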

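`__getitem__` treats the leading dims as the rotation's "virtual shape". It appends one full slice per trailing component dim: two for a 3x3 matrix, one for a quaternion. As a result, a user index never reaches the matrix or quaternion entries. The sketch below applies the same trick to plain tensors. `index_batch` is a hypothetical helper written only for illustration.

```python
from typing import Any

import torch

# Virtual shape (2, 5): 10 rotations, stored either as matrices or as quaternions
rot_mats = torch.randn(2, 5, 3, 3)
quats = torch.randn(2, 5, 4)


def index_batch(t: torch.Tensor, index: Any, n_trailing: int) -> torch.Tensor:
    """Apply `index` to the batch dims of t only, keeping its last n_trailing dims whole."""
    if not isinstance(index, tuple):
        index = (index,)
    # One full slice per component dim, as __getitem__ does
    return t[index + (slice(None),) * n_trailing]


print(index_batch(rot_mats, 1, 2).shape)         # torch.Size([5, 3, 3])
print(index_batch(rot_mats, (..., 0), 2).shape)  # torch.Size([2, 3, 3])
print(rot_mats[..., 0].shape)                    # torch.Size([2, 5, 3]): hits matrix columns instead
print(index_batch(quats, (..., 4), 1).shape)     # torch.Size([2, 4])
```

Without the padding, `quats[..., 4]` would index the size-4 component dim and raise an IndexError. With the padding, the same index selects the fifth rotation along the second batch dim.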

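The `__mul__` method is cut off above, so its body is not shown. Its docstring describes pointwise multiplication by a tensor, for example for masking. Going only by that, one plausible reading is to broadcast a `[*]`-shaped tensor over the trailing component dims, mirroring the trailing-dim bookkeeping in `__getitem__`. The function names below are hypothetical, and the behavior is an assumption rather than the library's code.

```python
import torch


def mul_rot_mats(rot_mats: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    """Multiply each [*, 3, 3] matrix by the matching scalar of a [*] tensor."""
    # right[..., None, None] has shape [*, 1, 1] and broadcasts over the 3x3 entries
    return rot_mats * right[..., None, None]


def mul_quats(quats: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    """Multiply each [*, 4] quaternion by the matching scalar of a [*] tensor."""
    # right[..., None] has shape [*, 1] and broadcasts over the 4 components
    return quats * right[..., None]


# Four identity rotations; the mask zeroes out entries 1 and 3 (e.g. padding)
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])
masked_mats = mul_rot_mats(torch.eye(3).expand(4, 3, 3), mask)
masked_quats = mul_quats(torch.tensor([1.0, 0.0, 0.0, 0.0]).expand(4, 4), mask)
```

Masked entries become all-zero matrices or quaternions, which are no longer valid rotations. If such a quaternion were passed back through `__init__` with `normalize_quats=True`, dividing by its zero norm would produce NaNs. Wrapping the result would therefore need `normalize_quats=False`.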

