import itertools

import torch
import torchvision
from torch import nn
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *

# Train the Tudui model on CIFAR-10 with SGD until its accuracy on the test
# set exceeds TARGET_ACCURACY. Training loss and test metrics are logged to
# TensorBoard ("logs_train"), and a checkpoint is saved after every epoch.

# Datasets are downloaded into ./CIFAR10 on first run.
train_data = torchvision.datasets.CIFAR10(
    "CIFAR10", train=True, transform=torchvision.transforms.ToTensor(), download=True
)
test_data = torchvision.datasets.CIFAR10(
    "CIFAR10", train=False, transform=torchvision.transforms.ToTensor(), download=True
)

train_data_size = len(train_data)
test_data_size = len(test_data)

# Shuffle the training set each epoch so SGD does not see batches in a fixed order.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64)

# create model
tudui = Tudui()

# create loss function
loss_fn = nn.CrossEntropyLoss()

# optim
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

# set some parameters to train
total_train_step = 0
total_test_step = 0
# NOTE(review): `epoch` is not used below; training stops only when the
# accuracy target is reached. Consider using it as an upper bound.
epoch = 10
# Stop training once test-set accuracy exceeds this fraction.
TARGET_ACCURACY = 0.8

# add tensorboard
writer = SummaryWriter("logs_train")

try:
    # Unbounded epoch counter. NOTE(review): if the model never reaches
    # TARGET_ACCURACY this loop never terminates on its own.
    for i in itertools.count():
        print("-------- This is No. {} times of training ----------".format(i + 1))

        # ---- training phase ----
        # Enable training-mode behaviour (matters only if Tudui has
        # dropout/batch-norm layers -- not visible from this file).
        tudui.train()
        for imgs, targets in train_dataloader:
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)

            # do some optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_step += 1
            if total_train_step % 100 == 0:
                print("Times of training {}, loss: {}".format(total_train_step, loss.item()))
                writer.add_scalar("train_loss", loss.item(), total_train_step)

        # ---- evaluation phase ----
        tudui.eval()
        # Sum of per-batch mean losses (not a per-sample average).
        total_test_loss = 0.0
        total_correct = 0
        with torch.no_grad():
            for imgs, targets in test_dataloader:
                outputs = tudui(imgs)
                total_test_loss += loss_fn(outputs, targets).item()
                # .item() keeps the running count a plain int rather than a tensor.
                total_correct += (outputs.argmax(1) == targets).sum().item()
        test_accuracy = total_correct / test_data_size

        print("Loss on the whole test set: {} ".format(total_test_loss))
        print("Accuracy rate on the whole test set: {} ".format(test_accuracy))
        writer.add_scalar("test_loss", total_test_loss, total_test_step)
        writer.add_scalar("test_accuracy", test_accuracy, total_test_step)
        total_test_step += 1

        # Save before checking the stop condition so the model that
        # reaches the target is checkpointed too.
        torch.save(tudui, "tudui_{}.pth".format(i))
        print("Model has been saved ")

        if test_accuracy > TARGET_ACCURACY:
            break
finally:
    # Flush TensorBoard events even if training is interrupted (e.g. Ctrl-C).
    writer.close()