style: format code with Black #13

Closed · wants to merge 1 commit
ebtorch/nn/architectures.py (8 changes: 4 additions, 4 deletions)

@@ -155,109 +155,107 @@
# Fully-Connected Block, New version
# Joint work with Davide Roznowicz (https://github.com/DavideRoznowicz)
class FCBlock(nn.Module):
def __init__(
self,
in_sizes: Union[List[int], tuple],
out_size: int,
bias: Optional[Union[List[bool], tuple, bool]] = None,
activation_fx: Optional[Union[List, nn.ModuleList, nn.Module]] = None,
dropout: Optional[Union[List[Union[float, bool]], float, bool, tuple]] = None,
batchnorm: Optional[Union[List[bool], bool, tuple]] = None,
) -> None:
super().__init__()

self.activation_fx = nn.ModuleList()

error_uneven_size: str = (
"The length of lists of arguments must be the same across them."
)
-error_illegal_dropout: str = (
-    "The 'dropout' argument must be either False, a float, or an iterable of floats and/or False."
-)
+error_illegal_dropout: str = "The 'dropout' argument must be either False, a float, or an iterable of floats and/or False."

# Default cases
if bias is None:
bias = [True] * len(in_sizes)
if dropout is None:
dropout = [False] * len(in_sizes)
if batchnorm is None:
batchnorm = [True] * (len(in_sizes) - 1) + [False]
if activation_fx is None:
for _ in range(len(in_sizes) - 1):
self.activation_fx.append(nn.ReLU())
self.activation_fx.append(nn.Identity())

# Ergonomics
if isinstance(bias, bool):
bias = [bias] * len(in_sizes)
if isinstance(dropout, bool):
if not dropout:
dropout = [False] * len(in_sizes)
else:
raise ValueError(error_illegal_dropout)
elif isinstance(dropout, float) or (
isinstance(dropout, int) and (dropout in (0, 1))
):
dropout = [dropout] * len(in_sizes)
elif not isinstance(dropout, list):
raise ValueError(error_illegal_dropout)

if isinstance(batchnorm, bool):
batchnorm = [batchnorm] * len(in_sizes)

if isinstance(activation_fx, list):
self.activation_fx = nn.ModuleList(copy.deepcopy(activation_fx))
elif isinstance(activation_fx, nn.Module) and not isinstance(
activation_fx, nn.ModuleList
):
for _ in enumerate(in_sizes):
self.activation_fx.append(copy.deepcopy(activation_fx))
elif isinstance(activation_fx, nn.ModuleList):
self.activation_fx = copy.deepcopy(activation_fx)

# Sanitize
if (
not len(in_sizes)
== len(bias)
== len(self.activation_fx)
== len(dropout)
== len(batchnorm)
):
raise ValueError(error_uneven_size)

# Start with an empty module list
self.module_battery = nn.ModuleList(modules=None)

for layer_idx in range(len(in_sizes) - 1):
self.module_battery.append(
nn.Linear(
in_features=in_sizes[layer_idx],
out_features=in_sizes[layer_idx + 1],
bias=bias[layer_idx],
)
)
self.module_battery.append(copy.deepcopy(self.activation_fx[layer_idx]))

if batchnorm[layer_idx]:
self.module_battery.append(
nn.BatchNorm1d(num_features=in_sizes[layer_idx + 1])
)

if isinstance(dropout[layer_idx], bool) and dropout[layer_idx]:
raise ValueError(error_illegal_dropout)
if not isinstance(dropout[layer_idx], bool):
self.module_battery.append(nn.Dropout(p=dropout[layer_idx]))

self.module_battery.append(
nn.Linear(in_features=in_sizes[-1], out_features=out_size, bias=bias[-1])
)
self.module_battery.append(copy.deepcopy(self.activation_fx[-1]))
if batchnorm[-1]:
self.module_battery.append(nn.BatchNorm1d(num_features=out_size))
if isinstance(dropout[-1], bool) and dropout[-1]:
raise ValueError(error_illegal_dropout)
if not isinstance(dropout[-1], bool):
self.module_battery.append(nn.Dropout(p=dropout[-1]))

Check notice (codefactor.io / CodeFactor) on line 258 in ebtorch/nn/architectures.py: Complex Method (ebtorch/nn/architectures.py#L158-L258).
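For orientation, a minimal usage sketch of the constructor flagged above; the import path, layer sizes, and option values are illustrative assumptions, not code from this repository or this PR.

# Hypothetical example: a three-layer MLP head built with FCBlock.
# All argument values below are assumptions for illustration.
import torch

from ebtorch.nn.architectures import FCBlock  # assumed import path

block = FCBlock(
    in_sizes=[784, 256, 64],  # input widths of the three stacked Linear layers
    out_size=10,              # width of the final Linear layer
    bias=True,                # a single bool is broadcast to every layer
    activation_fx=None,       # default: ReLU everywhere, Identity on the last layer
    dropout=0.1,              # a single float is broadcast to every layer
    batchnorm=None,           # default: BatchNorm1d everywhere except the last layer
)

y = block(torch.randn(32, 784))  # assumes the usual nn.Module forward over module_battery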

def reset_parameters(self) -> None:
for module in self.modules():
@@ -813,12 +811,14 @@
self.extract_z: bool = extract_z
self.extract_mv: bool = extract_mv

-def forward(self, x: Tensor) -> Union[
+def forward(
+    self, x: Tensor
+) -> Union[
Tensor,
Tuple[Tensor, Tensor],
Tuple[Tensor, Tensor, Tensor],
Tuple[Tensor, Tensor, Tensor, Tensor],
]:

Check notice (Code scanning / CodeQL): Returning tuples with varying lengths. BasicVAE.forward returns tuples of size 2 and 4, tuples of size 2 and 3, and tuples of size 3 and 4.
shared: Tensor = self.encoder(x)
mean: Tensor = self.mean_neck(shared)
logvar: Tensor = self.logvar_neck(shared)
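Regarding the CodeQL note above: BasicVAE.forward can return a bare Tensor or a tuple of two, three, or four tensors depending on extract_z and extract_mv. A caller-side sketch follows; it assumes only that the main output comes first in the tuple (the exact ordering is not shown in this hunk).

# Hypothetical caller-side unpacking of BasicVAE.forward's variable-arity return.
# `vae` and `x` are assumed to be a constructed BasicVAE and an input batch.
out = vae(x)
if isinstance(out, tuple):
    main, *extras = out  # 1 to 3 extra tensors, depending on extract_z / extract_mv
else:
    main, extras = out, []

Returning a fixed-size structure (for example a NamedTuple or a dict with optional entries) is a common way to avoid this class of CodeQL note, though that is a design choice outside the scope of a formatting PR.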
ebtorch/optim/lookahead.py (8 changes: 4 additions, 4 deletions)

@@ -119,10 +119,10 @@ def step(self, closure=None):
param_state["cached_params"].copy_(p.data)
if self.pullback_momentum == "pullback":
internal_momentum = self.optimizer.state[p]["momentum_buffer"]
self.optimizer.state[p]["momentum_buffer"] = (
internal_momentum.mul_(self.la_alpha).add_(
1.0 - self.la_alpha, param_state["cached_mom"]
)
self.optimizer.state[p][
"momentum_buffer"
] = internal_momentum.mul_(self.la_alpha).add_(
1.0 - self.la_alpha, param_state["cached_mom"]
)
param_state["cached_mom"] = self.optimizer.state[p][
"momentum_buffer"
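For reference, the statement being re-wrapped above implements the pullback update buf <- la_alpha * buf + (1 - la_alpha) * cached_mom. Below is a standalone sketch of the same arithmetic using the keyword alpha= form of Tensor.add_ (newer PyTorch versions flag the positional two-argument form used in the diff as deprecated); the tensor shapes and values are illustrative.

# Behavioural sketch of the pullback-momentum update (not part of the PR).
import torch

la_alpha = 0.5
buf = torch.randn(4)         # stands in for optimizer.state[p]["momentum_buffer"]
cached_mom = torch.randn(4)  # stands in for param_state["cached_mom"]

buf.mul_(la_alpha).add_(cached_mom, alpha=1.0 - la_alpha)
# buf now equals la_alpha * old_buf + (1 - la_alpha) * cached_mom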
ebtorch/optim/schedcos.py (3 changes: 2 additions, 1 deletion)

@@ -254,7 +254,8 @@ def _get_lr(self, t):
)
t_i = self.cycle_mul**i * self.t_initial
t_curr = (
-    t - (1 - self.cycle_mul**i) / (1 - self.cycle_mul) * self.t_initial
+    t
+    - (1 - self.cycle_mul**i) / (1 - self.cycle_mul) * self.t_initial
)
else:
i = t // self.t_initial
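The t_curr expression being re-wrapped above subtracts the start step of the current cycle, which for cycle_mul != 1 is the geometric-series sum t_initial * (1 - cycle_mul**i) / (1 - cycle_mul). A small numeric check with illustrative values, assuming the cycle index i has already been computed as in the code above the hunk:

# Illustrative check of the cycle arithmetic in the hunk above (values assumed).
t_initial, cycle_mul = 10, 2.0
t, i = 25, 1  # cycle 0 covers steps 0-9, cycle 1 covers steps 10-29

t_i = cycle_mul**i * t_initial  # length of the current cycle: 20.0
t_curr = t - (1 - cycle_mul**i) / (1 - cycle_mul) * t_initial  # 25 - 10 = 15.0
print(t_i, t_curr)  # 20.0 15.0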