pytorch 实现语义分割 PSPNet

语意分割是指一张图片上包含多个物体,通过语义分割可以识别物体分类、物体名称、像素识别的任务。和物体检测不同,他不会将物体框出来,而是根据像素的归属把物体标注出来。PSPNet 的输入是一张图片,例如300500,那么输出就是一个 300500 数组,数组中的值就是分类索引,如果是 20 中分类,分类索引就是 0-19。
PSPNet 的输入和输出,如下图
在这里插入图片描述

PSPNet 物体检测流程

  1. 预处理:图像调整为 475*475, 对颜色标准化。
  2. 图像输入神经网络,输出 21 * 475 * 475数据,每个数据就是当前像素分类概率
  3. 根据概率最高的分类生成图像
  4. 将图像还原成原有尺寸

数据准备

  1. 准备数据
  2. 数据增强
  3. 创建 Dataset 类
  4. 创建 Dataloader
准备数据

# 导入软件包
import os.path as osp
from PIL import Image

import torch.utils.data as data


def make_datapath_list(rootpath):
    """
    创建用于学习、验证的图像数据和标注数据的文件路径列表变量

    Parameters
    ----------
    rootpath : str
        指向数据文件夹的路径

    Returns
    -------
    ret : train_img_list, train_anno_list, val_img_list, val_anno_list
        保存了指向数据的路径列表变量
    """

    #创建指向图像文件和标注数据的路径的模板
    imgpath_template = osp.join(rootpath, 'JPEGImages', '%s.jpg')
    annopath_template = osp.join(rootpath, 'SegmentationClass', '%s.png')

    #训练和验证,分别获取相应的文件 ID(文件名)
    train_id_names = osp.join(rootpath + 'ImageSets/Segmentation/train.txt')
    val_id_names = osp.join(rootpath + 'ImageSets/Segmentation/val.txt')

    #创建指向训练数据的图像文件和标注文件的路径列表变量
    train_img_list = list()
    train_anno_list = list()

    for line in open(train_id_names):
        file_id = line.strip()  #删除空格和换行
        img_path = (imgpath_template % file_id)  #图像的路径
        anno_path = (annopath_template % file_id)  #标注数据的路径
        train_img_list.append(img_path)
        train_anno_list.append(anno_path)

    #创建指向验证数据的图像文件和标注文件的路径列表变量
    val_img_list = list()
    val_anno_list = list()

    for line in open(val_id_names):
        file_id = line.strip()  #删除空格和换行符
        img_path = (imgpath_template % file_id)  #图像的路径
        anno_path = (annopath_template % file_id)  #标注数据的路径
        val_img_list.append(img_path)
        val_anno_list.append(anno_path)

    return train_img_list, train_anno_list, val_img_list, val_anno_list


#确认执行结果,获取文件路径列表
rootpath = "./data/VOCdevkit/VOC2012/"

train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(
    rootpath=rootpath)

print(train_img_list[0])
print(train_anno_list[0])

创建 dataset
#首先导入数据处理类和数据增强类
from utils.data_augumentation import Compose, Scale, RandomRotation, RandomMirror, Resize, Normalize_Tensor


class DataTransform():
    """
    图像和标注的预处理类。训练和验证时分别采取不同的处理方法
    将图像的尺寸调整为input_size x input_size
    训练时进行数据增强处理


    Attributes
    ----------
    input_size : int
        指定调整图像尺寸的大小
    color_mean : (R, G, B)
        指定每个颜色通道的平均值
    color_std : (R, G, B)
        指定每个颜色通道的标准差
    """

    def __init__(self, input_size, color_mean, color_std):
        self.data_transform = {
            'train': Compose([
                Scale(scale=[0.5, 1.5]),  #图像的放大
                RandomRotation(angle=[-10, 10]),  #旋转
                RandomMirror(), #随机镜像
                Resize(input_size),  #调整尺寸(input_size)
                Normalize_Tensor(color_mean, color_std)  #颜色信息的正规化和张量化
            ]),
            'val': Compose([
                Resize(input_size),  #调整图像尺寸(input_size))
                Normalize_Tensor(color_mean, color_std) #颜色信息的正规化和张量化
            ])
        }

    def __call__(self, phase, img, anno_class_img):
        """
        Parameters
        ----------
        phase : 'train' or 'val'
            指定预处理的执行模式。
        """
        return self.data_transform[phase](img, anno_class_img)

