We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Great work! I want to use KAN conv instead of ResNet conv, how can i do it?
First, from kan_convolutional.KANConv import KAN_Convolutional_Layer from https://github.com/AntonioTepsich/Convolutional-KANs. Second, change the ResNet conv But there is a problem with the code, can you help me solve it?
from kan_convolutional.KANConv import KAN_Convolutional_Layer
import torch.nn as nn import torch from kan_convolutional.KANConv import KAN_Convolutional_Layer def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) def conv1x1(in_planes, out_planes, stride= 1): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_channel, out_channel, stride=(1,1), downsample=None, device: str = 'cuda:0'): super(BasicBlock, self).__init__() # use the kan convolutional self.conv1 = KAN_Convolutional_Layer( n_convs = 5, kernel_size= (3,3), stride=stride, device = device ) self.conv2 = KAN_Convolutional_Layer( n_convs = 5, kernel_size= (3,3), device = device ) # self.conv1 = conv3x3(in_channel, out_channel, stride) self.bn1 = nn.BatchNorm2d(15) self.relu = nn.ReLU(inplace=True) # self.conv2 = conv3x3(out_channel, out_channel) self.bn2 = nn.BatchNorm2d(15) self.downsample = downsample self.stride = stride def forward(self, x): residual = x print(f'Input shape: {x.shape}') out = self.conv1(x) # 3x3conv,s=1 print(f'After conv1 shape: {out.shape}') out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, in_channel, out_channel, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1, bias=False) # squeeze channels self.bn1 = nn.BatchNorm2d(out_channel) # ----------------------------------------- self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=stride, bias=False, padding=1) self.bn2 = nn.BatchNorm2d(out_channel) # ----------------------------------------- self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion, kernel_size=1, stride=1, bias=False) # unsqueeze channels self.bn3 = nn.BatchNorm2d(out_channel*self.expansion) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual out = self.relu(out) return out class ResNet(nn.Module): def __init__(self, block, blocks_num, num_classes=1000, include_top=True, device: str = 'cuda:0'): super(ResNet, self).__init__() self.include_top = include_top self.in_channel = 64 # self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, # padding=3, bias=False) self.conv1 = KAN_Convolutional_Layer( n_convs = 5, kernel_size= (7,7), stride=(2,2), padding=(3,3), device = device ) self.bn1 = nn.BatchNorm2d(15) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, blocks_num[0]) self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2) self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2) self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2) if self.include_top: self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') def _make_layer(self, block, channel, block_num, stride=1, device: str = 'cuda:0'): downsample = None if stride != 1 or self.in_channel != channel * block.expansion: downsample = nn.Sequential( KAN_Convolutional_Layer(n_convs = 5, kernel_size= (1,1), stride=(stride,stride), device = device), # nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False), # nn.BatchNorm2d(KAN_Convolutional_Layer(n_convs = 5, kernel_size= (1,1), stride=(stride,stride), device = device).convs[0].conv.in_features) nn.BatchNorm2d(15) ) layers = [] layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride)) self.in_channel = channel * block.expansion for _ in range(1, block_num): layers.append(block(self.in_channel, channel)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) if self.include_top: x = self.avgpool(x) x = torch.flatten(x, 1) x = self.fc(x) return x def resnet18(num_classes=1000, include_top=True): return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Great work!
I want to use KAN conv instead of ResNet conv, how can i do it?
First,
from kan_convolutional.KANConv import KAN_Convolutional_Layer
from https://github.com/AntonioTepsich/Convolutional-KANs.Second, change the ResNet conv
But there is a problem with the code, can you help me solve it?
The text was updated successfully, but these errors were encountered: