I am working on an algorithm for detecting the keys and the keyboard body in an image. The model of the keyboard is known, so a Blender environment has been created to generate training images with random lighting, angles, textures and objects on screen. I have chosen an approach utilising a UNet with the following structure:
class RELUConvBlock(nn.Module):
    """3x3 convolution followed by batch normalisation and ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved; only the channel count changes from ``in_ch`` to
    ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, 1, 1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class DownBlock(nn.Module):
    """Encoder stage: two stacked ``RELUConvBlock`` layers (no pooling)."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        layers = [
            RELUConvBlock(in_ch, out_ch),
            RELUConvBlock(out_ch, out_ch),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UpBlock(nn.Module):
    """Decoder stage: two stacked ``RELUConvBlock`` layers (no upsampling).

    Structurally identical to ``DownBlock``; kept as a separate class so the
    encoder and decoder stages stay distinguishable by name.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        layers = [
            RELUConvBlock(in_ch, out_ch),
            RELUConvBlock(out_ch, out_ch),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNet(nn.Module):
    """Four-level U-Net producing one score map per output channel.

    Args:
        out_ch: Number of channels of the returned score map.
        down_ch: Channel widths of the four encoder levels, shallowest first.
            The bottleneck uses twice the last width.
        in_ch: Number of channels of the input image.

    Raises:
        ValueError: If ``down_ch`` does not contain exactly four entries
            (the architecture has a fixed depth of four levels).
    """

    def __init__(self, out_ch=3, down_ch=(64, 128, 256, 512), in_ch=3):
        super().__init__()
        down_ch = tuple(down_ch)
        if len(down_ch) != 4:
            raise ValueError(
                f"down_ch must contain exactly 4 channel widths, got {len(down_ch)}"
            )
        c0, c1, c2, c3 = down_ch

        # Modules are created in the same order and under the same attribute
        # names as before, so state_dict keys and seeded initialisation are
        # unchanged for the default configuration.
        self.pool = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.down0 = DownBlock(in_ch, c0)
        self.down1 = DownBlock(c0, c1)
        self.down2 = DownBlock(c1, c2)
        self.down3 = DownBlock(c2, c3)

        self.bottleneck = DownBlock(c3, 2 * c3)

        # Each decoder stage receives the skip tensor concatenated with the
        # upsampled tensor, i.e. twice the width of its level. Writing it as
        # 2 * c_level (rather than "the width of the level below") keeps the
        # network valid when the widths do not exactly double per level.
        self.up3 = UpBlock(2 * c3, c3)
        self.up2 = UpBlock(2 * c2, c2)
        self.up1 = UpBlock(2 * c1, c1)
        self.up0 = UpBlock(2 * c0, c0)

        self.connect_b_up3 = nn.ConvTranspose2d(2 * c3, c3, kernel_size=2, stride=2)
        self.connect_up3_up2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.connect_up2_up1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.connect_up1_up0 = nn.ConvTranspose2d(c1, c0, kernel_size=2, stride=2)

        self.final_conv = nn.Conv2d(c0, out_ch, kernel_size=1)

    def _crop_to_match(self, tensor, target):
        """Return ``tensor`` cropped (from the top-left) to ``target``'s H x W."""
        _, _, h, w = target.shape
        return tensor[:, :, :h, :w]

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_ch, H, W)``.

        Returns:
            Raw, un-normalised scores of shape ``(N, out_ch, H', W')``; no
            sigmoid or softmax is applied here. ``H'`` and ``W'`` equal ``H``
            and ``W`` when those are multiples of 16; otherwise the 2x2
            pooling rounds down at each level and the output is
            correspondingly smaller than the input.
        """
        skip_connections = []
        for encoder in (self.down0, self.down1, self.down2, self.down3):
            x = encoder(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        decoder_stages = (
            (self.connect_b_up3, self.up3),
            (self.connect_up3_up2, self.up2),
            (self.connect_up2_up1, self.up1),
            (self.connect_up1_up0, self.up0),
        )
        for upsample, decoder in decoder_stages:
            x = upsample(x)
            # Deepest skip first. The skip may be one row/column larger than
            # the upsampled tensor when the pooled size was odd.
            skip_connection = self._crop_to_match(skip_connections.pop(), x)
            x = torch.cat((skip_connection, x), dim=1)
            x = decoder(x)

        return self.final_conv(x)


# The network is trained on 500 images generated from the blender env.
# Mixed-precision training loop. `model`, `train_loader` and `device` are
# expected to be defined before this point.
LEARNING_RATE = 1e-4
num_epochs = 10

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# autocast and GradScaler take a bare device *type* ("cuda", "cpu").
# str(torch.device("cuda:0")) is "cuda:0", which autocast rejects, so normalise
# through torch.device(...).type; this also accepts a plain string `device`.
device_type = torch.device(device).type
scaler = torch.amp.GradScaler(device_type)

model.train()
for epoch in range(num_epochs):
    loop = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_idx, (data, targets) in loop:
        data = data.to(device)
        targets = targets.to(device)

        # Forward pass and loss run under autocast; backward stays outside.
        with torch.amp.autocast(device_type=device_type):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        # Scale the loss so reduced-precision gradients do not underflow;
        # scaler.step() unscales before the optimizer update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loop.set_postfix(loss=loss.item())

# While the results on training and testing data are very promising (results_1), when tested on real photos the network performs poorly (results_2). What can be done to fix that? Is there any way other than creating a better blender env? Maybe a different type of nn or a different solution entirely? This is my first time using AI to solve real world problem so i probably made a lot of bad choices.
I have tried tweaking the parameters, making the Blender images darker to better reflect reality, and using OpenCV to manipulate the results, but none of that worked well.