import torch.nn as nn


class LeNet5(nn.Module):
    """LeNet-5-style convolutional classifier.

    Two ``Conv2d -> ReLU -> MaxPool2d(2, 2)`` stages are followed by three
    fully connected layers. ``fc1`` expects ``16 * 5 * 5`` flattened features.
    The conv/pool stack produces exactly that for 28x28 inputs:

    - ``conv1`` (5x5 kernel, padding 2) keeps 28x28, and pooling gives 14x14.
    - ``conv2`` (5x5 kernel, no padding) gives 10x10, and pooling gives 5x5.

    Inputs whose spatial size does not reduce to 5x5 will fail at ``fc1``
    with a shape mismatch.

    Args:
        num_classes: Number of output scores produced by ``fc3``. Defaults
            to 10, the original hard-coded value.
        in_channels: Number of channels in the input images. Defaults to 1,
            the original hard-coded value.
    """

    def __init__(self, num_classes=10, in_channels=1):
        super().__init__()
        # Attribute names and Sequential layouts match the original definition,
        # so state_dict keys (e.g. "conv1.0.weight", "fc3.weight") remain
        # compatible with previously saved checkpoints.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, 6, 5, 1, 2),  # 5x5 kernel, stride 1, padding 2
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),  # 5x5 kernel, no padding
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
        )
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, pic):
        """Compute class scores for a batch of images.

        Args:
            pic: Batched image tensor, expected shape
                ``(batch, in_channels, 28, 28)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the raw output of
            ``fc3``. No softmax is applied.
        """
        pic = self.conv1(pic)
        pic = self.conv2(pic)
        # flatten(1) keeps the batch dimension explicit. Unlike
        # view(N, -1), it works for an empty batch (N == 0) and for
        # non-contiguous activations such as channels_last tensors.
        pic = pic.flatten(1)
        pic = self.fc1(pic)
        pic = self.fc2(pic)
        pic = self.fc3(pic)
        return pic