Source code for fetalbrain.alignment.kelluwen_transforms

"""
This module contains helper functions to apply alignment parameters resulting from fBAN to an image. It is likely not
necessary to access these functions directly when doing standard image alignment operations, as most are shadowed
by functions in the main align.py module.

These functions have been taken from the
`Kelluwen Github <https://github.com/FelipeMoser/kelluwen/blob/main/kelluwen/functions/transforms.py>`_.

The functions were copied over to minimise dependencies, and their style adjusted to match the rest of the codebase.

Module functions
----------------

"""

from typing import Union, overload, Literal, Optional, TypedDict
import torch
from typeguard import typechecked


# Allowed multiplication orders of the translation (t), rotation (r) and
# scaling (s) matrices when composing or decomposing an affine transform.
ORDER_TYPES = Literal["trs", "tsr", "rts", "rst", "str", "srt"]

# Supported rotation parameterisations. For Euler angles the suffix gives the
# order in which the per-axis rotations are multiplied.
ROT_TYPES = Literal[
    "euler_xyz",
    "euler_xzy",
    "euler_yxz",
    "euler_yzx",
    "euler_zxy",
    "euler_zyx",
    "quaternions",
]

# Whether results are returned positionally or as a dictionary of named outputs.
RETURN_TYPES = Literal["positional", "named"]

# Point around which apply_affine applies a transform.
ORIGIN_TYPES = Literal["centre", "origin"]