class VOCDataset(data.Dataset):
    """
    用于创建VOC2012的Dataset的类,继承自PyTorch的Dataset类

    Attributes
    ----------
    img_list : 
        保存了图像路径列表
    anno_list :
        保存了标注路径列表n
    phase : 'train' or 'test'
        保存了标注路径列表
    transform : object
        预处理类的实例
    """

    def __init__(self, img_list, anno_list, phase, transform):
        self.img_list = img_list
        self.anno_list = anno_list
        self.phase = phase
        self.transform = transform

    def __len__(self):
        '''返回图像的张数'''
        return len(self.img_list)

    def __getitem__(self, index):
        '''
        获取经过预处理的图像的张量形式的数据和标注
        '''
        img, anno_class_img = self.pull_item(index)
        return img, anno_class_img

    def pull_item(self, index):
        ''''获取图像的张量形式的数据和标注'''

        # 1.读入图像数据
        image_file_path = self.img_list[index]
        img = Image.open(image_file_path)   #[高度][宽度][颜色RGB]

        # 2.读入标注图像数据
        anno_file_path = self.anno_list[index]
        anno_class_img = Image.open(anno_file_path)   #[ 高度 ][ 宽度 ]

        # 3.进行预处理操作
        img, anno_class_img = self.transform(self.phase, img, anno_class_img)

        return img, anno_class_img

#确认执行结果

#(RGB)颜色的平均值和均方差
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

