# Source code for mmv_im2im.models.nets.deeplabv3_2d

# ADAPTED FROM https://pytorch.org/vision/main/_modules/torchvision/models/segmentation/deeplabv3.html  # noqa E501
# and https://discuss.pytorch.org/t/how-to-modify-deeplabv3-and-fcn-models-for-grayscale-images/52688  # noqa E501

import torch
from mmv_im2im.utils.misc import parse_config_func_without_params


class Net(torch.nn.Module):
    """DeepLabV3 2D segmentation network built from ``torchvision``.

    Thin wrapper around the DeepLabV3 factory functions in
    ``torchvision.models.segmentation``. When ``in_channels`` is not 3, it
    also swaps the backbone's first convolution so the model accepts a
    different number of input channels (e.g. grayscale images).

    Args:
        backbone: Name of the torchvision factory function to build; must be
            one of ``Net.SUPPORTED_BACKBONES``.
        pretrained: Forwarded unchanged to the torchvision factory.
        pretrained_backbone: Forwarded unchanged to the torchvision factory.
        in_channels: Number of input image channels. When not 3, the first
            ``Conv2d`` of the backbone is replaced by a freshly initialised
            layer with the same geometry but ``in_channels`` inputs. Any
            pretrained weights of that single layer are discarded.
        num_classes: Forwarded unchanged to the torchvision factory.
        aux_loss: Forwarded unchanged to the torchvision factory.

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    # The backbone name doubles as the torchvision factory function name.
    SUPPORTED_BACKBONES = (
        "deeplabv3_resnet50",
        "deeplabv3_resnet101",
        "deeplabv3_mobilenet_v3_large",
    )

    def __init__(
        self,
        backbone,
        pretrained: bool = False,
        pretrained_backbone: bool = True,
        in_channels: int = 3,
        num_classes: int = 21,
        aux_loss: bool = None,
    ):
        super().__init__()
        # Fail fast. Previously an unknown name silently left ``self.net``
        # unset and only crashed later with an AttributeError.
        if backbone not in self.SUPPORTED_BACKBONES:
            raise ValueError(
                f"Unsupported backbone {backbone!r}; expected one of: "
                f"{', '.join(self.SUPPORTED_BACKBONES)}"
            )

        params = {
            "progress": False,
            "num_classes": num_classes,
            "aux_loss": aux_loss,
            "pretrained": pretrained,
            "pretrained_backbone": pretrained_backbone,
        }
        info = {
            "module_name": "torchvision.models.segmentation",
            "func_name": backbone,
            "params": params,
        }
        my_func = parse_config_func_without_params(info)
        self.net = my_func(**info["params"])

        if in_channels != 3:
            # Replace whichever Conv2d comes first in the backbone rather than
            # assuming a ``conv1`` attribute. The ResNet backbones name it
            # ``conv1``, but the MobileNetV3 backbone does not, so assigning
            # ``backbone.conv1`` there never reached the real input layer.
            self._replace_first_conv(self.net.backbone, in_channels)

    @staticmethod
    def _replace_first_conv(module: torch.nn.Module, in_channels: int) -> None:
        """Swap the first ``Conv2d`` found in ``module`` for one with ``in_channels`` inputs.

        "First" follows ``module.named_modules()`` order. The replacement
        keeps the original layer's output channels, kernel size, stride,
        padding, dilation, padding mode and whether it has a bias. Its
        weights are freshly initialised and it uses the default ``groups=1``.

        Args:
            module: Module tree to modify in place.
            in_channels: Number of input channels for the new layer.

        Raises:
            RuntimeError: If ``module`` contains no ``Conv2d`` layer.
        """
        for name, child in module.named_modules():
            if not isinstance(child, torch.nn.Conv2d):
                continue
            parent_name, _, attr = name.rpartition(".")
            parent = module.get_submodule(parent_name) if parent_name else module
            new_conv = torch.nn.Conv2d(
                in_channels,
                child.out_channels,
                kernel_size=child.kernel_size,
                stride=child.stride,
                padding=child.padding,
                dilation=child.dilation,
                bias=child.bias is not None,
                padding_mode=child.padding_mode,
            )
            # setattr on a Module re-registers the child under the same name,
            # so the parent's forward (or a ModuleDict/Sequential ordering)
            # picks up the replacement in place.
            setattr(parent, attr, new_conv)
            return
        raise RuntimeError("No Conv2d layer found in backbone to adapt to in_channels")

    def forward(self, x):
        """Run the wrapped network and return only its ``"out"`` entry.

        The wrapped model returns a mapping; any other entries (for example
        an auxiliary head's output) are discarded here.
        """
        return self.net(x)["out"]