KAN conv instead of ResNet conv #14

Open

woodszp opened this issue Aug 23, 2024 · 0 comments

woodszp commented Aug 23, 2024

Great work!
I want to use KAN convolutions in place of the ResNet convolutions. How can I do it?

First, I import the layer with from kan_convolutional.KANConv import KAN_Convolutional_Layer, using https://github.com/AntonioTepsich/Convolutional-KANs.
Second, I swap it in for the ResNet convolutions, as in the code below.
But there is a problem with the code. Can you help me solve it?

import torch.nn as nn
import torch
from kan_convolutional.KANConv import KAN_Convolutional_Layer


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=(1, 1), downsample=None, device: str = 'cuda:0'):
        super(BasicBlock, self).__init__()

        # Use KAN convolutions in place of the original 3x3 convolutions.
        # KAN_Convolutional_Layer concatenates the outputs of its n_convs
        # kernels, so it returns in_channels * n_convs channels, and unlike
        # conv3x3 above it is built here without padding.
        self.conv1 = KAN_Convolutional_Layer(
            n_convs=5,
            kernel_size=(3, 3),
            stride=stride,
            device=device
        )

        self.conv2 = KAN_Convolutional_Layer(
            n_convs=5,
            kernel_size=(3, 3),
            device=device
        )

        # self.conv1 = conv3x3(in_channel, out_channel, stride)
        # The hard-coded 15 only matches a 3-channel input (3 * 5 convs);
        # deeper blocks see more channels, and conv2 multiplies by 5 again.
        self.bn1 = nn.BatchNorm2d(15)
        self.relu = nn.ReLU(inplace=True)
        # self.conv2 = conv3x3(out_channel, out_channel)
        self.bn2 = nn.BatchNorm2d(15)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        print(f'Input shape: {x.shape}')  
        out = self.conv1(x)  # KAN conv (3x3, no padding), stride from __init__

        print(f'After conv1 shape: {out.shape}')  
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)  
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual  
        out = self.relu(out)
 
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
 
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
 
        return out


class ResNet(nn.Module):

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True, device: str = 'cuda:0'):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64

        # self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
        #                        padding=3, bias=False)
        
        self.conv1 = KAN_Convolutional_Layer(
            n_convs=5,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            device=device
        )

        # 15 = 3 input channels * 5 convs; valid only for RGB input
        self.bn1 = nn.BatchNorm2d(15)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1, device: str = 'cuda:0'):
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            # The KAN 1x1 downsample also multiplies its input channels by
            # n_convs, so the hard-coded BatchNorm2d(15) cannot track
            # channel * block.expansion the way the nn.Conv2d version did.
            downsample = nn.Sequential(
                KAN_Convolutional_Layer(n_convs=5, kernel_size=(1, 1), stride=(stride, stride), device=device),
                # nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                # nn.BatchNorm2d(KAN_Convolutional_Layer(n_convs=5, kernel_size=(1, 1), stride=(stride, stride), device=device).convs[0].conv.in_features)
                nn.BatchNorm2d(15)
            )

        layers = []
        layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):
            layers.append(block(self.in_channel, channel))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


def resnet18(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
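
The shape problem is most likely the channel bookkeeping: in Convolutional-KANs, KAN_Convolutional_Layer applies each of its n_convs kernels to the whole input and concatenates the results, so a layer returns in_channels * n_convs channels. That is why BatchNorm2d(15) works after the stem (3 RGB channels * 5 convs), but every later KAN layer multiplies the channel count by 5 again, the hard-coded 15s stop matching, and out += residual fails in both channels and, for the unpadded 3x3 layers in BasicBlock, spatial size.

A minimal sketch of one way to restore ResNet's channel contract, assuming the constructor signature used above: follow each KAN layer with a 1x1 nn.Conv2d that projects the multiplied channels back down. KANConvProjected is a hypothetical helper name, not part of either repository.

import torch.nn as nn
from kan_convolutional.KANConv import KAN_Convolutional_Layer


class KANConvProjected(nn.Module):
    """KAN convolution plus a 1x1 projection back to out_channel.

    Assumes KAN_Convolutional_Layer returns in_channel * n_convs channels,
    so the projection lets BatchNorm2d(out_channel) and the residual
    addition in BasicBlock line up again.
    """

    def __init__(self, in_channel, out_channel, kernel_size=(3, 3),
                 stride=(1, 1), padding=(1, 1), n_convs=5,
                 device: str = 'cuda:0'):
        super().__init__()
        self.kan = KAN_Convolutional_Layer(
            n_convs=n_convs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,  # (1, 1) keeps 3x3 output sizes, like conv3x3
            device=device,
        )
        # Project in_channel * n_convs channels back to out_channel.
        self.proj = nn.Conv2d(in_channel * n_convs, out_channel,
                              kernel_size=1, bias=False)

    def forward(self, x):
        return self.proj(self.kan(x))

With this wrapper, BasicBlock can use self.conv1 = KANConvProjected(in_channel, out_channel, stride=stride, device=device) together with self.bn1 = nn.BatchNorm2d(out_channel), and likewise for conv2 and for the downsample branch (with kernel_size=(1, 1), padding=(0, 0)), so the rest of the ResNet channel arithmetic is unchanged. Whether a linear 1x1 projection after the KAN kernels keeps the behavior you want from KANs is a design question; the alternative is to rework the blocks so every BatchNorm2d and the residual paths are sized to in_channels * n_convs explicitly.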