Segmentation Loss Pytorch Version
Published:
segmentation loss in pytorch
weighted bce
def criterion_pixel(logit_pixel, truth_pixel):
logit = logit_pixel.view(-1)
truth = truth_pixel.view(-1)
assert(logit.shape==truth.shape)
loss = F.binary_cross_entropy_with_logits(logit, truth, reduction='none')
if 0:
loss = loss.mean()
if 1:
pos = (truth>0.5).float()
neg = (truth<0.5).float()
pos_weight = pos.sum().item() + 1e-12
neg_weight = neg.sum().item() + 1e-12
loss = (0.25*pos*loss/pos_weight + 0.75*neg*loss/neg_weight).sum()
return loss
weighted soft dice
def criterion_pixel(logit_pixel, truth_pixel):
batch_size = len(logit_pixel)
logit = logit_pixel.view(batch_size,-1)
truth = truth_pixel.view(batch_size,-1)
assert(logit.shape==truth.shape)
loss = soft_dice_criterion(logit, truth)
loss = loss.mean()
return loss
def soft_dice_criterion(logit, truth, weight=[0.2,0.8]):
batch_size = len(logit)
probability = torch.sigmoid(logit)
p = probability.view(batch_size,-1)
t = truth.view(batch_size,-1)
w = truth.detach()
w = w*(weight[1]-weight[0])+weight[0]
p = w*(p*2-1) #convert to [0,1] --> [-1, 1]
t = w*(t*2-1)
intersection = (p * t).sum(-1)
union = (p * p).sum(-1) + (t * t).sum(-1)
dice = 1 - 2*intersection/union
loss = dice
return loss
weighted lovasz
(adjust margin instead)
def lovasz_loss(logit, truth, margin=[1,5]):
def compute_lovasz_gradient(truth): #sorted
truth_sum = truth.sum()
intersection = truth_sum - truth.cumsum(0)
union = truth_sum + (1 - truth).cumsum(0)
jaccard = 1. - intersection / union
T = len(truth)
jaccard[1:T] = jaccard[1:T] - jaccard[0:T-1]
gradient = jaccard
return gradient
def lovasz_hinge_one(logit , truth):
m = truth.detach()
m = m*(margin[1]-margin[0])+margin[0]
truth = truth.float()
sign = 2. * truth - 1.
hinge = (m - logit * sign)
hinge, permutation = torch.sort(hinge, dim=0, descending=True)
hinge = F.relu(hinge)
truth = truth[permutation.data]
gradient = compute_lovasz_gradient(truth)
loss = torch.dot(hinge, gradient)
return loss
#----
lovasz_one = lovasz_hinge_one
batch_size = len(truth)
loss = torch.zeros(batch_size).cuda()
for b in range(batch_size):
l, t = logit[b].view(-1), truth[b].view(-1)
loss[b] = lovasz_one(l, t)
return loss
def criterion_pixel(logit_pixel, truth_pixel):
batch_size = len(logit_pixel)
logit = logit_pixel.view(batch_size,-1)
truth = truth_pixel.view(batch_size,-1)
assert(logit.shape==truth.shape)
loss = lovasz_loss(logit, truth)
loss = loss.mean()
return loss