Сохранение и загрузка данных и модели

Последнее обновление: 14.12.2025

Сохранение данных

PyTorch предоставляет ряд специальных функций для сохранения данных во внешнем файле и для обратной загрузки данных из файла. Для сохранения данных применяется метод torch.save():

torch.save(obj, f, pickle_module=pickle, pickle_protocol=2, _use_new_zipfile_serialization=True)

Функция принимает следующие параметры:

  • obj (object): сохраняемый объект

  • f (Union[str, PathLike[str], IO[bytes]]): файлоподобный объект (должен реализовывать операции записи и сброса) или строка или объект os.PathLike, которые содержат имя файла

  • pickle_module (Any): модуль, используемый для сериализации метаданных и объектов

  • pickle_protocol (int): может быть указан для переопределения протокола по умолчанию

Из этих параметров обязательными являются только первые два - собственно сохраняемые данные и файл (путь к файлу), где надо выполнить сохранение. Возьмем простой пример:

import torch

# произвольные данные для сохранения в виде тензора
x = torch.tensor([0, 1, 2, 3, 4])
# сохраняем данные в файл "tensor.pt"
torch.save(x, "tensor.pt")

Здесь определяем простейший тензор - x и сохраняем его в файле "tensor.pt" (".pt" - одно из расширений файлов, используетмых в PyTorch).

И после выполнения этой программы в текущей папке проекта мы увидим файл "tensor.pt". Это бинарный файл, для сохранения в который применяется бинарная сериализации.

Загрузка данных

Для загрузки ранее сохраненных данных из файла применяется функция load():

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)[source]

Она принимает следующие параметры:

  • f: файлоподобный объект (должен реализовывать функции read(), readline(), tell() и seek()), или строка, или объект os.PathLike, которые содержат имя файла.

  • map_location: функция, объект torch.device, строка или словарь, которые указывают, как переназначать места хранения.

  • pickle_module: модуль, используемый для десериализации метаданных и объектов (должен соответствовать pickle_module, используемому для сериализации файла).

  • weights_only: указывает, следует ли ограничивать десериализацию только тензорами, примитивными типами, словарями и любыми типами, добавленными с помощью torch.serialization.add_safe_globals()

  • mmap: указывает, следует ли файл отображаться на память, а не загружаться полностью в память. Как правило, хранилища тензоров в файле сначала перемещаются с диска в память ЦП, после чего они перемещаются в место, которое было указано при сохранении или задано параметром map_location. Этот второй шаг ничего не делает, если конечное местоположение - центральный процессор. Если установлен флаг mmap, вместо копирования хранилищ тензоров с диска в память ЦП на первом шаге выполняется отображение файла f, что означает, что тензоры будут загружаться по мере обращения к ним.

  • pickle_load_args: необязательные аргументы ключевых слов

Например, применим эту функцию для загрузки ранее сохраненного тензора из файла:

import torch

# произвольные данные для сохранения в виде тензора
x = torch.tensor([0, 1, 2, 3, 4])
# сохраняем данные в файл "tensor.pt"
torch.save(x, "tensor.pt")

# загружаем данные из файла
y = torch.load("tensor.pt")
print(y)    # tensor([0, 1, 2, 3, 4])

Сохранение и загрузка состояния модели

Подобно тому , как мы можем сохранять и загружать тензоры, мы также можем сохранять и загружать модели и их состояния. И это играет очень важную роль, поскольку после обучения модели мы можем сохранить ее обученные параметры - весы (weight) и смещения (bias) в файл, а затем при необходимости загрузить их обратно, избежав посторного обучения, для которого может потребоваться продолжительное время и значительные аппаратные ресурсы.

Сохранение и загрузка моделей

Аналогично как мы сохраняем тензор, мы можем сохранить во внешнем файле модель, передав ее в функцию torch.save():

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features=4, out_features=2),
            nn.ReLU(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

# сохраняем модель в файл
torch.save(model, "model.pth")

Здесь определена произвольная тестовая модель NeuralNetwork с 3 слоями, объект которой сохраняется в файл "model.pth" (".pth" - это стандартное расширение, которое используется в PyTorch). И после выполнения этой программы в текущей папке проекта мы увидим файл "model.pth". Это бинарный файл, для сохранения в который применяется бинарная сериализации.

Затем мы можем загрузить модель из файла:

model2 = torch.load("model.pth", weights_only=False)

Сохранение состояния модели

При сохранении только весов модели достаточно передать в функцию torch.save() результат метода model.state_dict() (то есть словрь с весами и смещениями):

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features=4, out_features=2),
            nn.ReLU(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork()

print(model.state_dict())

# сохраняем параметры модели в файл
torch.save(model.state_dict(), "model_weights.pth")

Здесь данные сохраняются в файле "model_weights.pth". В целом там будет таже информация которую мы выводим на консоль с помощью print(model.state_dict())

Загрузка состояния модели

Для загрузки весов модели необходимо сначала создать экземпляр той же модели, а затем загрузить параметры с помощью метода load_state_dict():

load_state_dict(state_dict, strict=True, assign=False)

Этот метод принимает следующие параметры:

  • state_dict: словарь, который содержит параметры

  • strict: следует ли строго требовать, чтобы ключи в state_dict соответствовали ключам, возвращаемым функцией state_dict() этого модуля

  • assign: если установлено значение False, свойства тензоров сохраняются в текущем модуле, а если True - сохраняются в словаре состояния

Например, загрузим ранее сохраненные веса одной модели в другую модель:

import torch
from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features=4, out_features=2),
            nn.ReLU(),
            nn.Linear(2, 1)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model1 = NeuralNetwork()
# сохраняем веса model1 в файл
torch.save(model1.state_dict(), "model_weights.pth")
print("Веса model1:")
print(model1.state_dict())

model2 = NeuralNetwork()    # веса model1 и model2 не совпадают
print("\nВеса model2 до загрузки:")
print(model2.state_dict())

# загружаем веса
model2.load_state_dict(torch.load("model_weights.pth", weights_only=True))
model2.eval()

print("\nВеса model2 после загрузки:")
print(model2.state_dict())

Пример консольного вывода:

Веса model1:
OrderedDict({'linear_relu_stack.0.weight': tensor([[ 0.3493,  0.1946,  0.3091,  0.4323],
        [-0.1120,  0.3059,  0.1715, -0.3470]]), 'linear_relu_stack.0.bias': tensor([0.2972, 0.0108]), 'linear_relu_stack.2.weight': tensor([[0.5835, 0.2446]]), 'linear_relu_stack.2.bias': tensor([0.3949])})

Веса model2 до загрузки:
OrderedDict({'linear_relu_stack.0.weight': tensor([[ 0.4207, -0.1752, -0.1711,  0.0560],
        [ 0.1033, -0.3037,  0.4582, -0.3819]]), 'linear_relu_stack.0.bias': tensor([-0.4162, -0.0339]), 'linear_relu_stack.2.weight': tensor([[ 0.2078, -0.5866]]), 'linear_relu_stack.2.bias': tensor([0.3684])})

Веса model2 после загрузки:
OrderedDict({'linear_relu_stack.0.weight': tensor([[ 0.3493,  0.1946,  0.3091,  0.4323],
        [-0.1120,  0.3059,  0.1715, -0.3470]]), 'linear_relu_stack.0.bias': tensor([0.2972, 0.0108]), 'linear_relu_stack.2.weight': tensor([[0.5835, 0.2446]]), 'linear_relu_stack.2.bias': tensor([0.3949])})

В приведенном выше коде мы установили weights_only=True, чтобы ограничить функции, выполняемые во время десериализации, только теми, которые необходимы для загрузки весов. Кроме того, вызываеся метод model2.eval(), чтобы установить режим оценки для слоев Dropout и пакетной нормализации. В противном случае результаты инференса будут непоследовательными. (Инференс (inference) - это процесс, при котором ИИ использует ранее обученную модель для принятия решений на основе новых данных.)

Пример сохранения обученной модели

Возьмем пример с обучением модели из прошлой статьи и добавим к нему сохранение весов в файл после обучения:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

train_dataloader = DataLoader(training_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


# обучение модели
def train_loop(dataloader, model, loss_fn, optimizer, device):
    size = len(dataloader.dataset)
    # устанавливаем модель в режим обучения. 
    # хотя в данном случае это необязательно, но в целом может быть важно для пакетной нормализации и отсева слоев
    model.train()
    for batch, (X, y) in enumerate(dataloader):

        # перемещаем тензоры на устройство
        X = X.to(device)
        y = y.to(device)

        # Вычисляем предсказание (Forward pass)
        pred = model(X)
        # Вычисляем ошибку (Loss)
        loss = loss_fn(pred, y)

        # Обратное распространение ошибки (Backpropagation)
        loss.backward()         # Считаем новые градиенты
        optimizer.step()        # Обновляем веса
        optimizer.zero_grad()   # Сбрасываем градиенты с прошлого шага

# определяем устройство для обучения модели
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(device)
model = NeuralNetwork().to(device)     # устанавливаем модель и перемещаем ее на устройство

learning_rate = 1e-3        # Скорость обучения
batch_size = 64             # Размер пакета
epochs = 10                 # Количество эпох

# устанавливаем loss-функцию
loss_fn = nn.CrossEntropyLoss()
# устанавливаем оптимизатор
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Эпоха: {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer, device)

print("Завершено!")

# Сохраняем веса модели в файл
torch.save(model.state_dict(), "model_weights.pth")
print("Веса модели успешно сохранены в файл model_weights.pth")
Помощь сайту
Юмани:
410011174743222
Номер карты:
4048415020898850