当前位置: 首页 > article >正文

AF3 Rotation 类解读

Rotation 类(rigid_utils 模块)是 AlphaFold3 中用于 3D旋转 的核心组件,支持两种旋转表示: 1️⃣ 旋转矩阵 (3x3)
2️⃣ 四元数 (quaternion, 4元向量)

👉 设计目标

  • 允许灵活选择 旋转矩阵 或 四元数

  • 封装了常用的 旋转操作(组合、逆旋转、应用到点上等)

  • 像 torch.Tensor 一样,支持索引、拼接、广播等操作

源代码:

class Rotation:
    """
        A 3D rotation. Depending on how the object is initialized, the
        rotation is represented by either a rotation matrix or a
        quaternion, though both formats are made available by helper functions.
        To simplify gradient computation, the underlying format of the
        rotation cannot be changed in-place. Like Rigid, the class is designed
        to mimic the behavior of a torch Tensor, almost as if each Rotation
        object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
            Args:
                rot_mats:
                    A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                    quats
                quats:
                    A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                    normalize_quats is not True, must be a unit quaternion
                normalize_quats:
                    If quats is specified, whether to normalize quats
        """
        if((rot_mats is None and quats is None) or 
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or 
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision
        if(quats is not None):
            quats = quats.to(dtype=torch.float32)
        if(rot_mats is not None):
            rot_mats = rot_mats.to(dtype=torch.float32)

        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
            Returns an identity Rotation.

            Args:
                shape:
                    The "shape" of the resulting Rotation object. See documentation
                    for the shape property
                dtype:
                    The torch dtype for the rotation
                device:
                    The torch device for the new rotation
                requires_grad:
                    Whether the underlying tensors in the new rotation object
                    should require gradient computation
                fmt:
                    One of "quat" or "rot_mat". Determines the underlying format
                    of the new object's rotation 
            Returns:
                A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: f{fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
            Allows torch-style indexing over the virtual shape of the rotation
            object. See documentation for the shape property.

            Args:
                index:
                    A torch index. E.g. (1, 3, 2), or (slice(None,))
            Returns:
                The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
            Pointwise left multiplication of the rotation with a tensor. Can be
            used to e.g. mask the Ro

http://www.kler.cn/a/600418.html

相关文章:

  • Java多线程与高并发专题——如何利用 CompletableFuture 解决“聚合打车服务平台”的问题?
  • Sqladmin - FastAPI框架下一键生成管理后台
  • python常见反爬思路详解
  • 网络基础梳理
  • OWASP Top漏洞说明
  • Python爬虫获取1688商品(按图搜索)接口的返回数据说明
  • vulnhub-Tr0ll ssh爆破、wireshark流量分析,exp、寻找flag。思维导图带你清晰拿到所以flag
  • 蓝桥杯——————数位排序(java)
  • uniapp自身bug | uniapp+vue3打包后 index.html无法直接运行
  • Android Compose 框架基本状态管理(mutableStateOf、State 接口)深入剖析(十四)
  • Unity将运行时Mesh导出为fbx
  • 基于websocketpp实现的五子棋项目
  • MIPI 详解:XAPP894 D-PHY Solutions
  • 北京交通大学第三届C语言积分赛
  • 新手如何使用 Milvus
  • 大数据学习(83)-数仓建模理论
  • 新版 eslintrc 文件弃用 .eslintignore已弃用 替代方案
  • x-cmd install | Wuzz - Web 开发与安全测试利器,交互式 HTTP 工具
  • 基于javaweb的SpringBoot公司财务管理设计与实现(源码+文档+部署讲解)
  • Linux上位机开发实战(编写API库)