找回密码
 立即注册
首页 业界区 业界 【语义分割专栏】:U-net实战篇(附上完整可运行的代码py ...

【语义分割专栏】:U-net实战篇(附上完整可运行的代码pytorch)

崆蛾寺 3 天前
目录

  • 前言
  • U-net全流程代码

    • 模型搭建(model)
    • 数据处理(dataloader)
    • 评价指标(metric)
    • 训练流程(train)
    • 模型测试(test)

  • 效果图
  • 结语

前言

U-net原理篇讲解:【语义分割专栏】2:U-net原理篇(由浅入深) - carpell - 博客园
代码地址,下载可复现:fouen6/unet_semantic-segmentation
本篇文章收录于语义分割专栏,如果对语义分割领域感兴趣的,可以去看看专栏,会对经典的模型以及代码进行详细的讲解哦!其中会包含可复现的代码!(数据集文中提供了下载地址,下载不到可在评论区要取)
上篇文章已经带大家学习过了U-net的原理,相信大家对于原理应该有了比较深的了解。本文将会带大家去手动复现属于自己的一个语义分割模型。将会深入代码进行讲解,如果有讲错的地方欢迎大家批评指正!
其实所有的深度学习模型的搭建我认为可以总结成五部分:模型的构建,数据集的处理,评价指标的设定,训练流程,测试。其实感觉有点深度学习代码八股文的那种意思。本篇同样的也会按照这样的方式进行讲解,希望大家能够深入代码去进行了解学习。
请记住:只懂原理不懂代码,你就算有了很好的想法创新点,你也难以去实现,所以希望大家能够深入去了解,最好能够参考着本文自己复现一下。
1.png

U-net全流程代码

模型搭建(model)

我们先来看U-net模型代码,当然细节上跟原论文中的U-net不是完全一样,原来的U-net模型是适用于医学图像分割任务,所以其有部分设计也是为了医学图像分割设计的,我这里复现的U-net代码更适合普遍的语义分割任务,其输入输出的shape大小是相同的。
首先是我将所有的上采样下采样中的卷积部分集成到了一起,看模型结构能够看出,每个部分都是两次卷积,所以代码如下,就在设置不同stage的时候设置好输入输出通道即可。
  1. class Down_Up_Conv(nn.Module):
  2.     def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
  3.         super(Down_Up_Conv, self).__init__()
  4.         self.conv_block = nn.Sequential(
  5.             nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
  6.             nn.BatchNorm2d(out_channels),
  7.             nn.ReLU(),
  8.             nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
  9.             nn.BatchNorm2d(out_channels),
  10.             nn.ReLU()
  11.         )
  12.     def forward(self, x):
  13.         return self.conv_block(x)
复制代码
然后这是跳跃连接的代码,同时我们采取了crop操作。我们通过获取两个feature map的长宽,然后再对齐之后进行再通道维上的拼接,代码如下,还是比较好理解的。
  1. def crop_and_concat(upsampled, bypass):
  2.     """
  3.     将两个 feature map 在 H 和 W 上对齐后拼接(dim=1)
  4.     - upsampled: 解码器上采样后的特征图 (N, C1, H1, W1)
  5.     - bypass: 编码器传来的特征图 (N, C2, H2, W2)
  6.     """
  7.     h1, w1 = upsampled.shape[2], upsampled.shape[3]
  8.     h2, w2 = bypass.shape[2], bypass.shape[3]
  9.     # 计算差值
  10.     delta_h = h2 - h1
  11.     delta_w = w2 - w1
  12.     # 对 encoder 输出进行中心裁剪
  13.     bypass_cropped = bypass[:, :,
  14.                      delta_h // 2: delta_h // 2 + h1,
  15.                      delta_w // 2: delta_w // 2 + w1]
  16.     # 拼接通道维
  17.     return torch.cat([upsampled, bypass_cropped], dim=1)
复制代码
然后就是搭建我们的U-net模型了,这还是比较容易的,将encoder部分的五个阶段的下采样卷积定义好,注意通道数的变换,然后就是Decoder的上采样的过程,我们使用的是转置卷积,上采样后还有卷积过程,所以我们按照U-net的模型图搭建即可。注意,我这里是把maxpooling给摘出来了的,每个下采样卷积之后都会有一个maxpooling层,这个可别忘了,在forward里面有体现。定义好模型参数之后就是模型参数的初始化了,这个步骤可千万不能忘。
  1. class UNet(nn.Module):
  2.     def __init__(self, num_classes=2):
  3.         super(UNet, self).__init__()
  4.         self.stage_down1=Down_Up_Conv(3, 64)
  5.         self.stage_down2=Down_Up_Conv(64, 128)
  6.         self.stage_down3=Down_Up_Conv(128, 256)
  7.         self.stage_down4=Down_Up_Conv(256, 512)
  8.         self.stage_down5=Down_Up_Conv(512, 1024)
  9.         self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2,padding=1)
  10.         self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2,padding=1)
  11.         self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2,padding=1)
  12.         self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2,padding=1)
  13.         self.stage_up4=Down_Up_Conv(1024, 512)
  14.         self.stage_up3=Down_Up_Conv(512, 256)
  15.         self.stage_up2=Down_Up_Conv(256, 128)
  16.         self.stage_up1=Down_Up_Conv(128, 64)
  17.         self.stage_out=Down_Up_Conv(64, num_classes)
  18.         self.maxpool = nn.MaxPool2d(kernel_size=2)
  19.         self.initialize_weights()
  20.     def initialize_weights(self):
  21.         for m in self.modules():
  22.             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
  23.                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  24.                 if m.bias is not None:
  25.                     nn.init.constant_(m.bias, 0)
  26.             elif isinstance(m, nn.BatchNorm2d):
  27.                 nn.init.constant_(m.weight, 1)
  28.                 nn.init.constant_(m.bias, 0)
  29.     def forward(self, x):
  30.         stage1 = self.stage_down1(x)
  31.         x = self.maxpool(stage1)
  32.         stage2 = self.stage_down2(x)
  33.         x = self.maxpool(stage2)
  34.         stage3 = self.stage_down3(x)
  35.         x = self.maxpool(stage3)
  36.         stage4 = self.stage_down4(x)
  37.         x = self.maxpool(stage4)
  38.         stage5 = self.stage_down5(x)
  39.         x = self.up4(stage5)
  40.         x = self.stage_up4(crop_and_concat(x, stage4))
  41.         x = self.up3(x)
  42.         x = self.stage_up3(crop_and_concat(x, stage3))
  43.         x = self.up2(x)
  44.         x = self.stage_up2(crop_and_concat(x, stage2))
  45.         x = self.up1(x)
  46.         x = self.stage_up1(crop_and_concat(x, stage1))
  47.         out = self.stage_out(x)
  48.         return out
