In [6]:
Copied!
import torch
from torch import nn
# 单个Tensor写
x = torch.ones(3)
torch.save(x, 'x.pt')
# 单个Tensor读
x2 = torch.load('x.pt')
x2
import torch
from torch import nn
# 单个Tensor写
x = torch.ones(3)
torch.save(x, 'x.pt')
# 单个Tensor读
x2 = torch.load('x.pt')
x2
C:\Users\15087\AppData\Local\Temp\ipykernel_25544\1906428270.py:8: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
x2 = torch.load('x.pt')
Out[6]:
tensor([1., 1., 1.])
In [7]:
Copied!
# Tensor列表写
y = torch.zeros(4)
torch.save([x, y], 'x.pt')
# Tensor列表读
xy_list = torch.load('x.pt')
xy_list
# Tensor列表写
y = torch.zeros(4)
torch.save([x, y], 'x.pt')
# Tensor列表读
xy_list = torch.load('x.pt')
xy_list
C:\Users\15087\AppData\Local\Temp\ipykernel_25544\3493929128.py:5: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
xy_list = torch.load('x.pt')
Out[7]:
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
In [8]:
Copied!
# Tensor字典写
torch.save({'x': x, 'y': y}, 'x_dict.pt')
# Tensor字典读
xy = torch.load('x_dict.pt')
xy
# Tensor字典写
torch.save({'x': x, 'y': y}, 'x_dict.pt')
# Tensor字典读
xy = torch.load('x_dict.pt')
xy
C:\Users\15087\AppData\Local\Temp\ipykernel_25544\1202955033.py:4: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
xy = torch.load('x_dict.pt')
Out[8]:
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
读写模型¶
In [9]:
Copied!
# net.state_dict()用来获取模型的参数和缓冲区(比如权重、偏置)组成的字典
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
net.state_dict()
# net.state_dict()用来获取模型的参数和缓冲区(比如权重、偏置)组成的字典
class MLP(nn.Module):
def __init__(self):
super(MLP, self).__init__()
self.hidden = nn.Linear(3, 2)
self.act = nn.ReLU()
self.output = nn.Linear(2, 1)
def forward(self, x):
a = self.act(self.hidden(x))
return self.output(a)
net = MLP()
net.state_dict()
Out[9]:
OrderedDict([('hidden.weight',
tensor([[-0.0391, -0.0036, 0.3501],
[ 0.3094, -0.0750, -0.1477]])),
('hidden.bias', tensor([ 0.3402, -0.5161])),
('output.weight', tensor([[-0.0817, -0.0907]])),
('output.bias', tensor([-0.2088]))])
In [10]:
Copied!
# 模型中,只有含可学习参数的层才有status_dict
# optimizer也有status_dict
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
# 模型中,只有含可学习参数的层才有status_dict
# optimizer也有status_dict
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
Out[10]:
{'state': {},
'param_groups': [{'lr': 0.001,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'maximize': False,
'foreach': None,
'differentiable': False,
'fused': None,
'params': [0, 1, 2, 3]}]}
In [11]:
Copied!
# # 模型保存
# torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth
# # 模型加载
# model = TheModelClass(*args, **kwargs)
# model.load_state_dict(torch.load(PATH))
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
# # 模型保存
# torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth
# # 模型加载
# model = TheModelClass(*args, **kwargs)
# model.load_state_dict(torch.load(PATH))
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y
C:\Users\15087\AppData\Local\Temp\ipykernel_25544\1917582696.py:14: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. net2.load_state_dict(torch.load(PATH))
Out[11]:
tensor([[True],
[True]])