pytorch 载入预训练网络部分权重

首先weight_dict = torch.load(‘path_to_weight’)读取预训练网络的权重键值。

然后获取当前网络的权重键值
model_dict = model.state_dict() #model为当前定义的网络

最关键一步,根据键命名筛选出需要载入的部分权重。当前网络中要载入权重的部分,命名要与预训练网络相同
weight_dict = {k:v for k, v in weight_dict.items() if k in model_dict}

更新当前网络的键值字典
model_dict.update(weight_dict)

最后载入该键值字典到网络中
model.load_state_dict(model_dict)

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注