复制代码
数据处理(dataloader)

数据集名称:CamVid
数据集下载地址:Object Recognition in Video Dataset
2.png

在这里进行下载,CamVid数据集有两种,一种是官方的就是上述的下载地址的,总共有32种类别,划分的会更加的细致。但是一般官网的太难打开了,所以我们可以通过Kaggle中的CamVid (Cambridge-Driving Labeled Video Database)进行下载。
还有一种就是11类别的(不包括背景),会将一些语义相近的内容进行合并,就划分的没有这么细致,任务难度也会比较低一些。(如果你在网上找不到的话,可以在评论区发言或是私聊我要取)
CamVid 数据集主要用于自动驾驶场景中的语义分割,包含驾驶场景中的道路、交通标志、车辆等类别的标注图像。该数据集旨在推动自动驾驶系统在道路场景中的表现。
数据特点

  • 图像数量:包括701帧视频序列图像,分为训练集、验证集和测试集。
  • 类别:包含32个类别(也有包含11个类别的),包括道路、建筑物、车辆、行人等。
  • 挑战:由于数据集主要来自城市交通场景,因此面临着动态变化的天气、光照、交通密度等挑战
这里我已经专门发了一篇博客对语义分割任务常用的数据集做了深入的介绍,已经具体讲解了其实现的处理代码。如果你对语义分割常用数据集有不了解的话,可以先去我的语义分割专栏中进行了解哦!!  我这里就直接附上代码了。
  1. import os
  2. from PIL import Image
  3. import albumentations as A
  4. from albumentations.pytorch.transforms import ToTensorV2
  5. from torch.utils.data import Dataset, DataLoader
  6. import numpy as np
  7. import torch
  8. # 11类
  9. Cam_CLASSES = [ "Unlabelled","Sky","Building","Pole",
  10.                 "Road","Sidewalk", "Tree","SignSymbol",
  11.                 "Fence","Car","Pedestrian","Bicyclist"]
  12. # 用于做可视化
  13. Cam_COLORMAP = [
  14.     [0, 0, 0],[128, 128, 128],[128, 0, 0],[192, 192, 128],
  15.     [128, 64, 128],[0, 0, 192],[128, 128, 0],[192, 128, 128],
  16.     [64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192]
  17. ]
  18. # 转换RGB mask为类别id的函数
  19. def mask_to_class(mask):
  20.     mask_class = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int64)
  21.     for idx, color in enumerate(Cam_COLORMAP):
  22.         color = np.array(color)
  23.         # 每个像素和当前颜色匹配
  24.         matches = np.all(mask == color, axis=-1)
  25.         mask_class[matches] = idx
  26.     return mask_class
  27. class CamVidDataset(Dataset):
  28.     def __init__(self, image_dir, label_dir):
  29.         self.image_dir = image_dir
  30.         self.label_dir = label_dir
  31.         self.transform = A.Compose([
  32.             A.Resize(224, 224),
  33.             A.HorizontalFlip(),
  34.             A.VerticalFlip(),
  35.             A.Normalize(),
  36.             ToTensorV2(),
  37.         ])
  38.         self.images = sorted(os.listdir(image_dir))
  39.         self.labels = sorted(os.listdir(label_dir))
  40.         assert len(self.images) == len(self.labels), "Images and labels count mismatch!"
  41.     def __len__(self):
  42.         return len(self.images)
  43.     def __getitem__(self, idx):
  44.         img_path = os.path.join(self.image_dir, self.images[idx])
  45.         label_path = os.path.join(self.label_dir, self.labels[idx])
  46.         image = np.array(Image.open(img_path).convert("RGB"))
  47.         label_rgb = np.array(Image.open(label_path).convert("RGB"))
  48.         # RGB转类别索引
  49.         mask = mask_to_class(label_rgb)
  50.         #mask = torch.from_numpy(np.array(mask)).long()
  51.         # Albumentations 需要 (H, W, 3) 和 (H, W)
  52.         transformed = self.transform(image=image, mask=mask)
  53.         return transformed['image'], transformed['mask'].long()
  54. def get_dataloader(data_path, batch_size=4, num_workers=4):
  55.     train_dir = os.path.join(data_path, 'train')
  56.     val_dir = os.path.join(data_path, 'val')
  57.     trainlabel_dir = os.path.join(data_path, 'train_labels')
  58.     vallabel_dir = os.path.join(data_path, 'val_labels')
  59.     train_dataset = CamVidDataset(train_dir, trainlabel_dir)
  60.     val_dataset = CamVidDataset(val_dir, vallabel_dir)
  61.     train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
  62.     val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True, num_workers=num_workers)
  63.     return train_loader, val_loader
