Применение обученной модели

Последнее обновление: 14.12.2025

В прошлой теме было рассмотрено обучение модели в PyTorch на основе встроенного набора данных FashionMNIST с последующим сохранением весов в файл. Рассмотрим, как мы можем использовать эту модель для распознавания изображений одежды на простом примере.

Чтобы заново не обучать модель, предположим, что веса обученной модели у нас сохранены в файле model_weights.pth. Как это сделать, можно посмотреть в прошлой теме. А сейчас определим следующий код, который будет загружать и использовать эти веса:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# определяем класс модели
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Создаем "пустую" модель
model = NeuralNetwork()

# Загружаем веса из файла
model.load_state_dict(torch.load("model_weights.pth"))

# Переключаем в режим оценки
# Это отключает слои типа Dropout и BatchNormalization, которые нужны только при обучении
model.eval()

print("Модель загружена!")

# Список классов FashionMNIST для расшифровки ответа
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Загружаем одну картинку из тестового набора для проверки
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# Берем первую картинку (индекс 0)
x, y = test_data[0]

# Добавляем размерность батча (было [1, 28, 28] -> стало [1, 1, 28, 28]), так как модель ждет пакет
# Если модель обучалась на GPU, здесь данные могут остаться на CPU, так как загруженная модель по умолчанию на CPU
with torch.no_grad(): # Отключаем градиенты для скорости
    # Этот код выполняет техническую адаптацию данных: он превращает "просто картинку" в "пакет из одной картинки".
    pred = model(x.unsqueeze(0))
    predicted_class_index = pred[0].argmax(0)
    predicted_class_name = classes[predicted_class_index]
    actual_class_name = classes[y]

print(f"Предсказано: {predicted_class_name}, На самом деле: {actual_class_name}")

Рассмотрим ключевые моменты. Прежде всего загружаем ранее сохраненные веса модели из файла "model_weights.pth":

model = NeuralNetwork()

model.load_state_dict(torch.load("model_weights.pth"))

model.eval()

В данном случае модель представляет класс NeuralNetwork, соответственно веса из файла "model_weights.pth" также должны быть предназначены для модели NeuralNetwork.

Далее берем список классов FashionMNIST для расшифровки ответа:

classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

Для теста будем использовать картинку из того же набора FashionMNIST. Поэтому загружаем из тестового набора для проверки одну картинку (самую первую):

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

x, y = test_data[0]

Поскольку нам не надо обучать модель, то для большей производительности определяем блок torch.no_grad(), чтобы отключить вычисление градиентов

with torch.no_grad()
    pred = model(x.unsqueeze(0))
    predicted_class_index = pred[0].argmax(0)
    predicted_class_name = classes[predicted_class_index]
    actual_class_name = classes[y]

Этот блок кода и выполняет процесс предсказания (инференса) для одной конкретной картинки и расшифровку результата в понятный человеку текст. Разберем каждую строку подробно:

  1. pred = model(x.unsqueeze(0))

    Вызов x.unsqueeze(0) превращает картинку формата [1, 28, 28] в пакет [1, 1, 28, 28]. Зачем это нужно? Дело в том, что нейронные сети в PyTorch (и слои вроде nn.Linear, nn.Conv2d) обучены работать с пакетами (батчами) данных, а не с одиночными объектами. Они всегда ожидают на вход 4 измерения (для картинок). Однако, когда мы берем одну картинку из датасета (x, y = test_data[0]), она имеет только 3 измерения. Метод .unsqueeze(0) добавляет новую размерность (размером 1) по указанному индексу (в данном случае индекс 0, то есть в самое начало). По сути он берет картинку x и борачивает ее в "искусственный" пакет размером 1.

    Затем результат передается в модель: model(...)

    Результат модели получаем в переменную pred. Это тензор размера [1, 10] (1 картинка, 10 классов). Он содержит логиты - "сырые" числа (баллы уверенности).

  2. predicted_class_index = pred[0].argmax(0)

    Это самая важная математическая часть, где:

    • pred[0]: берем нулевой элемент, чтобы получить просто вектор из 10 чисел (pred - это пакет, даже если в нем только один элемент)

    • .argmax(0): эта функция возвращает индекс самого большого числа в этом списке

    Если наглядно, то представим, что модель выдала такие оценки для 3 классов:

    • Футболка: 0.5 (Индекс 0)

    • Брюки: 9.2 (Индекс 1)

    • Платье: 0.1 (Индекс 2)

    Тогда argmax(0) вернет число 1, потому что 9.2 - самое большое число, и оно стоит под индексом 1.

  3. predicted_class_name = classes[predicted_class_index]

    Здесь мы берем полученный индекс (например, 1) и смотрим в нашем списке названий (classes). Например, если predicted_class_index равно 1, то classes[1] вернет строку "Trouser".

  4. actual_class_name = classes[y]

    Здесь мы делаем то же самое, но для "правильного ответа" (y), который мы взяли из датасета. Это нужно, чтобы сравнить предсказание с реальностью. То есть в итоге у нас есть две переменные: то, что модель "подумала" (predicted_class_name), и то, что было "на самом деле" (actual_class_name)

