Transfer Learning in Deep Learning

Learn the fundamentals of transfer learning, how it accelerates model training by leveraging pre-trained models, and implement transfer learning for image classification using Keras.

🚀 advanced
⏱️ 50 minutes
👤 SuperML Team


📋 Prerequisites

  • Understanding of neural networks and CNNs
  • Basic Python and Keras familiarity

🎯 What You'll Learn

  • Understand what transfer learning is and why it is useful
  • Use pre-trained models to accelerate deep learning projects
  • Fine-tune models on custom datasets for new tasks
  • Build and evaluate a transfer learning pipeline using Keras

Introduction

Transfer learning lets you reuse models that were pre-trained on large datasets (such as ImageNet) to solve new tasks with less data and compute.

Instead of training a model from scratch, you reuse learned features, which:

✅ Reduces training time.
✅ Often improves model performance on smaller datasets.
✅ Requires fewer computational resources.


Why Use Transfer Learning?

  • Training deep networks from scratch requires large labeled datasets and substantial compute.
  • Pre-trained models capture generic patterns (edges, textures) useful for many tasks.
  • Transfer learning is widely used in computer vision and NLP applications.

Approaches to Transfer Learning

1️⃣ Feature Extraction: Freeze the pre-trained model and use it as a fixed feature extractor, replacing its top layer(s) with your own classifier.
2️⃣ Fine-Tuning: Unfreeze some of the later (top) layers of the pre-trained model and retrain them on your dataset at a low learning rate, so the learned features adapt to the new task.
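
The difference between the two approaches comes down to which layers are allowed to update. The following is a minimal pure-Python sketch of that idea, not Keras code: the layer names and the helper `set_trainable` are hypothetical and exist only for illustration. In both approaches the newly added classifier is always trained.

```python
def set_trainable(layer_names, n_unfrozen=0):
    """Return {layer_name: trainable}, with only the last n_unfrozen layers trainable."""
    cutoff = len(layer_names) - n_unfrozen
    return {name: index >= cutoff for index, name in enumerate(layer_names)}


# Hypothetical layer names standing in for a pre-trained base model.
base_layers = ["conv_1", "conv_2", "conv_3", "conv_4"]

# Feature extraction: every base layer stays frozen.
feature_extraction = set_trainable(base_layers, n_unfrozen=0)

# Fine-tuning: the last two base layers are unfrozen and retrained.
fine_tuning = set_trainable(base_layers, n_unfrozen=2)

print(feature_extraction)
print(fine_tuning)
```

In Keras the same decision is expressed through each layer's `trainable` attribute.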


Example: Transfer Learning for Image Classification with Keras

We will use MobileNetV2, a lightweight CNN pre-trained on ImageNet, to classify images from a custom dataset.


1️⃣ Install and Import Libraries

# In a terminal:
pip install tensorflow

# In Python:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

2️⃣ Load and Prepare Data

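Load your dataset with ImageDataGenerator, which here applies MobileNetV2's input preprocessing and holds out 20% of the images for validation. Augmentation options such as horizontal_flip or rotation_range can be passed to the same constructor. Note that ImageDataGenerator is marked as deprecated in recent TensorFlow releases in favor of tf.keras.utils.image_dataset_from_directory.

# MobileNetV2's ImageNet weights expect pixels scaled to [-1, 1], which
# preprocess_input does (rescale=1./255 would give [0, 1] instead).
train_datagen = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split=0.2
)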

train_generator = train_datagen.flow_from_directory(
    'path_to_dataset',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='training'
)

val_generator = train_datagen.flow_from_directory(
    'path_to_dataset',
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical',
    subset='validation'
)
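
flow_from_directory infers labels from the folder layout: it expects one subfolder per class under the dataset root, and it assigns integer labels to the class folders in alphanumeric order. The sketch below mimics that layout with the standard library only. The class names, file names, and both helper functions are hypothetical, and the files are empty placeholders rather than real images.

```python
import tempfile
from pathlib import Path


def make_demo_dataset(root, class_names, images_per_class=2):
    """Create a one-folder-per-class layout filled with empty placeholder files."""
    for name in class_names:
        class_dir = Path(root) / name
        class_dir.mkdir(parents=True, exist_ok=True)
        for i in range(images_per_class):
            (class_dir / f"img_{i}.jpg").touch()  # placeholder, not a real image


def class_indices(root):
    """Map class folder names to integer labels in sorted order."""
    names = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
    return {name: index for index, name in enumerate(names)}


with tempfile.TemporaryDirectory() as tmp:
    make_demo_dataset(tmp, ["dogs", "cats", "birds"])
    demo_indices = class_indices(tmp)

print(demo_indices)  # {'birds': 0, 'cats': 1, 'dogs': 2}
```

On a real generator, the mapping Keras actually used is available as train_generator.class_indices, which is worth checking before interpreting predictions.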

3️⃣ Load Pre-Trained Model

base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze base model layers
base_model.trainable = False

# Add custom top layers
x = base_model.output
x = GlobalAveragePooling2D()(x)
outputs = Dense(train_generator.num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=outputs)
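
To make the new head less of a black box, here is a small pure-Python sketch of what the two added layers compute. It assumes the standard MobileNetV2 (width multiplier 1.0), whose final feature map has 1280 channels. The 5-class count and the helper names are hypothetical.

```python
def global_average_pool(feature_map):
    """Average an H x W x C nested list over H and W, returning C values."""
    height = len(feature_map)
    width = len(feature_map[0])
    channels = len(feature_map[0][0])
    totals = [0.0] * channels
    for row in feature_map:
        for pixel in row:
            for channel, value in enumerate(pixel):
                totals[channel] += value
    return [total / (height * width) for total in totals]


def dense_param_count(n_inputs, n_outputs):
    """Weights plus biases of a fully connected layer."""
    return n_inputs * n_outputs + n_outputs


# A tiny 2 x 2 feature map with 2 channels.
tiny_map = [
    [[1.0, 10.0], [3.0, 20.0]],
    [[5.0, 30.0], [7.0, 40.0]],
]
pooled = global_average_pool(tiny_map)
print(pooled)  # [4.0, 25.0]

# 1280 pooled features feeding a hypothetical 5-class softmax layer.
head_params = dense_param_count(1280, 5)
print(head_params)  # 6405
```

With the base model frozen, these few thousand head parameters are the only ones updated during the first training phase, which is why it runs quickly.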

4️⃣ Compile and Train the Model

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(
    train_generator,
    epochs=10,
    validation_data=val_generator
)
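
model.fit returns a History object whose history attribute is a dictionary of per-epoch metrics. With metrics=['accuracy'], it contains keys such as 'accuracy' and 'val_accuracy'. The sketch below shows one way to read it. The numbers are placeholders made up for illustration, not results from this model, and both helper functions are hypothetical.

```python
def best_epoch(history_dict, metric="val_accuracy"):
    """Return (1-based epoch, value) at which the metric was highest."""
    values = history_dict[metric]
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index + 1, values[best_index]


def accuracy_gap(history_dict, epoch):
    """Training accuracy minus validation accuracy at a 1-based epoch."""
    index = epoch - 1
    return history_dict["accuracy"][index] - history_dict["val_accuracy"][index]


# Placeholder values shaped like history.history (NOT real results).
demo_history = {
    "accuracy": [0.60, 0.75, 0.85, 0.92],
    "val_accuracy": [0.58, 0.70, 0.74, 0.72],
}

epoch, val_acc = best_epoch(demo_history)
gap = accuracy_gap(demo_history, epoch)
print(f"best epoch: {epoch}, val_accuracy: {val_acc:.2f}, gap: {gap:.2f}")
```

A gap that keeps widening while validation accuracy stalls or drops is a common sign of overfitting, and a reason to stop training earlier or add augmentation.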

5️⃣ Fine-Tuning (Optional)

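After initial training, you can unfreeze the later layers of the base model and continue training with a low learning rate, so the pre-trained weights change only slightly.

base_model.trainable = True

# Keep the early layers and all BatchNormalization layers frozen.
# The cutoff of 100 is an illustrative choice; tune it for your dataset.
for i, layer in enumerate(base_model.layers):
    if i < 100 or isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Recompile so the changes to trainable take effect
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

history_finetune = model.fit(
    train_generator,
    epochs=5,
    validation_data=val_generator
)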

Conclusion

✅ Transfer learning allows efficient use of pre-trained models for new tasks.
✅ It helps you build performant models even with limited data and compute.
✅ You can implement feature extraction and fine-tuning using Keras easily.


What’s Next?

✅ Experiment with other pre-trained models like ResNet, Inception, or EfficientNet.
✅ Apply transfer learning for NLP tasks using models like BERT.
✅ Integrate transfer learning into your pipelines for real-world projects.


Join our SuperML Community to share your transfer learning projects and continue growing as a deep learning practitioner.


Happy Learning! 🚀
