下面的代码包含用途有:
1.训练时多GPU,推理时所有层多出一个module时替换;
2.训练模型出现层的定义不一致时替换;
3.打印训练过程中学习的参数,可视化对应参数的值。
import torchfrom collections import OrderedDictfrom your_model import Net# your net architecturenet = Net()model_path = "your_model_path"# load model parametersstate_dict = torch.load(model_path, map_location="cpu")# define a new dictnew_state_dict = OrderedDict()for k,v in state_dict.items(): # if you train model with mutil gpu if "module" in k: name = k[7:] else: name = k # if some layer's name need to change if "fc" in name: name = name.replace("fc", "classifier") # print your layer params print("layer : {}, value : {}nn".format(name, v)) # store the value to new_state_dict new_state_dict[name] = v# load model param to model architecturenet.load_state_dict(new_state_dict)