课堂练习,课后作业不想做了……

 1 import torch
 2 from torchvision import transforms
 3 from torchvision import datasets
 4 from torch.utils.data import DataLoader
 5 import torch.nn.functional as F
 6 import torch.optim as optim
 7 
# prepare dataset

batch_size = 64
# ToTensor scales pixels to [0, 1]; Normalize applies the given per-channel mean and std
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Training split: shuffled every epoch; downloaded to ../dataset/mnist/ on first run
train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
# Test split: fixed order so evaluation is reproducible
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
17 
18 
# design model using class
class Net(torch.nn.Module):
    """Fully-connected MNIST classifier: 784 -> 512 -> 256 -> 128 -> 64 -> 10."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(784, 512)
        self.l2 = torch.nn.Linear(512, 256)
        self.l3 = torch.nn.Linear(256, 128)
        self.l4 = torch.nn.Linear(128, 64)
        self.l5 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # Flatten to (batch, 784); -1 lets view infer the mini-batch size.
        x = x.view(-1, 784)
        for hidden in (self.l1, self.l2, self.l3, self.l4):
            x = F.relu(hidden(x))
        # No activation on the output layer: CrossEntropyLoss expects raw logits.
        return self.l5(x)
model = Net()

# construct loss and optimizer
# CrossEntropyLoss fuses LogSoftmax + NLLLoss, so the network outputs raw logits.
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
41 
42 
# training cycle forward, backward, update
def train(epoch):
    """Run one epoch over train_loader; print the mean loss every 300 mini-batches."""
    loss_sum = 0.0
    for step, (inputs, target) in enumerate(train_loader):
        optimizer.zero_grad()
        # Model prediction: (batch, 10) logits
        outputs = model(inputs)
        # Cross-entropy between logits (batch, 10) and class labels (batch,)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        if (step + 1) % 300 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, step + 1, loss_sum / 300))
            loss_sum = 0.0
61 
# Don't name this function "test*": test runners would auto-collect it as a test entry point.
def hehe_11():
    """Evaluate the model on test_loader and print accuracy as an integer percent."""
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            logits = model(images)
            # argmax over dim 1 (the class dimension) gives the predicted label
            predicted = logits.argmax(dim=1)
            correct += (predicted == labels).sum().item()  # tensor-wise comparison
            total += labels.size(0)
    print('accuracy on test set: %d %% ' % (100 * correct / total))
74 
75 
if __name__ == '__main__':
    # Train for 10 epochs, evaluating on the test set after each one.
    for epoch in range(10):
        train(epoch)
        hehe_11()

结果:

accuracy on test set: 97 %
[9, 300] loss: 0.039
[9, 600] loss: 0.042
[9, 900] loss: 0.040
accuracy on test set: 97 %
[10, 300] loss: 0.033
[10, 600] loss: 0.034
[10, 900] loss: 0.032
accuracy on test set: 97 %

原文地址:http://www.cnblogs.com/zhouyeqin/p/16818731.html

1. 本站所有资源来源于用户上传和网络,如有侵权请邮件联系站长! 2. 分享目的仅供大家学习和交流,请勿用于商业用途! 3. 如果你也有好源码或者教程,可以到用户中心发布,分享有积分奖励和额外收入! 4. 本站提供的源码、模板、插件等等其他资源,都不包含技术服务请大家谅解! 5. 如有链接无法下载、失效或广告,请联系管理员处理! 6. 本站资源售价只是赞助,收取费用仅维持本站的日常运营所需! 7. 如遇到加密压缩包,默认解压密码为"gltf",如遇到无法解压的请联系管理员! 8. 因为资源和程序源码均为可复制品,所以不支持任何理由的退款兑现,请斟酌后支付下载 声明:如果标题没有注明"已测试"或者"测试可用"等字样的资源源码均未经过站长测试.特别注意没有标注的源码不保证任何可用性