From 3b3092cc4c1866d1e10c74853827bcf4e10aafc2 Mon Sep 17 00:00:00 2001
From: lx
Date: Fri, 24 Jan 2025 16:17:46 +0800
Subject: [PATCH] 'update'

---
 basicts/metrics/corr.py     | 14 ++++++++------
 basicts/metrics/r_square.py |  8 ++++++--
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/basicts/metrics/corr.py b/basicts/metrics/corr.py
index 4e10ee4d..cac99cc5 100644
--- a/basicts/metrics/corr.py
+++ b/basicts/metrics/corr.py
@@ -1,7 +1,6 @@
 import numpy as np
 import torch
 
-
 def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
     """
     Calculate the Masked Pearson Correlation Coefficient between the predicted and target values,
@@ -22,6 +21,10 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float

     """

+    if len(prediction.shape) == 4:  # (Bs, L, N, 1) else (Bs, N, 1)
+        prediction = torch.mean(prediction, dim=1)
+        target = torch.mean(target, dim=1)
+
     if np.isnan(null_val):
         mask = ~torch.isnan(target)
     else:
@@ -32,18 +35,17 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float
     mask /= torch.mean(mask)  # Normalize mask to avoid bias in the loss due to the number of valid entries
     mask = torch.nan_to_num(mask)  # Replace any NaNs in the mask with zero

-    prediction_mean = torch.mean(prediction, dim=1, keepdim=True)
-    target_mean = torch.mean(target, dim=1, keepdim=True)
+    prediction_mean = torch.mean(prediction, dim=0, keepdim=True)
+    target_mean = torch.mean(target, dim=0, keepdim=True)

     # Compute the deviations (X - mean_X) and (Y - mean_Y)
     prediction_dev = prediction - prediction_mean
     target_dev = target - target_mean

     # Compute the Pearson correlation coefficient
-    numerator = torch.sum(prediction_dev * target_dev, dim=1, keepdim=True)  # numerator
-    denominator = torch.sqrt(torch.sum(prediction_dev ** 2, dim=1, keepdim=True) * torch.sum(target_dev ** 2, dim=1, keepdim=True))  # denominator
+    numerator = torch.sum(prediction_dev * target_dev, dim=0, keepdim=True)  # numerator
+    denominator = torch.sqrt(torch.sum(prediction_dev ** 2, dim=0, keepdim=True) * torch.sum(target_dev ** 2, dim=0, keepdim=True))  # denominator

     loss = numerator / denominator

-    loss = loss * mask  # Apply the mask to the loss
     loss = torch.nan_to_num(loss)  # Replace any NaNs in the loss with zero
diff --git a/basicts/metrics/r_square.py b/basicts/metrics/r_square.py
index 22a4fcfa..a8e9ae9e 100644
--- a/basicts/metrics/r_square.py
+++ b/basicts/metrics/r_square.py
@@ -22,6 +22,10 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =

     """

+    if len(prediction.shape) == 4:  # (Bs, L, N, 1) else (Bs, N, 1)
+        prediction = torch.mean(prediction, dim=1)
+        target = torch.mean(target, dim=1)
+
     eps = 5e-5
     if np.isnan(null_val):
         mask = ~torch.isnan(target)
@@ -34,8 +38,8 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =
     prediction = torch.nan_to_num(prediction)
     target = torch.nan_to_num(target)

-    ss_res = torch.sum(torch.pow((target - prediction), 2), dim=1)  # residual sum of squares
-    ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=1, keepdim=True), 2), dim=1)  # total sum of squares
+    ss_res = torch.sum(torch.pow((target - prediction), 2), dim=0)  # residual sum of squares
+    ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=0, keepdim=True), 2), dim=0)  # total sum of squares

     # Compute R^2
     loss = 1 - (ss_res / (ss_tot + eps))
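Note (not part of the patch): a minimal, self-contained sketch of what the revised metrics compute, assuming a toy (Bs, L, N, 1) input with no null values; the shapes and variable names below are illustrative only. After this change both metrics first collapse the horizon dimension (dim=1) and then reduce across the batch dimension (dim=0), so the correlation and R^2 are computed per node over the horizon-averaged series; the snippet mirrors that logic and cross-checks one node against numpy's corrcoef.

import numpy as np
import torch

# Toy shapes (assumed for illustration): batch 32, horizon 12, 7 nodes, 1 feature.
bs, horizon, num_nodes = 32, 12, 7
prediction = torch.randn(bs, horizon, num_nodes, 1)
target = prediction + 0.1 * torch.randn(bs, horizon, num_nodes, 1)

# Mirror the patched logic: collapse the horizon (dim=1), then correlate each
# node's horizon-averaged series across the batch (dim=0).
pred = torch.mean(prediction, dim=1)                     # (Bs, N, 1)
tgt = torch.mean(target, dim=1)                          # (Bs, N, 1)
pred_dev = pred - torch.mean(pred, dim=0, keepdim=True)
tgt_dev = tgt - torch.mean(tgt, dim=0, keepdim=True)
numerator = torch.sum(pred_dev * tgt_dev, dim=0)
denominator = torch.sqrt(torch.sum(pred_dev ** 2, dim=0) * torch.sum(tgt_dev ** 2, dim=0))
corr = numerator / denominator                           # (N, 1): one coefficient per node

# The patched masked_r2 analogously computes, per node, 1 - SS_res / SS_tot across the batch.
ss_res = torch.sum((tgt - pred) ** 2, dim=0)
ss_tot = torch.sum((tgt - torch.mean(tgt, dim=0, keepdim=True)) ** 2, dim=0)
r2 = 1 - ss_res / (ss_tot + 5e-5)                        # (N, 1): one R^2 per node

# Cross-check node 0 against numpy's corrcoef on the same horizon-averaged series.
ref = np.corrcoef(pred[:, 0, 0].numpy(), tgt[:, 0, 0].numpy())[0, 1]
print(corr[0, 0].item(), ref)  # these two values should agree closely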