
How to save a trained PyTorch neural network model: saving and loading PyTorch models


https://blog.csdn.net/remanented/article/details/89161297

1. I'm getting ready to train my own model and hope to end up with a well-trained model; that starts with initializing the model's parameters.

Method 1

```python
import torch.nn as nn
from torch.nn import init

# Initialization function: set the parameters of each layer type in the network
def weight_init(m):
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight.data)
        init.constant_(m.bias.data, 0.1)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0, 0.01)
        m.bias.data.zero_()
```

First define the initialization function, then call it; note that the network model has to be instantiated before applying it:

```python
# Define the network, then apply the initializer to every submodule
model = Net(args.input_channel, args.output_channel)
model.apply(weight_init)
```

Method 2

```python
def initNetParams(net):
    '''Init net parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)

initNetParams(net)
```

Method 3 (the one I use myself)

Define an initialize_weights function in a utils file:

```python
import torch.nn as nn

def initialize_weights(net):
    # Draw conv / deconv / linear weights from N(0, 0.02) and zero the biases
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
```

A screenshot originally appeared here showing where the initialization call goes; the full code is below.

```python
import torch
import utils

# Define the network structure
class CNNnet(torch.nn.Module):
    def __init__(self):
        super(CNNnet, self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16, 32, 3, 2, 1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU()
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, 3, 2, 1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU()
        )
        self.conv4 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 64, 2, 2, 0),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU()
        )
        self.mlp1 = torch.nn.Linear(2 * 2 * 64, 100)
        self.mlp2 = torch.nn.Linear(100, 10)
        # Apply the custom initializer from utils right after building the layers
        utils.initialize_weights(self)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.mlp1(x.view(x.size(0), -1))
        x = self.mlp2(x)
        return x
```
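A minimal sanity-check sketch, assuming the initializer above is saved as utils.py and CNNnet is importable as written; the 28×28 single-channel input is my own assumption, chosen because four stride-2 stages shrink 28 to 2 and the flattened 2*2*64 features then match the input size of mlp1:

```python
import torch

# Build the net; initialize_weights runs inside __init__
net = CNNnet()

# Dummy batch of 8 single-channel 28x28 images: spatial size 28 -> 14 -> 7 -> 4 -> 2
x = torch.randn(8, 1, 28, 28)
out = net(x)
print(out.shape)  # expected: torch.Size([8, 10])
```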

Combining the two files above (the utils file with initialize_weights and the file defining CNNnet) is all that is needed.
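As for the saving and loading mentioned in the title, this section does not show it explicitly, but the usual state_dict round trip is sketched below as an assumption-laden example: net is taken to be a trained CNNnet and 'cnn.pth' is just a placeholder filename.

```python
import torch

# Save only the learned parameters (smaller and more portable than saving the whole module)
torch.save(net.state_dict(), 'cnn.pth')

# Later: rebuild the same architecture, then load the parameters back in
restored = CNNnet()
restored.load_state_dict(torch.load('cnn.pth'))
restored.eval()  # put BatchNorm layers into inference mode before evaluating
```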
