# Source code for cellmaps_vnn.ccc_loss

import torch


class CCCLoss(torch.nn.Module):
    """
    A PyTorch module for calculating the Concordance Correlation Coefficient (CCC) Loss.

    The CCC measures the agreement between two variables in regression tasks.
    This module returns ``1 - CCC``, where::

        CCC = 2 * cov(x, y) / (var(x) + var(y) + (mean(x) - mean(y)) ** 2)

    Both inputs are flattened, and every statistic is computed over all
    elements. Variance and covariance use the unbiased (``n - 1``) estimator,
    matching the default of :func:`torch.var`. A loss of 0 means perfect
    agreement.
    """

    def __init__(self, eps=1e-6):
        """
        Initializes the CCCLoss module.

        :param eps: A small epsilon value added to the CCC denominator for
                    numerical stability. It keeps the loss finite when both
                    inputs are constant and equal. Default is 1e-6.
        :type eps: float
        """
        super().__init__()
        self.eps = eps

    def forward(self, y_true, y_hat):
        """
        Computes the CCC loss given true and predicted values.

        Both tensors are flattened before the statistics are computed. This
        means, e.g., predictions of shape ``(N, 1)`` and targets of shape
        ``(N,)`` are compared element by element rather than broadcast
        against each other. At least two elements are required; with a
        single element the unbiased estimators are undefined and the result
        is NaN.

        :param y_true: The true values.
        :type y_true: Tensor
        :param y_hat: The predicted values.
        :type y_hat: Tensor
        :raises ValueError: If ``y_true`` and ``y_hat`` have a different
                            number of elements.
        :return ccc: The calculated CCC loss (``1 - CCC``) as a 0-dim tensor.
        :rtype ccc: Tensor
        """
        # Flatten so mismatched but equal-sized shapes such as (N, 1) vs (N,)
        # do not broadcast into an (N, N) outer product in ``vx * vy``.
        y_true = y_true.reshape(-1)
        y_hat = y_hat.reshape(-1)
        n = y_true.numel()
        if n != y_hat.numel():
            raise ValueError(
                f"y_true and y_hat must have the same number of elements, "
                f"got {n} and {y_hat.numel()}"
            )

        y_true_mean = torch.mean(y_true)
        y_hat_mean = torch.mean(y_hat)
        y_true_var = torch.var(y_true)
        y_hat_var = torch.var(y_hat)

        vx = y_true - y_true_mean
        vy = y_hat - y_hat_mean
        # Unbiased covariance, consistent with torch.var's default (n - 1).
        # Computing it directly instead of as pcc * std_x * std_y avoids
        # torch.std, whose backward pass is NaN when the variance is zero.
        covariance = torch.sum(vx * vy) / (n - 1)

        # eps keeps the division finite when both inputs are constant and equal.
        denominator = y_true_var + y_hat_var + (y_hat_mean - y_true_mean) ** 2 + self.eps
        ccc = (2 * covariance) / denominator
        return 1 - ccc