1. 前言

最近在复现MCNN时发现一个问题,ShanghaiTech数据集图片的尺寸不一,转换为tensor后的shape形状不一致,无法直接进行多batch_size的数据加载。经过查找资料,有人提到可以定义dataloader的collate_fn函数,在加载时将数据裁剪为最小的图片尺寸,以便于堆叠成多个batch_size。

2. 代码

2.1 数据集的定义

dataset.py


import scipy.io as sio
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision

class myDatasets(Dataset):
    """ShanghaiTech crowd-counting dataset.

    Yields (image_tensor, density_map) pairs.  Images are loaded as
    single-channel ('L') arrays; density maps are synthesized on the fly
    from the head coordinates stored in the .mat ground-truth files.
    """

    def __init__(self, img_path, ann_path, down_sample=False, transform=None):
        """
        Args:
            img_path: directory containing IMG_*.jpg images.
            ann_path: directory containing GT_IMG_*.mat annotation files.
            down_sample: if True, build density maps at 1/4 resolution
                (for a network whose output is downsampled by 4).
            transform: optional torchvision transform applied to the image.
        """
        self.pre_img_path = img_path
        self.pre_ann_path = ann_path
        # Image "IMG_15.jpg" pairs with label "GT_IMG_15.mat", so listing
        # the image directory alone is enough; no listdir on ann_path.
        self.img_names = os.listdir(img_path)
        self.transform = transform
        self.down_sample = down_sample

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, index):
        img_name = self.img_names[index]
        mat_name = 'GT_' + img_name.replace('jpg', 'mat')

        img = Image.open(os.path.join(self.pre_img_path, img_name)).convert('L')
        img = np.array(img).astype(np.float32)

        if self.transform is not None:  # idiomatic None test (was "!= None")
            img = self.transform(img)
        # ToTensor already moves channels first, so no permute is needed.

        h, w = img.shape[1], img.shape[2]

        # os.path.join for consistency with the image path above (also
        # tolerates a missing trailing separator on pre_ann_path).
        anno = sio.loadmat(os.path.join(self.pre_ann_path, mat_name))
        xy = anno['image_info'][0][0][0][0][0]  # (N, 2) head coordinates
        density_map = self.get_density((h, w), xy).astype(np.float32)
        density_map = torch.from_numpy(density_map)

        return img, density_map

    def get_density(self, img_shape, points):
        """Build a density map by stamping a Gaussian at each head point.

        Args:
            img_shape: (h, w) of the transformed image.
            points: iterable of (x, y) head coordinates.

        Returns:
            np.ndarray of shape (h, w), or (h//4, w//4) when down_sample.
        """
        if self.down_sample:
            h, w = img_shape[0] // 4, img_shape[1] // 4
        else:
            h, w = img_shape[0], img_shape[1]
        f_sz = 15    # Gaussian window size (neighbourhood around each head)
        sigma = 4.0  # Gaussian sigma
        base_H = None  # full-size kernel, built lazily on first use
        # Density map, initialized to all zeros.
        labels = np.zeros(shape=(h, w))
        for loc in points:
            if base_H is None:
                # The kernel is identical for every interior point, so
                # build it once instead of once per point.
                base_H = self.fspecial(f_sz, f_sz, sigma)
            H = base_H
            if self.down_sample:
                x = min(max(0, abs(int(loc[0] / 4))), int(w))  # head x
                y = min(max(0, abs(int(loc[1] / 4))), int(h))  # head y
            else:
                x = min(max(0, abs(int(loc[0]))), int(w))  # head x
                y = min(max(0, abs(int(loc[1]))), int(h))  # head y
            if x > w or y > h:  # defensive; unreachable after the clamp above
                continue
            x1 = x - f_sz / 2; y1 = y - f_sz / 2
            x2 = x + f_sz / 2; y2 = y + f_sz / 2
            dfx1 = 0; dfy1 = 0; dfx2 = 0; dfy2 = 0

            # Clip the window at the image border and remember how much was
            # cut, so a correspondingly smaller kernel can be regenerated.
            change_H = False
            if x1 < 0:
                dfx1 = abs(x1); x1 = 0; change_H = True
            if y1 < 0:
                dfy1 = abs(y1); y1 = 0; change_H = True
            if x2 > w:
                dfx2 = x2 - w; x2 = w - 1; change_H = True
            if y2 > h:
                dfy2 = y2 - h; y2 = h - 1; change_H = True
            x1h = 1 + dfx1; y1h = 1 + dfy1
            x2h = f_sz - dfx2; y2h = f_sz - dfy2
            if change_H:
                H = self.fspecial(int(y2h - y1h + 1), int(x2h - x1h + 1), sigma)
            labels[int(y1):int(y2), int(x1):int(x2)] = labels[int(y1):int(y2), int(x1):int(x2)] + H
        return labels

    def fspecial(self, ksize_x=5, ksize_y=5, sigma=4):
        """Return a 2-D Gaussian kernel of shape (ksize_x, ksize_y)."""
        kx = cv2.getGaussianKernel(ksize_x, sigma)
        ky = cv2.getGaussianKernel(ksize_y, sigma)
        return np.multiply(kx, np.transpose(ky))

