Source code for fetalbrain.alignment.fBAN_v1

"""
This module contains the architecture of the alignment model fBAN_v1 that aligns the input scan
to a reference coordinate system. The rotations in this model are represented by quaternions.
The model consists of predicting coarse and fine alignment parameters after each other.

More details about the architecture can be found in the NeuroImage paper
(`Moser et al. NeuroImage, 2022) <https://www.sciencedirect.com/science/article/pii/S1053811922004608>`_).

[Local] Network code copied from the OMNI cluster: **../../../shared/stru0039/fBAN/v1/model.py**

"""
import torch
import torch.nn as nn
import torch.nn.functional as ff
from .kelluwen_transforms import (
    generate_affine,
    apply_affine,
    deconstruct_affine,
)


[docs] class AlignmentModel(nn.Module): """Model class for the alignment model fBAN which predicts the alignment parameters for a given scan. Args: shape_data: the input shape of the image dimensions_hidden: the dimensions of the hidden layers in the network size_kernel: convolutional kernel size, defaults to 3 :Example: >>> model = AlignmentModel() """ def __init__( self, shape_data: list[int] = [1, 160, 160, 160], dimensions_hidden: list[int] = [16, 32, 64, 128, 256, 128, 64, 32], size_kernel: int = 3, ) -> None: super(AlignmentModel, self).__init__() # Rename for readability dh = dimensions_hidden sd = shape_data sk = size_kernel # Convolutional blocks X1 self.enc_X1_cb00 = self._conv_block(sd[0], dh[0], sk) self.enc_X1_cb01 = self._conv_block(dh[0], dh[1], sk) self.enc_X1_cb02 = self._conv_block(dh[1], dh[2], sk) self.enc_X1_cb03 = self._conv_block(dh[2], dh[3], sk) self.enc_X1_cb04 = self._conv_block(dh[3], dh[4], sk) # Dense layers X1 self.dense_X1_00 = nn.Linear(dh[4] * 10 * 10 * 10, 8) # Convolutional blocks X2 self.enc_X2_cb00 = self._conv_block(sd[0], dh[0], sk) self.enc_X2_cb01 = self._conv_block(dh[0], dh[1], sk) self.enc_X2_cb02 = self._conv_block(dh[1], dh[2], sk) self.enc_X2_cb03 = self._conv_block(dh[2], dh[3], sk) self.enc_X2_cb04 = self._conv_block(dh[3], dh[4], sk) # Dense layers X2 self.dense_X2_00 = nn.Linear(dh[4] * 10 * 10 * 10, 8) # Deterministic self.maxpool = nn.MaxPool3d(2, 2) def _conv_block(self, ch_in: int, ch_out: int, size_kernel: int) -> nn.Sequential: return nn.Sequential( nn.Conv3d(ch_in, ch_out, size_kernel, padding="same"), nn.BatchNorm3d(ch_out), nn.ReLU(), nn.Conv3d(ch_out, ch_out, size_kernel, padding="same"), nn.BatchNorm3d(ch_out), nn.ReLU(), )
[docs] def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass of the alignment model Args: x: batch of data of size [B, 1, H, W, D] with pixel values between 0 and 255 Returns: translation: translation parameters of size [B, 3] between 0 and 1 rotation: rotation parameters of size [B, 4], represented as quarternions scaling: scaling parameters of size [B, 3] between 0 and 1 :Example: >>> model = AlignmentModel() >>> dummy_scan = torch.rand(1,1,160,160,160) >>> translation, rotation, scaling = model(dummy_scan) """ # Normalize input to [0,1] x = x / 255 # Model X1 hidden = self.enc_X1_cb00(x) hidden = self.enc_X1_cb01(self.maxpool(hidden)) hidden = self.enc_X1_cb02(self.maxpool(hidden)) hidden = self.enc_X1_cb03(self.maxpool(hidden)) hidden = self.enc_X1_cb04(self.maxpool(hidden)) hidden = self.dense_X1_00(torch.flatten(hidden, start_dim=1)) # Parameters X1 translation_X1 = hidden[:, :3] # 3 translation parameter rotation_X1 = ff.normalize(hidden[:, 3:7]) # 4 rotation parameters scaling_X1 = hidden[:, 7:8] # 1 scaling parameter # Transform X1 transform_X1 = generate_affine( parameter_translation=translation_X1 * 160, # Scan shape is (160,160,160) parameter_rotation=rotation_X1, parameter_scaling=scaling_X1.tile(1, 3), type_rotation="quaternions", transform_order="srt", ) # Scan aligned X1 x_aligned_X1 = apply_affine( image=x, transform_affine=transform_X1, type_resampling="bilinear", type_origin="centre", ) # Model X2 hidden = self.enc_X2_cb00(x_aligned_X1) hidden = self.enc_X2_cb01(self.maxpool(hidden)) hidden = self.enc_X2_cb02(self.maxpool(hidden)) hidden = self.enc_X2_cb03(self.maxpool(hidden)) hidden = self.enc_X2_cb04(self.maxpool(hidden)) hidden = self.dense_X2_00(torch.flatten(hidden, start_dim=1)) # Parameters X2 translation_X2 = hidden[:, :3] # 3 translation parameter rotation_X2 = ff.normalize(hidden[:, 3:7]) # 4 rotation parameters scaling_X2 = hidden[:, 7:8] # 1 scaling parameter # Transform X2 transform_X2 = generate_affine( parameter_translation=translation_X2 * 160, # Scan shape is (160,160,160) parameter_rotation=rotation_X2, parameter_scaling=scaling_X2.tile(1, 3), type_rotation="quaternions", transform_order="srt", ) # Combined transform transform = transform_X2 @ transform_X1 # Predicted parameters translation, rotation, scaling = deconstruct_affine( transform_affine=transform, transform_order="srt", type_rotation="quaternions", ) assert isinstance(translation, torch.Tensor) assert isinstance(rotation, torch.Tensor) assert isinstance(scaling, torch.Tensor) return translation / 160, rotation, scaling