Training a Deep Network in PyTorch

Learn how to build and train your first deep neural network using PyTorch with a clear, step-by-step example on the MNIST dataset.

🔰 beginner
⏱️ 60 minutes
👤 SuperML Team


📋 Prerequisites

  • Basic Python knowledge
  • Basic understanding of neural networks

🎯 What You'll Learn

  • Set up PyTorch for deep learning projects
  • Build and train a deep neural network on MNIST
  • Monitor the loss during training
  • Evaluate your model effectively

Introduction

In this tutorial, you will build and train your first deep neural network in PyTorch to classify handwritten digits from the MNIST dataset.


1️⃣ Setting Up PyTorch

First, install PyTorch and torchvision if you haven’t already:

pip install torch torchvision
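
To confirm the installation worked, a quick check along these lines can help. It only prints the installed version and whether a CUDA GPU is visible; the variable name `cuda_available` is just for illustration, and the rest of this tutorial runs fine on CPU.

```python
import torch

# Print the installed PyTorch version string.
print(torch.__version__)

# True if a CUDA-capable GPU is usable; False is fine for this tutorial.
cuda_available = torch.cuda.is_available()
print(cuda_available)
```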

2️⃣ Import Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

3️⃣ Prepare the Dataset

transform = transforms.Compose([
    transforms.ToTensor(),                # convert images to tensors with values in [0, 1]
    transforms.Normalize((0.5,), (0.5,))  # (x - 0.5) / 0.5 rescales values to [-1, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
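
ToTensor() scales pixel values to [0, 1], and Normalize((0.5,), (0.5,)) then applies (x - mean) / std per channel, which maps them to [-1, 1]. The arithmetic can be checked on a small made-up tensor without downloading anything. The tensor `pixels` below is invented for illustration and is not real MNIST data.

```python
import torch

# A made-up 1x2x2 "image" with values already in [0, 1], as ToTensor() would produce.
pixels = torch.tensor([[[0.0, 0.25], [0.5, 1.0]]])

# Same arithmetic as transforms.Normalize((0.5,), (0.5,)): (x - mean) / std.
mean, std = 0.5, 0.5
normalized = (pixels - mean) / std

print(normalized.min().item(), normalized.max().item())  # -1.0 1.0
```

With batch_size=64, a full batch from train_loader should have shape [64, 1, 28, 28] for the images and [64] for the labels; the final batch of an epoch may be smaller.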

4️⃣ Build the Model

class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()       # (N, 1, 28, 28) -> (N, 784)
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)      # one output score per digit class

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)                   # raw scores (logits), no softmax here
        return x

model = SimpleNN()
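
Before training, it can be worth checking that the network accepts a batch and returns one score per class. The sketch below rebuilds the same layer stack with nn.Sequential so that it runs on its own; the names `net` and `dummy_batch` are illustrative, and the random batch only stands in for real MNIST images.

```python
import torch
import torch.nn as nn

# Same layers as SimpleNN, written with nn.Sequential so this snippet is standalone.
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# Random stand-in for a batch of 64 grayscale 28x28 images.
dummy_batch = torch.randn(64, 1, 28, 28)
logits = net(dummy_batch)
print(logits.shape)  # torch.Size([64, 10])

# Count trainable parameters (weights and biases of the three linear layers).
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(num_params)  # 109386
```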

5️⃣ Define Loss Function and Optimizer

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
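
Note that the model's last layer has no softmax. nn.CrossEntropyLoss expects raw scores (logits) and integer class labels, and applies log-softmax internally. A small check on made-up numbers illustrates this; the tensors below are invented for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for 2 samples and 3 classes, plus integer class labels.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2])

# CrossEntropyLoss applies log-softmax internally, so it takes raw logits.
loss = nn.CrossEntropyLoss()(logits, labels)

# The same value computed in two explicit steps.
manual = F.nll_loss(F.log_softmax(logits, dim=1), labels)

print(torch.allclose(loss, manual))  # True
```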

6️⃣ Train the Model

epochs = 5

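for epoch in range(epochs):
    model.train()  # training mode (matters once layers like dropout are added)
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass
        loss = criterion(outputs, labels)  # compare logits with the true labels
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    # Average loss over all batches in this epoch
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")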
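
The loop above reports only the loss. If you also want training accuracy per epoch, one possible approach is to count correct predictions inside the same loop. The sketch below wraps this in a helper, `train_one_epoch`, which is a hypothetical name and not part of PyTorch. To stay runnable without a download, it uses a tiny random dataset and a one-layer model in place of MNIST and SimpleNN, so the printed numbers are not meaningful.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def train_one_epoch(model, loader, criterion, optimizer):
    """Run one training epoch and return (average loss, accuracy)."""
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, targets in loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # The predicted class is the index of the largest logit.
        correct += (outputs.argmax(dim=1) == targets).sum().item()
        total += targets.size(0)
    return running_loss / len(loader), correct / total

# Tiny random dataset standing in for MNIST (illustration only).
torch.manual_seed(0)
fake_images = torch.randn(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
fake_loader = DataLoader(TensorDataset(fake_images, fake_labels),
                         batch_size=64, shuffle=True)

demo_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
demo_criterion = nn.CrossEntropyLoss()
demo_optimizer = optim.SGD(demo_model.parameters(), lr=0.01)

avg_loss, accuracy = train_one_epoch(demo_model, fake_loader,
                                     demo_criterion, demo_optimizer)
print(f"Loss: {avg_loss:.4f}, Accuracy: {100 * accuracy:.2f}%")
```

In the tutorial's setting you would call the helper once per epoch with model, train_loader, criterion, and optimizer.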

7️⃣ Evaluate the Model

correct = 0
total = 0

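model.eval()  # inference mode (matters once layers like dropout are added)

with torch.no_grad():  # gradients are not needed for evaluation
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)  # index of the largest logit per sample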
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")
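
Overall accuracy can hide digits the model handles poorly. One way to look closer is per-class accuracy, sketched here with torch.bincount. The predictions and labels are made up for illustration; in the tutorial they would be collected across all batches of test_loader, with num_classes set to 10.

```python
import torch

# Made-up predictions and labels for 3 classes (illustration only).
predicted = torch.tensor([0, 1, 1, 2, 2, 2])
labels = torch.tensor([0, 1, 2, 2, 2, 1])
num_classes = 3

# How many samples each class has, and how many of them were predicted correctly.
per_class_total = torch.bincount(labels, minlength=num_classes)
per_class_correct = torch.bincount(labels[predicted == labels], minlength=num_classes)

# Assumes every class appears at least once; otherwise this divides by zero.
per_class_accuracy = per_class_correct.float() / per_class_total.float()
print(per_class_accuracy)
```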

Conclusion

✅ You built and trained your first deep neural network using PyTorch.
✅ You learned how to handle data, build models, train, and evaluate them.
✅ You now have the foundation to explore deeper and more complex models using PyTorch.


What’s Next?

✅ Experiment with different optimizers and learning rates.
✅ Add dropout layers to prevent overfitting.
✅ Explore building Convolutional Neural Networks (CNNs) for improved image classification.
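
As a starting point for the first two suggestions, here is a minimal sketch that adds dropout after each hidden layer and swaps SGD for Adam. The class name `DropoutNN`, the dropout probability 0.2, and the learning rate 1e-3 are illustrative assumptions, not tuned values from this tutorial.

```python
import torch
import torch.nn as nn
import torch.optim as optim

class DropoutNN(nn.Module):  # hypothetical name, not from the tutorial
    def __init__(self, p=0.2):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p)  # randomly zeroes activations during training

    def forward(self, x):
        x = self.flatten(x)
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        return self.fc3(x)  # raw logits, as before

dropout_model = DropoutNN()
adam_optimizer = optim.Adam(dropout_model.parameters(), lr=1e-3)

# Dropout is only active in training mode; eval() turns it off for inference.
dropout_model.eval()
with torch.no_grad():
    out = dropout_model(torch.randn(4, 1, 28, 28))
print(out.shape)  # torch.Size([4, 10])
```

Because dropout behaves differently in training and inference, remember to call model.train() before training and model.eval() before evaluating.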


Join the SuperML Community to share your projects and continue learning together.


Happy Learning with PyTorch! 🚀
