# adapted from https://github.com/juglab/EmbedSeg/tree/main/EmbedSeg/criterions
import torch
import torch.nn as nn
from mmv_im2im.utils.lovasz_losses import lovasz_hinge
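
# Overview (descriptive summary of the code below): SpatialEmbLoss_3d and
# SpatialEmbLoss_2d implement an EmbedSeg-style spatial embedding loss. Each
# pixel predicts an offset toward its instance center plus a bandwidth sigma;
# a Gaussian score dist = exp(-sum_k s_k * (e_k - c_k)^2), with s = exp(10 *
# mean sigma over the instance), is compared against the instance mask with a
# Lovasz hinge loss, while a seed map is regressed toward dist on foreground
# pixels and toward zero on background pixels, and a variance term keeps each
# pixel's sigma close to its instance mean.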


class SpatialEmbLoss_3d(nn.Module):
    def __init__(
        self,
        grid_z=32,
        grid_y=1024,
        grid_x=1024,
        pixel_z=1,
        pixel_y=1,
        pixel_x=1,
        n_sigma=3,
        foreground_weight=10,
        use_costmap=False,
    ):
        super().__init__()
        print(
            f"Created spatial emb loss function with: n_sigma: {n_sigma}, "
            f"foreground_weight: {foreground_weight}"
        )
        print("*************************")
        self.n_sigma = n_sigma
        self.foreground_weight = foreground_weight

        xm = (
            torch.linspace(0, pixel_x, grid_x)
            .view(1, 1, 1, -1)
            .expand(1, grid_z, grid_y, grid_x)
        )
        ym = (
            torch.linspace(0, pixel_y, grid_y)
            .view(1, 1, -1, 1)
            .expand(1, grid_z, grid_y, grid_x)
        )
        zm = (
            torch.linspace(0, pixel_z, grid_z)
            .view(1, -1, 1, 1)
            .expand(1, grid_z, grid_y, grid_x)
        )
        xyzm = torch.cat((xm, ym, zm), 0)
        self.register_buffer("xyzm", xyzm)

        self.use_costmap = use_costmap

    def forward(
        self,
        prediction,
        instances,
        labels,
        center_images,
        costmaps=None,
        w_inst=1,
        w_var=10,
        w_seed=1,
    ):
        # instances B 1 Z Y X
        batch_size, depth, height, width = (
            prediction.size(0),
            prediction.size(2),
            prediction.size(3),
            prediction.size(4),
        )

        # weighted loss
        if self.use_costmap:
            # only the instance ids need adjusting in this step: zero-weight
            # pixels have their ids zeroed out, so fully masked-out instances
            # are skipped below
            instances_adjusted = instances * costmaps
        else:
            instances_adjusted = instances

        xyzm_s = self.xyzm[:, 0:depth, 0:height, 0:width].contiguous()  # 3 x d x h x w

        loss = 0

        for b in range(0, batch_size):
            spatial_emb = torch.tanh(prediction[b, 0:3]) + xyzm_s  # 3 x d x h x w
            sigma = prediction[b, 3 : 3 + self.n_sigma]  # n_sigma x d x h x w
            seed_map = torch.sigmoid(
                prediction[b, 3 + self.n_sigma : 3 + self.n_sigma + 1]
            )  # 1 x d x h x w

            # loss accumulators
            var_loss = 0
            instance_loss = 0
            seed_loss = 0
            obj_count = 0

            if self.use_costmap:
                # per-pixel weights; assumes the costmap is broadcastable to
                # the 1 x d x h x w seed map
                costmap = costmaps[b].expand_as(seed_map)
            instance = instances[b]  # without costmap adjustment
            label = labels[b]
            center_image = center_images[b] > 0  # ensure a boolean mask

            # use adjusted instance to find all ids
            instance_ids = instances_adjusted[b].unique()
            instance_ids = instance_ids[instance_ids != 0]

            # regress bg to zero
            bg_mask = label == 0
            if bg_mask.sum() > 0:
                if self.use_costmap:
                    # weight each background pixel by its costmap value, since
                    # some of the background pixels might have zero weight
                    seed_loss += torch.sum(
                        costmap[bg_mask] * torch.pow(seed_map[bg_mask] - 0, 2)
                    )
                else:
                    seed_loss += torch.sum(torch.pow(seed_map[bg_mask] - 0, 2))

            for id in instance_ids:
                # use the original instance without costmap adjustment to fetch
                # the instance mask, since the costmap may partially cut some
                # instances and alter the ground truth; only use the costmap to
                # adjust the loss values at the end
                in_mask = instance.eq(id)  # 1 x d x h x w

                center_mask = in_mask & center_image
                if center_mask.sum().eq(1):
                    center = xyzm_s[center_mask.expand_as(xyzm_s)].view(3, 1, 1, 1)
                else:
                    xyz_in = xyzm_s[in_mask.expand_as(xyzm_s)].view(3, -1)
                    center = xyz_in.mean(1).view(3, 1, 1, 1)  # 3 x 1 x 1 x 1

                # calculate sigma
                sigma_in = sigma[in_mask.expand_as(sigma)].view(
                    self.n_sigma, -1
                )  # n_sigma x N

                s = sigma_in.mean(1).view(self.n_sigma, 1, 1, 1)  # n_sigma x 1 x 1 x 1

                # calculate var loss before exp
                if self.use_costmap:
                    # weight each in-instance pixel by its costmap value
                    costmap_in = costmap[in_mask]  # N values
                    var_loss = var_loss + torch.mean(
                        costmap_in * torch.pow(sigma_in - s[..., 0, 0].detach(), 2)
                    )
                else:
                    var_loss = var_loss + torch.mean(
                        torch.pow(sigma_in - s[..., 0, 0].detach(), 2)
                    )

                s = torch.exp(s * 10)
                dist = torch.exp(
                    -1
                    * torch.sum(torch.pow(spatial_emb - center, 2) * s, 0, keepdim=True)
                )

                # apply lovasz-hinge loss
                # TODO: if the costmap only switches whole instances on or off,
                # this loss is still correct; otherwise it may be slightly off.
                instance_loss = instance_loss + lovasz_hinge(dist * 2 - 1, in_mask)

                # seed loss
                if self.use_costmap:
                    seed_loss += self.foreground_weight * torch.sum(
                        costmap_in
                        * torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2)
                    )
                else:
                    seed_loss += self.foreground_weight * torch.sum(
                        torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2)
                    )

                # calculate instance iou
                # iou_instance = calculate_iou(dist > 0.5, in_mask)

                obj_count += 1

            if obj_count > 0:
                instance_loss /= obj_count
                var_loss /= obj_count

            if self.use_costmap:
                seed_loss = seed_loss / costmap.sum()
            else:
                seed_loss = seed_loss / (depth * height * width)

            loss += w_inst * instance_loss + w_var * var_loss + w_seed * seed_loss

        # average the per-sample losses over the batch
        loss = loss / (b + 1)

        # adding prediction.sum() * 0 keeps the loss connected to every element
        # of the prediction, even for samples with no foreground instances
        return loss + prediction.sum() * 0
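

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API; the
# helper name below is made up for this example). It shows the tensor layout
# the 3D loss expects: prediction is B x (3 + n_sigma + 1) x Z x Y x X
# (3 offsets, n_sigma sigmas, 1 seed channel), while instances, labels and
# center images are B x 1 x Z x Y x X.
def _demo_spatial_emb_loss_3d():
    criterion = SpatialEmbLoss_3d(grid_z=16, grid_y=64, grid_x=64, n_sigma=3)
    pred = torch.randn(1, 7, 8, 32, 32)  # 3 offsets + 3 sigmas + 1 seed
    instances = torch.zeros(1, 1, 8, 32, 32, dtype=torch.long)
    instances[0, 0, 2:6, 8:20, 8:20] = 1  # a single toy instance
    labels = (instances > 0).long()  # binary foreground / background
    centers = torch.zeros_like(instances, dtype=torch.bool)
    centers[0, 0, 4, 14, 14] = True  # one center voxel inside the instance
    return criterion(pred, instances, labels, centers)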


class SpatialEmbLoss_2d(nn.Module):
    def __init__(
        self,
        grid_y=1024,
        grid_x=1024,
        pixel_y=1,
        pixel_x=1,
        n_sigma=2,
        foreground_weight=10,
        use_costmap=False,
    ):
        super().__init__()
        print(
            f"Created spatial emb loss function with: n_sigma: {n_sigma}, "
            f"foreground_weight: {foreground_weight}"
        )
        print("*************************")
        self.n_sigma = n_sigma
        self.foreground_weight = foreground_weight

        # coordinate map
        xm = torch.linspace(0, pixel_x, grid_x).view(1, 1, -1).expand(1, grid_y, grid_x)
        ym = torch.linspace(0, pixel_y, grid_y).view(1, -1, 1).expand(1, grid_y, grid_x)
        xym = torch.cat((xm, ym), 0)
        self.register_buffer("xym", xym)

        self.use_costmap = use_costmap

    def forward(
        self,
        prediction,
        instances,
        labels,
        center_images,
        costmaps=None,
        w_inst=1,
        w_var=10,
        w_seed=1,
    ):
        # instances B 1 Y X
        batch_size, height, width = (
            prediction.size(0),
            prediction.size(2),
            prediction.size(3),
        )

        xym_s = self.xym[:, 0:height, 0:width].contiguous()  # 2 x h x w

        # weighted loss
        if self.use_costmap:
            # only the instance ids need adjusting in this step: zero-weight
            # pixels have their ids zeroed out, so fully masked-out instances
            # are skipped below (clone to avoid modifying the caller's tensor
            # in place)
            instances_adjusted = instances.clone()
            instances_adjusted[costmaps == 0] = 0
        else:
            instances_adjusted = instances

        loss = 0

        for b in range(0, batch_size):
            spatial_emb = torch.tanh(prediction[b, 0:2]) + xym_s  # 2 x h x w  # TODO
            sigma = prediction[b, 2 : 2 + self.n_sigma]  # n_sigma x h x w
            seed_map = torch.sigmoid(
                prediction[b, 2 + self.n_sigma : 2 + self.n_sigma + 1]
            )  # 1 x h x w

            # loss accumulators
            var_loss = 0
            instance_loss = 0
            seed_loss = 0
            obj_count = 0

            if self.use_costmap:
                # per-pixel weights; assumes the costmap is broadcastable to
                # the 1 x h x w seed map
                costmap = costmaps[b].expand_as(seed_map)
            instance = instances[b]  # without costmap adjustment
            label = labels[b]
            center_image = center_images[b] > 0

            # use adjusted instance to find all ids
            instance_ids = instances_adjusted[b].unique()
            instance_ids = instance_ids[instance_ids != 0]

            # regress bg to zero
            bg_mask = label == 0
            if bg_mask.sum() > 0:
                if self.use_costmap:
                    # weight each background pixel by its costmap value, since
                    # some of the background pixels might have zero weight
                    seed_loss += torch.sum(
                        costmap[bg_mask] * torch.pow(seed_map[bg_mask] - 0, 2)
                    )
                else:
                    seed_loss += torch.sum(torch.pow(seed_map[bg_mask] - 0, 2))

            for id in instance_ids:
                # use the original instance without costmap adjustment to fetch
                # the instance mask, since the costmap may partially cut some
                # instances and alter the ground truth; only use the costmap to
                # adjust the loss values at the end
                in_mask = instance.eq(id)  # 1 x h x w

                center_mask = in_mask & center_image
                if center_mask.sum().eq(1):
                    center = xym_s[center_mask.expand_as(xym_s)].view(2, 1, 1)
                else:
                    xy_in = xym_s[in_mask.expand_as(xym_s)].view(
                        2, -1
                    )  # TODO --> should this edge case change?
                    center = xy_in.mean(1).view(2, 1, 1)  # 2 x 1 x 1

                # calculate sigma
                sigma_in = sigma[in_mask.expand_as(sigma)].view(self.n_sigma, -1)

                s = sigma_in.mean(1).view(self.n_sigma, 1, 1)  # n_sigma x 1 x 1

                # calculate var loss before exp
                if self.use_costmap:
                    # weight each in-instance pixel by its costmap value
                    costmap_in = costmap[in_mask]  # N values
                    var_loss = var_loss + torch.mean(
                        costmap_in * torch.pow(sigma_in - s[..., 0].detach(), 2)
                    )
                else:
                    var_loss = var_loss + torch.mean(
                        torch.pow(sigma_in - s[..., 0].detach(), 2)
                    )

                s = torch.exp(s * 10)  # TODO
                dist = torch.exp(
                    -1
                    * torch.sum(torch.pow(spatial_emb - center, 2) * s, 0, keepdim=True)
                )

                # apply lovasz-hinge loss
                # TODO: if the costmap only switches whole instances on or off,
                # this loss is still correct; otherwise it may be slightly off.
                instance_loss = instance_loss + lovasz_hinge(dist * 2 - 1, in_mask)

                # seed loss
                if self.use_costmap:
                    seed_loss += self.foreground_weight * torch.sum(
                        costmap_in
                        * torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2)
                    )
                else:
                    seed_loss += self.foreground_weight * torch.sum(
                        torch.pow(seed_map[in_mask] - dist[in_mask].detach(), 2)
                    )

                # calculate instance iou
                # iou_instance = calculate_iou(dist > 0.5, in_mask)

                obj_count += 1

            if obj_count > 0:
                instance_loss /= obj_count
                var_loss /= obj_count

            if self.use_costmap:
                seed_loss = seed_loss / costmap.sum()
            else:
                seed_loss = seed_loss / (height * width)

            loss += w_inst * instance_loss + w_var * var_loss + w_seed * seed_loss

        # average the per-sample losses over the batch
        loss = loss / (b + 1)

        # adding prediction.sum() * 0 keeps the loss connected to every element
        # of the prediction, even for samples with no foreground instances
        return loss + prediction.sum() * 0
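

# --------------------------------------------------------------------------
# Minimal usage sketch for the 2D loss with a costmap (illustrative only, not
# part of the library API; the helper name is made up for this example). The
# prediction is B x (2 + n_sigma + 1) x Y x X, and the costmap is assumed to
# share the B x 1 x Y x X layout of the instance / label / center tensors.
def _demo_spatial_emb_loss_2d():
    criterion = SpatialEmbLoss_2d(grid_y=64, grid_x=64, n_sigma=2, use_costmap=True)
    pred = torch.randn(1, 5, 32, 32)  # 2 offsets + 2 sigmas + 1 seed
    instances = torch.zeros(1, 1, 32, 32, dtype=torch.long)
    instances[0, 0, 8:20, 8:20] = 1  # a single toy instance
    labels = (instances > 0).long()
    centers = torch.zeros_like(instances, dtype=torch.bool)
    centers[0, 0, 14, 14] = True  # one center pixel inside the instance
    costmaps = torch.ones(1, 1, 32, 32)  # uniform per-pixel weights
    return criterion(pred, instances, labels, centers, costmaps=costmaps)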


def calculate_iou(pred, label):
    intersection = ((label == 1) & (pred == 1)).sum()
    union = ((label == 1) | (pred == 1)).sum()
    if not union:
        return 0
    else:
        iou = intersection.item() / union.item()
        return iou
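

# Illustrative note (not from the original source): calculate_iou expects two
# boolean / binary masks of the same shape, matching the commented-out
# `iou_instance = calculate_iou(dist > 0.5, in_mask)` calls above. For example,
# a 2 x 2 predicted square fully inside a 3 x 3 ground-truth square gives
# intersection 4 and union 9, so an IoU of 4 / 9:
#
#     pred_mask = torch.zeros(4, 4, dtype=torch.bool)
#     pred_mask[1:3, 1:3] = True
#     gt_mask = torch.zeros(4, 4, dtype=torch.bool)
#     gt_mask[0:3, 0:3] = True
#     calculate_iou(pred_mask, gt_mask)  # -> 0.444...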