View Code

2.2 使用

demo.py


from config import get_args
from model import MCNN
from dataset import myDatasets
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
import time
from utils import get_mse_mae,show
import os
import numpy as np
import matplotlib.pyplot as plt
from debug_utils import ModelVerbose
import random
import cv2

args = get_args()

# Resolve the dataset directory layout; only ShanghaiTech Part A is handled.
if args.dataset == 'ShanghaiTechA':
    if os.name == 'nt':
        # for windows
        train_imgs_path = args.dataset_path + r'\train_data\images\\'
        train_labels_path = args.dataset_path+r'\train_data\ground-truth\\'
        test_imgs_path = args.dataset_path+r'\test_data\images\\'
        test_labels_path = args.dataset_path+r'\test_data\ground-truth\\'
    else:
        # for linux
        train_imgs_path = os.path.join(args.dataset_path,'train_data/images/')
        train_labels_path = os.path.join(args.dataset_path,'train_data/ground-truth/')
        test_imgs_path = os.path.join(args.dataset_path,'test_data/images/')
        test_labels_path = os.path.join(args.dataset_path,'test_data/ground-truth/')
    # print(F"{train_imgs_path=}\n{train_labels_path=}\n{test_imgs_path=}\n{test_labels_path=}")
else:
    raise Exception(F'Dataset {args.dataset} Not Implement')

# Image transforms: only ToTensor for now; Resize/Normalize left disabled.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
    # torchvision.transforms.Resize((768,1024)),
    # torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def get_min_size(batch):
    """Return the smallest (height, width) over a batch of CHW tensors."""
    min_h = float('inf')
    min_w = float('inf')
    for tensor in batch:
        _, height, width = tensor.shape
        min_h = min(min_h, height)
        min_w = min(min_w, width)
    return min_h, min_w

def random_crop_img(img, size):
    """Randomly crop a CHW image tensor to (size[0], size[1])."""
    _, full_h, full_w = img.shape
    top = random.randint(0, full_h - size[0])
    left = random.randint(0, full_w - size[1])
    return img[:, top:top + size[0], left:left + size[1]]

def random_crop_dtmap(dt_map, size):
    """Randomly crop a 2-D density-map tensor to (size[0], size[1])."""
    full_h, full_w = dt_map.shape
    top = random.randint(0, full_h - size[0])
    left = random.randint(0, full_w - size[1])
    return dt_map[top:top + size[0], left:left + size[1]]

def random_crop(img, dt_map, size):
    """Crop img (CHW) and dt_map (HW) with ONE shared random window so the
    two stay spatially aligned.

    NOTE(review): the window is drawn from the image dimensions — assumes
    dt_map has the same H and W as img (i.e. no down-sampling); confirm.
    """
    _, full_h, full_w = img.shape
    top = random.randint(0, full_h - size[0])
    left = random.randint(0, full_w - size[1])
    bottom, right = top + size[0], left + size[1]
    return img[:, top:bottom, left:right], dt_map[top:bottom, left:right]



def c_f(batch):
    """collate_fn: crop every sample to the batch's smallest image size so
    variable-sized images can be stacked into one batch tensor.

    Args:
        batch: list of (img, density_map) tuples from myDatasets.__getitem__,
            where img is a CHW tensor and density_map an HW tensor.

    Returns:
        [imgs, density_maps] with shapes (B, C, H, W) and (B, H, W).

    Raises:
        TypeError: if the samples are not torch tensors.
    """
    # Transpose [(img, den), ...] into (imgs, dens) column tuples.
    imgs, dens = zip(*batch)
    if not (isinstance(imgs[0], torch.Tensor) and isinstance(dens[0], torch.Tensor)):
        # Report the element types, not type(batch[0]) (which is always
        # tuple and so was useless in the original message).
        raise TypeError(
            "batch must contain tensors; found ({}, {})".format(type(imgs[0]), type(dens[0]))
        )
    min_h, min_w = get_min_size(imgs)
    cropped_imgs = []
    cropped_dens = []
    # Crop image and density map with the SAME random window so they stay
    # spatially aligned (independent crops would break the correspondence).
    for img, den in zip(imgs, dens):
        _img, _dtmap = random_crop(img, den, (min_h, min_w))
        cropped_imgs.append(_img)
        cropped_dens.append(_dtmap)
    return [torch.stack(cropped_imgs), torch.stack(cropped_dens)]

# Train loader uses the custom collate_fn so variable-sized images can be
# cropped to a common size and batched together.
train_datasets = myDatasets(train_imgs_path, train_labels_path,down_sample=False,transform=transform)
train_data_loader = DataLoader(train_datasets, batch_size=args.batch_size,collate_fn=c_f)
# NOTE(review): test images also vary in size but this loader keeps the
# default collation — with batch_size > 1 stacking may fail; confirm.
test_datasets = myDatasets(test_imgs_path, test_labels_path,down_sample=True,transform=transform)
test_data_loader = DataLoader(test_datasets, batch_size=args.batch_size)