@typechecked
def deconstruct_affine(
    transform_affine: torch.Tensor,
    transform_order: ORDER_TYPES = "srt",
    type_rotation: ROT_TYPES = "euler_xyz",
    type_output: RETURN_TYPES = "positional",
) -> Union[tuple, dict[str, torch.Tensor]]:
    """Deconstructs an affine transform into its translation, rotation, and scaling parameters.

    Args:
        transform_affine: Homogeneous affine transform being deconstructed. Must be of shape
            (B, C, 4, 4) or (B, 4, 4) for 3D transforms, or (B, C, 3, 3) or (B, 3, 3) for 2D transforms.
        transform_order: Order of multiplication of translation, rotation, and scaling transforms,
            defaults to 'srt'.
        type_rotation: Type of rotation parameters: quaternions or Euler angles. For Euler angles, the
            order of the multiplication of the rotations around x, y, and z is represented in the name
            (euler_xyz, euler_yzx, etc.), defaults to 'euler_xyz'. Ignored for 2D transforms.
        type_output: Determines how the outputs are returned. If set to positional, it returns positional
            outputs. If set to named, it returns a dictionary with named outputs, defaults to 'positional'.

    Returns:
        parameter_translation: tensor of shape (B, C, 3), or (B, C, 2) for 2D transforms.
        parameter_rotation: tensor of shape (B, C, 3) for Euler angles, (B, C, 4) for quaternions
            ordered (w, x, y, z) with w >= 0, or (B, C, 1) (angle in radians) for 2D transforms.
        parameter_scaling: tensor of shape (B, C, 3), or (B, C, 2) for 2D transforms.
        The channel dimension C is only present if transform_affine had one. With type_output='named'
        the same tensors are returned under the keys 'parameter_translation', 'parameter_rotation'
        and 'parameter_scaling'.

    Raises:
        ValueError: if transform_affine is not 3D/4D, or its last two dimensions are not (3, 3) or (4, 4).

    Example:
        >>> transform_affine = torch.eye(4,4, dtype=torch.float32).unsqueeze(0)
        >>> transl, rot, scale = deconstruct_affine(transform_affine, transform_order='srt', type_rotation='euler_xyz',\
              type_output='positional')
    """
    # Validate arguments
    if transform_affine.dim() not in (3, 4):
        raise ValueError(f"expected a 3D or 4D transform_affine, got {transform_affine.dim()!r}D instead")
    if transform_affine.shape[-2:] not in ((3, 3), (4, 4)):
        raise ValueError(f"unexpected shape of transform_affine {transform_affine.shape!r}")

    # Always work with a channel dimension internally; it is removed again before returning
    channel_dimension = transform_affine.dim() == 4
    if not channel_dimension:
        transform_affine = transform_affine[:, None, ...]
    if not transform_affine.is_floating_point():
        transform_affine = transform_affine.float()

    size = transform_affine.shape[-1]  # 4 for 3D transforms, 3 for 2D transforms
    factory = {"dtype": transform_affine.dtype, "device": transform_affine.device}
    identity = torch.eye(size, **factory).tile((*transform_affine.shape[:2], 1, 1))

    # In these orders S is left of R, so the linear part is S @ R (rows scaled);
    # otherwise it is R @ S (columns scaled).
    scaling_left = transform_order in ("srt", "str", "tsr")

    # Extract scaling parameters as the row or column norms of the linear part
    linear = transform_affine[..., :-1, :-1]
    parameter_scaling = linear.norm(dim=-1) if scaling_left else linear.norm(dim=-2)

    # Extract scaling transform, matching the affine's dtype/device so the matmuls below work
    transform_scaling = generate_scaling(parameter_scaling, type_output="positional")
    assert isinstance(transform_scaling, torch.Tensor)
    scaling_inverse = transform_scaling.to(**factory).inverse()

    # Extract rotation transform by removing scaling, then discarding the translation column
    if scaling_left:
        transform_rotation = scaling_inverse @ transform_affine
    else:
        transform_rotation = transform_affine @ scaling_inverse
    transform_rotation[..., :-1, -1] = 0

    # Extract translation transform by undoing whatever precedes T in the product
    if transform_order in ("trs", "tsr"):
        transform_translation = identity
        transform_translation[..., :-1, -1] = transform_affine[..., :-1, -1]
    elif transform_order == "str":
        transform_translation = identity
        transform_translation[..., :-1, -1] = (scaling_inverse @ transform_affine)[..., :-1, -1]
    elif transform_order == "rts":
        transform_translation = transform_rotation.inverse() @ transform_affine @ scaling_inverse
    elif transform_order == "rst":
        transform_translation = scaling_inverse @ transform_rotation.inverse() @ transform_affine
    elif transform_order == "srt":
        transform_translation = transform_rotation.inverse() @ scaling_inverse @ transform_affine
    else:
        raise ValueError(f"unknown transform_order {transform_order!r}")
    parameter_translation = transform_translation[..., :-1, -1]

    # Extract rotation parameters
    if size == 3:
        # 2D rotation [[cos, -sin], [sin, cos]]; atan2 keeps the full (-pi, pi] range
        parameter_rotation = torch.atan2(transform_rotation[..., 1, 0], transform_rotation[..., 0, 0])[..., None]
    elif type_rotation == "quaternions":
        parameter_rotation = _rotation_matrix_to_quaternion(transform_rotation[..., :-1, :-1])
    else:
        parameter_rotation = _rotation_matrix_to_euler(transform_rotation, type_rotation)

    # Remove channels if required
    if not channel_dimension:
        parameter_scaling = parameter_scaling[:, 0]
        parameter_translation = parameter_translation[:, 0]
        parameter_rotation = parameter_rotation[:, 0]

    # Return results
    if type_output == "positional":
        return parameter_translation, parameter_rotation, parameter_scaling
    return {
        "parameter_translation": parameter_translation,
        "parameter_rotation": parameter_rotation,
        "parameter_scaling": parameter_scaling,
    }