#(RGB)颜色的平均值和均方差
train_dataset = VOCDataset(train_img_list, train_anno_list, phase="train", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

val_dataset = VOCDataset(val_img_list, val_anno_list, phase="val", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

#读取数据的示例
print(val_dataset.__getitem__(0)[0].shape)
print(val_dataset.__getitem__(0)[1].shape)
print(val_dataset.__getitem__(0))

#创建数据加载器

batch_size = 8

train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

#集中保存到字典型变量中
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}

#确认执行结果
batch_iterator = iter(dataloaders_dict["val"])  #转换为迭代器
imges, anno_class_imges = next(batch_iterator)  #取出第一个元素
print(imges.size())  # torch.Size([8, 3, 475, 475])
print(anno_class_imges.size())  # torch.Size([8, 3, 475, 475])

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# 每次执行都会改变

# ·读取图像数据
index = 0
imges, anno_class_imges = train_dataset.__getitem__(index)

#图像的显示
img_val = imges
img_val = img_val.numpy().transpose((1, 2, 0))
plt.imshow(img_val)
plt.show()

# 标注图像的显示
anno_file_path = train_anno_list[0]
anno_class_img = Image.open(anno_file_path)   # [高度][宽度][颜色RGB]
p_palette = anno_class_img.getpalette()

anno_class_img_val = anno_class_imges.numpy()
anno_class_img_val = Image.fromarray(np.uint8(anno_class_img_val), mode="P")
anno_class_img_val.putpalette(p_palette)
plt.imshow(anno_class_img_val)
plt.show()

# 读取图像数据
index = 0
imges, anno_class_imges = val_dataset.__getitem__(index)

# 画像的表示
img_val = imges
img_val = img_val.numpy().transpose((1, 2, 0))
plt.imshow(img_val)
plt.show()

# 标注图像的显示
anno_file_path = train_anno_list[0]
anno_class_img = Image.open(anno_file_path)   # [高度][宽度][颜色RGB]
p_palette = anno_class_img.getpalette()

anno_class_img_val = anno_class_imges.numpy()
anno_class_img_val = Image.fromarray(np.uint8(anno_class_img_val), mode="P")
anno_class_img_val.putpalette(p_palette)
plt.imshow(anno_class_img_val)
plt.show()

PSPNet 网络构建

PSPNet网络包括 Feature(Encoder)、Pyramid Pooling、Decoder、AuxLoss 四个模块构成。

# 实现 PSPNet 网络
# 导入软件包
import torch
import torch.nn as nn
import torch.nn.functional as F

class PSPNet(nn.Module):
    def __init__(self, n_classes):
        super(PSPNet, self).__init__()

        #参数设置
        block_config = [3, 4, 6, 3]  # resnet50
        img_size = 475
        img_size_8 = 60  #设为img_size的1/8

        #创建组成子网络的四个模块
        self.feature_conv = FeatureMap_convolution()
        self.feature_res_1 = ResidualBlockPSP(
            n_blocks=block_config[0], in_channels=128, mid_channels=64, out_channels=256, stride=1, dilation=1)
        self.feature_res_2 = ResidualBlockPSP(
            n_blocks=block_config[1], in_channels=256, mid_channels=128, out_channels=512, stride=2, dilation=1)
        self.feature_dilated_res_1 = ResidualBlockPSP(
            n_blocks=block_config[2], in_channels=512, mid_channels=256, out_channels=1024, stride=1, dilation=2)
        self.feature_dilated_res_2 = ResidualBlockPSP(
            n_blocks=block_config[3], in_channels=1024, mid_channels=512, out_channels=2048, stride=1, dilation=4)

        self.pyramid_pooling = PyramidPooling(in_channels=2048, pool_sizes=[
            6, 3, 2, 1], height=img_size_8, width=img_size_8)

        self.decode_feature = DecodePSPFeature(
            height=img_size, width=img_size, n_classes=n_classes)

        self.aux = AuxiliaryPSPlayers(
            in_channels=1024, height=img_size, width=img_size, n_classes=n_classes)

    def forward(self, x):
        x = self.feature_conv(x)
        x = self.feature_res_1(x)
        x = self.feature_res_2(x)
        x = self.feature_dilated_res_1(x)

        output_aux = self.aux(x)  #将Feature模块中转到Aux模块

        x = self.feature_dilated_res_2(x)

        x = self.pyramid_pooling(x)
        output = self.decode_feature(x)

        return (output, output_aux)

Feature 子网络

Feature 包括 5 个子网络,FeatureMap、两个ResidualBlockPSP和两个 dilated 版的 ResidualBlockPSP。

# Feature Map Convolution 子网络
class conv2DBatchNormRelu(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super(conv2DBatchNormRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, dilation, bias=bias)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        #inplase指定不保存输入数据,直接计算输出结果,达到节约内存的目的

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        outputs = self.relu(x)

        return outputs

class FeatureMap_convolution(nn.Module):
    def __init__(self):
        '''创建网络结构'''
        super(FeatureMap_convolution, self).__init__()

        # #卷积层1
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias = 3, 64, 3, 2, 1, 1, False
        self.cbnr_1 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size, stride, padding, dilation, bias)

        #卷积层2
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias = 64, 64, 3, 1, 1, 1, False
        self.cbnr_2 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size, stride, padding, dilation, bias)

        #卷积层3
        in_channels, out_channels, kernel_size, stride, padding, dilation, bias = 64, 128, 3, 1, 1, 1, False
        self.cbnr_3 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size, stride, padding, dilation, bias)

        #最大池化层
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.cbnr_1(x)
        x = self.cbnr_2(x)
        x = self.cbnr_3(x)
        outputs = self.maxpool(x)
        return outputs


class ResidualBlockPSP(nn.Sequential):
    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation):
        super(ResidualBlockPSP, self).__init__()

        #设置bottleNeckPSP
        self.add_module(
            "block1",
            bottleNeckPSP(in_channels, mid_channels,
                          out_channels, stride, dilation)
        )

        #循环设置bottleNeckIdentifyPSP
        for i in range(n_blocks - 1):
            self.add_module(
                "block" + str(i+2),
                bottleNeckIdentifyPSP(
                    out_channels, mid_channels, stride, dilation)
            )

class conv2DBatchNorm(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, bias):
        super(conv2DBatchNorm, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride, padding, dilation, bias=bias)
        self.batchnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        outputs = self.batchnorm(x)

        return outputs

class bottleNeckPSP(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, stride, dilation):
        super(bottleNeckPSP, self).__init__()

        self.cbr_1 = conv2DBatchNormRelu(
            in_channels, mid_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.cbr_2 = conv2DBatchNormRelu(
            mid_channels, mid_channels, kernel_size=3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.cb_3 = conv2DBatchNorm(
            mid_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)

        #跳跃链接
        self.cb_residual = conv2DBatchNorm(
            in_channels, out_channels, kernel_size=1, stride=stride, padding=0, dilation=1, bias=False)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.cb_3(self.cbr_2(self.cbr_1(x)))
        residual = self.cb_residual(x)
        return self.relu(conv + residual)

class bottleNeckIdentifyPSP(nn.Module):
    def __init__(self, in_channels, mid_channels, stride, dilation):
        super(bottleNeckIdentifyPSP, self).__init__()

        self.cbr_1 = conv2DBatchNormRelu(
            in_channels, mid_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.cbr_2 = conv2DBatchNormRelu(
            mid_channels, mid_channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False)
        self.cb_3 = conv2DBatchNorm(
            mid_channels, in_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.cb_3(self.cbr_2(self.cbr_1(x)))
        residual = x
        return self.relu(conv + residual)



Pyramid Pooling池化
class PyramidPooling(nn.Module):
    def __init__(self, in_channels, pool_sizes, height, width):
        super(PyramidPooling, self).__init__()

        #在forward中使用的图像尺寸
        self.height = height
        self.width = width

        #各个卷积层输出的通道数
        out_channels = int(in_channels / len(pool_sizes))

       #生成每个卷积层
        # 该实现方法非常“耿直”,虽然笔者很想用 for 循环来编写这段代码,但最后还是决定优先以容易理解的方                     式编写
        # pool_sizes: [6, 3, 2, 1]
        self.avpool_1 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[0])
        self.cbr_1 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)

        self.avpool_2 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[1])
        self.cbr_2 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)

        self.avpool_3 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[2])
        self.cbr_3 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)

        self.avpool_4 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[3])
        self.cbr_4 = conv2DBatchNormRelu(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False)

    def forward(self, x):

        out1 = self.cbr_1(self.avpool_1(x))
        out1 = F.interpolate(out1, size=(
            self.height, self.width), mode="bilinear", align_corners=True)

        out2 = self.cbr_2(self.avpool_2(x))
        out2 = F.interpolate(out2, size=(
            self.height, self.width), mode="bilinear", align_corners=True)

        out3 = self.cbr_3(self.avpool_3(x))
        out3 = F.interpolate(out3, size=(
            self.height, self.width), mode="bilinear", align_corners=True)

        out4 = self.cbr_4(self.avpool_4(x))
        out4 = F.interpolate(out4, size=(
            self.height, self.width), mode="bilinear", align_corners=True)

        #最后将结果进行合并,指定dim=1按通道数的维数进行合并
        output = torch.cat([x, out1, out2, out3, out4], dim=1)

        return output

Decoder和 AuxLoss的实现

