pytorch | 深度学习分割网络U-net的pytorch模型实现

    xiaoxiao2021-03-25  114

    这个是pytorch出来没多久的时候写的了,现在看是非常傻逼的方法,羞耻感十足。 推荐学习项目【pix2pix】的代码,优雅! –作者 2018.1.30

    U-Net 的实现现在github上非常多了吧!用dense-net大概也随随便便吊打了吧!不要用我这个啦~批判性参考一下pytorch咋用还差不多~! –作者 2018.4.09

    大概不支持pytorch 0.4以及以上版本


    pytorch是一个很好用的工具。作为一个python的深度学习包,它的接口调用起来很方便,具备自动求导功能,适合快速实现构思,且代码可读性强,比如前阵子的WGAN[1]。好了,回到U-Net。原文:arXiv:1505.04597 [cs.CV];主页:U-Net: Convolutional Networks for Biomedical Image Segmentation。该文章实现了一个用于生物医学图像分割的网络,是2015年的模型,好像是该领域(比赛)的冠军。模型长得像个巨大的U,故取名U-Net,之前很火的动漫线稿自动上色[2]用的就是这个模型。当然,该模型也许比不上现在的各种生成式模型了,不过拿来在pytorch里练练手、做boundary(边界)提取,还是可以的。注意这个网络的输出size与输入size不一致:网络里的3×3卷积不做padding,每次卷积都会让特征图的边长缩小2个像素,所以输出比输入小(例如输入572×572时输出为388×388),应用起来需要额外的处理。模型长这个鬼样:


    参考pytorch的tutorial代码,实现如下:

    #unet.py: from __future__ import division import torch.nn as nn import torch.nn.functional as F import torch from numpy.linalg import svd from numpy.random import normal from math import sqrt class UNet(nn.Module): def __init__(self,colordim =1): super(UNet, self).__init__() self.conv1_1 = nn.Conv2d(colordim, 64, 3) # input of (n,n,1), output of (n-2,n-2,64) self.conv1_2 = nn.Conv2d(64, 64, 3) self.bn1 = nn.BatchNorm2d(64) self.conv2_1 = nn.Conv2d(64, 128, 3) self.conv2_2 = nn.Conv2d(128, 128, 3) self.bn2 = nn.BatchNorm2d(128) self.conv3_1 = nn.Conv2d(128, 256, 3) self.conv3_2 = nn.Conv2d(256, 256, 3) self.bn3 = nn.BatchNorm2d(256) self.conv4_1 = nn.Conv2d(256, 512, 3) self.conv4_2 = nn.Conv2d(512, 512, 3) self.bn4 = nn.BatchNorm2d(512) self.conv5_1 = nn.Conv2d(512, 1024, 3) self.conv5_2 = nn.Conv2d(1024, 1024, 3) self.upconv5 = nn.Conv2d(1024, 512, 1) self.bn5 = nn.BatchNorm2d(512) self.bn5_out = nn.BatchNorm2d(1024) self.conv6_1 = nn.Conv2d(1024, 512, 3) self.conv6_2 = nn.Conv2d(512, 512, 3) self.upconv6 = nn.Conv2d(512, 256, 1) self.bn6 = nn.BatchNorm2d(256) self.bn6_out = nn.BatchNorm2d(512) self.conv7_1 = nn.Conv2d(512, 256, 3) self.conv7_2 = nn.Conv2d(256, 256, 3) self.upconv7 = nn.Conv2d(256, 128, 1) self.bn7 = nn.BatchNorm2d(128) self.bn7_out = nn.BatchNorm2d(256) self.conv8_1 = nn.Conv2d(256, 128, 3) self.conv8_2 = nn.Conv2d(128, 128, 3) self.upconv8 = nn.Conv2d(128, 64, 1) self.bn8 = nn.BatchNorm2d(64) self.bn8_out = nn.BatchNorm2d(128) self.conv9_1 = nn.Conv2d(128, 64, 3) self.conv9_2 = nn.Conv2d(64, 64, 3) self.conv9_3 = nn.Conv2d(64, colordim, 1) self.bn9 = nn.BatchNorm2d(colordim) self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False) self.upsample = nn.UpsamplingBilinear2d(scale_factor=2) self._initialize_weights() def forward(self, x1): x1 = F.relu(self.bn1(self.conv1_2(F.relu(self.conv1_1(x1))))) # print('x1 size: %d'%(x1.size(2))) x2 = F.relu(self.bn2(self.conv2_2(F.relu(self.conv2_1(self.maxpool(x1)))))) # print('x2 