def _rotation_matrix_to_quaternion(rotation: torch.Tensor) -> torch.Tensor:
    """Converts (..., 3, 3) rotation matrices to quaternions (w, x, y, z) with w >= 0.

    Adapted to pytorch from nibabel's ``mat2quat``
    (https://nipy.org/nibabel/reference/nibabel.quaternions.html): the quaternion is the eigenvector
    belonging to the largest eigenvalue of a symmetric 4x4 matrix built from the rotation entries.
    """
    Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = rotation.flatten(start_dim=-2).unbind(-1)
    K = torch.zeros((*rotation.shape[:-2], 4, 4), dtype=rotation.dtype, device=rotation.device)
    # Only the lower triangle is filled: torch.linalg.eigh reads the lower triangle by default
    K[..., 0, 0] = Qxx - Qyy - Qzz
    K[..., 1, 0] = Qyx + Qxy
    K[..., 1, 1] = Qyy - Qxx - Qzz
    K[..., 2, 0] = Qzx + Qxz
    K[..., 2, 1] = Qzy + Qyz
    K[..., 2, 2] = Qzz - Qxx - Qyy
    K[..., 3, 0] = Qyz - Qzy
    K[..., 3, 1] = Qzx - Qxz
    K[..., 3, 2] = Qxy - Qyx
    K[..., 3, 3] = Qxx + Qyy + Qzz
    K = K / 3
    vals, vecs = torch.linalg.eigh(K)
    # Eigenvectors are columns with components (x, y, z, w); reorder to (w, x, y, z)
    vecs = vecs[..., [3, 0, 1, 2], :]
    idx = vals.argmax(dim=-1)
    q = torch.gather(vecs, -1, idx[..., None, None].expand(*idx.shape, 4, 1)).squeeze(-1)
    # q and -q describe the same rotation; choose the one with non-negative w
    return torch.where(q[..., :1] < 0, -q, q)


# Per Euler convention: (sign, alpha numerator, alpha denominator, beta, gamma numerator,
# gamma denominator), each index pair being a (row, column) of the rotation matrix.
_EULER_INDICES = {
    "euler_xyz": (-1, (1, 2), (2, 2), (0, 2), (0, 1), (0, 0)),
    "euler_xzy": (1, (2, 1), (1, 1), (0, 1), (0, 2), (0, 0)),
    "euler_yxz": (1, (0, 2), (2, 2), (1, 2), (1, 0), (1, 1)),
    "euler_yzx": (-1, (2, 0), (0, 0), (1, 0), (1, 2), (1, 1)),
    "euler_zxy": (-1, (0, 1), (1, 1), (2, 1), (2, 0), (2, 2)),
    "euler_zyx": (1, (1, 0), (0, 0), (2, 0), (2, 1), (2, 2)),
}


def _rotation_matrix_to_euler(rotation: torch.Tensor, type_rotation: str) -> torch.Tensor:
    """Converts (..., >=3, >=3) rotation matrices to Euler angles (..., 3) for the given convention."""
    sign, alpha_num, alpha_den, beta_idx, gamma_num, gamma_den = _EULER_INDICES[type_rotation]
    alpha = torch.atan2(
        sign * rotation[..., alpha_num[0], alpha_num[1]],
        rotation[..., alpha_den[0], alpha_den[1]],
    )
    # Clamp so rounding error near +-90 degrees cannot push asin outside its domain (NaN)
    beta = -sign * torch.asin(rotation[..., beta_idx[0], beta_idx[1]].clamp(-1.0, 1.0))
    gamma = torch.atan2(
        sign * rotation[..., gamma_num[0], gamma_num[1]],
        rotation[..., gamma_den[0], gamma_den[1]],
    )
    return torch.stack([alpha, beta, gamma], dim=-1)