def color_map(img, color='gray'):
    """Turn a 2-D density map into a displayable color image.

    Args:
        img: 2-D array (density map).
        color: 'jet' produces an inverted-JET heat map (uint8 BGR); anything
            else converts the raw input to 3 channels unchanged.

    Returns:
        3-channel image array (uint8 for 'jet'; input dtype otherwise).
    """
    max_pixel = np.max(img)
    min_pixel = np.min(img)
    delta = max_pixel - min_pixel
    if delta == 0:
        # Constant image: avoid division by zero; everything maps to 0.
        delta = 1
    labels_int = (img - min_pixel) / delta * 255
    # Invert before applying JET, otherwise high densities come out blue
    # and low densities red — the opposite of a heat map.
    labels_int = 255 - labels_int
    labels_int = labels_int.astype(np.uint8)
    if color == 'jet':
        return cv2.applyColorMap(labels_int, cv2.COLORMAP_JET)
    # NOTE(review): this branch ignores the normalized labels_int and
    # converts the raw (possibly float) input — confirm that is intended.
    img_ = img[::, ::]
    img_ = cv2.cvtColor(img_, cv2.COLOR_GRAY2RGB)
    return img_

# Visual sanity check: overlay each density map (as a JET heat map) on its
# grayscale image for the first couple of batches.
for i,(imgs,targets) in enumerate(train_data_loader):
    # img.shape:        (1,1,768,1024)
    # targets.shape:    (1,192,256)
    for j in range(args.batch_size):
        img = imgs[j][0].numpy()
        dtmap = targets[j].numpy()
        # dtmap = cv2.resize(dtmap,img.shape[::-1])
        img = cv2.cvtColor(img,cv2.COLOR_GRAY2RGB)
        img = img.astype(np.uint8)
        dtmap = color_map(dtmap,'jet')
        # Blend image and heat map 50/50 for display.
        visual_img = cv2.addWeighted(img,0.5,dtmap,0.5,0)
        plt.imshow(visual_img)
        plt.show()
    if i>1:
        break
    

View Code

2.3 配置

config.py


import argparse


def get_args():
    """Parse command-line arguments for MCNN training/evaluation."""
    parser = argparse.ArgumentParser(description='MCNN')
    parser.add_argument('--dataset', type=str, default='ShanghaiTechA')
    parser.add_argument('--dataset_path', type=str,
                        default=r"C:\Users\ocean\Downloads\datasets\ShanghaiTech\part_A\\")
    parser.add_argument('--save_path', type=str, default='./save_file/')
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--epochs', type=int, default=600)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--optimizer', type=str, default='Adam')
    args = parser.parse_args()
    # For Jupyter notebooks use parser.parse_known_args()[0] instead.
    return args

View Code

3. 总结

其中比较值得说道的是collate_fn函数c_f(),它的代码如下所示:

def c_f(batch):
    """Custom collate_fn: crop every sample to the batch's minimum image
    size so variable-sized tensors can be stacked into one batch."""
    transposed = list(zip(*batch))
    imgs, dens = [transposed[0],transposed[1]]
    error_msg = "batch must contain tensors; found {}"
    if isinstance(imgs[0],torch.Tensor) and isinstance(dens[0],torch.Tensor):
        min_h, min_w = get_min_size(imgs)
        cropped_imgs = []
        cropped_dens = []
        for i in range(len(batch)):
            # One shared random window keeps image and density map aligned.
            _img,_dtmap = random_crop(imgs[i],dens[i],(min_h,min_w))
            cropped_imgs.append(_img)
            cropped_dens.append(_dtmap)
        cropped_imgs = torch.stack(cropped_imgs)
        cropped_dens = torch.stack(cropped_dens)
        return [cropped_imgs,cropped_dens] # returning without the list wrapper should also work
    raise TypeError((error_msg.format(type(batch[0]))))

这里传入的参数batch是一个list,其长度是batch_size。它的每一个元素代表了一个数据集单元,即自定义数据集类中__getitem__方法return的值。由于我们的__getitem__方法return了img和density_map两个数据,所以batch的每一个数据单元其实是一个元组(img, density_map)。

list(zip(*batch))所做的事情是把batch中的imgs和density_maps分别拿出来各自成为一个列表,方便下一步的处理。

在处理最后还要将列表中的元素堆叠成tensor返回

原文地址:http://www.cnblogs.com/x-ocean/p/16878864.html

1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长! 2. 分享目的仅供大家学习和交流,请务用于商业用途! 3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入! 4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解! 5. 如有链接无法下载、失效或广告,请联系管理员处理! 6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需! 7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员! 8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载 声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试.特别注意没有标注的源码不保证任何可用性