import torch
from torch import nn,optim
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
import argparse
# Command-line hyperparameters.
parser = argparse.ArgumentParser()
# type=int is required: without it, a value given on the command line arrives
# as a str (only the default is an int), which DataLoader's batch sampler
# rejects.
parser.add_argument('--batch_size', type=int, default=64,
                    help='mini-batch size used for training')
parser.add_argument('--learning_rate', type=float, default=1e-2,
                    help='SGD learning rate')
parser.add_argument('--num_epochs', type=int, default=20,
                    help='number of passes over the training set')
args = parser.parse_args()



# Model architecture and activation functions.
class simpleCNN(nn.Module):
    """Four conv blocks followed by a three-layer fully connected head.

    The fc head is sized for a [b, 128, 4, 4] feature map, which is what a
    [b, 1, 28, 28] input produces. The output is a [b, 10] tensor of raw
    scores: the final layer is Linear, with no softmax applied.
    """

    def __init__(self) -> None:
        super().__init__()
        # [b,1,28,28] -> [b,16,26,26]
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3), nn.BatchNorm2d(16), nn.ReLU(True))
        # [b,16,26,26] -> [b,32,24,24] -> max-pool -> [b,32,12,12]
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3), nn.BatchNorm2d(32), nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # [b,32,12,12] -> [b,64,10,10]
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3), nn.BatchNorm2d(64), nn.ReLU(True))
        # [b,64,10,10] -> [b,128,8,8] -> max-pool -> [b,128,4,4]
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # Flattened 128*4*4 features -> 1024 -> 128 -> 10 scores.
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, 10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the [b, 10] score tensor for a batch of images."""
        # Intermediates are plain locals rather than attributes on self, so
        # each batch's activations (and their autograd graph) are freed once
        # the caller is done with the output.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.reshape(x.size(0), -1)
        return self.fc(x)

# Data preprocessing: ToTensor, then Normalize with mean 0.5 / std 0.5,
# i.e. (x - 0.5) / 0.5 applied to the single channel.
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# Load the MNIST splits from ./data; only the training split passes
# download=True.
_mnist_kwargs = dict(root='./data', transform=data_tf)
train_data = datasets.MNIST(train=True, download=True, **_mnist_kwargs)
test_data = datasets.MNIST(train=False, **_mnist_kwargs)

train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch_size)
# batch_size=10000 — presumably chosen so the whole test split is one batch.
test_loader = DataLoader(test_data, shuffle=False, batch_size=10000)

# Build the network, then move it to the GPU when CUDA is available.
model = simpleCNN()
if torch.cuda.is_available():
    model = model.cuda()

# Cross-entropy loss and SGD with momentum 0.78.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=args.learning_rate,
    momentum=0.78,
)

# Train the model.
train_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for epoch in range(args.num_epochs):
    # Explicitly enable training mode (BatchNorm uses batch statistics and
    # updates its running stats), regardless of how the model was left.
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for img, label in train_loader:
        img = img.to(train_device)
        label = label.to(train_device)
        out = model(img)

        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a detached tensor: this avoids holding the graph and
        # avoids forcing a host/device sync on every step.
        epoch_loss += loss.detach()
        num_batches += 1
    # Report the mean loss over the epoch (not just the noisy last batch).
    # max(..., 1) guards against an empty loader.
    mean_loss = float(epoch_loss) / max(num_batches, 1)
    print("epochs:{},loss:{:.6f}".format(epoch, mean_loss))

# Evaluate the model on the test split.
model.eval()
eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
eval_loss = 0.0
eval_acc = 0
# No gradients are needed for evaluation. Skipping autograd graph
# construction keeps memory bounded for the large test batches.
with torch.no_grad():
    for img, label in test_loader:
        img = img.to(eval_device)
        label = label.to(eval_device)
        out = model(img)
        loss = criterion(out, label)

        # criterion returns the batch-mean loss (CrossEntropyLoss's default
        # reduction). Re-weight by batch size so that dividing by the dataset
        # size below yields a per-sample mean.
        eval_loss += loss.item() * label.size(0)
        pred = out.argmax(dim=1)
        eval_acc += (pred == label).sum().item()
# Both accumulators are Python numbers, so this prints plain floats rather
# than a tensor repr.
print('Test loss:{:.6f},ACC:{:.6f}'.format(eval_loss / len(test_data), eval_acc / len(test_data)))