@typechecked
def apply_affine(
    image: torch.Tensor,
    transform_affine: torch.Tensor,
    shape_output: Optional[Union[torch.Size, list[int]]] = None,
    type_resampling: Literal["bilinear", "nearest"] = "bilinear",
    type_origin: ORIGIN_TYPES = "centre",
    type_output: RETURN_TYPES = "positional",
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Applies an affine transform to an image.

    Every output voxel index is mapped through the inverse of ``transform_affine`` into the input image,
    which is then sampled at that location. Output voxels that map outside the input image are set to 0.

    Args:
        image: image being transformed. Must be of shape (B, C, H, W) or (B, C, H, W, D). Images that are
            not float32/float64 are converted to float32.
        transform_affine: homogeneous affine transform being applied. Must be of shape (B, C, N, N),
            (B, 1, N, N), or (B, N, N), with N=3 for 2D images and N=4 for 3D images.
        shape_output: Output shape of transformed image. Must have the same batch and channel as image.
            If None, the output_shape=image.shape. Defaults to None.
        type_resampling: interpolation algorithm used when sampling image: "bilinear" (linear interpolation
            along every spatial axis, i.e. trilinear for 3D images) or "nearest".
        type_origin: point around which the transform is applied: "centre" of the image or the index
            "origin". Defaults to "centre".
        type_output: Determines how the outputs are returned. If set to positional, it returns positional
            outputs. If set to named, it returns a dictionary with named outputs.

    Returns:
        image_transformed of shape shape_output, or {"image": image_transformed} if type_output is "named".

    Raises:
        ValueError: if the image/transform ranks, the transform shape, shape_output or type_resampling
            are invalid.

    Example:
        >>> image = torch.rand((1, 1, 160, 160, 160))
        >>> identity_affine = torch.eye(4, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        >>> image_transformed = apply_affine(image, identity_affine)
    """
    # Validate arguments
    if image.dim() not in (4, 5):
        raise ValueError(f"expected a 4D or 5D image, got {image.dim()!r}D instead")
    if transform_affine.dim() not in (3, 4):
        raise ValueError(f"expected a 3D or 4D transform_affine, got {transform_affine.dim()!r}D instead")
    size = image.dim() - 1  # homogeneous transform size: 3 for 2D images, 4 for 3D images
    if (
        transform_affine.shape[0] != image.shape[0]
        or (transform_affine.dim() == 4 and transform_affine.shape[1] not in (1, image.shape[1]))
        or tuple(transform_affine.shape[-2:]) != (size, size)
    ):
        raise ValueError("transform_affine.shape doesn't match image.shape")
    if shape_output is None:
        shape_output = image.shape
    elif len(shape_output) != image.dim() or tuple(shape_output[:2]) != tuple(image.shape[:2]):
        # tuple() on both sides so a list shape_output compares equal to torch.Size
        raise ValueError("shape_output doesn't match image.shape")
    shape_output = torch.Size(shape_output)
    if type_resampling not in ("bilinear", "nearest"):
        raise ValueError(f"unknown type_resampling {type_resampling!r}")

    # Add a channel dimension to the transform if required; it broadcasts over image channels
    if transform_affine.dim() == 3:
        transform_affine = transform_affine[:, None, :, :]

    # Resample in floating point, keeping double precision when the image already has it
    if image.dtype not in (torch.float, torch.double):
        image = image.float()
    transform_affine = transform_affine.to(dtype=image.dtype)

    # Apply the transform about the image centre: shift centre to origin, transform, shift back
    if type_origin == "centre":
        factory = {"dtype": transform_affine.dtype, "device": transform_affine.device}
        transform_origin = torch.eye(size, **factory).tile((*transform_affine.shape[:2], 1, 1))
        transform_origin[..., :-1, -1] = -(torch.tensor(image.shape[2:], **factory) - 1) / 2
        transform_affine = transform_origin.inverse() @ transform_affine @ transform_origin

    # Map every output voxel index back into the input image (inverse mapping)
    transform_inverse = transform_affine.inverse()[..., :-1, :]
    factory = {"dtype": image.dtype, "device": image.device}
    grids = torch.meshgrid(*(torch.arange(s, **factory) for s in shape_output[2:]), indexing="ij")
    coordinates = torch.stack((*grids, torch.ones(shape_output[2:], **factory)))
    coordinates = transform_inverse @ coordinates.reshape((1, 1, size, -1))

    # Find transformed coordinates that lie inside the image, and send the rest to index 0.
    # masked_fill (rather than multiplying by the mask) also clears NaN/inf coordinates.
    upper_bound = torch.tensor(image.shape[2:], **factory)[:, None] - 1
    mask = ((coordinates >= 0) & (coordinates <= upper_bound)).all(dim=2)
    coordinates = coordinates.masked_fill(~mask[:, :, None, :], 0)

    # Lower/upper neighbour indices and their weights along each spatial axis
    axes = coordinates.unbind(dim=2)
    neighbours = [((v.ceil() - 1).long(), v.ceil().long()) for v in axes]
    if type_resampling == "nearest":
        # All the weight goes to the neighbour that rounding selects
        weights = [(v.ceil() - v.round(), v.round() - (v.ceil() - 1)) for v in axes]
    else:
        weights = [(v.ceil() - v, v - (v.ceil() - 1)) for v in axes]

    # Sum the weighted contributions of all 2**ndim corners surrounding each sample point
    batch = torch.arange(shape_output[0], device=image.device)[:, None, None]
    channel = torch.arange(shape_output[1], device=image.device)[None, :, None]
    terms = []
    for corner in range(2 ** len(axes)):
        # Bit k of `corner` picks the lower (0) or upper (1) neighbour along axis k
        bits = [(corner >> k) & 1 for k in range(len(axes))]
        index = [neighbours[k][bit] for k, bit in enumerate(bits)]
        weight = weights[0][bits[0]]
        for k in range(1, len(axes)):
            weight = weight * weights[k][bits[k]]
        terms.append(image[(batch, channel, *index)] * weight)

    # Zero samples that fell outside the input image and restore the spatial layout
    image_transformed = (sum(terms) * mask).reshape(shape_output)

    # Return results
    if type_output == "positional":
        return image_transformed
    return {"image": image_transformed}
@overload
def generate_affine(
    parameter_translation: torch.Tensor,
    parameter_rotation: torch.Tensor,
    parameter_scaling: torch.Tensor,
    type_output: Literal["positional"] = "positional",
    type_rotation: ROT_TYPES = "euler_xyz",
    transform_order: ORDER_TYPES = "trs",
) -> torch.Tensor:
    """Typing overload: with positional output the affine transform tensor is returned."""


@overload
def generate_affine(
    parameter_translation: torch.Tensor,
    parameter_rotation: torch.Tensor,
    parameter_scaling: torch.Tensor,
    type_output: Literal["named"],
    type_rotation: ROT_TYPES = "euler_xyz",
    transform_order: ORDER_TYPES = "trs",
) -> dict[str, torch.Tensor]:
    """Typing overload: with named output a dictionary of tensors is returned."""
@typechecked
def generate_affine(
    parameter_translation: torch.Tensor,
    parameter_rotation: torch.Tensor,
    parameter_scaling: torch.Tensor,
    type_output: RETURN_TYPES = "positional",
    type_rotation: ROT_TYPES = "euler_xyz",
    transform_order: ORDER_TYPES = "trs",
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Generates an affine transform from translation, rotation, and scaling parameters.

    Args:
        parameter_translation: Translation parameters in pixels. Must be of shape (B, C, parameters) or
            (B, parameters), with parameters=2 or 3 for 2D and 3D images, respectively.
        parameter_rotation: Rotation parameters in radians. Must be of shape (B, C, parameters) or
            (B, parameters), with parameters=1, 3 or 4, for 2D, 3D Euler angles, and 3D quaternions,
            respectively.
        parameter_scaling: Scaling parameters. Must be of shape (B, C, parameters) or (B, parameters),
            with parameters=2 or 3 for 2D and 3D images, respectively.
        type_output: Determines how the outputs are returned. If set to "positional", it returns
            positional outputs. If set to "named", it returns a dictionary with named outputs.
            Defaults to 'positional'.
        type_rotation: Type of rotation parameters: quaternions or Euler angles. For Euler angles, the
            order of the multiplication of the rotations around x, y, and z is represented in the name
            (euler_xyz, euler_yzx, etc.), defaults to "euler_xyz".
        transform_order: Order of multiplication of translation, rotation, and scaling transforms,
            e.g. "trs" gives T @ R @ S. Defaults to "trs".

    Returns:
        transform_affine: tensor of shape (*batch_shape, 4, 4) for 3D or (*batch_shape, 3, 3) for 2D,
        where batch_shape is the parameters' shape without the last dimension; or the dictionary
        {"transform_affine": transform_affine} if type_output is "named".

    Raises:
        ValueError: if the parameters' leading dimensions, devices, or dimensionality (2D vs 3D) differ.

    Example:
        >>> parameter_translation = torch.rand((1, 3))
        >>> parameter_rotation = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
        >>> parameter_scaling = torch.rand((1, 3))
        >>> transform_affine = generate_affine(parameter_translation, parameter_rotation, parameter_scaling,\
              type_rotation='quaternions')
    """
    # Validate arguments
    if (
        parameter_translation.shape[:-1] != parameter_rotation.shape[:-1]
        or parameter_translation.shape[:-1] != parameter_scaling.shape[:-1]
    ):
        raise ValueError("mismatched shape of parameters")
    if (
        parameter_translation.device != parameter_rotation.device
        or parameter_translation.device != parameter_scaling.device
    ):
        raise ValueError("mismatched devices of parameters")
    # 2D uses a single rotation angle; 3D uses Euler angles (3) or quaternions (4). Checking this here
    # gives a clear error instead of a matmul shape error between 3x3 and 4x4 transforms.
    spatial_dims = parameter_translation.shape[-1]
    rotation_sizes = (1,) if spatial_dims == 2 else (3, 4)
    if parameter_scaling.shape[-1] != spatial_dims or parameter_rotation.shape[-1] not in rotation_sizes:
        raise ValueError("mismatched dimensionality of parameters")

    # Generate required transforms
    transform_translation = generate_translation(parameter_translation)
    transform_rotation = generate_rotation(parameter_rotation, type_rotation)
    transform_scaling = generate_scaling(parameter_scaling)
    assert isinstance(transform_translation, torch.Tensor)
    assert isinstance(transform_rotation, torch.Tensor)
    assert isinstance(transform_scaling, torch.Tensor)

    # Multiply the transforms in the order spelled out by transform_order
    transforms = {"t": transform_translation, "r": transform_rotation, "s": transform_scaling}
    first, second, third = (transforms[letter] for letter in transform_order)
    transform_affine = first @ second @ third

    # Return results
    if type_output == "positional":
        return transform_affine
    return {"transform_affine": transform_affine}
def generate_translation(
    parameter_translation: torch.Tensor,
    type_output: RETURN_TYPES = "positional",
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Generates a homogeneous translation transform from translation parameters.

    Args:
        parameter_translation: Translation parameters. Must be of shape (B, C, parameters), (B, parameters)
            or (parameters,), with parameters=2 or 3 for 2D and 3D images, respectively.
        type_output: Determines how the outputs are returned. If set to "positional", it returns positional
            outputs. If set to "named", it returns a dictionary with named outputs, defaults to "positional".

    Raises:
        ValueError: if parameter_translation is not 1D, 2D or 3D.
        ValueError: if the last dimension of parameter_translation is not 2 or 3.
        ValueError: if type_output is not "positional" or "named".

    Returns:
        transform_translation: tensor of shape (*parameter_translation.shape[:-1], parameters + 1,
        parameters + 1) in torch's default floating dtype, e.g. (B, C, 4, 4) for 3D parameters with a
        channel dimension; or the dictionary {"transform_translation": transform_translation} if
        type_output is "named".

    Example:
        >>> parameter_translation = torch.rand((1, 3))
        >>> transform_translation = generate_translation(parameter_translation)
    """
    # Validate arguments
    if parameter_translation.dim() not in (1, 2, 3):
        raise ValueError(
            f"expected a 1D, 2D, or 3D parameter_translation, got {parameter_translation.dim()!r}D instead"
        )
    if parameter_translation.shape[-1] not in (2, 3):
        raise ValueError("unexpected shape of parameter_translation")
    if type_output not in ("positional", "named"):
        raise ValueError(f"unknown type_output {type_output!r}")

    # Generate one identity transform per batch/channel entry
    transform_tiling = (*parameter_translation.shape[:-1], 1, 1)
    transform_translation = torch.eye(parameter_translation.shape[-1] + 1, device=parameter_translation.device)
    transform_translation = transform_translation.tile(transform_tiling)

    # Populate the translation column (last column, excluding the homogeneous row)
    transform_translation[..., :-1, -1] = parameter_translation

    # Return results
    if type_output == "positional":
        return transform_translation
    return {"transform_translation": transform_translation}
def generate_scaling(
    parameter_scaling: torch.Tensor,
    type_output: RETURN_TYPES = "positional",
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Generates a homogeneous scaling transform from scaling parameters.

    Args:
        parameter_scaling: Scaling parameters. Must be of shape (B, C, parameters), (B, parameters) or
            (parameters,), with parameters=2 or 3 for 2D and 3D images, respectively.
        type_output: Determines how the outputs are returned. If set to "positional", it returns positional
            outputs. If set to "named", it returns a dictionary with named outputs. Defaults to positional.

    Raises:
        ValueError: if parameter_scaling is not 1D, 2D or 3D.
        ValueError: if the last dimension of parameter_scaling is not 2 or 3.
        ValueError: if type_output is not "positional" or "named".

    Returns:
        transform_scaling: tensor of shape (*parameter_scaling.shape[:-1], parameters + 1, parameters + 1)
        in torch's default floating dtype, e.g. (B, C, 4, 4) for 3D parameters with a channel dimension;
        or the dictionary {"transform_scaling": transform_scaling} if type_output is "named".

    Example:
        >>> parameter_scaling = torch.rand((1, 3))
        >>> transform_scaling = generate_scaling(parameter_scaling)
    """
    # Validate arguments
    if parameter_scaling.dim() not in (1, 2, 3):
        raise ValueError(f"expected a 1D, 2D, or 3D parameter_scaling, got {parameter_scaling.dim()!r}D instead")
    if parameter_scaling.shape[-1] not in (2, 3):
        raise ValueError("unexpected shape of parameter_scaling")
    if type_output not in ("positional", "named"):
        raise ValueError(f"unknown type_output {type_output!r}")

    # Generate one identity transform per batch/channel entry
    transform_tiling = (*parameter_scaling.shape[:-1], 1, 1)
    transform_scaling = torch.eye(parameter_scaling.shape[-1] + 1, device=parameter_scaling.device)
    transform_scaling = transform_scaling.tile(transform_tiling)

    # Write the scale factors onto the diagonal through a view; the homogeneous entry stays 1
    transform_scaling.diagonal(dim1=-2, dim2=-1)[..., :-1] = parameter_scaling

    # Return results
    if type_output == "positional":
        return transform_scaling
    return {"transform_scaling": transform_scaling}
def generate_rotation(
    parameter_rotation: torch.Tensor,
    type_rotation: ROT_TYPES = "euler_xyz",
    type_output: RETURN_TYPES = "positional",
) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
    """Generates a homogeneous rotation transform from rotation parameters.

    Args:
        parameter_rotation: Rotation parameters. Must be of shape (B, C, parameters), (B, parameters) or
            (parameters,), with parameters=1, 3 or 4, for 2D angles, 3D Euler angles (radians), and 3D
            quaternions (w, x, y, z), respectively. Quaternions are normalised to unit length before use.
        type_rotation: Type of rotation parameters: quaternions or Euler angles. For Euler angles, the order
            of the multiplication of the rotations around x, y, and z is represented in the name (euler_xyz,
            euler_yzx, etc.), e.g. "euler_xyz" gives Rx @ Ry @ Rz. This variable will be ignored for 2D
            rotations. Defaults to euler_xyz.
        type_output: Determines how the outputs are returned. If set to "positional", it returns positional
            outputs. If set to "named", it returns a dictionary with named outputs. Defaults to positional.

    Raises:
        ValueError: if parameter_rotation has an unsupported rank or size, does not match type_rotation,
            contains a (near-)zero quaternion, or if type_output is invalid.

    Returns:
        transform_rotation: tensor of shape (*parameter_rotation.shape[:-1], 4, 4) for 3D rotations or
        (*parameter_rotation.shape[:-1], 3, 3) for 2D rotations, in torch's default floating dtype; or the
        dictionary {"transform_rotation": transform_rotation} if type_output is "named".

    Example:
        # Rotation transform for quaternions
        >>> parameter_rotation = torch.rand((1, 4))
        >>> transform_rotation = generate_rotation(parameter_rotation, type_rotation="quaternions")
        # Rotation transform for Euler angles
        >>> parameter_rotation = torch.rand((1, 3))
        >>> transform_rotation = generate_rotation(parameter_rotation, type_rotation="euler_xyz")
    """
    # Validate arguments
    if parameter_rotation.dim() not in (1, 2, 3):
        raise ValueError(f"expected a 1D, 2D, or 3D parameter_rotation, got {parameter_rotation.dim()!r}D instead")
    n_params = parameter_rotation.shape[-1]
    if n_params not in (1, 3, 4):
        raise ValueError("unexpected shape of parameter_rotation")
    if n_params != 1:
        if type_rotation[:5] == "euler" and n_params != 3:
            raise ValueError("mismatch between type_rotation and shape of parameter_rotation")
        if type_rotation == "quaternions" and n_params != 4:
            raise ValueError("mismatch between type_rotation and shape of parameter_rotation")
    if type_output not in ("positional", "named"):
        raise ValueError(f"unknown type_output {type_output!r}")
    device = parameter_rotation.device

    # Generate one identity transform per batch/channel entry (3x3 for 2D, 4x4 for 3D)
    transform_tiling = (*parameter_rotation.shape[:-1], 1, 1)
    transform_identity = torch.eye(3 if n_params == 1 else 4, device=device).tile(transform_tiling)

    if n_params == 1:
        # 2D rotation [[cos, -sin], [sin, cos]]
        angle = parameter_rotation[..., 0]
        transform_rotation = transform_identity
        transform_rotation[..., 0, 0] = torch.cos(angle)
        transform_rotation[..., 0, 1] = -torch.sin(angle)
        transform_rotation[..., 1, 0] = torch.sin(angle)
        transform_rotation[..., 1, 1] = torch.cos(angle)
    elif n_params == 3:
        # (rows, columns, signs) of the cos/sin/sin/cos entries of each single-axis rotation
        index = {
            "x": torch.tensor([[1, 1, 2, 2], [1, 2, 1, 2], [1, -1, 1, 1]]),
            "y": torch.tensor([[0, 0, 2, 2], [0, 2, 0, 2], [1, 1, -1, 1]]),
            "z": torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1], [1, -1, 1, 1]]),
        }
        # Multiply the per-axis rotations in the order given by the suffix, e.g. "xyz"
        transform_rotation = transform_identity.clone()
        for i, axis in enumerate(type_rotation[6:]):
            rows, cols, signs = index[axis]
            angle = parameter_rotation[..., i]
            values = signs.to(device) * torch.stack(
                [torch.cos(angle), torch.sin(angle), torch.sin(angle), torch.cos(angle)], dim=-1
            )
            transform_axis = transform_identity.clone()
            # index_put needs matching dtypes, so cast (e.g. float64 angles) to the transform's dtype
            transform_axis[..., rows, cols] = values.to(transform_axis.dtype)
            transform_rotation = transform_rotation @ transform_axis
    else:
        # A zero quaternion has no direction and cannot describe a rotation
        norm = parameter_rotation.norm(dim=-1)
        if torch.any(norm < 1e-5):
            raise ValueError(
                "parameter_rotation of type Quaternion must have a non-zero norm. "
                f"Got parameter_rotation.norm(dim=-1)={norm} instead"
            )
        # Normalise so the matrix below is a proper rotation even for non-unit input
        q0, q1, q2, q3 = (parameter_rotation / norm[..., None]).unbind(-1)
        transform_rotation = transform_identity
        transform_rotation[..., 0, 0] = 1 - 2 * (q2**2 + q3**2)
        transform_rotation[..., 0, 1] = 2 * (q1 * q2 - q3 * q0)
        transform_rotation[..., 0, 2] = 2 * (q1 * q3 + q2 * q0)
        transform_rotation[..., 1, 0] = 2 * (q1 * q2 + q3 * q0)
        transform_rotation[..., 1, 1] = 1 - 2 * (q1**2 + q3**2)
        transform_rotation[..., 1, 2] = 2 * (q2 * q3 - q1 * q0)
        transform_rotation[..., 2, 0] = 2 * (q1 * q3 - q2 * q0)
        transform_rotation[..., 2, 1] = 2 * (q2 * q3 + q1 * q0)
        transform_rotation[..., 2, 2] = 1 - 2 * (q1**2 + q2**2)

    # Return results
    if type_output == "positional":
        return transform_rotation
    return {"transform_rotation": transform_rotation}