复制代码
评价指标(metric)

我们这里语义分割采用的评价指标为:PA(像素准确率),CPA(类别像素准确率),MPA(类别平均像素准确率),IoU(交并比),mIoU(平均交并比),FWIoU(频率加权交并比),mF1(平均F1分数)。
这里我已经专门发了一篇博客对这些平均指标做了深入的介绍,已经具体讲解了其实现的代码。如果你对这些评价指标有不了解的话,可以先去我的语义分割专栏中进行了解哦!!  我这里就直接附上代码了。
  1. import numpy as np
  2. __all__ = ['SegmentationMetric']
  3. class SegmentationMetric(object):
  4.     def __init__(self, numClass):
  5.         self.numClass = numClass
  6.         self.confusionMatrix = np.zeros((self.numClass,) * 2)
  7.     def genConfusionMatrix(self, imgPredict, imgLabel):
  8.         mask = (imgLabel >= 0) & (imgLabel < self.numClass)
  9.         label = self.numClass * imgLabel[mask] + imgPredict[mask]
  10.         count = np.bincount(label, minlength=self.numClass ** 2)
  11.         confusionMatrix = count.reshape(self.numClass, self.numClass)
  12.         return confusionMatrix
  13.     def addBatch(self, imgPredict, imgLabel):
  14.         assert imgPredict.shape == imgLabel.shape
  15.         self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)
  16.         return self.confusionMatrix
  17.     def pixelAccuracy(self):
  18.         acc = np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()
  19.         return acc
  20.     def classPixelAccuracy(self):
  21.         denominator = self.confusionMatrix.sum(axis=1)
  22.         denominator = np.where(denominator == 0, 1e-12, denominator)
  23.         classAcc = np.diag(self.confusionMatrix) / denominator
  24.         return classAcc
  25.     def meanPixelAccuracy(self):
  26.         classAcc = self.classPixelAccuracy()
  27.         meanAcc = np.nanmean(classAcc)
  28.         return meanAcc
  29.     def IntersectionOverUnion(self):
  30.         intersection = np.diag(self.confusionMatrix)
  31.         union = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
  32.             self.confusionMatrix)
  33.         union = np.where(union == 0, 1e-12, union)
  34.         IoU = intersection / union
  35.         return IoU
  36.     def meanIntersectionOverUnion(self):
  37.         mIoU = np.nanmean(self.IntersectionOverUnion())
  38.         return mIoU
  39.     def Frequency_Weighted_Intersection_over_Union(self):
  40.         denominator1 = np.sum(self.confusionMatrix)
  41.         denominator1 = np.where(denominator1 == 0, 1e-12, denominator1)
  42.         freq = np.sum(self.confusionMatrix, axis=1) / denominator1
  43.         denominator2 = np.sum(self.confusionMatrix, axis=1) + np.sum(self.confusionMatrix, axis=0) - np.diag(
  44.             self.confusionMatrix)
  45.         denominator2 = np.where(denominator2 == 0, 1e-12, denominator2)
  46.         iu = np.diag(self.confusionMatrix) / denominator2
  47.         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
  48.         return FWIoU
  49.     def classF1Score(self):
  50.         tp = np.diag(self.confusionMatrix)
  51.         fp = self.confusionMatrix.sum(axis=0) - tp
  52.         fn = self.confusionMatrix.sum(axis=1) - tp
  53.         precision = tp / (tp + fp + 1e-12)
  54.         recall = tp / (tp + fn + 1e-12)
  55.         f1 = 2 * precision * recall / (precision + recall + 1e-12)
  56.         return f1
  57.     def meanF1Score(self):
  58.         f1 = self.classF1Score()
  59.         mean_f1 = np.nanmean(f1)
  60.         return mean_f1
  61.     def reset(self):
  62.         self.confusionMatrix = np.zeros((self.numClass, self.numClass))
  63.     def get_scores(self):
  64.         scores = {
  65.             'Pixel Accuracy': self.pixelAccuracy(),
  66.             'Class Pixel Accuracy': self.classPixelAccuracy(),
  67.             'Intersection over Union': self.IntersectionOverUnion(),
  68.             'Class F1 Score': self.classF1Score(),
  69.             'Frequency Weighted Intersection over Union': self.Frequency_Weighted_Intersection_over_Union(),
  70.             'Mean Pixel Accuracy': self.meanPixelAccuracy(),
  71.             'Mean Intersection over Union(mIoU)': self.meanIntersectionOverUnion(),
  72.             'Mean F1 Score': self.meanF1Score()
  73.         }
  74.         return scores
