首页 > 编程知识 正文

pytorch 读取模型,pytorch输入数据格式

时间:2023-05-04 16:34:58 阅读:208632 作者:4785

下面的代码包含用途有:

1.训练时多GPU,推理时所有层多出一个module时替换;

2.训练模型出现层的定义不一致时替换;

3.打印训练过程中学习的参数,可视化对应参数的值。

import torchfrom collections import OrderedDictfrom your_model import Net# your net architecturenet = Net()model_path = "your_model_path"# load model parametersstate_dict = torch.load(model_path, map_location="cpu")# define a new dictnew_state_dict = OrderedDict()for k,v in state_dict.items(): # if you train model with mutil gpu if "module" in k: name = k[7:] else: name = k # if some layer's name need to change if "fc" in name: name = name.replace("fc", "classifier") # print your layer params print("layer : {}, value : {}nn".format(name, v)) # store the value to new_state_dict new_state_dict[name] = v# load model param to model architecturenet.load_state_dict(new_state_dict)

版权声明:该文观点仅代表作者本人。处理文章:请发送邮件至 三1五14八八95#扣扣.com 举报,一经查实,本站将立刻删除。