Skip to content

Commit

Permalink
'update'
Browse files Browse the repository at this point in the history
  • Loading branch information
superarthurlx committed Jan 24, 2025
1 parent 4d0d52c commit 3b3092c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
14 changes: 8 additions & 6 deletions basicts/metrics/corr.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
import torch


def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float = np.nan) -> torch.Tensor:
"""
Calculate the Masked Pearson Correlation Coefficient between the predicted and target values,
Expand All @@ -22,6 +21,10 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float
"""

if len(prediction.shape) == 4: # (Bs, L, N, 1) else (Bs, N, 1)
prediction = torch.mean(prediction, dim=1)
target = torch.mean(target, dim=1)

if np.isnan(null_val):
mask = ~torch.isnan(target)
else:
Expand All @@ -32,18 +35,17 @@ def masked_corr(prediction: torch.Tensor, target: torch.Tensor, null_val: float
mask /= torch.mean(mask) # Normalize mask to avoid bias in the loss due to the number of valid entries
mask = torch.nan_to_num(mask) # Replace any NaNs in the mask with zero

prediction_mean = torch.mean(prediction, dim=1, keepdim=True)
target_mean = torch.mean(target, dim=1, keepdim=True)
prediction_mean = torch.mean(prediction, dim=0, keepdim=True)
target_mean = torch.mean(target, dim=0, keepdim=True)

# 计算偏差 (X - mean_X) 和 (Y - mean_Y)
prediction_dev = prediction - prediction_mean
target_dev = target - target_mean

# 计算皮尔逊相关系数
numerator = torch.sum(prediction_dev * target_dev, dim=1, keepdim=True) # 分子
denominator = torch.sqrt(torch.sum(prediction_dev ** 2, dim=1, keepdim=True) * torch.sum(target_dev ** 2, dim=1, keepdim=True)) # 分母
numerator = torch.sum(prediction_dev * target_dev, dim=0, keepdim=True) # 分子
denominator = torch.sqrt(torch.sum(prediction_dev ** 2, dim=0, keepdim=True) * torch.sum(target_dev ** 2, dim=0, keepdim=True)) # 分母
loss = numerator / denominator

loss = loss * mask # Apply the mask to the loss
loss = torch.nan_to_num(loss) # Replace any NaNs in the loss with zero

Expand Down
8 changes: 6 additions & 2 deletions basicts/metrics/r_square.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =
"""

if len(prediction.shape) == 4: # (Bs, L, N, 1) else (Bs, N, 1)
prediction = torch.mean(prediction, dim=1)
target = torch.mean(target, dim=1)

eps = 5e-5
if np.isnan(null_val):
mask = ~torch.isnan(target)
Expand All @@ -34,8 +38,8 @@ def masked_r2(prediction: torch.Tensor, target: torch.Tensor, null_val: float =
prediction = torch.nan_to_num(prediction)
target = torch.nan_to_num(target)

ss_res = torch.sum(torch.pow((target - prediction), 2), dim=1) # 残差平方和
ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=1, keepdim=True), 2), dim=1) # 总平方和
ss_res = torch.sum(torch.pow((target - prediction), 2), dim=0) # 残差平方和
ss_tot = torch.sum(torch.pow(target - torch.mean(target, dim=0, keepdim=True), 2), dim=0) # 总平方和

# 计算 R^2
loss = 1 - (ss_res / (ss_tot + eps))
Expand Down

0 comments on commit 3b3092c

Please sign in to comment.