import torch from torch import nn,optim from torch.utils.data import DataLoader from torchvision import datasets,transforms import argparse
# Command-line configuration for training.
parser = argparse.ArgumentParser(description="Train and evaluate a small CNN on MNIST.")
# type=int is required: without it a value passed on the command line stays a
# str, and DataLoader rejects a non-integer batch_size.
parser.add_argument('--batch_size', type=int, default=64,
                    help='training minibatch size (default: 64)')
parser.add_argument('--learning_rate', type=float, default=1e-2,
                    help='SGD learning rate (default: 1e-2)')
parser.add_argument('--num_epochs', type=int, default=20,
                    help='number of passes over the training set (default: 20)')
args = parser.parse_args()
class simpleCNN(nn.Module):
    """Small convolutional classifier for 28x28 images.

    Four conv stages (3x3 conv -> BatchNorm -> ReLU); stages 2 and 4 also end
    with a 2x2 max-pool. A three-layer fully connected head then maps the
    flattened features to raw class logits (no softmax is applied).

    Args:
        in_channels: number of input image channels (default 1, e.g. MNIST).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Spatial size for a 28x28 input: 28 -> 26 -> 24 -> 12 (pool)
        # -> 10 -> 8 -> 4 (pool), so 128 channels * 4 * 4 features remain.
        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape (batch, num_classes).

        x is expected to be (batch, in_channels, 28, 28); other spatial sizes
        will not match the first Linear layer of the head.
        """
        # Activations are kept in locals, not stored on self: attributes would
        # keep the previous batch's tensors (and autograd graph) alive.
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.reshape(x.size(0), -1)
        return self.fc(x)
# Image preprocessing: ToTensor yields floats in [0, 1]; Normalize with
# mean 0.5 / std 0.5 then rescales them to [-1, 1].
data_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# MNIST train/test splits stored under ./data (downloaded on first use via the
# training split).
train_data = datasets.MNIST(
    root='./data',
    train=True,
    transform=data_tf,
    download=True,
)
test_data = datasets.MNIST(
    root='./data',
    train=False,
    transform=data_tf,
)

# Shuffled minibatches for training; the whole 10000-image test split is
# evaluated as a single batch in fixed order.
train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=10000, shuffle=False)
# Build the network, moving it to the GPU when CUDA is available.
model = simpleCNN()
if torch.cuda.is_available():
    model = model.cuda()
# The model's head ends in a plain Linear layer (raw logits), which is what
# CrossEntropyLoss expects.
criterion = nn.CrossEntropyLoss()
# SGD with momentum; the learning rate comes from the command line.
optimizer = optim.SGD(
    model.parameters(),
    lr=args.learning_rate,
    momentum=0.78,
)
# Training: args.num_epochs passes over train_loader, one optimizer step per
# minibatch.
# Send each batch to whatever device the model's parameters live on.
device = next(model.parameters()).device
for epoch in range(args.num_epochs):
    # Set training mode explicitly so BatchNorm uses per-batch statistics even
    # if the model was switched to eval mode beforehand.
    model.train()
    running_loss = 0.0
    num_samples = 0
    for img, label in train_loader:
        img = img.to(device)
        label = label.to(device)
        out = model(img)
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; weight it by batch size so the
        # epoch average is per-sample even when the last batch is smaller.
        # detach() avoids keeping the graph; staying on-device avoids a sync
        # every step.
        running_loss += loss.detach() * label.size(0)
        num_samples += label.size(0)
    # Report the mean loss over the whole epoch instead of the loss of only
    # the last (possibly partial) minibatch.
    epoch_loss = float(running_loss) / max(num_samples, 1)
    print("epochs:{},loss:{:.6f}".format(epoch, epoch_loss))
# Evaluation: mean cross-entropy loss and accuracy over the test split.
model.eval()
device = next(model.parameters()).device
eval_loss = 0.0
eval_acc = 0
# No gradients are needed here; without no_grad the forward pass over a
# 10000-image batch would record a large, useless autograd graph.
with torch.no_grad():
    for img, label in test_loader:
        img = img.to(device)
        label = label.to(device)
        out = model(img)
        loss = criterion(out, label)
        # criterion returns the batch mean; weight by batch size so the final
        # average is over samples rather than batches.
        eval_loss += loss.item() * label.size(0)
        pred = out.argmax(dim=1)
        eval_acc += (pred == label).sum().item()
print('Test loss:{},ACC:{:.6f}'.format(eval_loss / len(test_data), eval_acc / len(test_data)))