"""How to apply our multi-structure (TEDS-Net) segmentation method.
"""
import torch
import numpy as np
from typing import Optional, Literal
from pathlib import Path
from fetalbrain.tedsnet_multi.network.TEDS_Net import TEDS_Net
from fetalbrain.alignment.align import prepare_scan
from fetalbrain.utils import read_image
from fetalbrain.tedsnet_multi.hemisphere_detector import load_sidedetector_model, detect_side
from ..model_paths import TEDS_MULTI_MODEL_PATH, PRIOR_SHAPE_PATH
[docs]
def load_tedsmulti_model(model_path: Optional[Path] = None) -> TEDS_Net:
    """Build the multi-structure segmentation network and load its trained weights.

    Args:
        model_path: path to the trained model weights. When ``None``,
            ``TEDS_MULTI_MODEL_PATH`` is used.

    Returns:
        model: segmentation model in eval mode with the trained weights loaded.

    Note:
        As a side effect this disables gradient tracking globally through
        ``torch.set_grad_enabled(False)``.

    Example:
        >>> model = load_tedsmulti_model()
    """
    weights_file = TEDS_MULTI_MODEL_PATH if model_path is None else model_path
    network = TEDS_Net()

    # Without CUDA the stored tensors are remapped onto the CPU; with CUDA
    # torch.load is left at its default (map_location=None).
    target_device = None if torch.cuda.is_available() else torch.device('cpu')
    state = torch.load(weights_file, map_location=target_device)

    network.load_state_dict(state)
    network.eval()
    torch.set_grad_enabled(False)
    return network
[docs]
def get_prior_shape_sa(sd: Literal[0, 1]) -> torch.Tensor:
    """Load the prior shape for one side as a one-hot tensor.

    The label volume is read from ``PRIOR_SHAPE_PATH``, one half of its second
    axis is cleared depending on ``sd``, and the result is one-hot encoded
    over labels 1..10.

    Args:
        sd: which side to get the prior shape for, either 0 or 1

    Returns:
        prior_shape: float32 tensor of shape [10, *volume_shape] with one
        binary channel per label.

    Example:
        >>> prior_shape = get_prior_shape_sa(0)
    """
    assert sd in [0, 1], "sd should be either 0 or 1"

    raw_prior, _ = read_image(PRIOR_SHAPE_PATH)

    # The stored volume has its first and last axes permuted; astype returns a
    # fresh integer copy, so the edits below never touch the loaded image.
    labels = np.swapaxes(raw_prior, 0, 2).astype(int)

    # Remember the cavum (label 2) before blanking a hemisphere: it lies around
    # the midplane and must survive on either side.
    cavum_mask = labels == 2

    # NOTE(review): the 80/160 split presumably assumes axis 1 has extent 160 — confirm.
    if sd == 0:
        labels[:, 80:160, :] = 0
    elif sd == 1:
        labels[:, 0:80, :] = 0
    labels[cavum_mask] = 2

    # One-hot encode labels 1..n_classes by broadcasting against the volume
    n_classes = 10
    class_ids = np.arange(1, n_classes + 1).reshape(-1, 1, 1, 1)
    encoded = (labels[np.newaxis] == class_ids).astype(np.float32)
    return torch.from_numpy(encoded)
