torch-cifar10-2

torch-cifar10-2

上次提到实际中知道的都是imgs和labels,所以肯定需要自己制作traindata,这次就用真正的cifar10的原数据来做一个训练集。

kaggle cifar10的训练数据集有5w张,不过没多大。测试集好像有30w张。 那么按torch里面的要求,需要一个类来自己实现。

我是这样写的,

class Cifar(torch.utils.data.Dataset):
    def __init__(self, csvfile, transformer=None):
        #需要一个文件名和label的对应
        self.img_list = read_csv(csvfile)   # img_list = [(img_path, label), ..., ]
        self.transformer = transformer

    def __getitem__(self, idx):
        img_path = self.img_list[idx][0]   # abs path
        #img = np.array(cv2.imread(img_path), dtype=np.float32)   # maybe this is error!!!
        img = Image.open(img_path)
        label = self.img_list[idx][1]
        img = self.transformer(img) 
        return img, label
    
    def __len__(self):
        return len(self.img_list)

其中read_csv我是这样写的

import cv2 
import torch
import os
import numpy as np
from PIL import Image

def read_csv(csvfile):
    
    ret = []
    dirname = os.path.split(csvfile)[0]
    
    with open(csvfile, 'r') as f:
        lines = f.readlines()[1:]
        tokens = [l.strip().split(",") for l in lines]
    
        name2label = {}
        for x,y in tokens:
            newname = os.path.join(dirname, 'train', str(x)+".png")    
            print(newname)
            if y not in name2label:
                label = len(name2label)
                name2label[y] = label 
            else:
                label = name2label[y] 

            ret.append((newname, label))
    
        print("name2label", name2label)
    return ret 

因为torch里面要求transforms on PIL image, 刚开始我用cv的时候就报错了, 然后 trainset改成了下面的

# 1.prepare train data and test data
# data augment
# transforms on PIL image
transform_train = transforms.Compose([
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  
           ])

csv_file = '/home/pengkun/torch_learn/cifar10_torch/data/cifar_10/trainLabels.csv'
trainset = Cifar10Folder.Cifar(csv_file, transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

然后其它的都没有更改,运行成功了,现在的问题是,上面的过程对不对,如何去检测对不对呢?我准备找一个模型训练一下,然后上传到kaggle上面对比一下。

打赏,谢谢~~

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,多谢支持~

打开微信扫一扫,即可进行扫码打赏哦