Spaces:
Sleeping
Sleeping
# Copyright 2021 DeepMind Technologies Limited | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Rot3Array Matrix Class.""" | |
from __future__ import annotations | |
import dataclasses | |
from typing import List | |
import torch | |
from dockformer.utils.geometry import utils | |
from dockformer.utils.geometry import vector | |
from dockformer.utils.tensor_utils import tensor_tree_map | |
# Names of the nine Rot3Array matrix-entry attributes in row-major order
# (row letter first, column letter second): xx, xy, xz, yx, yy, yz, zx, zy, zz.
COMPONENTS = [row + col for row in 'xyz' for col in 'xyz']
@dataclasses.dataclass(frozen=True)
class Rot3Array:
    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.

    Each field stores one matrix entry as a tensor, named row-then-column
    (``xy`` is row x, column y). ``to_tensor`` stacks the fields, which
    requires them to share one shape, so a single instance can hold a whole
    batch of 3x3 matrices.

    The dataclass-generated ``__init__`` takes the nine entries positionally
    in row-major order (xx, xy, xz, yx, yy, yz, zx, zy, zz); every
    constructor call in this class relies on that ordering.
    """
    # Matrix entries; see the class docstring for the naming convention.
    xx: torch.Tensor = dataclasses.field(metadata={'dtype': torch.float32})
    xy: torch.Tensor
    xz: torch.Tensor
    yx: torch.Tensor
    yy: torch.Tensor
    yz: torch.Tensor
    zx: torch.Tensor
    zy: torch.Tensor
    zz: torch.Tensor

    # Opt out of NumPy's ufunc machinery: NumPy ufuncs raise TypeError on
    # instances and NumPy binary operators return NotImplemented, so
    # operations defer to this class instead of treating it as an array.
    __array_ufunc__ = None

    def __getitem__(self, index):
        """Returns a Rot3Array whose components are each indexed by `index`."""
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: getattr(self, name)[index] for name in field_names}
        )

    def __mul__(self, other: torch.Tensor):
        """Multiplies every component element-wise by `other`.

        This scales the individual entries; it is not matrix composition
        (use the `@` operator for that).
        """
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: getattr(self, name) * other for name in field_names}
        )

    def __matmul__(self, other: Rot3Array) -> Rot3Array:
        """Composes two Rot3Arrays, returning the matrix product self @ other.

        Column j of the product is `self` applied to column j of `other`.
        """
        c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
        c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
        c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
        # Re-assemble the three result columns into row-major field order.
        return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)

    def map_tensor_fn(self, fn) -> Rot3Array:
        """Applies `fn` to each component tensor and wraps the results."""
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: fn(getattr(self, name)) for name in field_names}
        )

    def inverse(self) -> Rot3Array:
        """Returns inverse of Rot3Array.

        Implemented as the transpose, which equals the true inverse only when
        the matrix is orthonormal (i.e. an actual rotation).
        """
        return Rot3Array(
            self.xx, self.yx, self.zx,
            self.xy, self.yy, self.zy,
            self.xz, self.yz, self.zz,
        )

    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies Rot3Array to point (matrix-vector product R @ p)."""
        return vector.Vec3Array(
            self.xx * point.x + self.xy * point.y + self.xz * point.z,
            self.yx * point.x + self.yy * point.y + self.yz * point.z,
            self.zx * point.x + self.zy * point.y + self.zz * point.z,
        )

    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
        """Applies inverse (transposed) Rot3Array to point."""
        return self.inverse().apply_to_point(point)

    def unsqueeze(self, dim: int) -> Rot3Array:
        """Inserts a singleton dimension at `dim` into every component."""
        return Rot3Array(*[getattr(self, c).unsqueeze(dim) for c in COMPONENTS])

    def stop_gradient(self) -> Rot3Array:
        """Returns a copy whose components are detached from autograd."""
        return Rot3Array(*[getattr(self, c).detach() for c in COMPONENTS])

    @classmethod
    def identity(cls, shape, device) -> Rot3Array:
        """Returns identity of given shape.

        Args:
          shape: Shape of each component tensor.
          device: Torch device on which to allocate the components.

        Returns:
          A float32 Rot3Array with ones on the diagonal and zeros elsewhere.
        """
        ones = torch.ones(shape, dtype=torch.float32, device=device)
        zeros = torch.zeros(shape, dtype=torch.float32, device=device)
        return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)

    @classmethod
    def from_two_vectors(
        cls, e0: vector.Vec3Array,
        e1: vector.Vec3Array
    ) -> Rot3Array:
        """Construct Rot3Array from two Vectors.

        Rot3Array is constructed such that in the corresponding frame 'e0' lies
        on the positive x-Axis and 'e1' lies in the xy plane with positive sign
        of y.

        Args:
          e0: Vector
          e1: Vector

        Returns:
          Rot3Array whose columns are the orthonormalized e0, e1 and their
          cross product.
        """
        # Normalize the unit vector for the x-axis, e0.
        e0 = e0.normalized()
        # Make e1 perpendicular to e0 (one Gram-Schmidt step).
        c = e1.dot(e0)
        e1 = (e1 - c * e0).normalized()
        # Compute e2 as cross product of e0 and e1.
        e2 = e0.cross(e1)
        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)

    @classmethod
    def from_array(cls, array: torch.Tensor) -> Rot3Array:
        """Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
        rows = torch.unbind(array, dim=-2)
        rc = [torch.unbind(e, dim=-1) for e in rows]
        # Flatten row by row so the arguments arrive in row-major field order.
        return cls(*[e for row in rc for e in row])

    def to_tensor(self) -> torch.Tensor:
        """Convert Rot3Array to array of shape [..., 3, 3]."""
        return torch.stack(
            [
                torch.stack([self.xx, self.xy, self.xz], dim=-1),
                torch.stack([self.yx, self.yy, self.yz], dim=-1),
                torch.stack([self.zx, self.zy, self.zz], dim=-1),
            ],
            dim=-2,
        )

    @classmethod
    def from_quaternion(cls,
        w: torch.Tensor,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        normalize: bool = True,
        eps: float = 1e-6
    ) -> Rot3Array:
        """Construct Rot3Array from components of quaternion.

        Args:
          w, x, y, z: Quaternion components (scalar part first).
          normalize: If True, rescale the quaternion to unit norm first.
          eps: Lower clamp on the squared norm, guarding rsqrt against zero.

        Returns:
          The corresponding rotation matrix as a Rot3Array.
        """
        if normalize:
            inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
            w = w * inv_norm
            x = x * inv_norm
            y = y * inv_norm
            z = z * inv_norm
        xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
        xy = 2.0 * (x * y - w * z)
        xz = 2.0 * (x * z + w * y)
        yx = 2.0 * (x * y + w * z)
        yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
        yz = 2.0 * (y * z - w * x)
        zx = 2.0 * (x * z - w * y)
        zy = 2.0 * (y * z + w * x)
        zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)

    def reshape(self, new_shape):
        """Reshapes every component tensor to `new_shape`."""
        field_names = utils.get_field_names(Rot3Array)
        return Rot3Array(
            **{name: getattr(self, name).reshape(new_shape) for name in field_names}
        )

    @classmethod
    def cat(cls, rots: List[Rot3Array], dim: int) -> Rot3Array:
        """Concatenates a list of Rot3Arrays component-wise along `dim`."""
        field_names = utils.get_field_names(Rot3Array)
        return cls(
            **{
                name: torch.cat([getattr(r, name) for r in rots], dim=dim)
                for name in field_names
            }
        )