Segmentation network module

This module contains the 3D U-Net architecture for multiclass subcortical segmentation, as described in Hesse et al. (NeuroImage, 2022).

class fetalbrain.structural_segmentation.segmentation_model.UNet(n_channels: int, n_classes: int = 5, min_featuremaps: int = 64, depth: int = 5, transposed_conv: bool = False)[source]

Bases: Module

3D U-Net architecture for multiclass subcortical segmentation. The network returns the logits of the prediction, i.e. the output before any final activation. In practice, a softmax activation should be applied along the channel dimension of the logits to obtain a multi-class prediction.

Parameters:
  • n_channels – number of channels in the input image

  • n_classes – number of output classes in the prediction. Defaults to 5.

  • min_featuremaps – number of feature maps in the first encoder block. Defaults to 64.

  • depth – depth of the U-Net architecture. Defaults to 5.

  • transposed_conv – whether to use transposed convolutions for upsampling. Defaults to False.

Example

>>> import torch
>>> from fetalbrain.structural_segmentation.segmentation_model import UNet
>>> input_im = torch.rand((1, 1, 160, 160, 160)) * 255
>>> model = UNet(1, 5, min_featuremaps=16, depth=5)
>>> output = model(input_im)
>>> assert output.shape == (1, 5, 160, 160, 160)
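Since the model returns raw logits, obtaining a discrete segmentation map requires a softmax along the channel dimension followed by an argmax. A minimal sketch using a dummy logits tensor with the same [B, C_out, H, W, D] layout (the small spatial size is only to keep the example fast):

```python
import torch

# Dummy logits with the layout [B, C_out, H, W, D] the model produces.
logits = torch.randn(1, 5, 8, 8, 8)

# Softmax along the channel (class) dimension gives per-voxel class probabilities.
probs = torch.softmax(logits, dim=1)

# Argmax over the class dimension gives a discrete label map.
labels = probs.argmax(dim=1)

assert probs.shape == (1, 5, 8, 8, 8)
assert labels.shape == (1, 8, 8, 8)
# Probabilities along the class dimension sum to 1 for every voxel.
assert torch.allclose(probs.sum(dim=1), torch.ones(1, 8, 8, 8))
```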
forward(x: Tensor) → Tensor[source]

Makes a prediction for a batch of images.

Parameters:

x – tensor with input image of size [B, C_in, H, W, D] with pixel values between 0 and 255

Returns:

logits – tensor with output prediction of size [B, C_out, H, W, D]
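At inference time the usual PyTorch pattern applies: put the model in eval mode and disable gradient tracking before calling it. The sketch below uses a single Conv3d as a stand-in for UNet (so it runs without the fetalbrain package installed); the tensor shapes follow the documented [B, C, H, W, D] convention and input values in [0, 255]:

```python
import torch
import torch.nn as nn

# Stand-in for UNet: any module mapping [B, 1, H, W, D] -> [B, 5, H, W, D].
model = nn.Conv3d(in_channels=1, out_channels=5, kernel_size=3, padding=1)

# Input image with pixel values between 0 and 255, as the docstring specifies.
input_im = torch.rand(1, 1, 16, 16, 16) * 255

model.eval()           # inference mode (affects e.g. dropout/batchnorm layers)
with torch.no_grad():  # no gradients needed for prediction
    logits = model(input_im)

assert logits.shape == (1, 5, 16, 16, 16)
```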