В прошлой теме было рассмотрено обучение модели в PyTorch на основе встроенного набора данных FashionMNIST с последующим сохранением весов в файл. Рассмотрим, как мы можем использовать эту модель для распознавания изображений одежды на простом примере.
Чтобы заново не обучать модель, предположим, что веса обученной модели у нас сохранены в файле model_weights.pth. Как это сделать, можно посмотреть в прошлой теме. А сейчас определим следующий код, который будет загружать и использовать эти веса:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
# определяем класс модели
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# Создаем "пустую" модель
model = NeuralNetwork()
# Загружаем веса из файла
model.load_state_dict(torch.load("model_weights.pth"))
# Переключаем в режим оценки
# Это отключает слои типа Dropout и BatchNormalization, которые нужны только при обучении
model.eval()
print("Модель загружена!")
# Список классов FashionMNIST для расшифровки ответа
classes = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
# Загружаем одну картинку из тестового набора для проверки
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
# Берем первую картинку (индекс 0)
x, y = test_data[0]
# Добавляем размерность батча (было [1, 28, 28] -> стало [1, 1, 28, 28]), так как модель ждет пакет
# Если модель обучалась на GPU, здесь данные могут остаться на CPU, так как загруженная модель по умолчанию на CPU
with torch.no_grad(): # Отключаем градиенты для скорости
# Этот код выполняет техническую адаптацию данных: он превращает "просто картинку" в "пакет из одной картинки".
pred = model(x.unsqueeze(0))
predicted_class_index = pred[0].argmax(0)
predicted_class_name = classes[predicted_class_index]
actual_class_name = classes[y]
print(f"Предсказано: {predicted_class_name}, На самом деле: {actual_class_name}")
Рассмотрим ключевые моменты. Прежде всего загружаем ранее сохраненные веса модели из файла "model_weights.pth":
model = NeuralNetwork()
model.load_state_dict(torch.load("model_weights.pth"))
model.eval()
В данном случае модель представляет класс NeuralNetwork, соответственно веса из файла "model_weights.pth" также должны быть предназначены для модели NeuralNetwork.
Далее берем список классов FashionMNIST для расшифровки ответа:
classes = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
Для теста будем использовать картинку из того же набора FashionMNIST. Поэтому загружаем из тестового набора для проверки одну картинку (самую первую):
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
x, y = test_data[0]
Поскольку нам не надо обучать модель, то для большей производительности определяем блок torch.no_grad(), чтобы отключить вычисление градиентов
with torch.no_grad()
pred = model(x.unsqueeze(0))
predicted_class_index = pred[0].argmax(0)
predicted_class_name = classes[predicted_class_index]
actual_class_name = classes[y]
Этот блок кода и выполняет процесс предсказания (инференса) для одной конкретной картинки и расшифровку результата в понятный человеку текст. Разберем каждую строку подробно:
pred = model(x.unsqueeze(0))
Вызов x.unsqueeze(0) превращает картинку формата [1, 28, 28] в пакет [1, 1, 28, 28].
Зачем это нужно? Дело в том, что нейронные сети в PyTorch (и слои вроде nn.Linear, nn.Conv2d) обучены работать с пакетами (батчами) данных, а не с одиночными объектами. Они всегда ожидают на вход 4 измерения (для картинок).
Однако, когда мы берем одну картинку из датасета (x, y = test_data[0]), она имеет только 3 измерения.
Метод .unsqueeze(0) добавляет новую размерность (размером 1) по указанному индексу (в данном случае индекс 0, то есть в самое начало).
По сути он берет картинку x и борачивает ее в "искусственный" пакет размером 1.
Затем результат передается в модель: model(...)
Результат модели получаем в переменную pred. Это тензор размера [1, 10] (1 картинка, 10 классов). Он содержит логиты - "сырые" числа (баллы уверенности).
predicted_class_index = pred[0].argmax(0)
Это самая важная математическая часть, где:
pred[0]: берем нулевой элемент, чтобы получить просто вектор из 10 чисел (pred - это пакет, даже если в нем только один элемент)
.argmax(0): эта функция возвращает индекс самого большого числа в этом списке
Если наглядно, то представим, что модель выдала такие оценки для 3 классов:
Футболка: 0.5 (Индекс 0)
Брюки: 9.2 (Индекс 1)
Платье: 0.1 (Индекс 2)
Тогда argmax(0) вернет число 1, потому что 9.2 - самое большое число, и оно стоит под индексом 1.
predicted_class_name = classes[predicted_class_index]
Здесь мы берем полученный индекс (например, 1) и смотрим в нашем списке названий (classes). Например, если predicted_class_index равно 1,
то classes[1] вернет строку "Trouser".
actual_class_name = classes[y]
Здесь мы делаем то же самое, но для "правильного ответа" (y), который мы взяли из датасета. Это нужно, чтобы сравнить предсказание с реальностью.
То есть в итоге у нас есть две переменные: то, что модель "подумала" (predicted_class_name), и то, что было "на самом деле" (actual_class_name)
В конце выводим результат на консоль. В результате после выполнения программы вы увидим что-то типа следующего:
Модель загружена! Предсказано: Ankle boot, На самом деле: Ankle boot
Теперь пойдем чуть дальше - возьмем случайные 9 картинок из тестового набора, прогоним их через модель и построим таблицу, где, если модель угадала правильно, текст будет зеленым, если ошиблась - красным. Для визуализации изображений используем библиотеку Matplotlib, поэтому перед запуском кода убедитесь, что она у вас установлена. Если нет, установим ее с помощью команды:
pip install matplotlib
В коде также, как и в примере выше, будем загружать веса модели из "model_weights.pth":
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import random
class NeuralNetwork(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
model = NeuralNetwork().to(device)
# Загружаем веса из файла
model.load_state_dict(torch.load("model_weights.pth"))
# Переключаем в режим оценки
model.eval()
print("Модель загружена!")
# Список имен классов (чтобы видеть названия, а не цифры)
classes = [
"T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
"Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]
# набор для тестирования
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
# --- 3. Визуализация ---
# Создаем фигуру (окно) размером 10x10 дюймов
figure = plt.figure(figsize=(10, 10))
cols, rows = 3, 3 # Будет сетка 3 на 3 картинки
for i in range(1, cols * rows + 1):
# Берем случайный индекс из датасета
sample_idx = torch.randint(len(test_data), size=(1,)).item()
img, label = test_data[sample_idx]
# Отправляем в модель (добавляем измерение батча через unsqueeze)
img = img.to(device)
with torch.no_grad():
pred_logits = model(img.unsqueeze(0))
pred_label = pred_logits.argmax(1).item()
# Добавляем подграфик
figure.add_subplot(rows, cols, i)
# Определяем цвет текста: Зеленый если верно, Красный если ошибка
title_color = "green" if pred_label == label else "red"
plt.title(f"Pred: {classes[pred_label]}\nTrue: {classes[label]}", color=title_color)
plt.axis("off") # Убираем оси координат
# Отрисовываем картинку
# img.squeeze() убирает размерность канала (1, 28, 28) -> (28, 28), чтобы matplotlib понял
plt.imshow(img.cpu().squeeze(), cmap="gray")
plt.show()
Поскольку код в целом снабжен комментариями, то подсвечу лишь отдельные моменты:
torch.randint: выбирает случайную картинку, чтобы при каждом запуске скрипта можно было увидеть разные примеры.
sample_idx = torch.randint(len(test_data), size=(1,)).item() img, label = test_data[sample_idx]
title_color: сравнивает предсказание с истиной. Это очень удобно для визуального анализа слабых мест модели (например, вы часто будете видеть, что модель путает Shirt (Рубашку), T-shirt (Футболку) и Coat (Пальто), так как они похожи).
title_color = "green" if pred_label == label else "red"
cmap="gray": говорит Matplotlib рисовать картинку в черно-белых тонах (иначе по умолчанию она была бы желто-фиолетовой).
Запустим приложение и мы увидим визуальный результат работы модели:
Как видно из скриншота, модель угадала 5 примеров и ошиблась в 4 случаях. В принципе не удивительно, я сам не особо понимаю, что на ряде картинок изображено.