Commit c83c763

removing bias from linear model regularisation
stanton119 committed Jun 23, 2021
1 parent 354534a commit c83c763
Showing 2 changed files with 4 additions and 4 deletions.
pl_bolts/models/regression/linear_regression.py (4 changes: 2 additions & 2 deletions)
@@ -58,12 +58,12 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[st

         # L1 regularizer
         if self.hparams.l1_strength > 0:
-            l1_reg = sum(param.abs().sum() for param in self.parameters())
+            l1_reg = self.linear.weight.abs().sum()
             loss += self.hparams.l1_strength * l1_reg

         # L2 regularizer
         if self.hparams.l2_strength > 0:
-            l2_reg = sum(param.pow(2).sum() for param in self.parameters())
+            l2_reg = self.linear.weight.pow(2).sum()
             loss += self.hparams.l2_strength * l2_reg

         loss /= x.size(0)
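Why the change: self.parameters() yields every registered parameter of the module, including self.linear.bias, so the old penalty also shrank the intercept. The new code sums only over self.linear.weight, leaving the bias unregularised. A minimal sketch of the before/after difference, using a plain nn.Linear as a stand-in for the Bolts model (names here are illustrative, not from the commit):

    import torch
    from torch import nn

    linear = nn.Linear(in_features=3, out_features=1)

    # Old behaviour: every parameter is penalised, bias included.
    l1_all = sum(param.abs().sum() for param in linear.parameters())

    # New behaviour: only the weight matrix is penalised.
    l1_weights_only = linear.weight.abs().sum()

    # The gap between the two is exactly the bias contribution.
    print(torch.isclose(l1_all - l1_weights_only, linear.bias.abs().sum()))  # tensor(True)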
pl_bolts/models/regression/logistic_regression.py (4 changes: 2 additions & 2 deletions)
@@ -61,12 +61,12 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[st

         # L1 regularizer
         if self.hparams.l1_strength > 0:
-            l1_reg = sum(param.abs().sum() for param in self.parameters())
+            l1_reg = self.linear.weight.abs().sum()
             loss += self.hparams.l1_strength * l1_reg

         # L2 regularizer
         if self.hparams.l2_strength > 0:
-            l2_reg = sum(param.pow(2).sum() for param in self.parameters())
+            l2_reg = self.linear.weight.pow(2).sum()
             loss += self.hparams.l2_strength * l2_reg

         loss /= x.size(0)
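The unchanged context lines in both hunks show the same surrounding pattern: each penalty is added to the summed loss, and the total is then divided by the batch size, so the penalties are averaged per sample along with the data term. A self-contained sketch of that pattern, assuming the linear-regression variant's sum-reduced MSE data term (the data-loss line is not visible in this diff, so that part is an assumption):

    import torch
    from torch import nn
    import torch.nn.functional as F

    def regularised_loss(linear, x, y, l1_strength=0.0, l2_strength=0.0):
        # Hypothetical helper mirroring the training_step pattern above.
        loss = F.mse_loss(linear(x), y, reduction="sum")  # assumed data term
        if l1_strength > 0:
            loss += l1_strength * linear.weight.abs().sum()  # bias excluded
        if l2_strength > 0:
            loss += l2_strength * linear.weight.pow(2).sum()  # bias excluded
        return loss / x.size(0)  # data term and penalties share the batch average

    linear = nn.Linear(3, 1)
    x, y = torch.randn(8, 3), torch.randn(8, 1)
    print(regularised_loss(linear, x, y, l1_strength=0.01, l2_strength=0.01))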
