MNIST数字分类

模型结构

简单的三层全连接模型:两个隐藏层

超参数定义:

import torch
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import argparse
#参数定义
parser=argparse.ArgumentParser()
parser.add_argument('--batch_size',default=64)
parser.add_argument('--learning_rate',type=float,default=1e-2)
parser.add_argument('--num_epochs',type=int,default=20)
args=parser.parse_args()

模型结构:

#模型结构与激活函数
class simpleNet(nn.Module):
def __init__(self,in_dim,n_hidden1,n_hidden2,out_dim) -> None:
super(simpleNet,self).__init__()
self.layer1=nn.Sequential(
nn.Linear(in_dim,n_hidden1),nn.ReLU(True))
self.layer2=nn.Sequential(
nn.Linear(n_hidden1,n_hidden2),nn.ReLU(True))
self.layer3=nn.Sequential(
nn.Linear(n_hidden2,out_dim))
def forward(self,x):
self.x1=self.layer1(x)
self.x2=self.layer2(self.x1)
self.x3=self.layer3(self.x2)
return self.x3

数据预处理:

此处transforms.Compose()将各种预处理操作组合在一起,transforms.ToTensor()将图片转换为Tensor数据类型,transforms.Normalize()完成数据的去中心化和标准化,减去均值再除以方差,输入第一个参数为均值,第二个参数为方差。因为本例中图片时灰度图片,只有一个通道,如果是彩色图片有三个通道,需要使用transforms.Normalize([a,b,c],[d,e,f])

#数据预处理
data_tf=transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize([0.5],[0.5])]
)

读取数据集:

#读取数据集
train_data=datasets.MNIST(
root='./data',train=True,transform=data_tf,download=True
)
test_data=train_data=datasets.MNIST(
root='./data',train=False,transform=data_tf)
train_loader=DataLoader(train_data,batch_size=args.batch_size,shuffle=True)
test_loader=DataLoader(test_data,batch_size=args.batch_size,shuffle=False)

训练模型:

if torch.cuda.is_available():
model=simpleNet(28*28,300,100,10).cuda()
else:
model=simpleNet(28*28,300,100,10)

criterion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=args.learning_rate,weight_decay=0.01,momentum=0.78)

#训练模型
for epoch in range(args.num_epochs):
for batch in train_loader:
img,label=batch
img=img.reshape(img.size(0),28*28)
if torch.cuda.is_available():
img=img.cuda()
label=label.cuda()
out=model(img)

loss=criterion(out,label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("epochs:{},loss:{:.6f}".format(epoch,loss))

测试模型:

#测试模型
model.eval()
eval_loss=0
eval_acc=0
for batch in test_loader:
img,label=batch
img=img.reshape(img.size(0),28*28)
if torch.cuda.is_available():
img=img.cuda()
label=label.cuda()
out=model(img)
loss=criterion(out,label)

eval_loss+=loss.detach()*label.size(0)
pred=torch.max(out,dim=1)[1]
num_correct=(pred==label).sum()
eval_acc+=num_correct.detach()
print('Test loss:{},ACC:{:.6f}'.format(eval_loss/len(test_data),eval_acc/len(test_data)))

输出结果:

rs2XC.png