from pickletools import optimize  # NOTE(review): unused; looks like an IDE auto-import
from turtle import forward  # NOTE(review): unused; importing turtle pulls in tkinter and may fail on installs without Tk
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
import sys
from visdom import Visdom

# Live training-loss curve in the 'train_loss' window, seeded with a (0, 0) point.
# NOTE(review): presumably requires a reachable visdom server — confirm deployment.
viz = Visdom()
viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))

# Work relative to the script's directory so './data' lands next to the script.
# sys.path[0] is '' in interactive sessions / `python -c`, and os.chdir('')
# raises FileNotFoundError, so fall back to the current directory (a no-op).
os.chdir(sys.path[0] or os.getcwd())


class MyRnn(nn.Module):
    """Stacked-LSTM classifier that predicts from the output of the last time step.

    Args:
        in_dim: number of features per time step.
        hidden_dim: LSTM hidden-state size.
        n_layer: number of stacked LSTM layers.
        n_class: number of output classes (size of the logit vector).
    """

    def __init__(self, in_dim, hidden_dim, n_layer, n_class) -> None:
        super().__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        # batch_first=True: the LSTM expects input shaped (batch, seq_len, in_dim).
        self.lstm = nn.LSTM(in_dim, hidden_dim, n_layer, batch_first=True)
        # Attribute name (sic) kept unchanged so existing state_dict keys still load.
        self.classifer = nn.Linear(hidden_dim, n_class)

    def forward(self, x):
        """Return logits of shape (batch, n_class) for x shaped (batch, seq_len, in_dim)."""
        out, _ = self.lstm(x)  # out: (batch, seq_len, hidden_dim); (h_n, c_n) unused
        return self.classifer(out[:, -1, :])


# 784 = 28 * 28: each flattened MNIST image is fed as one feature vector.
model = MyRnn(784, 50, 2, 10)
if torch.cuda.is_available():
    model = model.cuda()

criter = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
epochs = 30
# Preprocessing for MNIST: PIL image -> float tensor, then per-channel
# (x - 0.5) / 0.5 on the single grayscale channel.
_to_tensor = transforms.ToTensor()
_center = transforms.Normalize(mean=[0.5], std=[0.5])
data_tf = transforms.Compose([_to_tensor, _center])
# MNIST splits, both preprocessed with data_tf.
# BUG FIX: the original `test_data = train_data = datasets.MNIST(train=False)`
# rebound train_data to the test split, so the model was trained on the very
# data it was evaluated on. Each split now has its own dataset object.
train_data = datasets.MNIST(root='./data', train=True, transform=data_tf, download=True)
test_data = datasets.MNIST(root='./data', train=False, transform=data_tf, download=True)

# Shuffle only the training data; evaluation order does not matter.
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
# Train for `epochs` epochs, plotting per-batch loss to visdom, then evaluate
# on the test split after every epoch.
global_step = 0
use_cuda = torch.cuda.is_available()
for epoch in range(epochs):
    # Restore training mode: the evaluation pass below switches to eval mode.
    model.train()
    for img, label in train_loader:
        optimizer.zero_grad()
        if use_cuda:
            img = img.cuda()
            label = label.cuda()
        # Flatten each image to (batch, img.size(1), 784). With single-channel
        # MNIST this is a length-1 sequence of 784 features per sample.
        # NOTE(review): the LSTM therefore sees only one time step — confirm intended.
        img = img.reshape(img.size(0), img.size(1), 784)
        global_step += 1
        output = model(img)
        loss = criter(output, label)
        viz.line([loss.item()], [global_step], win='train_loss', update='append')
        loss.backward()
        optimizer.step()
    # Reports the loss of the epoch's final batch, not an epoch average.
    print("epochs:{},loss:{:.6f}".format(epoch, loss.item()))

    model.eval()
    eval_loss = 0.0
    eval_acc = 0
    # No gradients are needed for evaluation; avoid building autograd graphs.
    with torch.no_grad():
        for img, label in test_loader:
            if use_cuda:
                img = img.cuda()
                label = label.cuda()
            img = img.reshape(img.size(0), img.size(1), 784)
            out = model(img)
            loss = criter(out, label)
            # criter returns the batch mean; weight by batch size for a dataset mean.
            eval_loss += loss.item() * label.size(0)
            pred = out.argmax(dim=1)
            eval_acc += (pred == label).sum().item()
    print('Test loss:{},ACC:{:.6f}'.format(eval_loss / len(test_data), eval_acc / len(test_data)))