size: %d'%(x2.size(2))) x3 = F.relu(self.bn3(self.conv3_2(F.relu(self.conv3_1(self.maxpool(x2)))))) # print('x3 size: %d'%(x3.size(2))) x4 = F.relu(self.bn4(self.conv4_2(F.relu(self.conv4_1(self.maxpool(x3)))))) # print('x4 size: %d'%(x4.size(2))) xup = F.relu(self.conv5_2(F.relu(self.conv5_1(self.maxpool(x4))))) # x5 # print('x5 size: %d'%(xup.size(2))) xup = self.bn5(self.upconv5(self.upsample(xup))) # x6in cropidx = (x4.size(2) - xup.size(2)) // 2 x4 = x4[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)] # print('crop1 size: %d, x9 size: %d'%(x4crop.size(2),xup.size(2))) xup = self.bn5_out(torch.cat((x4, xup), 1)) # x6 cat x4 xup = F.relu(self.conv6_2(F.relu(self.conv6_1(xup)))) # x6out xup = self.bn6(self.upconv6(self.upsample(xup))) # x7in cropidx = (x3.size(2) - xup.size(2)) // 2 x3 = x3[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)] # print('crop1 size: %d, x9 size: %d'%(x3crop.size(2),xup.size(2))) xup = self.bn6_out(torch.cat((x3, xup), 1) ) # x7 cat x3 xup = F.relu(self.conv7_2(F.relu(self.conv7_1(xup)))) # x7out xup = self.bn7(self.upconv7(self.upsample(xup)) ) # x8in cropidx = (x2.size(2) - xup.size(2)) // 2 x2 = x2[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)] # print('crop1 size: %d, x9 size: %d'%(x2crop.size(2),xup.size(2))) xup = self.bn7_out(torch.cat((x2, xup), 1)) # x8 cat x2 xup = F.relu(self.conv8_2(F.relu(self.conv8_1(xup)))) # x8out xup = self.bn8(self.upconv8(self.upsample(xup)) ) # x9in cropidx = (x1.size(2) - xup.size(2)) // 2 x1 = x1[:, :, cropidx:cropidx + xup.size(2), cropidx:cropidx + xup.size(2)] # print('crop1 size: %d, x9 size: %d'%(x1crop.size(2),xup.size(2))) xup = self.bn8_out(torch.cat((x1, xup), 1)) # x9 cat x1 xup = F.relu(self.conv9_3(F.relu(self.conv9_2(F.relu(self.conv9_1(xup)))))) # x9out return F.softsign(self.bn9(xup)) def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 
m.weight.data.normal_(0, sqrt(2. / n)) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() unet = UNet().cuda()

    训练集:因为没找到原文使用的数据集,就先用BSDS500代替了,这里的BSDS500是按我上一篇博文处理过的版本。由于训练集很少,可以做随机裁剪和随机水平翻转的数据增广(下面的代码只实现了随机裁剪)。注意torchvision.transforms这个包不支持对多幅图像(输入图和标签图)做相同的随机裁剪,所以把这个增广步骤放到了train.py里。

    #BSDDataLoader.py #这里主要是想说明pytorch的训练集load操作,简直傻瓜式操作!妈妈再也不用担心我的预处理了! from os.path import exists, join from torchvision.transforms import Compose, CenterCrop, ToTensor, Scale import torch.utils.data as data from os import listdir from PIL import Image def bsd500(dest="/dir/to/dataset"):#自行修改路径!! if not exists(dest): print("dataset not exist ") return dest def input_transform(crop_size): return Compose([ CenterCrop(crop_size), ToTensor() ]) def get_training_set(size, target_mode='seg', colordim=1): root_dir = bsd500() train_dir = join(root_dir, "train") return DatasetFromFolder(train_dir,target_mode,colordim, input_transform=input_transform(size), target_transform=input_transform(size)) def get_test_set(size, target_mode='seg', colordim=1): root_dir = bsd500() test_dir = join(root_dir, "test") return DatasetFromFolder(test_dir,target_mode,colordim, input_transform=input_transform(size), target_transform=input_transform(size)) def is_image_file(filename): return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) def load_img(filepath,colordim): if colordim==1: img = Image.open(filepath).convert('L') else: img = Image.open(filepath).convert('RGB') #y, _, _ = img.split() return img class DatasetFromFolder(data.Dataset): def __init__(self, image_dir, target_mode, colordim, input_transform=None, target_transform=None): super(DatasetFromFolder, self).__init__() self.image_filenames = [x for x in listdir( join(image_dir,'data') ) if is_image_file(x)] self.input_transform = input_transform self.target_transform = target_transform self.image_dir = image_dir self.target_mode = target_mode self.colordim = colordim def __getitem__(self, index): input = load_img(join(self.image_dir,'data',self.image_filenames[index]),self.colordim) if self.target_mode=='seg': target = load_img(join(self.image_dir,'seg',self.image_filenames[index]),1) else: target = load_img(join(self.image_dir,'bon',self.image_filenames[index]),1) if self.input_transform: input = 