В конце выводим результат на консоль. В результате после выполнения программы вы увидим что-то типа следующего:

Модель загружена!
Предсказано: Ankle boot, На самом деле: Ankle boot

Визуализация

Теперь пойдем чуть дальше - возьмем случайные 9 картинок из тестового набора, прогоним их через модель и построим таблицу, где, если модель угадала правильно, текст будет зеленым, если ошиблась - красным. Для визуализации изображений используем библиотеку Matplotlib, поэтому перед запуском кода убедитесь, что она у вас установлена. Если нет, установим ее с помощью команды:

pip install matplotlib

В коде также, как и в примере выше, будем загружать веса модели из "model_weights.pth":


import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import matplotlib.pyplot as plt
import random

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

device = "cuda" if torch.cuda.is_available() else "cpu"

model = NeuralNetwork().to(device)

# Загружаем веса из файла
model.load_state_dict(torch.load("model_weights.pth"))

# Переключаем в режим оценки
model.eval()

print("Модель загружена!")


# Список имен классов (чтобы видеть названия, а не цифры)
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
# набор для тестирования
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
# --- 3. Визуализация ---
# Создаем фигуру (окно) размером 10x10 дюймов
figure = plt.figure(figsize=(10, 10))
cols, rows = 3, 3 # Будет сетка 3 на 3 картинки

for i in range(1, cols * rows + 1):
    # Берем случайный индекс из датасета
    sample_idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[sample_idx]
    
    # Отправляем в модель (добавляем измерение батча через unsqueeze)
    img = img.to(device)
    with torch.no_grad():
        pred_logits = model(img.unsqueeze(0))
        pred_label = pred_logits.argmax(1).item()

    # Добавляем подграфик
    figure.add_subplot(rows, cols, i)
    
    # Определяем цвет текста: Зеленый если верно, Красный если ошибка
    title_color = "green" if pred_label == label else "red"
    
    plt.title(f"Pred: {classes[pred_label]}\nTrue: {classes[label]}", color=title_color)
    plt.axis("off") # Убираем оси координат
    
    # Отрисовываем картинку
    # img.squeeze() убирает размерность канала (1, 28, 28) -> (28, 28), чтобы matplotlib понял
    plt.imshow(img.cpu().squeeze(), cmap="gray")

plt.show()

Поскольку код в целом снабжен комментариями, то подсвечу лишь отдельные моменты:

  1. torch.randint: выбирает случайную картинку, чтобы при каждом запуске скрипта можно было увидеть разные примеры.

    sample_idx = torch.randint(len(test_data), size=(1,)).item()
    img, label = test_data[sample_idx]
    
  2. title_color: сравнивает предсказание с истиной. Это очень удобно для визуального анализа слабых мест модели (например, вы часто будете видеть, что модель путает Shirt (Рубашку), T-shirt (Футболку) и Coat (Пальто), так как они похожи).

    title_color = "green" if pred_label == label else "red"
  3. cmap="gray": говорит Matplotlib рисовать картинку в черно-белых тонах (иначе по умолчанию она была бы желто-фиолетовой).

Запустим приложение и мы увидим визуальный результат работы модели:

Визуализация корректности работы модели в PyTorch

Как видно из скриншота, модель угадала 5 примеров и ошиблась в 4 случаях. В принципе не удивительно, я сам не особо понимаю, что на ряде картинок изображено.

Помощь сайту
Юмани:
410011174743222
Номер карты:
4048415020898850