CutMix PyTorch Version
Published:
A PyTorch implementation of CutMix.
Version 1: training-loop snippet with a `rand_bbox` helper
# Move the batch to the GPU first; any CutMix mixing below happens on-device.
input = input.cuda()
target = target.cuda()

r = np.random.rand(1)
# CutMix is enabled only for a positive Beta parameter, and then only when the
# random draw falls below args.cutmix_prob.
use_cutmix = args.beta > 0 and r < args.cutmix_prob
if use_cutmix:
    # Sample the mixing ratio and a random pairing of samples within the batch.
    lam = np.random.beta(args.beta, args.beta)
    rand_index = torch.randperm(input.size(0)).cuda()
    target_a = target
    target_b = target[rand_index]
    # Paste the same box from the paired samples into every image, in place.
    # The right-hand side uses advanced indexing, so it is a copy taken first.
    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
    # The box may be clipped at the border: re-derive lam from its real area.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size(-1) * input.size(-2)))

# compute output (one forward pass, shared by both paths)
output = model(input)
if use_cutmix:
    loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
else:
    loss = criterion(output, target)
def rand_bbox(size, lam):
    """Sample a random box over the last two dimensions described by ``size``.

    The box centre is drawn uniformly over the grid. Its side lengths are
    ``size[2] * sqrt(1 - lam)`` and ``size[3] * sqrt(1 - lam)``, truncated to
    int. The box is then clipped to the grid, so its realised area can be
    smaller than ``(1 - lam)`` of the grid. Callers needing the exact mixing
    ratio should recompute it from the returned coordinates.

    Args:
        size: shape sequence; only ``size[2]`` and ``size[3]`` are read
            (presumably an NCHW tensor shape -- TODO confirm at call sites).
        lam: mixing coefficient; ``sqrt(1 - lam)`` is taken, so it is
            presumably in ``[0, 1]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: bounds clipped to ``[0, size[2]]`` for
        the ``bbx*`` pair and ``[0, size[3]]`` for the ``bby*`` pair, suitable
        for slicing ``[bbx1:bbx2]`` along dim 2 and ``[bby1:bby2]`` along dim 3.
    """
    # NOTE: despite the names, W is size[2] and H is size[3] (height then
    # width for an NCHW shape). Usage is consistent because bbx* pairs with
    # dim 2 and bby* with dim 3.
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
    # The builtin truncates toward zero exactly as the alias did.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2
Version 2: batch-level CutMix applied in a DataLoader collate function
def cutmix(batch, alpha):
    """Apply CutMix to a collated batch, pasting one random box into every sample.

    First, a random permutation of the batch is drawn. Next, a mixing ratio is
    sampled from ``Beta(alpha, alpha)``. Finally, a box of area about
    ``(1 - lam) * H * W``, centred at a uniformly random point, is copied from
    the permuted samples into ``data``.

    Args:
        batch: ``(data, targets)``. ``data`` is treated as 4-D ``(N, C, H, W)``:
            ``shape[2:]`` is unpacked into height and width, and the box is
            sliced as ``[:, :, y0:y1, x0:x1]``. ``targets`` is indexed along
            its first dimension by the same permutation.
        alpha: ``Beta(alpha, alpha)`` concentration; must be positive for
            ``np.random.beta``.

    Returns:
        ``(data, (targets, shuffled_targets, lam))``. ``data`` is modified in
        place and returned. ``lam`` is the fraction of each image NOT covered by
        the pasted box, computed from the box that was actually pasted.
    """
    data, targets = batch

    indices = torch.randperm(data.size(0))
    # Advanced indexing copies, so the paste below reads unmodified pixels.
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]
    # Clipping at the image border and rounding change the box size, so the
    # sampled lam no longer matches the pasted area. Recompute it from the box
    # so the label weights agree with the pixels.
    lam = 1 - (x1 - x0) * (y1 - y0) / (image_w * image_h)

    targets = (targets, shuffled_targets, lam)
    return data, targets
class CutMixCollator:
    """Collator that batches samples with PyTorch's default collate, then applies ``cutmix``.

    Args:
        alpha: stored as ``self.alpha`` and forwarded unchanged to ``cutmix``
            on every call.
    """

    def __init__(self, alpha):
        self.alpha = alpha

    def __call__(self, batch):
        """Collate the list of samples in ``batch`` and return the CutMix result."""
        collated = torch.utils.data.dataloader.default_collate(batch)
        return cutmix(collated, self.alpha)
class CutMixCriterion:
    """Cross-entropy loss for targets packed as ``(targets_a, targets_b, lam)``.

    The loss is ``lam * CE(preds, targets_a) + (1 - lam) * CE(preds, targets_b)``.

    Args:
        reduction: passed straight through to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, reduction):
        self.criterion = nn.CrossEntropyLoss(reduction=reduction)

    def __call__(self, preds, targets):
        """Return the ``lam``-weighted sum of the two cross-entropy losses."""
        first_targets, second_targets, weight = targets
        first_loss = self.criterion(preds, first_targets)
        second_loss = self.criterion(preds, second_targets)
        return weight * first_loss + (1 - weight) * second_loss