Source code for propel_loss.propel

import math
import itertools

import torch
import numpy as np


# Define EPS for numerical stability
eps = 7.0/3 - 4.0/3 - 1
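# Note: this expression is a classic trick that, on IEEE-754 double-precision
# platforms, should evaluate to machine epsilon, e.g.:
#
#   >>> import sys
#   >>> eps == sys.float_info.epsilon
#   True   # eps == 2.220446049250313e-16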


def getloss(muG, sigmaG, muM, sigmaM):
    r"""PROPEL Loss function

    Implements the following equations:

    In forward pass:

    .. math::
        L = -\log\underbrace{\left[ \frac{2}{I} \sum_{i=1}^{I} G(P_{gt}, P_{i}) \right]\rule[-12pt]{0pt}{5pt}}_{\mbox{$T1$}}
            + \log \underbrace{\left[ H({P_{gt}}) + \frac{1}{I^2}\sum_{i=1}^{I} H({P_{i}})
            + \frac{2}{I^2} \sum_{i < j}^{I} G({P_{i},P_{j}}) \right]\rule[-12pt]{0pt}{5pt}}_{\mbox{$T2$}}

    In backward pass:

    .. math::
        \frac{\partial L}{\partial \mu_{x_{ni}}} =
            -\frac{1}{T1}\left[ \frac{\partial G(P_{gt}, P_{i})}{\partial \mu_{x_{ni}}} \right]
            + \frac{1}{T2} \left[ \frac{2}{I^2} \sum_{i < j}^{I}
              \frac{\partial G({P_{i},P_{j}})}{\partial \mu_{x_{ni}}} \right]

    .. math::
        \frac{\partial L}{\partial \sigma_{x_{ni}}} =
            -\frac{1}{T1}\left[ \frac{\partial G(P_{gt}, P_{i})}{\partial \sigma_{x_{ni}}} \right]
            + \frac{1}{T2} \left[ \frac{1}{I^2} \frac{\partial H({P_{i}})}{\partial \sigma_{x_{ni}}}
            + \frac{2}{I^2} \sum_{i < j}^{I} \frac{\partial G({P_{i},P_{j}})}{\partial \sigma_{x_{ni}}} \right]

    Args:
        muG (torch.tensor): mean for ground-truth Gaussian distribution
        sigmaG (torch.tensor): standard deviation for ground-truth Gaussian distribution
        muM (torch.tensor): means for model mixture of Gaussians distribution (model output)
        sigmaM (torch.tensor): standard deviations for model mixture of Gaussians distribution (model output)

    Returns:
        torch.tensor: computed loss in forward pass, gradients w.r.t. muM/sigmaM in backward pass
    """
    # get the number of Gaussians in the mixture model
    num_gaussians = muM.shape[1]

    # T1 = \frac{2}{I} \sum_{i=1}^{I} G(P_{gt}, P_{i})
    T1 = ((2 / num_gaussians) * g_function(muG, sigmaG, muM, sigmaM)
          ).sum(dim=-1)  # sum over the num_gaussians dimension

    # T2 = H({P_{gt}}) + \frac{1}{I^2}\sum_{i=1}^{I} H({P_{i}})
    T2 = h_function(sigmaG) + ((1/num_gaussians**2) *
                               h_function(sigmaM)).sum(dim=-1)

    # computing the cross term (the 2ab in a^2 + b^2 + 2ab):
    # \frac{2}{I^2} \sum_{i < j}^{I} G({P_{i},P_{j}})
    i_index = torch.tensor(
        [i for i, j in itertools.combinations(range(num_gaussians), 2)])
    j_index = torch.tensor(
        [j for i, j in itertools.combinations(range(num_gaussians), 2)])
    i_index = i_index.to(muM.device)
    j_index = j_index.to(muM.device)
    T2_in = g_function(muM.index_select(1, i_index), sigmaM.index_select(1, i_index),
                       muM.index_select(1, j_index), sigmaM.index_select(1, j_index)).sum(dim=1)
    T2 = T2 + (2 / (num_gaussians**2)) * T2_in

    L = -torch.log10(T1) + torch.log10(T2)
    return L, T1, T2
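# A minimal usage sketch for calling getloss directly, assuming a batch of 2 samples,
# a mixture of 4 Gaussians and 3 output dimensions; the toy values and the helper
# name _example_getloss are illustrative only. PROPEL.forward squares its sigma
# values before calling getloss, so this sketch passes squared values as well.
def _example_getloss():
    num_batch, num_gaussians, num_dims = 2, 4, 3
    muG = torch.randn(num_batch, num_dims)              # ground-truth means
    sigmaG = torch.full((1, 1, 1), 0.1).pow(2)          # squared ground-truth sigma (broadcast)
    muM = torch.randn(num_batch, num_gaussians, num_dims, requires_grad=True)
    sigmaM = torch.rand(num_batch, num_gaussians, num_dims).add(0.1)  # positive values
    L, T1, T2 = getloss(muG, sigmaG, muM, sigmaM)
    loss = L.mean()
    loss.backward()                                      # gradients flow back into muM
    print(loss.item(), muM.grad.shape)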
def h_function(sigmaM):
    r"""H function implementation

    Implements the following equation:

    .. math::
        H(P_i) = \frac{1}{(2\sqrt{\pi})^n \sqrt{\sigma_{x_{1i}}\cdots\sigma_{x_{ni}}}}

    Args:
        sigmaM (torch.tensor): standard deviation of our input Gaussian distribution

    Returns:
        torch.tensor: result of H(P_i)
    """
    num_dims = sigmaM.shape[-1]
    # product of the per-dimension sigmas, then the full denominator term
    dTerm = sigmaM.prod(dim=-1)
    dTerm_all = (2*math.sqrt(math.pi))**(num_dims) * torch.sqrt(dTerm)
    # clamp to eps to avoid division by zero
    out = 1/(dTerm_all.clamp_min(eps))
    return out
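# Sanity-check sketch (the helper name and toy value are illustrative only): if the
# sigma argument is read as a variance v (PROPEL.forward squares its sigma before it
# reaches this function), then in 1D the formula above reduces to 1 / (2*sqrt(pi*v)),
# the closed form of the integral of a squared 1D Gaussian density.
def _check_h_function_1d():
    v = torch.tensor([[[0.25]]])                 # single 1D Gaussian, variance 0.25
    expected = 1.0 / (2.0 * math.sqrt(math.pi * 0.25))
    assert math.isclose(h_function(v).item(), expected, rel_tol=1e-6)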
def g_function(muM1, sigmaM1, muM2, sigmaM2):
    r"""G function implementation

    Implements the following equation:

    .. math::
        G(P_i, P_j) = \frac{e^{\big[\frac{2\mu_{x_{1i}}\mu_{x_{1j}} - {\mu_{x_{1i}}}^2 - {\mu_{x_{1j}}}^2}{2(\sigma_{x_{1i}}+\sigma_{x_{1j}})} + \cdots + \frac{2\mu_{x_{ni}}\mu_{x_{nj}} - {\mu_{x_{ni}}}^2 - {\mu_{x_{nj}}}^2}{2(\sigma_{x_{ni}}+\sigma_{x_{nj}})}\big]}}{(\sqrt{2\pi})^n \sqrt{(\sigma_{x_{1i}} + \sigma_{x_{1j}}) \cdots (\sigma_{x_{ni}} + \sigma_{x_{nj}})}}

    Args:
        muM1 (torch.tensor): mean for first Gaussian distribution
        sigmaM1 (torch.tensor): standard deviation for first Gaussian distribution
        muM2 (torch.tensor): mean for second Gaussian distribution
        sigmaM2 (torch.tensor): standard deviation for second Gaussian distribution

    Returns:
        torch.tensor: result of G(P_1, P_2)
    """
    num_dims = muM1.shape[-1]
    num_gaussians = muM2.shape[1]

    # expand muM1 if it was received without a num_gaussians dimension,
    # so that it broadcasts against muM2
    if len(muM1.shape) < len(muM2.shape):
        muM1 = muM1.unsqueeze(dim=1)

    # calculate the denominator term
    sumSigma = sigmaM1 + sigmaM2
    mulSigma = sumSigma.prod(dim=-1)  # product across dimensions
    aTerm = (math.sqrt(2*math.pi)**(num_dims)) * torch.sqrt(mulSigma)
    A = 1/(aTerm.clamp_min(eps))

    # calculate the exponent term
    bTerm = 2*(sigmaM1 + sigmaM2)
    B = ((2*muM1*muM2 - muM1.pow(2) - muM2.pow(2)) /
         (bTerm.clamp_min(eps))).sum(dim=-1)

    out = A * torch.exp(B)
    return out
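# Sanity-check sketch (the helper name and toy values are illustrative only): in 1D,
# reading the sigma arguments as variances v1 and v2, the formula above should match
# the closed-form overlap integral of two Gaussian densities,
# exp(-(mu1 - mu2)^2 / (2*(v1 + v2))) / sqrt(2*pi*(v1 + v2)).
def _check_g_function_1d():
    mu1, v1 = torch.tensor([[[0.5]]]), torch.tensor([[[0.2]]])
    mu2, v2 = torch.tensor([[[-1.0]]]), torch.tensor([[[0.3]]])
    expected = math.exp(-(0.5 - (-1.0))**2 / (2 * (0.2 + 0.3))) \
        / math.sqrt(2 * math.pi * (0.2 + 0.3))
    assert math.isclose(g_function(mu1, v1, mu2, v2).item(), expected, rel_tol=1e-5)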
def unpack_prediction(pred, num_dims):
    r"""Helper function to unpack the tensor coming from the output of a neural network

    It expects the pred tensor to have the following shape:
    [num_batch, num_gaussians, num_dimensions * 2]
    where:
        [num_batch, num_gaussians, :num_dimensions] corresponds to the mean
        [num_batch, num_gaussians, num_dimensions:] corresponds to the standard deviation

    Args:
        pred (torch.tensor): prediction output from a neural network with shape
            [num_batch, num_gaussians, num_dimensions * 2]
        num_dims (int): number of dimensions to unpack data for, e.g. 3 for 3D problems

    Returns:
        tuple of torch.tensors: unpacked mean (g_mu) and standard deviation (g_sigma)
    """
    # the split is applied along the last dimension (dim=2)
    g_mu, g_sigma = torch.split(pred, [num_dims, num_dims], dim=2)
    return g_mu, g_sigma
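# Usage sketch, assuming a batch of 8, 2 Gaussians and 3 dimensions, so the packed
# prediction has shape [8, 2, 6]; the toy values are illustrative only:
#
#   >>> pred = torch.randn(8, 2, 6)
#   >>> g_mu, g_sigma = unpack_prediction(pred, num_dims=3)
#   >>> g_mu.shape, g_sigma.shape
#   (torch.Size([8, 2, 3]), torch.Size([8, 2, 3]))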
class PROPEL(torch.nn.modules.loss._Loss):
    r"""PRObabilistic Parametric rEgression Loss (PROPEL) for enabling neural networks
    to output parameters of a mixture of Gaussian distributions, from [1].

    [1] "PROPEL: Probabilistic Parametric Regression Loss for Convolutional Neural
    Networks", M. Asad et al. - 25th International Conference on Pattern Recognition
    (ICPR), 2020

    Usage instructions:

    The expected output shape from the neural network is:

        [num_batches, num_gaussians, 2 * num_dimensions]

    where:
        num_batches    --> number of batches
        num_gaussians  --> number of Gaussians in the mixture of Gaussians
        num_dimensions --> number of dimensions in each sample
                           (the factor of 2 accounts for the mean/variance of each dimension)

    num_gaussians controls how complex the fitted mixture of Gaussians can be,
    e.g. num_gaussians = 10 can model a much more complex distribution than
    num_gaussians = 2, whereas num_gaussians = 2 has a regularising effect by
    modelling the ground-truth distribution with only two Gaussians in the mixture.

    One example is inferring 3D head orientation from 2D images. The output is 3D
    (num_dimensions = 3), so with two Gaussians and a batch size of b the network
    output has shape [b, 2, 2 * num_dimensions] = [b, 2, 6].

    For further usage examples, see the optimisation tests inside the tests/ folder
    within this project directory.
    """

    def __init__(self, sigma=0.1, reduction='mean'):
        super(PROPEL, self).__init__(reduction=reduction)
        # defining sigma as a parameter so the *.to(device) function
        # can be used to move it to another device from the loss object
        self.sigma = torch.nn.Parameter(
            torch.Tensor([[[sigma]]]), requires_grad=False)
    def forward(self, output, target):
        # target shape:
        #   [batch, num_dims]
        # output shape (network output):
        #   [batch, num_gaussians, 2 * num_dims]
        #
        # the 2 accounts for mean and variance:
        #   first  [batch, num_gaussians, num_dims] corresponds to the mean
        #   second [batch, num_gaussians, num_dims] corresponds to the variance

        # get size of everything
        num_batch = output.shape[0]
        num_dims = int(output.shape[2]/2)

        # check that target labels are in line with what we got
        assert num_dims == target.shape[1], \
            "Number of dimensions doesn't match: output [%d] != target [%d]" % (
                num_dims, target.shape[1])

        # split output prediction into relevant sections
        # splits are of size num_dims (mean) and num_dims (variance)
        output_mu, output_sigma = unpack_prediction(output, num_dims)

        # apply differentiable loss - forward + backward
        L, _, _ = getloss(target, self.sigma.pow(2),
                          output_mu, output_sigma.pow(2))

        # apply selected reduction method
        if self.reduction == 'mean':
            return L.mean()
        elif self.reduction == 'sum':
            return L.sum()
        else:  # 'none'
            return L
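# End-to-end usage sketch: a tiny network head producing
# [batch, num_gaussians, 2 * num_dims] outputs for a 3D regression target, trained
# with PROPEL. The Linear layer, its sizes, the softplus used to keep the predicted
# sigma positive and the helper name _example_training_step are assumptions made for
# this illustration, not requirements of the loss itself.
def _example_training_step():
    num_gaussians, num_dims = 2, 3
    net = torch.nn.Linear(16, num_gaussians * 2 * num_dims)
    criterion = PROPEL(sigma=0.1, reduction='mean')
    optimiser = torch.optim.SGD(net.parameters(), lr=1e-2)

    features = torch.randn(4, 16)          # e.g. features from a backbone
    target = torch.randn(4, num_dims)      # ground-truth 3D values

    out = net(features).view(-1, num_gaussians, 2 * num_dims)
    mu, raw_sigma = unpack_prediction(out, num_dims)
    sigma = torch.nn.functional.softplus(raw_sigma) + 1e-3  # keep predicted sigma positive
    output = torch.cat([mu, sigma], dim=2)

    loss = criterion(output, target)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    print(loss.item())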