File size: 7,060 Bytes
bca3a49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#            http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rot3Array Matrix Class."""

from __future__ import annotations
import dataclasses
from typing import List

import torch

from dockformer.utils.geometry import utils
from dockformer.utils.geometry import vector
from dockformer.utils.tensor_utils import tensor_tree_map


# Names of the nine rotation-matrix entries in row-major order: the first
# letter is the row, the second the column ('xx', 'xy', 'xz', 'yx', ..., 'zz').
COMPONENTS = [f"{row}{col}" for row in "xyz" for col in "xyz"]

@dataclasses.dataclass(frozen=True)
class Rot3Array:
    """Rotation matrix in 3 dimensional space, implemented as a struct of arrays.

    Each field holds one matrix entry: the first letter names the row and the
    second the column (``xy`` is row x, column y). A batch of rotations is
    therefore nine component tensors rather than one ``[..., 3, 3]`` tensor;
    ``from_array`` and ``to_tensor`` convert between the two layouts. The
    components are expected to share a shape, but this is not validated here.

    The dataclass is frozen, so fields cannot be reassigned and every
    operation below returns a new ``Rot3Array``. In-place tensor ops on a
    component are still possible, which is why ``identity`` never lets two
    components share storage.
    """
    # The ``dtype`` metadata on ``xx`` is informational; nothing in this class
    # reads it.
    xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    xy: torch.Tensor
    xz: torch.Tensor
    yx: torch.Tensor
    yy: torch.Tensor
    yz: torch.Tensor
    zx: torch.Tensor
    zy: torch.Tensor
    zz: torch.Tensor

    # Setting ``__array_ufunc__ = None`` opts out of NumPy's ufunc override
    # protocol: NumPy operators defer to this class's reflected methods (or
    # raise TypeError) instead of treating a Rot3Array as an array operand.
    __array_ufunc__ = None

    def __getitem__(self, index) -> Rot3Array:
        """Apply ``index`` to every component tensor.

        Because each component is a separate tensor, this indexes the batch
        dimensions only; the 3x3 structure itself is not indexed.
        """
        return self.map_tensor_fn(lambda t: t[index])

    def __mul__(self, other: torch.Tensor) -> Rot3Array:
        """Multiply every component elementwise by ``other``.

        ``other`` may be a tensor (broadcast against each component) or a
        Python number. This scales matrix entries; it is NOT rotation
        composition, which is ``@``.
        """
        return self.map_tensor_fn(lambda t: t * other)

    def __matmul__(self, other: Rot3Array) -> Rot3Array:
        """Compose two rotations: ``(self @ other)`` applies ``other`` first.

        Column j of the product is ``self`` applied to column j of ``other``.
        """
        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
        # Re-assemble the result from its columns in row-major field order.
        return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

    def map_tensor_fn(self, fn) -> Rot3Array:
        """Return a new Rot3Array whose components are ``fn`` applied to ours.

        The other component-wise helpers (indexing, scaling, ``unsqueeze``,
        ``stop_gradient``, ``reshape``) are expressed through this method, so
        they all enumerate the fields the same way.
        """
        return Rot3Array(
            **{
                field.name: fn(getattr(self, field.name))
                for field in dataclasses.fields(Rot3Array)
            }
        )

    def inverse(self) -> Rot3Array:
        """Return the inverse rotation.

        This is the transpose, which equals the inverse only when the matrix
        is orthonormal (a proper rotation).
        """
        return Rot3Array(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz
        )

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Rotate ``point``, i.e. compute the matrix-vector product R @ p."""
        return vector.Vec3Array(
            self.xx * point.x + self.xy * point.y + self.xz * point.z,
            self.yx * point.x + self.yy * point.y + self.yz * point.z,
            self.zx * point.x + self.zy * point.y + self.zz * point.z
        )

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Apply the inverse (transposed) rotation to ``point``."""
        return self.inverse().apply_to_point(point)

    def unsqueeze(self, dim: int) -> Rot3Array:
        """Insert a size-1 dimension at ``dim`` in every component."""
        return self.map_tensor_fn(lambda t: t.unsqueeze(dim))

    def stop_gradient(self) -> Rot3Array:
        """Return a copy whose components are detached from the autograd graph."""
        return self.map_tensor_fn(lambda t: t.detach())

    @classmethod
    def identity(
        cls,
        shape,
        device=None,
        dtype: torch.dtype = torch.float32,
    ) -> Rot3Array:
        """Return identity rotations whose components have the given ``shape``.

        Args:
            shape: Shape of each component tensor (anything ``torch.ones``
                accepts).
            device: Device to allocate on; ``None`` uses torch's default.
            dtype: Component dtype; defaults to ``torch.float32``.
        Returns:
            Rot3Array with ones on the diagonal and zeros elsewhere.
        """
        # Allocate a fresh tensor per component. Sharing a single ones/zeros
        # tensor across fields would make an in-place op on one component
        # silently modify every component aliased to it.
        def _entry(on_diagonal: bool) -> torch.Tensor:
            make = torch.ones if on_diagonal else torch.zeros
            return make(shape, dtype=dtype, device=device)

        # Row-major order matches the field order xx, xy, xz, yx, ..., zz.
        return cls(*[_entry(row == col) for row in range(3) for col in range(3)])

    @classmethod
    def from_two_vectors(
        cls, e0: vector.Vec3Array,
        e1: vector.Vec3Array
    ) -> Rot3Array:
        """Construct Rot3Array from two Vectors.

        Rot3Array is constructed such that in the corresponding frame 'e0' lies on
        the positive x-Axis and 'e1' lies in the xy plane with positive sign of y.
        The columns of the result are the Gram-Schmidt orthonormalized e0, e1
        and their cross product.

        NOTE(review): behavior when e1 is parallel to e0 depends on how
        ``Vec3Array.normalized`` treats a zero-length vector — confirm.

        Args:
            e0: Vector
            e1: Vector
        Returns:
            Rot3Array
        """
        # Normalize the unit vector for the x-axis, e0.
        e0 = e0.normalized()
        # Make e1 perpendicular to e0 (one Gram-Schmidt step), then normalize.
        c = e1.dot(e0)
        e1 = (e1 - c * e0).normalized()
        # Compute e2 as cross product of e0 and e1 (right-handed frame).
        e2 = e0.cross(e1)
        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rot3Array:
        """Construct a Rot3Array from a tensor of shape ``[..., 3, 3]``.

        The second-to-last axis indexes rows and the last axis columns.
        """
        rows = torch.unbind(array, dim=-2)
        rc = [torch.unbind(e, dim=-1) for e in rows]
        return cls(*[e for row in rc for e in row])

    def to_tensor(self) -> torch.Tensor:
        """Convert to a single tensor of shape ``[..., 3, 3]`` (inverse of ``from_array``)."""
        return torch.stack(
            [
                torch.stack([self.xx, self.xy, self.xz], dim=-1),
                torch.stack([self.yx, self.yy, self.yz], dim=-1),
                torch.stack([self.zx, self.zy, self.zz], dim=-1)
            ],
            dim=-2
        )

    @classmethod
    def from_quaternion(cls,
        w: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        normalize: bool = True,
        eps: float = 1e-6
    ) -> Rot3Array:
        """Construct a Rot3Array from quaternion components ``(w, x, y, z)``.

        ``w`` is the scalar part. When ``normalize`` is true the quaternion is
        first scaled to unit norm; its squared norm is clamped to at least
        ``eps`` so a zero quaternion does not divide by zero.
        """
        if normalize:
            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
            w = w * inv_norm
            x = x * inv_norm
            y = y * inv_norm
            z = z * inv_norm
        xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
        xy = 2.0 * (x * y - w * z)
        xz = 2.0 * (x * z + w * y)
        yx = 2.0 * (x * y + w * z)
        yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
        yz = 2.0 * (y * z - w * x)
        zx = 2.0 * (x * z - w * y)
        zy = 2.0 * (y * z + w * x)
        zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

    def reshape(self, new_shape) -> Rot3Array:
        """Reshape every component tensor to ``new_shape``."""
        return self.map_tensor_fn(lambda t: t.reshape(new_shape))

    @classmethod
    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
        """Concatenate ``rots`` component-wise along ``dim`` (see ``torch.cat``)."""
        return cls(
            **{
                field.name: torch.cat([getattr(r, field.name) for r in rots], dim=dim)
                for field in dataclasses.fields(Rot3Array)
            }
        )