PyTorch предоставляет ряд специальных функций для сохранения данных во внешнем файле и для обратной загрузки данных из файла. Для сохранения данных применяется метод torch.save():
torch.save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=True)
Функция принимает следующие параметры:
obj (object): сохраняемый объект
f (Union[str, PathLike[str], IO[bytes]]): файлоподобный объект (должен реализовывать операции записи и сброса) или строка или объект os.PathLike, которые содержат имя файла
pickle_module (Any): модуль, используемый для сериализации метаданных и объектов
pickle_protocol (int): может быть указан для переопределения протокола по умолчанию
Из этих параметров обязательными являются только первые два - собственно сохраняемые данные и файл (путь к файлу), где надо выполнить сохранение. Возьмем простой пример:
import torch # произвольные данные для сохранения в виде тензора x = torch.tensor([0, 1, 2, 3, 4]) # сохраняем данные в файл "tensor.pt" torch.save(x, "tensor.pt")
Здесь определяем простейший тензор - x и сохраняем его в файле "tensor.pt" (".pt" - одно из расширений файлов, используетмых в PyTorch).
И после выполнения этой программы в текущей папке проекта мы увидим файл "tensor.pt". Это бинарный файл, для сохранения в который применяется бинарная сериализации.
Для загрузки ранее сохраненных данных из файла применяется функция load():
torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[source]
Она принимает следующие параметры:
f: файлоподобный объект (должен реализовывать функции read(), readline(), tell() и seek()), или строка, или объект os.PathLike, которые содержат имя файла.
map_location: функция, объект torch.device, строка или словарь, которые указывают, как переназначать места хранения.
pickle_module: модуль, используемый для десериализации метаданных и объектов (должен соответствовать pickle_module, используемому для сериализации файла).
weights_only: указывает, следует ли ограничивать десериализацию только тензорами, примитивными типами, словарями и любыми типами,
добавленными с помощью torch.serialization.add_safe_globals()
mmap: указывает, следует ли файл отображаться на память, а не загружаться полностью в память.
Как правило, хранилища тензоров в файле сначала перемещаются с диска в память ЦП, после чего они перемещаются в место, которое было указано при сохранении или задано параметром map_location.
Этот второй шаг ничего не делает, если конечное местоположение - центральный процессор. Если установлен флаг mmap, вместо копирования хранилищ тензоров с диска в память ЦП на первом шаге выполняется отображение файла f, что означает, что тензоры будут загружаться по мере обращения к ним.
pickle_load_args: необязательные аргументы ключевых слов
Например, применим эту функцию для загрузки ранее сохраненного тензора из файла:
import torch
# произвольные данные для сохранения в виде тензора
x = torch.tensor([0, 1, 2, 3, 4])
# сохраняем данные в файл "tensor.pt"
torch.save(x, "tensor.pt")
# загружаем данные из файла
y = torch.load("tensor.pt")
print(y) # tensor([0, 1, 2, 3, 4])
Подобно тому , как мы можем сохранять и загружать тензоры, мы также можем сохранять и загружать модели и их состояния. И это играет очень важную роль, поскольку после обучения модели мы можем сохранить ее обученные параметры - весы (weight) и смещения (bias) в файл, а затем при необходимости загрузить их обратно, избежав посторного обучения, для которого может потребоваться продолжительное время и значительные аппаратные ресурсы.
Аналогично как мы сохраняем тензор, мы можем сохранить во внешнем файле модель, передав ее в функцию torch.save():
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=4, out_features=2),
nn.ReLU(),
nn.Linear(2, 1)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
# сохраняем модель в файл
torch.save(model, "model.pth")
Здесь определена произвольная тестовая модель NeuralNetwork с 3 слоями, объект которой сохраняется в файл "model.pth" (".pth" - это стандартное расширение, которое используется в PyTorch). И после выполнения этой программы в текущей папке проекта мы увидим файл "model.pth". Это бинарный файл, для сохранения в который применяется бинарная сериализации.
Затем мы можем загрузить модель из файла:
model2 = torch.load("model.pth", weights_only=False)
При сохранении только весов модели достаточно передать в функцию torch.save() результат метода model.state_dict() (то есть словрь с весами и смещениями):
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=4, out_features=2),
nn.ReLU(),
nn.Linear(2, 1)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
print(model.state_dict())
# сохраняем параметры модели в файл
torch.save(model.state_dict(), "model_weights.pth")
Здесь данные сохраняются в файле "model_weights.pth". В целом там будет таже информация которую мы выводим на консоль с помощью print(model.state_dict())
Для загрузки весов модели необходимо сначала создать экземпляр той же модели, а затем загрузить параметры с помощью метода load_state_dict():
load_state_dict(state_dict, strict=True, assign=False)
Этот метод принимает следующие параметры:
state_dict: словарь, который содержит параметры
strict: следует ли строго требовать, чтобы ключи в state_dict соответствовали ключам, возвращаемым функцией state_dict() этого модуля
assign: если установлено значение False, свойства тензоров сохраняются в текущем модуле, а если True - сохраняются в словаре состояния
Например, загрузим ранее сохраненные веса одной модели в другую модель:
import torch
from torch import nn
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=4, out_features=2),
nn.ReLU(),
nn.Linear(2, 1)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model1 = NeuralNetwork()
# сохраняем веса model1 в файл
torch.save(model1.state_dict(), "model_weights.pth")
print("Веса model1:")
print(model1.state_dict())
model2 = NeuralNetwork() # веса model1 и model2 не совпадают
print("\nВеса model2 до загрузки:")
print(model2.state_dict())
# загружаем веса
model2.load_state_dict(torch.load("model_weights.pth", weights_only=True))
model2.eval()
print("\nВеса model2 после загрузки:")
print(model2.state_dict())
Пример консольного вывода:
Веса model1:
OrderedDict({'linear_relu_stack.0.weight': tensor([[ 0.3493, 0.1946, 0.3091, 0.4323],
[-0.1120, 0.3059, 0.1715, -0.3470]]), 'linear_relu_stack.0.bias': tensor([0.2972, 0.0108]), 'linear_relu_stack.2.weight': tensor([[0.5835, 0.2446]]), 'linear_relu_stack.2.bias': tensor([0.3949])})
Веса model2 до загрузки:
OrderedDict({'linear_relu_stack.0.weight': tensor([[ 0.4207, -0.1752, -0.1711, 0.0560],
[ 0.1033, -0.3037, 0.4582, -0.3819]]), 'linear_relu_stack.0.bias': tensor([-0.4162, -0.0339]), 'linear_relu_stack.2.weight': tensor([[ 0.2078, -0.5866]]), 'linear_relu_stack.2.bias': tensor([0.3684])})
Веса model2 после загрузки:
OrderedDict({'linear_relu_stack.0.weight': tensor([[ 0.3493, 0.1946, 0.3091, 0.4323],
[-0.1120, 0.3059, 0.1715, -0.3470]]), 'linear_relu_stack.0.bias': tensor([0.2972, 0.0108]), 'linear_relu_stack.2.weight': tensor([[0.5835, 0.2446]]), 'linear_relu_stack.2.bias': tensor([0.3949])})
В приведенном выше коде мы установили weights_only=True, чтобы ограничить функции, выполняемые во время десериализации, только теми, которые необходимы для загрузки весов.
Кроме того, вызываеся метод model2.eval(), чтобы установить режим оценки для слоев Dropout и пакетной нормализации. В противном случае результаты инференса будут непоследовательными.
(Инференс (inference) - это процесс, при котором ИИ использует ранее обученную модель для принятия решений на основе новых данных.)
Возьмем пример с обучением модели из прошлой статьи и добавим к нему сохранение весов в файл после обучения:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# обучение модели
def train_loop(dataloader, model, loss_fn, optimizer, device):
size = len(dataloader.dataset)
# устанавливаем модель в режим обучения.
# хотя в данном случае это необязательно, но в целом может быть важно для пакетной нормализации и отсева слоев
model.train()
for batch, (X, y) in enumerate(dataloader):
# перемещаем тензоры на устройство
X = X.to(device)
y = y.to(device)
# Вычисляем предсказание (Forward pass)
pred = model(X)
# Вычисляем ошибку (Loss)
loss = loss_fn(pred, y)
# Обратное распространение ошибки (Backpropagation)
loss.backward() # Считаем новые градиенты
optimizer.step() # Обновляем веса
optimizer.zero_grad() # Сбрасываем градиенты с прошлого шага
# определяем устройство для обучения модели
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(device)
model = NeuralNetwork().to(device) # устанавливаем модель и перемещаем ее на устройство
learning_rate = 1e-3 # Скорость обучения
batch_size = 64 # Размер пакета
epochs = 10 # Количество эпох
# устанавливаем loss-функцию
loss_fn = nn.CrossEntropyLoss()
# устанавливаем оптимизатор
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
for t in range(epochs):
print(f"Эпоха: {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer, device)
print("Завершено!")
# Сохраняем веса модели в файл
torch.save(model.state_dict(), "model_weights.pth")
print("Веса модели успешно сохранены в файл model_weights.pth")