复制代码
训练流程(train)

到这里,所有的前期准备都已经就绪,我们就要开始训练我们的模型了。
  1. def parse_arguments():
  2.     parser = argparse.ArgumentParser()
  3.     parser.add_argument('--data_root', type=str, default='../../data/CamVid/CamVid(11)', help='Dataset root path')
  4.     parser.add_argument('--data_name', type=str, default='CamVid', help='Dataset class names')
  5.     parser.add_argument('--model', type=str, default='unet', help='Segmentation model')
  6.     parser.add_argument('--num_classes', type=int, default=12, help='Number of classes')
  7.     parser.add_argument('--epochs', type=int, default=50, help='Epochs')
  8.     parser.add_argument('--lr', type=float, default=0.005, help='Learning rate')
  9.     parser.add_argument('--momentum', type=float, default=0.9, help='Momentum')
  10.     parser.add_argument('--weight-decay', type=float, default=1e-4, help='Weight decay')
  11.     parser.add_argument('--batch_size', type=int, default=8, help='Batch size')
  12.     parser.add_argument('--checkpoint', type=str, default='./checkpoint', help='Checkpoint directory')
  13.     parser.add_argument('--resume', type=str, default=None, help='Resume checkpoint path')
  14.     return parser.parse_args()
复制代码
首先来看看我们的一些参数的设定,一般我们都是这样放在最前面,能够让人更加快速的了解其代码的一些核心参数设置。首先就是我们的数据集位置(data_root),然后就是我们的数据集名称(classes_name),这个暂时没什么用,因为我们目前只用了CamVid数据集,然后就是检测模型的选择(model),我们选择unet模型,数据集的类别数(num_classes),训练epoch数,这个你设置大一点也行,因为我们会在训练过程中保存最好结果的模型的。学习率(lr),动量(momentum),权重衰减(weight-decay),这些都属于模型超参数,大家可以尝试不同的数值,多试试,就会有个大致的了解的,批量大小(batch_size)根据自己电脑性能来设置,一般都是为2的倍数,保存权重的文件夹(checkpoint),是否继续训练(resume)。
[code]def train(args):    if not os.path.exists(args.checkpoint):        os.makedirs(args.checkpoint)    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    n_gpu = torch.cuda.device_count()    print(f"Device: {device}, GPUs available: {n_gpu}")    # Dataloader    train_loader, val_loader = get_dataloader(args.data_root, batch_size=args.batch_size)    train_dataset_size = len(train_loader.dataset)    val_dataset_size = len(val_loader.dataset)    print(f"Train samples: {train_dataset_size}, Val samples: {val_dataset_size}")    # Model    model = get_model(num_classes=args.num_classes)    model.to(device)    # Loss + Optimizer + Scheduler    criterion = nn.CrossEntropyLoss(ignore_index=0)    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)    scaler = torch.cuda.amp.GradScaler()    # Resume    start_epoch = 0    best_miou = 0.0    if args.resume and os.path.isfile(args.resume):        print(f"Loading checkpoint '{args.resume}'")        checkpoint = torch.load(args.resume)        start_epoch = checkpoint['epoch']        best_miou = checkpoint['best_miou']        model.load_state_dict(checkpoint['model_state_dict'])        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])        print(f"Loaded checkpoint (epoch {start_epoch})")    # Training history    history = {        'train_loss': [],        'val_loss': [],        'pixel_accuracy': [],        'miou': []    }    print(f"
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册