[docs]
def generate_multiclass_prediction(prediction: torch.Tensor) -> np.ndarray:
    """Convert the TEDS-multiclass output into a multiclass segmentation mask.

    Every channel is treated as an independent binary prediction thresholded
    at 0.4. Channel ``c`` (0-based) is written into the mask as label
    ``c + 1``; where several channels fire on the same voxel, the highest
    channel index wins.

    Args:
        prediction: prediction from the TEDS model of size [B, C, H, W, D]
            (C is 10 for the trained model). An unbatched [C, H, W, D]
            tensor is accepted as well.

    Returns:
        combined_pred: integer mask of size [H, W, D] for a single item
        (B == 1 or unbatched input), or [B, H, W, D] when B > 1.
    """
    # due to TEDS-Net's design each channel is a binary map, thresholded at 0.4
    binary = (prediction > 0.4).cpu().numpy()

    # Put the channel axis first explicitly. A blanket squeeze() would confuse
    # the batch axis with the channel axis for B > 1 and would also drop any
    # spatial axis of length 1.
    batched = binary.ndim == 5
    per_channel = np.moveaxis(binary, 1, 0) if batched else binary

    combined_pred = np.zeros(per_channel.shape[1:], dtype=np.int32)
    # later channels overwrite earlier ones on overlapping voxels
    for label, mask in enumerate(per_channel, start=1):
        combined_pred[mask] = label

    # keep the historical [H, W, D] return for a single-item batch
    if batched and combined_pred.shape[0] == 1:
        combined_pred = combined_pred[0]
    return combined_pred
[docs]
def segment_tedsall(
    aligned_scan: torch.Tensor, segm_model: TEDS_Net, side: Literal[0, 1] = 0
) -> tuple[np.ndarray, dict]:
    """Segment an aligned scan with the multi-structure TEDS model.

    The scan's last three axes are reversed before the forward pass and the
    model output is permuted back afterwards, so the returned mask follows
    the axis order of ``aligned_scan``.

    Args:
        aligned_scan: 5-D scan tensor (the permutation indexes axes 0..4);
            presumably [B, C, H, W, D] — TODO confirm against ``prepare_scan``.
        segm_model: segmentation model, called as ``segm_model(scan, prior)``
            and expected to return two values, the first being the prediction.
        side: which side's prior shape to use, either 0 or 1. Defaults to 0.

    Returns:
        A tuple ``(multiclass, key_maps)``: the multiclass segmentation mask
        produced by ``generate_multiclass_prediction`` and a dict mapping
        each structure name to its integer label in that mask.
    """
    # swapping axes 2 and 4 is its own inverse, so the same order is reused below
    axis_order = (0, 1, 4, 3, 2)
    model_input = aligned_scan.permute(*axis_order)

    # prior shape for the requested side, with a leading batch axis, on the scan's device
    shape_prior = get_prior_shape_sa(side).unsqueeze(0).to(model_input.device)

    # forward pass; the second model output is not used here
    logits, _unused = segm_model(model_input, shape_prior)

    # back to the scan's axis order, then collapse channels into one label map
    multiclass = generate_multiclass_prediction(logits.permute(*axis_order))

    # channel order of the model output; labels start at 1 (0 is background)
    structure_names = (
        "Cortical Plate",
        "Cavum Septum",
        "Cerebellum",
        "Choriod Plex",
        "Ventricle",
        "DGM",
        "Thalamus",
        "Brainstem",
        "WM",
        "Frontal Horns",
    )
    key_maps = {name: label for label, name in enumerate(structure_names, start=1)}
    return multiclass, key_maps
[docs]
def segment_scan_tedsall(aligned_scan: torch.Tensor) -> tuple[np.ndarray, dict]:
    """Execute the whole TEDSall segmentation pipeline on one aligned scan.

    Loads the segmentation model onto the scan's device, runs the scan through
    ``prepare_scan``, detects the side with the side-detector model and then
    segments the prepared scan using the prior for that side.

    Args:
        aligned_scan: aligned scan tensor; it is passed through
            ``prepare_scan`` before any model sees it.

    Returns:
        A tuple ``(multiclass, keys)``: the segmentation mask with all
        singleton axes squeezed out, and the dict mapping structure names to
        their integer labels.
    """
    device = aligned_scan.device
    network = load_tedsmulti_model().to(device)
    prepared = prepare_scan(aligned_scan)

    # the side is predicted by the side-detector model rather than fixed
    detector = load_sidedetector_model()
    side, _ = detect_side(prepared, detector)

    segmentation, structure_labels = segment_tedsall(prepared, network, side=side)
    return segmentation.squeeze(), structure_labels