self.input_transform(input) if self.target_transform: target = self.target_transform(target) return input, target def __len__(self): return len(self.image_filenames) #train.py ''' 因为原文中网络的input和output size不一样,不知道他是怎么搞的loss 简单起见,我就将groundtruth中心crop到和output一样大,然后求MSE loss了 结果还是收敛的,做过增广的数据用于训练,得到的测试集loss要大一点,因为训练时的尺度不一样,估计影响了泛化效果 ''' from __future__ import print_function from math import log10 import numpy as np import random import os import torch import torch.nn as nn import torch.optim as optim from torch.autograd import Variable from torch.utils.data import DataLoader from unet import UNet from BSDDataLoader import get_training_set,get_test_set import torchvision # Training settings class option: def __init__(self): self.cuda = True #use cuda? self.batchSize = 4 #training batch size self.testBatchSize = 4 #testing batch size self.nEpochs = 140 #umber of epochs to train for self.lr = 0.001 #Learning Rate. Default=0.01 self.threads = 4 #number of threads for data loader to use self.seed = 123 #random seed to use. 
Default=123 self.size = 428 self.remsize = 20 self.colordim = 1 self.target_mode = 'bon' self.pretrain_net = "/home/wcd/PytorchProject/Unet/unetdata/checkpoint/model_epoch_140.pth" def map01(tensor,eps=1e-5): #input/output:tensor max = np.max(tensor.numpy(), axis=(1,2,3), keepdims=True) min = np.min(tensor.numpy(), axis=(1,2,3), keepdims=True) if (max-min).any(): return torch.from_numpy( (tensor.numpy() - min) / (max-min + eps) ) else: return torch.from_numpy( (tensor.numpy() - min) / (max-min) ) def sizeIsValid(size): for i in range(4): size -= 4 if size%2: return 0 else: size /= 2 for i in range(4): size -= 4 size *= 2 return size-4 opt = option() target_size = sizeIsValid(opt.size) print("outputsize is: "+str(target_size)) if not target_size: raise Exception("input size invalid") target_gap = (opt.size - target_size)//2 cuda = opt.cuda if cuda and not torch.cuda.is_available(): raise Exception("No GPU found, please run without --cuda") torch.manual_seed(opt.seed) if cuda: torch.cuda.manual_seed(opt.seed) print('===> Loading datasets') train_set = get_training_set(opt.size + opt.remsize, target_mode=opt.target_mode, colordim=opt.colordim) test_set = get_test_set(opt.size, target_mode=opt.target_mode, colordim=opt.colordim) training_data_loader = DataLoader(dataset=train_set, num_workers=opt.threads, batch_size=opt.batchSize, shuffle=True) testing_data_loader = DataLoader(dataset=test_set, num_workers=opt.threads, batch_size=opt.testBatchSize, shuffle=False) print('===> Building unet') unet = UNet(opt.colordim) criterion = nn.MSELoss() if cuda: unet = unet.cuda() criterion = criterion.cuda() pretrained = True if pretrained: unet.load_state_dict(torch.load(opt.pretrain_net)) optimizer = optim.SGD(unet.parameters(), lr=opt.lr) print('===> Training unet') def train(epoch): epoch_loss = 0 for iteration, batch in enumerate(training_data_loader, 1): randH = random.randint(0, opt.remsize) randW = random.randint(0, opt.remsize) input = Variable(batch[0][:, :, randH:randH 
+ opt.size, randW:randW + opt.size]) target = Variable(batch[1][:, :, randH + target_gap:randH + target_gap + target_size, randW + target_gap:randW + target_gap + target_size]) #target =target.squeeze(1) #print(target.data.size()) if cuda: input = input.cuda() target = target.cuda() input = unet(input) #print(input.data.size()) loss = criterion( input, target) epoch_loss += loss.data[0] loss.backward() optimizer.step() if iteration%10 is 0: print("===> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, len(training_data_loader), loss.data[0])) imgout = input.data/2 +1 torchvision.utils.save_image(imgout,"/home/wcd/PytorchProject/Unet/unetdata/checkpoint/epch_"+str(epoch)+'.jpg') print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(epoch, epoch_loss / len(training_data_loader))) def test(): totalloss = 0 for batch in testing_data_loader: input = Variable(batch[0],volatile=True) target = Variable(batch[1][:, :, target_gap:target_gap + target_size, target_gap:target_gap + target_size], volatile=True) #target =target.long().squeeze(1) if cuda: input = input.cuda() target = target.cuda() optimizer.zero_grad() prediction = unet(input) loss = criterion(prediction, target) totalloss += loss.data[0] print("===> Avg. test loss: {:.4f} dB".format(totalloss / len(testing_data_loader))) def checkpoint(epoch): model_out_path = "/home/wcd/PytorchProject/Unet/unetdata/checkpoint/model_epoch_{}.pth".format(epoch) torch.save(unet.state_dict(), model_out_path) print("Checkpoint saved to {}".format(model_out_path)) for epoch in range(141, 141+opt.nEpochs + 1): train(epoch) if epoch%10 is 0: checkpoint(epoch) test() checkpoint(epoch)

    如果想要看看网络的结构,还可以这样:

    # Visualise the U-Net autograd graph with graphviz.
