欢迎访问 生活随笔!

生活随笔

当前位置: 首页 >

基于SegNet和UNet的遥感图像分割代码解读

发布时间:2025/3/15 51 豆豆
生活随笔 收集整理的这篇文章主要介绍了 基于SegNet和UNet的遥感图像分割代码解读 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

基于SegNet和UNet的遥感图像分割代码解读

目录

    • 基于SegNet和UNet的遥感图像分割代码解读
  • 前言
  • 概述
  • 代码框架
  • 代码细节分析
    • 划分数据集gen_dataset.py
    • UNet模型训练unet_train.py
    • 模型融合combind.py
    • UNet模型预测unet_predict.py
    • 分类结果集成ensemble.py
    • SegNet模型训练segnet_train.py

前言

上了一学期的课,趁着寒假有时间,看了往年论文和部分比赛的代码,现在整理出来。整理的这部分内容以实际操作为主,主要讲解代码部分的分析。

概述

首先来分享一个小项目:一个基于 SegNet 和 UNet 的遥感图像分割比赛项目。代码来自 GitHub,下面先对项目做一个简要介绍。

代码框架

以下是项目的代码结构:共有 4 个子目录,分别是 deprecated、ensemble、segnet 和 unet。其中,deprecated 存放作者的一些代码草稿;ensemble 用于对不同模型的分类结果进行集成;segnet 和 unet 分别对应两种典型网络,各自包含网络结构定义、训练代码、预测代码以及划分训练集和测试集的代码。

代码细节分析

划分数据集gen_dataset.py

import cv2
import random
import os
import numpy as np
from tqdm import tqdm

img_w = 256
img_h = 256
# The dataset consists of 5 large images.
image_sets = ['1.png', '2.png', '3.png', '4.png', '5.png']


def gamma_transform(img, gamma):
    """Gamma-correct an 8-bit image through a 256-entry look-up table (LUT)."""
    gamma_table = [np.power(x / 255.0, gamma) * 255.0 for x in range(256)]
    gamma_table = np.round(np.array(gamma_table)).astype(np.uint8)
    # cv2.LUT maps every pixel value through the table, changing exposure/contrast.
    return cv2.LUT(img, gamma_table)


def random_gamma_transform(img, gamma_vari):
    """Gamma-correct img with a gamma drawn log-uniformly from [1/gamma_vari, gamma_vari].

    Note: gamma_vari == 1.0 always yields gamma == 1, i.e. the identity transform.
    """
    log_gamma_vari = np.log(gamma_vari)
    alpha = np.random.uniform(-log_gamma_vari, log_gamma_vari)
    gamma = np.exp(alpha)
    return gamma_transform(img, gamma)


def rotate(xb, yb, angle):
    """Rotate an image tile and its label tile by angle degrees (counter-clockwise).

    The centre is the exact centre of the pixel grid, ((w-1)/2, (h-1)/2), so a
    multiple of 90 degrees on a square tile is an exact pixel permutation.
    Rotating about (w/2, h/2) instead shifts the result by one pixel and
    leaves a black border row/column. The label uses nearest-neighbour
    interpolation so class ids are never blended.
    """
    M_rotate = cv2.getRotationMatrix2D(((img_w - 1) / 2.0, (img_h - 1) / 2.0), angle, 1)
    xb = cv2.warpAffine(xb, M_rotate, (img_w, img_h))
    yb = cv2.warpAffine(yb, M_rotate, (img_w, img_h), flags=cv2.INTER_NEAREST)
    return xb, yb


def blur(img):
    """Smooth img with a 3x3 mean (box) filter."""
    return cv2.blur(img, (3, 3))


def add_noise(img):
    """Set 200 randomly chosen pixels of img to white (salt noise).

    Modifies img in place and also returns it; pass a copy if the caller's
    array must stay untouched.
    """
    for _ in range(200):
        temp_x = np.random.randint(0, img.shape[0])
        temp_y = np.random.randint(0, img.shape[1])
        img[temp_x][temp_y] = 255
    return img


def data_augment(xb, yb, gamma_vari=2.0):
    """Randomly augment an image tile xb together with its label tile yb.

    Geometric transforms (rotations, horizontal flip) are applied to both
    tiles so they stay aligned. Photometric transforms (gamma, blur, salt
    noise) are applied to the image only. Each transform fires independently
    with the probability tested below.

    gamma_vari: range of the random gamma. A value of 1.0 would make the gamma
        step a no-op; 2.0 draws gamma from [0.5, 2].
    """
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 90)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 180)
    if np.random.random() < 0.25:
        xb, yb = rotate(xb, yb, 270)
    if np.random.random() < 0.25:
        # flipCode > 0: flip around the y axis (left-right mirror).
        xb = cv2.flip(xb, 1)
        yb = cv2.flip(yb, 1)
    if np.random.random() < 0.25:
        xb = random_gamma_transform(xb, gamma_vari)
    if np.random.random() < 0.25:
        xb = blur(xb)
    if np.random.random() < 0.2:
        xb = add_noise(xb)
    return xb, yb


def creat_dataset(image_num=50000, mode='original'):
    """Cut random img_h x img_w tiles from the source/label images and save them.

    image_num tiles are produced in total, split evenly over image_sets. In
    'augment' mode every tile pair goes through data_augment first. Tiles are
    written to ./unet_train/road/src/<n>.png and ./unet_train/road/label/<n>.png.
    A label copy multiplied by 50 (for viewing) goes to ./unet_train/visualize/<n>.png.

    Raises:
        FileNotFoundError: if a source or label image cannot be read.
    """
    print('creating dataset...')
    for out_dir in ('./unet_train/visualize', './unet_train/road/src', './unet_train/road/label'):
        # cv2.imwrite returns False instead of raising when the folder is missing.
        os.makedirs(out_dir, exist_ok=True)
    image_each = image_num / len(image_sets)
    g_count = 0
    for name in tqdm(image_sets):
        count = 0
        src_img = cv2.imread('./data/src/' + name)  # 3 channels
        label_img = cv2.imread('./data/road_label/' + name, cv2.IMREAD_GRAYSCALE)  # single channel
        if src_img is None or label_img is None:
            raise FileNotFoundError('cannot read source/label image: %s' % name)
        X_height, X_width, _ = src_img.shape
        while count < image_each:
            # randint is inclusive: X_width - img_w is the last offset where a full tile fits.
            random_width = random.randint(0, X_width - img_w)
            random_height = random.randint(0, X_height - img_h)
            # Copy the crops. Slicing returns views, and add_noise writes in place,
            # which would otherwise paint noise into src_img for all later tiles.
            src_roi = src_img[random_height: random_height + img_h,
                              random_width: random_width + img_w, :].copy()
            label_roi = label_img[random_height: random_height + img_h,
                                  random_width: random_width + img_w].copy()
            if mode == 'augment':
                src_roi, label_roi = data_augment(src_roi, label_roi)
            visualize = label_roi * 50
            cv2.imwrite(('./unet_train/visualize/%d.png' % g_count), visualize)
            cv2.imwrite(('./unet_train/road/src/%d.png' % g_count), src_roi)
            cv2.imwrite(('./unet_train/road/label/%d.png' % g_count), label_roi)
            count += 1
            g_count += 1


if __name__ == '__main__':
    creat_dataset(mode='augment')

UNet模型训练unet_train.py

# coding=utf-8
import matplotlib
# matplotlib.use("Agg") must run before pyplot is imported. It selects a
# non-interactive backend, so figures are only saved to file, never shown.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import (Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization,
                          Reshape, Permute, Activation, Input)
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from keras.models import Model
from keras.layers.merge import concatenate
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# Fix the random seeds so repeated runs draw the same random numbers, which
# makes experiments on the same data comparable. Seeding `random` as well
# makes the train/val split in get_train_val reproducible.
seed = 7
np.random.seed(seed)
random.seed(seed)

img_w = 256
img_h = 256
# A single output channel: binary segmentation with a sigmoid output.
n_label = 1

# 5 classes in total.
classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']


def load_img(path, grayscale=False):
    """Read an image with OpenCV.

    Colour images are returned as float arrays scaled to [0, 1]. Grayscale
    images are label maps and are returned unscaled: dividing class ids by 255
    would turn a 0/1 target into 0/0.0039 for the binary cross-entropy loss.

    Raises:
        IOError: if the file cannot be read.
    """
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    img = cv2.imread(path, flag)
    if img is None:
        raise IOError('cannot read image: %s' % path)
    if not grayscale:
        img = np.array(img, dtype="float") / 255.0
    return img


# Training data directory; it must contain src/ and label/ sub-directories.
# It is rebound from the --data argument in __main__.
filepath = './unet_train/'


def get_train_val(val_rate=0.25):
    """Shuffle the file names in <filepath>/src and split them.

    After shuffling, the first val_rate fraction becomes the validation set
    and the rest the training set.

    Returns:
        (train_set, val_set): two lists of file names.
    """
    train_url = os.listdir(os.path.join(filepath, 'src'))
    random.shuffle(train_url)
    val_num = int(val_rate * len(train_url))
    return train_url[val_num:], train_url[:val_num]


def _batch_generator(batch_size, data):
    """Yield (images, labels) batches forever, cycling through the file names in data.

    Images come from <filepath>/src and labels from <filepath>/label (same
    file name). A trailing partial batch at the end of each pass is dropped.

    Raises:
        ValueError: if data cannot fill a single batch (the loop would otherwise spin forever).
    """
    if batch_size <= 0 or len(data) < batch_size:
        raise ValueError('need at least batch_size=%d samples, got %d' % (batch_size, len(data)))
    while True:
        images, labels = [], []
        for url in data:
            img = load_img(os.path.join(filepath, 'src', url))
            images.append(img_to_array(img))
            label = load_img(os.path.join(filepath, 'label', url), grayscale=True)
            labels.append(img_to_array(label))
            if len(images) == batch_size:
                yield (np.array(images), np.array(labels))
                images, labels = [], []


def generateData(batch_size, data=None):
    """Endless generator of training batches built from the file names in data."""
    return _batch_generator(batch_size, data or [])


def generateValidData(batch_size, data=None):
    """Endless generator of validation batches built from the file names in data."""
    return _batch_generator(batch_size, data or [])


def unet():
    """Build and compile the U-Net (a symmetric, U-shaped encoder/decoder).

    Assumes Keras' image_data_format is 'channels_first': inputs are
    (3, H, W) and skip connections are concatenated on axis 1.
    """
    inputs = Input((3, img_w, img_h))

    # Encoder: two 3x3 convolutions per level, then 2x2 max pooling.
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(inputs)
    conv1 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(pool1)
    conv2 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(pool2)
    conv3 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(pool3)
    conv4 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    # Bottleneck. No pooling follows it: each decoder level must upsample back
    # to exactly the size of the encoder map it is concatenated with.
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(pool4)
    conv5 = Conv2D(512, (3, 3), activation="relu", padding="same")(conv5)

    # Decoder. UpSampling2D(size=(2, 2)) repeats every row and column twice,
    # e.g. [[1,2],[3,4]] -> [[1,1,2,2],[1,1,2,2],[3,3,4,4],[3,3,4,4]].
    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(up6)
    conv6 = Conv2D(256, (3, 3), activation="relu", padding="same")(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(up7)
    conv7 = Conv2D(128, (3, 3), activation="relu", padding="same")(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(up8)
    conv8 = Conv2D(64, (3, 3), activation="relu", padding="same")(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(up9)
    conv9 = Conv2D(32, (3, 3), activation="relu", padding="same")(conv9)

    conv10 = Conv2D(n_label, (1, 1), activation="sigmoid")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    # Binary cross-entropy for the single sigmoid channel. The metric string
    # 'acc' keeps the history keys 'acc'/'val_acc', which the checkpoint and
    # the plot in train() both read. Keras >= 2.3 records 'accuracy' when
    # 'accuracy' is passed instead.
    model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['acc'])
    return model


def train(args):
    """Train the U-Net, checkpoint the best model and save a loss/accuracy plot.

    args: dict with keys 'model' (checkpoint path) and 'plot' (figure path).
    """
    EPOCHS = 10
    BS = 16  # batch size
    model = unet()
    # Keep only the weights with the best validation accuracy.
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc', save_best_only=True, mode='max')
    callbacks = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    # max_queue_size: maximum size of the internal generator queue.
    H = model.fit_generator(generator=generateData(BS, train_set),
                            steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS, verbose=1,
                            validation_data=generateValidData(BS, val_set),
                            validation_steps=valid_numb // BS,
                            callbacks=callbacks, max_queue_size=1)

    # Plot training/validation loss and accuracy with the ggplot style.
    plt.style.use("ggplot")
    plt.figure()
    N = len(H.history["loss"])
    plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
    plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
    plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
    plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on U-Net Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    # Put the legend in the lower-left corner.
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])


def args_parse():
    """Parse command-line arguments into a dict (keys: data, model, plot)."""
    ap = argparse.ArgumentParser()
    # Default to the module's filepath; the former default=True crashed at `True + 'src'`.
    ap.add_argument("-d", "--data", help="training data's path", default=filepath)
    ap.add_argument("-m", "--model", required=True, help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args


if __name__ == '__main__':
    args = args_parse()
    filepath = args['data']
    train(args)

为了看清楚 UNet 每一层输出张量(tensor)的形状,我们将模型结构打印出来,如下:

```
Layer (type)                      Output Shape           Param #    Connected to
=============================================================================================
input_7 (InputLayer)              (None, 3, 256, 256)    0
conv2d_79 (Conv2D)                (None, 32, 256, 256)   896        input_7[0][0]
conv2d_80 (Conv2D)                (None, 32, 256, 256)   9248       conv2d_79[0][0]
max_pooling2d_29 (MaxPooling2D)   (None, 32, 128, 128)   0          conv2d_80[0][0]
conv2d_81 (Conv2D)                (None, 64, 128, 128)   18496      max_pooling2d_29[0][0]
conv2d_82 (Conv2D)                (None, 64, 128, 128)   36928      conv2d_81[0][0]
max_pooling2d_30 (MaxPooling2D)   (None, 64, 64, 64)     0          conv2d_82[0][0]
conv2d_83 (Conv2D)                (None, 128, 64, 64)    73856      max_pooling2d_30[0][0]
conv2d_84 (Conv2D)                (None, 128, 64, 64)    147584     conv2d_83[0][0]
max_pooling2d_31 (MaxPooling2D)   (None, 128, 32, 32)    0          conv2d_84[0][0]
conv2d_85 (Conv2D)                (None, 256, 32, 32)    295168     max_pooling2d_31[0][0]
conv2d_86 (Conv2D)                (None, 256, 32, 32)    590080     conv2d_85[0][0]
max_pooling2d_32 (MaxPooling2D)   (None, 256, 16, 16)    0          conv2d_86[0][0]
conv2d_87 (Conv2D)                (None, 512, 16, 16)    1180160    max_pooling2d_32[0][0]
conv2d_88 (Conv2D)                (None, 512, 16, 16)    2359808    conv2d_87[0][0]
up_sampling2d_13 (UpSampling2D)   (None, 512, 32, 32)    0          conv2d_88[0][0]
concatenate_13 (Concatenate)      (None, 768, 32, 32)    0          up_sampling2d_13[0][0]
                                                                    conv2d_86[0][0]
conv2d_89 (Conv2D)                (None, 256, 32, 32)    1769728    concatenate_13[0][0]
conv2d_90 (Conv2D)                (None, 256, 32, 32)    590080     conv2d_89[0][0]
up_sampling2d_14 (UpSampling2D)   (None, 256, 64, 64)    0          conv2d_90[0][0]
concatenate_14 (Concatenate)      (None, 384, 64, 64)    0          up_sampling2d_14[0][0]
                                                                    conv2d_84[0][0]
conv2d_91 (Conv2D)                (None, 128, 64, 64)    442496     concatenate_14[0][0]
conv2d_92 (Conv2D)                (None, 128, 64, 64)    147584     conv2d_91[0][0]
up_sampling2d_15 (UpSampling2D)   (None, 128, 128, 128)  0          conv2d_92[0][0]
concatenate_15 (Concatenate)      (None, 192, 128, 128)  0          up_sampling2d_15[0][0]
                                                                    conv2d_82[0][0]
conv2d_93 (Conv2D)                (None, 64, 128, 128)   110656     concatenate_15[0][0]
conv2d_94 (Conv2D)                (None, 64, 128, 128)   36928      conv2d_93[0][0]
up_sampling2d_16 (UpSampling2D)   (None, 64, 256, 256)   0          conv2d_94[0][0]
concatenate_16 (Concatenate)      (None, 96, 256, 256)   0          up_sampling2d_16[0][0]
                                                                    conv2d_80[0][0]
conv2d_95 (Conv2D)                (None, 32, 256, 256)   27680      concatenate_16[0][0]
conv2d_96 (Conv2D)                (None, 32, 256, 256)   9248       conv2d_95[0][0]
conv2d_97 (Conv2D)                (None, 1, 256, 256)    33         conv2d_96[0][0]
=============================================================================================
Total params: 7,846,657
Trainable params: 7,846,657
Non-trainable params: 0
```

模型融合combind.py

# coding=utf-8
import numpy as np
import cv2
import csv
from tqdm import tqdm

# Per-class prediction masks for each of the three test images. In every pool
# the order is vegetation, building, water, road (mask k -> label k + 1).
mask1_pool = ['testing1_vegetation_predict.png', 'testing1_building_predict.png',
              'testing1_water_predict.png', 'testing1_road_predict.png']
mask2_pool = ['testing2_vegetation_predict.png', 'testing2_building_predict.png',
              'testing2_water_predict.png', 'testing2_road_predict.png']
mask3_pool = ['testing3_vegetation_predict.png', 'testing3_building_predict.png',
              'testing3_water_predict.png', 'testing3_road_predict.png']
# Label values of the merged mask: 0:none 1:vegetation 2:building 3:water 4:road
img_sets = ['pre1.png', 'pre2.png', 'pre3.png']


def combind_all_mask():
    """Merge the four per-class masks of each test image into one label map.

    A mask marks its class with pixel value 255. When several classes claim
    the same pixel, the priority is building (2) > water (3) > road (4) >
    vegetation (1); unclaimed pixels stay 0. Masks are read from
    ./predict_mask/ and results are written to ./final_result/<img_sets[i]>.
    The output has the shape of the input masks.

    Raises:
        IOError: if a mask file cannot be read.
        ValueError: if the four masks of one image differ in shape.
    """
    pools = (mask1_pool, mask2_pool, mask3_pool)
    for mask_pool, final_name in tqdm(list(zip(pools, img_sets))):
        class_masks = []
        for name in mask_pool:
            img = cv2.imread('./predict_mask/' + name, 0)
            if img is None:
                raise IOError('cannot read mask: ./predict_mask/' + name)
            class_masks.append(img == 255)
        if len({m.shape for m in class_masks}) != 1:
            raise ValueError('masks for %s differ in shape' % final_name)
        vegetation, building, water, road = class_masks
        # np.select picks the first true condition per pixel, so listing the
        # classes in priority order reproduces building > water > road > vegetation.
        final_mask = np.select([building, water, road, vegetation], [2, 3, 4, 1],
                               default=0).astype(np.uint8)
        cv2.imwrite('./final_result/' + final_name, final_mask)


if __name__ == '__main__':
    print('combinding mask...')
    combind_all_mask()

UNet模型预测unet_predict.py

import cv2
import random
import numpy as np
import os
import argparse
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from sklearn.preprocessing import LabelEncoder

# Run prediction on the GPU with index 1.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

TEST_SET = ['1.png', '2.png', '3.png']
image_size = 256  # tile size the model was trained on

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)


def args_parse():
    """Parse command-line arguments into a dict (keys: model, stride)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-m", "--model", required=True, help="path to trained model model")
    ap.add_argument("-s", "--stride", required=False, help="crop slide stride",
                    type=int, default=image_size)
    args = vars(ap.parse_args())
    return args


def _padded_length(length, stride):
    """Smallest size >= length fully covered by image_size tiles placed every stride pixels."""
    # Ceil division of (length - image_size) by stride; no extra tiles if one suffices.
    extra_tiles = max(0, -(-(length - image_size) // stride))
    return extra_tiles * stride + image_size


def predict(args):
    """Predict a binary mask for every image in TEST_SET with a sliding window.

    Training used image_size x image_size tiles, so each test image is
    zero-padded until such tiles, taken every `stride` pixels, cover it
    completely. Later (overlapping) tiles overwrite earlier ones. The 0/1
    mask, cropped back to the original size, is written to ./predict/pre<n>.png.

    Raises:
        ValueError: if stride is not in 1..image_size (larger strides leave gaps).
        IOError: if a test image cannot be read.
    """
    print("[INFO] loading network...")
    model = load_model(args["model"])
    stride = args['stride']
    if not 0 < stride <= image_size:
        raise ValueError('stride must be in 1..%d, got %r' % (image_size, stride))
    # cv2.imwrite returns False instead of raising when the folder is missing.
    os.makedirs('./predict', exist_ok=True)
    for n, path in enumerate(TEST_SET):
        image = cv2.imread('./test/' + path)
        if image is None:
            raise IOError('cannot read test image: ./test/' + path)
        h, w, _ = image.shape
        padding_h = _padded_length(h, stride)
        padding_w = _padded_length(w, stride)
        padding_img = np.zeros((padding_h, padding_w, 3), dtype=np.uint8)
        # Zero-fill the part that does not belong to the original image.
        padding_img[0:h, 0:w, :] = image[:, :, :]
        # Scale to [0, 1] exactly as the training images were.
        padding_img = padding_img.astype("float") / 255.0
        # img_to_array follows Keras' image_data_format. With 'channels_first'
        # the result is (3, H, W), which the [:3, rows, cols] slicing below assumes.
        padding_img = img_to_array(padding_img)
        print('src:', padding_img.shape)
        mask_whole = np.zeros((padding_h, padding_w), dtype=np.uint8)
        for i in range((padding_h - image_size) // stride + 1):
            for j in range((padding_w - image_size) // stride + 1):
                top, left = i * stride, j * stride
                crop = padding_img[:3, top:top + image_size, left:left + image_size]
                crop = np.expand_dims(crop, axis=0)
                # verbose=2: one log line instead of a progress bar.
                pred = model.predict(crop, verbose=2)
                # The model outputs sigmoid probabilities, so threshold them;
                # a plain astype(np.uint8) would truncate nearly every value to 0.
                pred = (pred.reshape((image_size, image_size)) > 0.5).astype(np.uint8)
                mask_whole[top:top + image_size, left:left + image_size] = pred
        # Crop back to the original image size.
        cv2.imwrite('./predict/pre' + str(n + 1) + '.png', mask_whole[0:h, 0:w])


if __name__ == '__main__':
    args = args_parse()
    predict(args)

分类结果集成ensemble.py

import numpy as np
import cv2
import argparse

# One directory per model whose results are ensembled.
RESULT_PREFIXX = ['./result1/', './result2/', './result3/']


# Each result mask holds class ids 0~4 (5 classes).
def vote_per_image(image_id, n_classes=5):
    """Pixel-wise majority vote over the result masks of one image.

    Reads <prefix><image_id>.png (grayscale class ids 0..n_classes-1) from
    every directory in RESULT_PREFIXX. Writes the most-voted class per pixel
    to vote_mask<image_id>.png. Ties go to the smallest class id.

    Raises:
        IOError: if a result mask cannot be read.
        ValueError: if the masks differ in shape or contain a class id >= n_classes.
    """
    result_list = []
    for prefix in RESULT_PREFIXX:
        path = prefix + str(image_id) + '.png'
        im = cv2.imread(path, 0)
        if im is None:
            raise IOError('cannot read result mask: ' + path)
        result_list.append(im)
    stack = np.stack(result_list)  # (n_results, H, W); raises if shapes differ
    if stack.max() >= n_classes:
        raise ValueError('class id %d out of range 0..%d' % (stack.max(), n_classes - 1))
    # votes[c, h, w] = how many results assign class c to pixel (h, w).
    votes = np.stack([(stack == c).sum(axis=0) for c in range(n_classes)])
    # argmax returns the first maximum, i.e. the smallest class id on a tie.
    vote_mask = votes.argmax(axis=0).astype(np.uint8)
    cv2.imwrite('vote_mask' + str(image_id) + '.png', vote_mask)


if __name__ == '__main__':
    # Vote on image 3, combining one mask from each directory in RESULT_PREFIXX.
    vote_per_image(3)

SegNet模型训练segnet_train.py

#coding=utf-8
import matplotlib
matplotlib.use("Agg")  # headless backend: the plot is only saved to a file
import matplotlib.pyplot as plt
import argparse
import numpy as np
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, BatchNormalization, Reshape, Permute, Activation
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import img_to_array
from keras.callbacks import ModelCheckpoint
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import cv2
import random
import os
from tqdm import tqdm

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
seed = 7
np.random.seed(seed)
# get_train_val shuffles with the stdlib RNG, so seed it too for a
# reproducible train/validation split.
random.seed(seed)

#data_shape = 360*480
img_w = 256
img_h = 256
# 4 object classes plus 1 background class
n_label = 4 + 1

classes = [0., 1., 2., 3., 4.]
labelencoder = LabelEncoder()
labelencoder.fit(classes)

image_sets = ['1.png', '2.png', '3.png']


def load_img(path, grayscale=False):
    """Read an image from ``path`` with OpenCV.

    Colour images are returned as float arrays scaled to [0, 1]. Grayscale
    images (label masks) are returned unscaled so their pixel values stay the
    raw class ids that ``labelencoder`` was fitted on.

    Raises:
        IOError: if OpenCV cannot read ``path``.
    """
    if grayscale:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    else:
        img = cv2.imread(path)
    if img is None:
        raise IOError('cannot read image: ' + path)
    if not grayscale:
        img = np.array(img, dtype="float") / 255.0
    return img


filepath = './train/'


def get_train_val(val_rate=0.25):
    """Randomly split the file names under ``filepath + 'src'``.

    Returns:
        (train_set, val_set): after shuffling, the first
        ``int(val_rate * total)`` names form the validation set and the rest
        the training set.
    """
    train_url = os.listdir(filepath + 'src')
    random.shuffle(train_url)
    val_num = int(val_rate * len(train_url))
    return train_url[val_num:], train_url[:val_num]


def _batch_generator(batch_size, data):
    """Yield ``(images, one_hot_labels)`` batches forever from file names ``data``.

    Images come from ``filepath + 'src/'`` and labels from
    ``filepath + 'label/'``. ``one_hot_labels`` has shape
    ``(batch_size, img_w * img_h, n_label)``. A trailing partial batch at the
    end of each pass over ``data`` is dropped.

    Raises:
        ValueError: (on first ``next``) if ``batch_size`` is not positive or
            ``data`` holds fewer than ``batch_size`` names; the original loop
            would otherwise spin forever without yielding.
    """
    data = list(data) if data is not None else []
    if batch_size <= 0:
        raise ValueError('batch_size must be positive, got %r' % batch_size)
    if len(data) < batch_size:
        raise ValueError('need at least %d samples, got %d' % (batch_size, len(data)))
    while True:
        images, labels = [], []
        for url in data:
            images.append(img_to_array(load_img(filepath + 'src/' + url)))
            label = load_img(filepath + 'label/' + url, grayscale=True)
            labels.append(img_to_array(label).reshape((img_w * img_h,)))
            if len(images) == batch_size:
                flat = labelencoder.transform(np.array(labels).flatten())
                one_hot = to_categorical(flat, num_classes=n_label)
                one_hot = one_hot.reshape((batch_size, img_w * img_h, n_label))
                yield (np.array(images), one_hot)
                images, labels = [], []


# data for training
def generateData(batch_size, data=None):
    """Endless generator of training batches; see ``_batch_generator``."""
    return _batch_generator(batch_size, data)


# data for validation
def generateValidData(batch_size, data=None):
    """Endless generator of validation batches; see ``_batch_generator``."""
    return _batch_generator(batch_size, data)


def _conv_bn_block(model, filters, repeats):
    """Append ``repeats`` x (3x3 same-padded ReLU Conv2D + BatchNormalization)."""
    for _ in range(repeats):
        model.add(Conv2D(filters, (3, 3), strides=(1, 1), padding='same', activation='relu'))
        # Tensors are channels-first (input_shape=(3, w, h)), so normalise over
        # the channel axis; the Keras default axis=-1 would normalise over width.
        model.add(BatchNormalization(axis=1))


def SegNet():
    """Build and compile the SegNet encoder-decoder.

    Input: ``(3, img_w, img_h)`` channels-first images. Output: per-pixel
    softmax of shape ``(img_w * img_h, n_label)``. Compiled with
    categorical cross-entropy, SGD and accuracy.
    """
    model = Sequential()
    # encoder
    model.add(Conv2D(64, (3, 3), strides=(1, 1), input_shape=(3, img_w, img_h), padding='same', activation='relu'))
    model.add(BatchNormalization(axis=1))
    _conv_bn_block(model, 64, 1)
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))  # (128,128)
    _conv_bn_block(model, 128, 2)
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))  # (64,64)
    _conv_bn_block(model, 256, 3)
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))  # (32,32)
    _conv_bn_block(model, 512, 3)
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))  # (16,16)
    _conv_bn_block(model, 512, 3)
    model.add(MaxPooling2D(pool_size=(2, 2), dim_ordering='th'))  # (8,8)
    # decoder
    model.add(UpSampling2D(size=(2, 2)))  # (16,16)
    _conv_bn_block(model, 512, 3)
    model.add(UpSampling2D(size=(2, 2)))  # (32,32)
    _conv_bn_block(model, 512, 3)
    model.add(UpSampling2D(size=(2, 2)))  # (64,64)
    _conv_bn_block(model, 256, 3)
    model.add(UpSampling2D(size=(2, 2)))  # (128,128)
    _conv_bn_block(model, 128, 2)
    model.add(UpSampling2D(size=(2, 2)))  # (256,256)
    _conv_bn_block(model, 64, 2)
    # 1x1 conv to per-pixel class scores, then flatten the spatial dims.
    model.add(Conv2D(n_label, (1, 1), strides=(1, 1), padding='same'))
    model.add(Reshape((n_label, img_w * img_h)))
    # Swap axes 1 and 2 (like np.swapaxes) -> (pixels, classes) for softmax.
    model.add(Permute((2, 1)))
    model.add(Activation('softmax'))
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
    return model


def train(args):
    """Train SegNet on the data under ``filepath`` and plot the curves.

    The model with the best validation accuracy is saved to ``args['model']``;
    the loss/accuracy plot is saved to ``args['plot']``.
    """
    EPOCHS = 30
    BS = 16
    model = SegNet()
    modelcheck = ModelCheckpoint(args['model'], monitor='val_acc', save_best_only=True, mode='max')
    callbacks = [modelcheck]
    train_set, val_set = get_train_val()
    train_numb = len(train_set)
    valid_numb = len(val_set)
    print("the number of train data is", train_numb)
    print("the number of val data is", valid_numb)
    H = model.fit_generator(generator=generateData(BS, train_set),
                            steps_per_epoch=train_numb // BS,
                            epochs=EPOCHS,
                            verbose=1,
                            validation_data=generateValidData(BS, val_set),
                            validation_steps=valid_numb // BS,
                            callbacks=callbacks,
                            max_q_size=1)

    # plot the training loss and accuracy
    plt.style.use("ggplot")
    plt.figure()
    epochs = np.arange(0, EPOCHS)
    plt.plot(epochs, H.history["loss"], label="train_loss")
    plt.plot(epochs, H.history["val_loss"], label="val_loss")
    plt.plot(epochs, H.history["acc"], label="train_acc")
    plt.plot(epochs, H.history["val_acc"], label="val_acc")
    plt.title("Training Loss and Accuracy on SegNet Satellite Seg")
    plt.xlabel("Epoch #")
    plt.ylabel("Loss/Accuracy")
    plt.legend(loc="lower left")
    plt.savefig(args["plot"])


def args_parse():
    """Parse the command line into a dict with keys augment, model, plot."""
    ap = argparse.ArgumentParser()
    ap.add_argument("-a", "--augment", help="using data augment or not",
                    action="store_true", default=False)
    ap.add_argument("-m", "--model", required=True,
                    help="path to output model")
    ap.add_argument("-p", "--plot", type=str, default="plot.png",
                    help="path to output accuracy/loss plot")
    args = vars(ap.parse_args())
    return args


if __name__ == '__main__':
    args = args_parse()
    # Rebinds the module-level filepath read by get_train_val / the generators.
    if args['augment']:
        filepath = './aug/train/'
    train(args)
    #predict()

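同理,为了搞清楚SegNet每一层输入输出的tensor形状,我们用model.summary()把各层的shape打印如下。需要注意两点:

(1)这份输出是在n_label=1的配置下得到的,所以最后一个1×1卷积只有1个输出通道。若按上面代码中的n_label=5运行,最后四层的输出形状应依次为(None, 5, 256, 256)、(None, 5, 65536)、(None, 65536, 5)、(None, 65536, 5)。

(2)第一个BatchNormalization层作用在64通道的特征图上,但参数个数是1024=4×256,而不是4×64=256。这说明默认的axis=-1是沿宽度维、而不是通道维做归一化的。对于channels_first的数据,应设置BatchNormalization(axis=1)。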

_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= conv2d_98 (Conv2D) (None, 64, 256, 256) 1792 _________________________________________________________________ batch_normalization_1 (Batch (None, 64, 256, 256) 1024 _________________________________________________________________ conv2d_99 (Conv2D) (None, 64, 256, 256) 36928 _________________________________________________________________ batch_normalization_2 (Batch (None, 64, 256, 256) 1024 _________________________________________________________________ max_pooling2d_33 (MaxPooling (None, 64, 128, 128) 0 _________________________________________________________________ conv2d_100 (Conv2D) (None, 128, 128, 128) 73856 _________________________________________________________________ batch_normalization_3 (Batch (None, 128, 128, 128) 512 _________________________________________________________________ conv2d_101 (Conv2D) (None, 128, 128, 128) 147584 _________________________________________________________________ batch_normalization_4 (Batch (None, 128, 128, 128) 512 _________________________________________________________________ max_pooling2d_34 (MaxPooling (None, 128, 64, 64) 0 _________________________________________________________________ conv2d_102 (Conv2D) (None, 256, 64, 64) 295168 _________________________________________________________________ batch_normalization_5 (Batch (None, 256, 64, 64) 256 _________________________________________________________________ conv2d_103 (Conv2D) (None, 256, 64, 64) 590080 _________________________________________________________________ batch_normalization_6 (Batch (None, 256, 64, 64) 256 _________________________________________________________________ conv2d_104 (Conv2D) (None, 256, 64, 64) 590080 _________________________________________________________________ batch_normalization_7 (Batch (None, 256, 64, 64) 256 
_________________________________________________________________ max_pooling2d_35 (MaxPooling (None, 256, 32, 32) 0 _________________________________________________________________ conv2d_105 (Conv2D) (None, 512, 32, 32) 1180160 _________________________________________________________________ batch_normalization_8 (Batch (None, 512, 32, 32) 128 _________________________________________________________________ conv2d_106 (Conv2D) (None, 512, 32, 32) 2359808 _________________________________________________________________ batch_normalization_9 (Batch (None, 512, 32, 32) 128 _________________________________________________________________ conv2d_107 (Conv2D) (None, 512, 32, 32) 2359808 _________________________________________________________________ batch_normalization_10 (Batc (None, 512, 32, 32) 128 _________________________________________________________________ max_pooling2d_36 (MaxPooling (None, 512, 16, 16) 0 _________________________________________________________________ conv2d_108 (Conv2D) (None, 512, 16, 16) 2359808 _________________________________________________________________ batch_normalization_11 (Batc (None, 512, 16, 16) 64 _________________________________________________________________ conv2d_109 (Conv2D) (None, 512, 16, 16) 2359808 _________________________________________________________________ batch_normalization_12 (Batc (None, 512, 16, 16) 64 _________________________________________________________________ conv2d_110 (Conv2D) (None, 512, 16, 16) 2359808 _________________________________________________________________ batch_normalization_13 (Batc (None, 512, 16, 16) 64 _________________________________________________________________ max_pooling2d_37 (MaxPooling (None, 512, 8, 8) 0 _________________________________________________________________ up_sampling2d_17 (UpSampling (None, 512, 16, 16) 0 _________________________________________________________________ conv2d_111 (Conv2D) (None, 512, 16, 16) 2359808 
_________________________________________________________________ batch_normalization_14 (Batc (None, 512, 16, 16) 64 _________________________________________________________________ conv2d_112 (Conv2D) (None, 512, 16, 16) 2359808 _________________________________________________________________ batch_normalization_15 (Batc (None, 512, 16, 16) 64 _________________________________________________________________ conv2d_113 (Conv2D) (None, 512, 16, 16) 2359808 _________________________________________________________________ batch_normalization_16 (Batc (None, 512, 16, 16) 64 _________________________________________________________________ up_sampling2d_18 (UpSampling (None, 512, 32, 32) 0 _________________________________________________________________ conv2d_114 (Conv2D) (None, 512, 32, 32) 2359808 _________________________________________________________________ batch_normalization_17 (Batc (None, 512, 32, 32) 128 _________________________________________________________________ conv2d_115 (Conv2D) (None, 512, 32, 32) 2359808 _________________________________________________________________ batch_normalization_18 (Batc (None, 512, 32, 32) 128 _________________________________________________________________ conv2d_116 (Conv2D) (None, 512, 32, 32) 2359808 _________________________________________________________________ batch_normalization_19 (Batc (None, 512, 32, 32) 128 _________________________________________________________________ up_sampling2d_19 (UpSampling (None, 512, 64, 64) 0 _________________________________________________________________ conv2d_117 (Conv2D) (None, 256, 64, 64) 1179904 _________________________________________________________________ batch_normalization_20 (Batc (None, 256, 64, 64) 256 _________________________________________________________________ conv2d_118 (Conv2D) (None, 256, 64, 64) 590080 _________________________________________________________________ batch_normalization_21 (Batc (None, 256, 64, 64) 256 
_________________________________________________________________ conv2d_119 (Conv2D) (None, 256, 64, 64) 590080 _________________________________________________________________ batch_normalization_22 (Batc (None, 256, 64, 64) 256 _________________________________________________________________ up_sampling2d_20 (UpSampling (None, 256, 128, 128) 0 _________________________________________________________________ conv2d_120 (Conv2D) (None, 128, 128, 128) 295040 _________________________________________________________________ batch_normalization_23 (Batc (None, 128, 128, 128) 512 _________________________________________________________________ conv2d_121 (Conv2D) (None, 128, 128, 128) 147584 _________________________________________________________________ batch_normalization_24 (Batc (None, 128, 128, 128) 512 _________________________________________________________________ up_sampling2d_21 (UpSampling (None, 128, 256, 256) 0 _________________________________________________________________ conv2d_122 (Conv2D) (None, 64, 256, 256) 73792 _________________________________________________________________ batch_normalization_25 (Batc (None, 64, 256, 256) 1024 _________________________________________________________________ conv2d_123 (Conv2D) (None, 64, 256, 256) 36928 _________________________________________________________________ batch_normalization_26 (Batc (None, 64, 256, 256) 1024 _________________________________________________________________ conv2d_124 (Conv2D) (None, 1, 256, 256) 65 _________________________________________________________________ reshape_1 (Reshape) (None, 1, 65536) 0 _________________________________________________________________ permute_1 (Permute) (None, 65536, 1) 0 _________________________________________________________________ activation_1 (Activation) (None, 65536, 1) 0 ================================================================= Total params: 31,795,841 Trainable params: 31,791,425 Non-trainable params: 4,416 
_________________________________________________________________

总结

以上是生活随笔为你收集整理的基于SegNet和UNet的遥感图像分割代码解读的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。