import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn, optim
from Lenet5 import Lenet5
from torch.nn import functional as F
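
# The Lenet5 module imported above is not part of this listing. Lenet5Sketch
# below is a hypothetical stand-in: its name and layer sizes are assumptions,
# following the common LeNet-5 layout with ReLU and max pooling. It accepts
# the [b, 3, 32, 32] batches built in main() and returns [b, 10] logits. The
# author's Lenet5 may differ.
import torch
from torch import nn


class Lenet5Sketch(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        # [b, 3, 32, 32] -> conv5 -> [b, 6, 28, 28] -> pool -> [b, 6, 14, 14]
        # -> conv5 -> [b, 16, 10, 10] -> pool -> [b, 16, 5, 5]
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # 16 * 5 * 5 = 400 flattened features -> 120 -> 84 -> num_classes
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        # Return raw logits; nn.CrossEntropyLoss applies log-softmax itself.
        return self.classifier(self.features(x))
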
batch_size = 128
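
# main() below normalizes with the per-channel mean/std commonly used for
# ImageNet. As an optional alternative (an assumption, not part of the
# original script), the hypothetical helper channel_mean_std estimates
# per-channel statistics from the training images themselves. Feed it a
# DataLoader whose transform is only transforms.ToTensor(), then pass
# mean.tolist() and std.tolist() to transforms.Normalize.
import torch
from torch.utils.data import DataLoader


def channel_mean_std(loader: DataLoader):
    """Return per-channel (mean, std) over every pixel of every image."""
    total = 0        # number of pixels seen per channel
    s = sq = None    # running per-channel sums of x and x**2
    for x, _ in loader:
        x = x.double()  # float64 reduces rounding error in the sums
        if s is None:
            s = torch.zeros(x.size(1), dtype=torch.float64)
            sq = torch.zeros(x.size(1), dtype=torch.float64)
        # x is [b, c, h, w]; reduce over batch, height and width
        s += x.sum(dim=(0, 2, 3))
        sq += (x ** 2).sum(dim=(0, 2, 3))
        total += x.size(0) * x.size(2) * x.size(3)
    if total == 0:
        raise ValueError("loader yielded no images")
    mean = s / total
    # Population std via E[x^2] - E[x]^2; clamp guards tiny negative rounding.
    std = (sq / total - mean ** 2).clamp(min=0).sqrt()
    return mean, std
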
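
def main():
    # CIFAR-10 training set: 50,000 RGB images of size 32x32 in 10 classes.
    # Resize is a no-op for CIFAR-10 (images are already 32x32); Normalize
    # uses the per-channel mean/std commonly cited for ImageNet.
    cifar_train = datasets.CIFAR10(
        root="/media/D/dataset/CIFAR10",
        train=True,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]),
        download=False,  # the dataset must already exist under root
    )
    cifar_train = DataLoader(
        dataset=cifar_train,
        batch_size=batch_size,
        shuffle=True,  # reshuffle the training samples every epoch
    )

    # CIFAR-10 test set: 10,000 images with the same preprocessing.
    cifar_test = datasets.CIFAR10(
        root="/media/D/dataset/CIFAR10",
        train=False,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]),
        download=False,
    )
    cifar_test = DataLoader(
        dataset=cifar_test,
        batch_size=batch_size,
        shuffle=False,  # sample order does not affect test accuracy
    )

    # Sanity check on one batch: x is [b, 3, 32, 32], label is [b].
    # Use the built-in next(); recent DataLoader iterators have no .next().
    x, label = next(iter(cifar_train))
    print(x.shape, label.shape)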
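
    # Fall back to the CPU when no CUDA device is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Lenet5().to(device)
    # CrossEntropyLoss takes raw logits [b, 10] and integer labels [b].
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    for epoch in range(10000):
        # Training pass over the whole training set.
        model.train()
        for batch_idx, (x, label) in enumerate(cifar_train):
            x, label = x.to(device), label.to(device)
            logits = model(x)                # [b, 10]
            loss = criterion(logits, label)  # scalar
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()
            optimizer.step()
        # Loss of the last mini-batch of this epoch.
        print(epoch, loss.item())

        # Evaluation pass: top-1 accuracy on the test set.
        model.eval()
        with torch.no_grad():  # no gradients are needed for evaluation
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)  # predicted class per image
                correct = torch.eq(pred, label).float().sum().item()
                total_correct += correct
                total_num += x.size(0)
            acc = total_correct / total_num
            print(epoch, acc)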
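
# A hypothetical helper (the name evaluate and its signature are assumptions)
# that factors the evaluation pass of main() into a reusable function. It
# computes the same top-1 accuracy: correct predictions divided by the number
# of test images. Inside main() it could be called as
# acc = evaluate(model, cifar_test, device).
import torch
from torch import nn
from torch.utils.data import DataLoader


def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    """Return the top-1 accuracy of model over every batch in loader."""
    model.eval()  # disable dropout, use running BatchNorm statistics
    total_correct = 0
    total_num = 0
    with torch.no_grad():  # no autograd graph is built
        for x, label in loader:
            x, label = x.to(device), label.to(device)
            pred = model(x).argmax(dim=1)  # index of the largest logit
            total_correct += (pred == label).sum().item()
            total_num += x.size(0)
    return total_correct / total_num
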

if __name__ == '__main__':
    main()