import torch
from graphviz import Digraph
from torch.autograd import Variable
from unet import UNet


def make_dot(var):
    """Build a graphviz ``Digraph`` of the autograd graph that produced *var*.

    Tensor/Variable nodes (such as the input and the weights) are light-blue
    boxes labelled with their size. Every other node shows the class name of
    its autograd function. Both attribute names are supported: the old
    ``creator``/``previous_functions`` and the newer ``grad_fn``/``next_functions``.
    """
    node_attr = dict(style='filled', shape='box', align='left', fontsize='10', ranksep='0.1', height='0.5')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="30,14"))
    seen = set()

    def size_label(t):
        return '(' + ', '.join(['%d' % v for v in t.size()]) + ')'

    def parents(fn):
        # Newer autograd: next_functions holds (node, index) pairs, node may be None.
        edges = getattr(fn, 'next_functions', None)
        if edges is None:
            edges = getattr(fn, 'previous_functions', ())
        return [u[0] for u in edges if u[0] is not None]

    def add_nodes(fn):
        if fn in seen:
            return
        seen.add(fn)
        if isinstance(fn, Variable):
            dot.node(str(id(fn)), size_label(fn), fillcolor='lightblue')
        elif hasattr(fn, 'variable'):
            # Newer autograd represents a leaf tensor as an AccumulateGrad node.
            dot.node(str(id(fn)), size_label(fn.variable), fillcolor='lightblue')
        else:
            dot.node(str(id(fn)), type(fn).__name__)
        for parent in parents(fn):
            dot.edge(str(id(parent)), str(id(fn)))
            add_nodes(parent)

    root = getattr(var, 'grad_fn', None)
    if root is None:
        root = getattr(var, 'creator', None)
    if root is not None:
        add_nodes(root)
    return dot


colordim = 1  # `opt` is not defined in this script; 1 matches UNet's default
unet = UNet(colordim)
x = Variable(torch.rand(1, colordim, 572, 572))
if torch.cuda.is_available():
    # The model and its input must be on the same device.
    unet = unet.cuda()
    x = x.cuda()
h_x = unet(x)
make_dot(h_x)

    然后就会看到这样的图: 真壮观,科科!

    然后看看结果吧~ ground truth data


    [1] https://github.com/martinarjovsky/WassersteinGAN
    [2] https://github.com/pfnet/PaintsChainer
    转载请注明原文地址: https://ju.6miu.com/read-7292.html

    最新回复(0)