torch-cifar10

torch-cifar10

之前一直想整理下torch的用法,从这篇blog开始吧。

这是用torch来做kaggle cifar10的例子。希望在总结的时候,能够把其中的关键地方给学习到,以后能够用到其他的模型训练上面。

第一步,准备数据

# 1.prepare train data and test data
# data augment
# transforms on PIL image
transform_train = transforms.Compose([
                        transforms.RandomCrop(32, padding=4),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),  
           ])

# Normalize(mean, std)  # the len of mean or std corresponding to the number of channels,
transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                    
            ])

# train data

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=True, num_workers=2)

这里面最important的函数是

torch.utils.data.DataLoader(Data, batch_size=, shuffle=True, num_workers=2)

其中的Data需要是

torch.utils.data.Dataset

的一个对象。这里是用了已经有了的CIFAR10了,但是实际中在做项目的时候,可能只有Imgs,可能需要自己写一个类来继承

torch.utils.data.Dataset

比如下面就是一个框架, 当然其他的函数可以自己根据需要来加上去。

class Data(torch.utils.data.Dataset):
    
    def __init__(self, fileDir, transformer=None):
        # 下面的就省略了,其中transformer是做数据增强的时候用的变换,这个变换可以自己实现,也可以用torch里面自带的,上面用的就是torch自带的,但是有时候可能不能够完成自己的任务 ,所以必要的时候需要自己写一个来实现具体的任务。

    def __getitem__(self, index):
        # 用这个的好处是可以通过index来得到对应的数据


    def __len__(self):
        return 有关数据的长度

构建网络结构

这里用LeNet来做的。为了有结构化,可以新建一个models的文件目录专门来放模型。 然后里面加上

__init__.py

如果希望在其他的文件里面直接

from models import *

的话,需要在 __init__.py 的里面加上 from lenet import * 之类的,到时候就可以用models下面的所有的东西了,包括类。

这里建立一个 “modles/lenet.py” 来写

# 2. construct model
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3,6,5)
        self.conv2 = nn.Conv2d(6,16,5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84) 
        self.fc3 = nn.Linear(84,10)

    def forward(self, x): 
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1) 
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out 

写网络结构这一部分,和其他的框架基本上一样,比如mxnet, 这里需要是

torch.nn.Module

的子类, 需要实现

__init__

函数和前向计算的

forward

函数,有的网络结构比较复杂的可能要先得到网络每层的配置信息,然后写一个专门的

_makelayer

来制作每一个block.

准备训练

这一步需要的东西 有一些多,比如下面的这些

    # 3.get the network
    net = lenet.LeNet() 
    # choose which device to train
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark=True
    # 4. loss function
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

其中有决定在哪里训练的,是在gpu还是cpu上面,如果在gpu 上面的话,需要用

“net = torch.nn.DataParallel(net)”

然后要指定 cudnn.benchmark=True, 加了这个cudnn的好处是可以让cudnn内置的auto-tuner自动寻找最适合当前配置的高效算法,来达到优化运行效率的问题.当然这个并不是加了一定就好,

  • 如果网络的输入的数据维度或类型上变化不大,这样设置可以增加效率

  • 如果网络的输入数据在每次iter的时候都变化的话,会导致cudnn每次都要去寻找一遍最优的配置,这样反而会降低运行效率。

torch.cuda.is_available() 是判断有没有cuda的。

如果是多gpu的话,需要指定

“net = torch.nn.DataParallel(net, device_ids=args.gpu)“

其中 args.gpu 是个列表,比如[0,2,3]其中第一个必须得是0,因为torch在训练的时候是有一个主训练的device,这个就是”0“号device,其中数据必需放在”0”号上面,参数更新也是在”0“号上面完成的,这一点要特别注意,另外,这里的”0“号并不一定是机子上的第一块gpu,而是从

“CUDA_VISIBLE_DEVICES = ‘2,3,4,8’ “ 中的第一个。

开始训练和测试

代码如下

def train(epoch):
    print("\n Epoch: %d" % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)   # put the data onto device
        optimizer.zero_grad()  # clear the gradients
        outputs = net(inputs)  # get the pred output
        loss = criterion(outputs, targets)
        # bp
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)  # or torch.max(outpus.data, 1)   
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()   # the num of predict true
        print("Train's loss: %.4f, Train's acc: %.6f" %(train_loss/(batch_idx+1), 1.0*correct/total))

def test(epoch):
    global best_acc
    net.eval()   # just like net.train()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            print("Test's loss: %.4f, Test's acc: %.6f" %(test_loss/(batch_idx+1), 1.0*correct/total))

需要注意的地方有

“loss=criterion(outputs, targets)”

因为从上面可以看出从LeNet出来的并不一定在(0,1)之间,然后这个targets现在也不清晰长什么 样,正常的话,如果batch 是16,那么输出的应该 是(16,10), 而targets也是(16,10)的话这样就好做ce了,即one-hot的方式来表示label, 但是经过把outputs和targets打印出来之后发现outputs是(128,10)的,而targets是(128,)的,里面的内容是0到9. 查了一下torch的源码,知道了需要outputs就是网络的输出,不一定经过softmax都可以, 而targets是真实的label,是个一维的向量,这里刚好都满足。

保存模型

下面是一种保存的方式

 #8. decide to save model
    acc = 1.0 *correct/total
    if acc > best_acc:
        print("Saving...")
        state = { 
                'net':net.state_dict(),
                'acc':acc,
                'epoch':epoch,
                }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/cifar.t7')
        best_acc = acc 

这里是以准确率来判断是否要更新模型的,这种办法最直观,在实际中常这样,因为看上去效果好了。

run

for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)

其他的用到的参数需要在代码前面补充,包括用到的包。

打赏,谢谢~~

取消

感谢您的支持,我会继续努力的!

扫码支持
扫码打赏,多谢支持~

打开微信扫一扫,即可进行扫码打赏哦