diff --git a/pl_bolts/models/vision/unet.py b/pl_bolts/models/vision/unet.py index eac3f407b4..9b9a5cdfa5 100644 --- a/pl_bolts/models/vision/unet.py +++ b/pl_bolts/models/vision/unet.py @@ -17,13 +17,15 @@ class UNet(nn.Module): Args: num_classes: Number of output classes required + input_channels: Number of channels in input images (default 3) num_layers: Number of layers in each side of U-net (default 5) features_start: Number of features in first layer (default 64) - bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. + bilinear: Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. """ def __init__( self, num_classes: int, + input_channels: int = 3, num_layers: int = 5, features_start: int = 64, bilinear: bool = False @@ -31,7 +33,7 @@ def __init__( super().__init__() self.num_layers = num_layers - layers = [DoubleConv(3, features_start)] + layers = [DoubleConv(input_channels, features_start)] feats = features_start for _ in range(num_layers - 1):