class DecodePSPFeature(nn.Module):
    def __init__(self, height, width, n_classes):
        super(DecodePSPFeature, self).__init__()

        #在forward中使用的图像尺寸
        self.height = height
        self.width = width

        self.cbr = conv2DBatchNormRelu(
            in_channels=4096, out_channels=512, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(
            in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.cbr(x)
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(
            x, size=(self.height, self.width), mode="bilinear", align_corners=True)

        return output

class AuxiliaryPSPlayers(nn.Module):
    def __init__(self, in_channels, height, width, n_classes):
        super(AuxiliaryPSPlayers, self).__init__()

       #在forward中使用的图像尺寸
        self.height = height
        self.width = width

        self.cbr = conv2DBatchNormRelu(
            in_channels=in_channels, out_channels=256, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(
            in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.cbr(x)
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(
            x, size=(self.height, self.width), mode="bilinear", align_corners=True)

        return output

# 定义模型
net = PSPNet(n_classes=21)
net
 #生成伪数据
batch_size = 2
dummy_img = torch.rand(batch_size, 3, 475, 475)

#计算
outputs = net(dummy_img)
print(outputs)

训练
# 导入软件包
import random
import math
import time
import pandas as pd
import numpy as np

import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.optim as optim

# 初始设定
# Setup seeds
torch.manual_seed(1234)
np.random.seed(1234)
random.seed(1234)

from utils.dataloader import make_datapath_list, DataTransform, VOCDataset

# 创建文件路径列表
rootpath = "./data/VOCdevkit/VOC2012/"
train_img_list, train_anno_list, val_img_list, val_anno_list = make_datapath_list(
    rootpath=rootpath)

# Dataset作成
#(RGB) 颜色的平均值和均方差
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

train_dataset = VOCDataset(train_img_list, train_anno_list, phase="train", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

val_dataset = VOCDataset(val_img_list, val_anno_list, phase="val", transform=DataTransform(
    input_size=475, color_mean=color_mean, color_std=color_std))

#生成DataLoader
batch_size = 8

train_dataloader = data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

val_dataloader = data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

#集中保存到字典型变量中
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}


from utils.pspnet import PSPNet

# 制作网络模型
#使用ADE20K数据集中事先训练好的模型,ADE20K的分类数量是150
net = PSPNet(n_classes=150)

#载入 ADE20K 中事先训练好的参数
state_dict = torch.load("./weights/pspnet50_ADE20K.pth")
net.load_state_dict(state_dict)

#将分类用的卷积层替换为输出数量为21的卷积层
n_classes = 21
net.decode_feature.classification = nn.Conv2d(
    in_channels=512, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

net.aux.classification = nn.Conv2d(
    in_channels=256, out_channels=n_classes, kernel_size=1, stride=1, padding=0)

#对替换的卷积层进行初始化。由于激励函数是Sigmoid,因此使用Xavier进行初始化


def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:  #如果bias存在
            nn.init.constant_(m.bias, 0.0)


net.decode_feature.classification.apply(weights_init)
net.aux.classification.apply(weights_init)


print('网络设置完毕 :成功的载入了事先训练完毕的权重')


#设置损失函数
class PSPLoss(nn.Module):
    """#设置损失函数"""

    def __init__(self, aux_weight=0.4):
        super(PSPLoss, self).__init__()
        self.aux_weight = aux_weight  #aux_loss的权重

    def forward(self, outputs, targets):
        """
        损失函数的计算。

        Parameters
        ----------
        outputs : PSPNet的输出(tuple)
            (output=torch.Size([num_batch, 21, 475, 475]), output_aux=torch.Size([num_batch, 21, 475, 475]))。

        targets : [num_batch, 475, 4755]
            正解的标注信息

        Returns
        -------
        loss :张量
            损失值
        """

        loss = F.cross_entropy(outputs[0], targets, reduction='mean')
        loss_aux = F.cross_entropy(outputs[1], targets, reduction='mean')

        return loss+self.aux_weight*loss_aux


criterion = PSPLoss(aux_weight=0.4)


# #由于使用的是微调,因此要降低学习率く
optimizer = optim.SGD([
    {'params': net.feature_conv.parameters(), 'lr': 1e-3},
    {'params': net.feature_res_1.parameters(), 'lr': 1e-3},
    {'params': net.feature_res_2.parameters(), 'lr': 1e-3},
    {'params': net.feature_dilated_res_1.parameters(), 'lr': 1e-3},
    {'params': net.feature_dilated_res_2.parameters(), 'lr': 1e-3},
    {'params': net.pyramid_pooling.parameters(), 'lr': 1e-3},
    {'params': net.decode_feature.parameters(), 'lr': 1e-2},
    {'params': net.aux.parameters(), 'lr': 1e-2},
], momentum=0.9, weight_decay=0.0001)


# #设置调度器
def lambda_epoch(epoch):
    max_epoch = 30
    return math.pow((1-epoch/max_epoch), 0.9)


scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_epoch)

#创建对模型进行训练的函数


def train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs):

    #确认GPU是否可用
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用的设备 :", device)

    #将网络载入GPU中
    net.to(device)

    #如果网络相对固定,开启高速处理选项
    torch.backends.cudnn.benchmark = True

    #图像的张数
    num_train_imgs = len(dataloaders_dict["train"].dataset)
    num_val_imgs = len(dataloaders_dict["val"].dataset)
    batch_size = dataloaders_dict["train"].batch_size

    #设置迭代计数器
    iteration = 1
    logs = []

    # multiple minibatch
    batch_multiplier = 3

    # epochのループ
    for epoch in range(num_epochs):

        #epoch的循环
        t_epoch_start = time.time()
        t_iter_start = time.time()
        epoch_train_loss = 0.0  #epoch的损失和
        epoch_val_loss = 0.0  #epoch的损失和

        print('-------------')
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-------------')

        #对每轮epoch进行训练和验证的循环
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()  #将模式设置为训练模式
                scheduler.step()  #更新最优化调度器
                optimizer.zero_grad()
                print('(train)')

            else:
                if((epoch+1) % 5 == 0):
                    net.eval()   #将模型设置为验证模式
                    print('-------------')
                    print('(val)')
                else:
                    #每5轮进行1次验证
                    continue

           #从数据加载器中读取每个小批量并进行循环
            count = 0  # multiple minibatch
            for imges, anno_class_imges in dataloaders_dict[phase]:
                #如果小批量的尺寸是1,批量正规化处理会报错,因此需要避免
                if imges.size()[0] == 1:
                    continue

                #如果GPU可用,将数据传输到GPU中
                imges = imges.to(device)
                anno_class_imges = anno_class_imges.to(device)

                
                #使用multiple minibatch对参数进行更新
                if (phase == 'train') and (count == 0):
                    optimizer.step()
                    optimizer.zero_grad()
                    count = batch_multiplier

                #正向传播计算
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(imges)
                    loss = criterion(
                        outputs, anno_class_imges.long()) / batch_multiplier

                    #训练时采用反向传播
                    if phase == 'train':
                        loss.backward() #梯度的计算
                        count -= 1  # multiple minibatch

                        if (iteration % 10 == 0):  #每10次迭代显示一次loss
                            t_iter_finish = time.time()
                            duration = t_iter_finish - t_iter_start
                            print('迭代 {} || Loss: {:.4f} || 10iter: {:.4f} sec.'.format(
                                iteration, loss.item()/batch_size*batch_multiplier, duration))
                            t_iter_start = time.time()

                        epoch_train_loss += loss.item() * batch_multiplier
                        iteration += 1

                    #验证时
                    else:
                        epoch_val_loss += loss.item() * batch_multiplier

        #每个epoch的phase的loss和正解率
        t_epoch_finish = time.time()
        print('-------------')
        print('epoch {} || Epoch_TRAIN_Loss:{:.4f} ||Epoch_VAL_Loss:{:.4f}'.format(
            epoch+1, epoch_train_loss/num_train_imgs, epoch_val_loss/num_val_imgs))
        print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))
        t_epoch_start = time.time()

        #保存日志
        log_epoch = {'epoch': epoch+1, 'train_loss': epoch_train_loss /
                     num_train_imgs, 'val_loss': epoch_val_loss/num_val_imgs}
        logs.append(log_epoch)
        df = pd.DataFrame(logs)
        df.to_csv("log_output.csv")

    #保存最终的网络
    torch.save(net.state_dict(), 'weights/pspnet50_' +
               str(epoch+1) + '.pth')

# 学习和验证的实现
num_epochs = 30
train_model(net, dataloaders_dict, criterion, scheduler, optimizer, num_epochs=num_epochs)

至此, 网络搭建完成,CPU 训练太慢了,5.1 假期找个 GPU 服务器训练试试,然后上结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/583591.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Redis基本數據結構 ― List

Redis基本數據結構 ― List 介紹常用命令範例1. 將元素推入List中2. 取得List內容3. 彈出元素 介紹 Redis中的List結構是一個雙向鏈表。 LPUSH LPOP StackLPUSH RPOP QueueLPUSH BRPOP Queue(消息隊列) 常用命令 命令功能LPUSH將元素推入列表左端RPUSH將元素推入列表右…

特别推荐一个学习开发编程的网站

http://www.somecore.cn/ 为开发人员提供一系列好看的技术备忘单,方便开发过程中速查基本语法、快捷键、命令,节省查找时间,提高开发效率。 【人生苦短,抓住重点】

Java 面向对象—重载和重写/覆盖(面试)

重载和重写/覆盖: 重载(overload): Java重载是发生在本类中的,允许同一个类中,有多个同名方法存在,方法名可以相同,方法参数的个数和类型不同,即要求形参列表不一致。重载…

有趣的 CSS 图标整合技术!sprites精灵图,css贴图定位

你好,我是云桃桃。 一个希望帮助更多朋友快速入门 WEB 前端的程序媛。 云桃桃-大专生,一枚程序媛,感谢关注。回复 “前端基础题”,可免费获得前端基础 100 题汇总,回复 “前端工具”,可获取 Web 开发工具合…

【C语言进阶】程序编译中的预处理操作

📚作者简介:爱编程的小马,正在学习C/C,Linux及MySQL.. 📚以后会将数据结构收录为一个系列,敬请期待 ● 本期内容讲解C语言中程序预处理要做的事情 目录 1.1 预处理符号 1.2 #define 1.2.1 #define定义标识…

数据结构(01)——链表OJ

目录 移除链表元素 思路1 不创建虚拟头节点 思路2 创建虚拟头节点 反转链表 寻找链表中间节点 判断链表是否相交 回文链表 环形链表 环形链表|| 移除链表元素 . - 力扣(LeetCode) 要想移除链表的元素,那么只需要将目标节点的前一…

07_for循环返回值while循环

文章目录 1.循环返回值2.yield接收for返回值3.scala调用yield方法创建线程对象4.scala中的while循环5.scala中的流程控制 1.循环返回值 for循环返回值是Unit 原因是防止产生歧义; 2.yield接收for返回值 // 2.yield关键字打破循环,可以使for循环输出…

智慧农业设备——虫情监测系统

随着科技的不断进步和农业生产的日益现代化,智慧农业成为了新时代农业发展的重要方向。其中,虫情监测系统作为智慧农业的重要组成部分,正逐渐受到广大农户和农业专家的关注。 虫情监测系统是一种基于现代传感技术、图像识别技术和大数据分析技…

面试笔记——线程池

线程池的核心参数&#xff08;原理&#xff09; public ThreadPoolExecutor(int corePoolSize,int maximumPoolSize,long keepAliveTime,TimeUnit unit,BlockingQueue<Runnable> workQueue,ThreadFactory threadFactory,RejectedExecutionHandler handler)corePoolSize …

25计算机考研院校数据分析 | 四川大学

四川大学(Sichuan University)简称“川大”&#xff0c;由中华人民共和国教育部直属&#xff0c;中央直管副部级建制&#xff0c;是世界一流大学建设高校、985工程”、"211工程"重点建设的高水平综合性全国重点大学&#xff0c;入选”2011计划"、"珠峰计划…

PostgreSQL的学习心得和知识总结(一百四十)|深入理解PostgreSQL数据库 psql工具 \set 变量内部及HOOK机制

目录结构 注&#xff1a;提前言明 本文借鉴了以下博主、书籍或网站的内容&#xff0c;其列表如下&#xff1a; 1、参考书籍&#xff1a;《PostgreSQL数据库内核分析》 2、参考书籍&#xff1a;《数据库事务处理的艺术&#xff1a;事务管理与并发控制》 3、PostgreSQL数据库仓库…

【能力展现】魔改ZXING源码实现商业级DM码检测能力

学习《OpenCV应用开发&#xff1a;入门、进阶与工程化实践》一书 做真正的OpenCV开发者&#xff0c;从入门到入职&#xff0c;一步到位&#xff01; 什么是DM码 dataMatrix是一种二维码&#xff0c;原名datacode&#xff0c;由美国国际资料公司于1989年发明。dataMatrix二维码…

GuildFi升级为Zentry的背后 链游公会的探索与转型

​链游即区块链游戏&#xff0c;指依托区块链技术构建的游戏产品。其与传统游戏的最大区别在于区块链的去中心化特性对玩家的资产有着天然的确权行为&#xff0c;因此玩家在链游中的资产是作为玩家的个人资产存在。较于 GameFi 来说&#xff0c;链游的包含范围更大&#xff0c;…

吴恩达机器学习笔记:第 8 周-14降维(Dimensionality Reduction) 14.3-14.5

目录 第 8 周 14、 降维(Dimensionality Reduction)14.3 主成分分析问题14.4 主成分分析算法14.5 选择主成分的数量 第 8 周 14、 降维(Dimensionality Reduction) 14.3 主成分分析问题 主成分分析(PCA)是最常见的降维算法。 在 PCA 中&#xff0c;我们要做的是找到一个方向…

【高校科研前沿】华东师大白开旭教授博士研究生李珂为一作在RSE发表团队最新成果:基于波谱特征优化的全球大气甲烷智能反演技术

文章简介 论文名称&#xff1a;Developing unbiased estimation of atmospheric methane via machine learning and multiobjective programming based on TROPOMI and GOSAT data&#xff08;基于TROPOMI和GOSAT数据&#xff0c;通过机器学习和多目标规划实现大气甲烷的无偏估…

OS复习笔记ch5-1

引言 讲解完进程和线程之后&#xff0c;我们就要来到进程的并发控制这里&#xff0c;这一章和下一章是考试喜欢考察的点&#xff0c;有可能会出大题&#xff0c;面试也有可能会被频繁问到&#xff0c;所以章节内容较多。请小伙伴们慢慢食用&#xff0c;看完之后多思考加强消化…

【JPE】顶刊测算-工业智能化数据(附stata代码)

数据来源&#xff1a;国家TJ局、CEC2008、IFR数据 时间跨度&#xff1a;2006-2019年 数据范围&#xff1a;各省、地级市 数据指标&#xff1a; 本数据集展示了2006-2019年各省、各地级市的共工业智能化水平的数据。本数据集包含三种构建工业机器人密度来反映工业智能化水平的方…

基于Springboot的数字化农家乐管理平台(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的数字化农家乐管理平台&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系…

Apache Seata基于改良版雪花算法的分布式UUID生成器分析2

title: 关于新版雪花算法的答疑 author: selfishlover keywords: [Seata, snowflake, UUID, page split] date: 2021/06/21 本文来自 Apache Seata官方文档&#xff0c;欢迎访问官网&#xff0c;查看更多深度文章。 关于新版雪花算法的答疑 在上一篇关于新版雪花算法的解析中…

web前端学习笔记4

4. 盒子模型 4.0 代码地址 https://gitee.com/qiangge95243611/java118/tree/master/web/day044.1 什么是盒子模型(Box Model) 所有HTML元素可以看作盒子,在CSS中,"box model"这一术语是用来设计和布局时使用。 CSS盒模型本质上是一个盒子,封装周围的HTML元素,…
最新文章