# Source code for fetalbrain.tedsnet_multi.hemisphere_detector

import torch
import torchvision
from pathlib import Path
from typing import Optional, Literal
import numpy as np
from ..model_paths import SIDE_DETECTOR_MODEL_PATH


def load_sidedetector_model(model_path: Optional[Path] = None) -> torch.nn.Module:
    """Build the side detection network and load its trained weights.

    Args:
        model_path: path to the trained model weights. When None, the packaged
            ``SIDE_DETECTOR_MODEL_PATH`` is used.

    Returns:
        model: two-class ResNet-18 with the trained weights loaded, in eval mode.

    Note:
        Calling this function disables gradient computation globally through
        ``torch.set_grad_enabled(False)``.
    """
    weights_file = SIDE_DETECTOR_MODEL_PATH if model_path is None else model_path

    # Architecture comes from the torchvision hub entry; no pretrained weights are requested.
    network = torch.hub.load("pytorch/vision", "resnet18", weights=None, num_classes=2)

    # Map the stored tensors to the CPU regardless of the device they were saved from.
    state_dict = torch.load(weights_file, map_location=torch.device("cpu"))
    network.load_state_dict(state_dict)

    # Inference only: eval mode for the network, autograd switched off process-wide.
    network.eval()
    torch.set_grad_enabled(False)
    return network
def detect_side(
    aligned_scan: torch.Tensor, model: torch.nn.Module, from_atlas: bool = True
) -> tuple[Literal[0, 1], float]:
    """Predict which side of the brain is visible in an aligned scan.

    Takes as input a scan aligned (no scaling) to the atlas or bean coordinate system.
    Three neighbouring slices along the last axis are used as the channels of a
    2D image that is passed to the model.

    Args:
        aligned_scan: scan of shape [1, 1, H, W, D]. If its maximum is at most 1 the
            intensities are multiplied by 255 before inference. The tensor passed in
            is left untouched.
        model: side detection network that maps a [1, 3, ., .] image to two outputs.
        from_atlas: if True (default) slices 85:88 of the last axis are used, otherwise
            slices 79:82 in the orientation the model was trained on.

    Returns:
        pred: 0 for left, 1 for right (numpy integer, index of the largest output)
        probs: probability of the prediction, i.e. the sigmoid of the selected output,
            returned as a numpy array with one entry per batch element
    """
    if torch.max(aligned_scan) <= 1:
        # Rebind instead of using `*=` so the caller's tensor is not rescaled in place.
        aligned_scan = aligned_scan * 255

    if not from_atlas:
        # this is the orientation it was trained on by Maddy
        midslice = aligned_scan[:, 0, :, :, 79:82].permute(0, 3, 2, 1)
    else:
        # this is equivalent to this in the atlas orientation
        # NOTE(review): the flip is assumed to belong to the atlas branch only - confirm.
        midslice = aligned_scan[:, 0, :, :, 85:88].permute(0, 3, 1, 2)
        midslice = torch.flip(midslice, dims=[3])

    outputs = torch.sigmoid(model(midslice)).detach().cpu().numpy()
    # argmax runs over the flattened [batch, 2] array, so this assumes a batch size of 1.
    pred = np.argmax(outputs)
    return pred, outputs[:, pred]  # type: ignore