A date with PyTorch Lightning⚡

Supercharge your model with a beautiful wrapper for PyTorch

Do you often forget to call model.eval() before evaluating your model's performance? Do you spend time setting up training on TPUs? Sure, Early Stopping and Model Checkpointing are essential parts of the training process, and we should take the time to configure them properly. But isn't it frustrating to set all of this up for every new project? Wouldn't it be great to have a wrapper that handles the boilerplate, follows best practices, and lets you focus on improving your model's accuracy? Well, you're in luck, because that's exactly what PyTorch Lightning is here to do.

What is PyTorch Lightning?

PyTorch Lightning is a wrapper for PyTorch code that abstracts away the unimportant and repetitive details of training. The most beautiful thing about it is the way it lets you structure and organize your code, which makes it even more fun to work with. If you wish to know more about the inception and goals of this project, I suggest you read From PyTorch to PyTorch Lightning — A gentle introduction by William Falcon, the creator of PyTorch Lightning.

In this post, I’d like to share how I came across PyTorch Lightning and how it helped me structure and improve my code. We will also look at the differences between the vanilla implementation and the Lightning implementation of a PyTorch model.

Love at first sight

I happened to come across PyTorch Lightning while implementing Abhishek Thakur's tutorial on Bengali.AI: Handwritten Grapheme Classification Using PyTorch. I was looking for a package that would help me implement early stopping as quickly as possible. As some of you might know, early stopping halts the training process once the model's performance stops improving. I came across Bjarten/early-stopping-pytorch and PyTorch Ignite, but neither of them matched the elegance of Lightning. All I had to do was pass early_stop_callback=True to my Trainer object and start training my model. It's that easy!

%%capture
!pip install pytorch-lightning
%%capture
# Fixes a TQDM progress bar issue
!pip install git+https://github.com/lezwon/pytorch-lightning.git@tqdm-fix --upgrade
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
trainer = Trainer(early_stop_callback=True)

If I wanted to customize my callback, I could configure it accordingly and add it to my trainer object.

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode='min'
)
trainer = Trainer(early_stop_callback=early_stop_callback)

In case you are wondering what a trainer object is: it is simply an instance of the Trainer class, which contains all the boilerplate code for a project, i.e. Early Stopping, Model Checkpointing, Distributed Training, etc. It basically takes care of all the nitty-gritty details for you while training your model, as the sketch below illustrates. After that, let's see how to set everything up from scratch.
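
A rough sketch of a Trainer with a few common options (gpus and max_epochs are standard Trainer arguments, and early_stop_callback matches the Lightning version used in this post; flag names may differ in newer releases):

from pytorch_lightning import Trainer

# A Trainer with a few common options: train for at most 14 epochs
# on one GPU, with early stopping enabled.
trainer = Trainer(
    gpus=1,
    max_epochs=14,
    early_stop_callback=True
)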

Before Lightning came along

Let's train an MNIST model in vanilla PyTorch. As given in the example here, we would declare our model using the nn.Module class.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

Our train and test logic would be written separately, along with our data loading and pre-processing logic.

def train(model, device, train_loader, optimizer, epoch):
    global LOG_INTERVAL
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
EPOCHS = 14
LR = 1.0
GAMMA = 0.7
NO_CUDA = False
SEED = 1
LOG_INTERVAL = 10
SAVE_MODEL = False
use_cuda = not NO_CUDA and torch.cuda.is_available()
torch.manual_seed(SEED)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=TEST_BATCH_SIZE, shuffle=True, **kwargs)
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=LR)
scheduler = StepLR(optimizer, step_size=1, gamma=GAMMA)
for epoch in range(1, EPOCHS + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()
    if SAVE_MODEL:
        torch.save(model.state_dict(), "mnist_cnn.pt")

Looking at the output, we get an accuracy of around 99%, which is great. Our implementation, however, contains a lot of boilerplate code and does not look very neat. For example, we had to do many things manually, from checking for CUDA support and copying our tensors to the device, to building the model checkpointing functionality ourselves. PyTorch gives us the flexibility to define our workflow on our own terms, which is why it is so popular. Nevertheless, wouldn't it be great if all of this came out of the box? PyTorch Lightning was born to solve this problem. Let's see how this would be done the Lightning way.

After Lightning came along

We define our model the same way in PyTorch Lightning as we do in vanilla PyTorch, except that we extend LightningModule instead of nn.Module. A LightningModule is an nn.Module with some added functionality on top of it. We will add our data loading, pre-processing, and training logic within the LightningModule itself.

import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import random_split
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
class Net(LightningModule):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
    def configure_optimizers(self):
        # use self.parameters() so the optimizer is tied to this module's weights
        optimizer = optim.Adadelta(self.parameters(), lr=self.LR)
        scheduler = StepLR(optimizer, step_size=1, gamma=self.GAMMA)
        return [optimizer], [scheduler]
    def prepare_data(self):
        dataset = MNIST(
            '../data',
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        self.train_dataset, self.validation_dataset = random_split(
            dataset, [50000, 10000])
        self.test_dataset = MNIST(
            '../data',
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset,
            batch_size=self.BATCH_SIZE, shuffle=True
        )
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.nll_loss(y_hat, y)
        return {'val_loss': loss}
    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.BATCH_SIZE, shuffle=True
        )
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.nll_loss(y_hat, y)}
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.TEST_BATCH_SIZE, shuffle=True
        )
     
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        accurate = torch.max(y_hat, 1)[1].view(y.size()) == y
        return {'test_loss': F.nll_loss(y_hat, y), 'accurate': accurate}
      
    def test_epoch_end(self, outputs):
        preds = [x['accurate'] for x in outputs]
        accuracy = torch.flatten(torch.cat(preds).float()).mean()
        avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'avg_test_loss': avg_test_loss, 'accuracy':accuracy}
        return {'avg_test_loss': avg_test_loss, 'log': tensorboard_logs,
                'progress_bar': tensorboard_logs}

The LightningModule helps you structure your code by providing a set of abstract methods that can be filled with the required code. Let's go through some of these.

  • prepare_data: Download and load your dataset here

  • configure_optimizers: Configure the optimizers and learning-rate schedulers required

  • train_dataloader: Set up a dataloader for the training dataset

  • training_step: Do a forward pass on a training batch and return the loss

  • val_dataloader: Set up a dataloader for the validation dataset

  • validation_step: Do a forward pass on a validation batch and return the validation loss

  • validation_epoch_end: Calculate the average validation loss at the end of an epoch

  • test_dataloader: Set up a dataloader for the test dataset

  • test_step: Do a forward pass on a test batch

  • test_epoch_end: Calculate the average test loss and model accuracy

You can check out the rest of the functions in the Lightning module in the Lightning documentation.

As you can see, I no longer have to worry about calling cuda(), to(), eval(), no_grad(), zero_grad(), backward(), or step(). PyTorch Lightning takes all of this off my plate and handles it for me. The code is much cleaner and easier to read than before. Now let's go ahead and initialize the model and trainer objects.

model = Net(
    BATCH_SIZE = 64,
    TEST_BATCH_SIZE = 64,
    LR = 1.0,
    GAMMA = 0.7
)
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=3,
    verbose=True,
    mode='min'
)
trainer = Trainer(gpus=1, early_stop_callback=early_stop_callback)

Pass in the required arguments and hyperparameters to the model. Make sure you set the gpus argument to the number of GPUs your system supports. Once this is done, we can go ahead and start training our model.

trainer.fit(model)

As we can see, the model has completed the training process. Thanks to the Early Stopping callback, training halted automatically once the validation loss stopped improving. We can now go ahead and test the accuracy of our model. To do this, we just call the .test() method on our trainer object.

trainer.test()

As you can see, the accuracy score and test loss are shown in the output. Lightning didn't increase our score, but it definitely helped us save some training time. You can access the best-performing model in the lightning_logs folder.

! ls lightning_logs/version_0/checkpoints

We can load this model directly with Lightning and resume the training process anytime we want, since Lightning saves the training state along with the model weights.

from pathlib import Path
PATH = "lightning_logs/version_0/checkpoints"
ckpt_path = str(list(Path(PATH).glob('*.ckpt'))[-1])
ckpt_path
'lightning_logs/version_0/checkpoints/epoch=12.ckpt'
model = Net.load_from_checkpoint(ckpt_path)
model.eval()
Net(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout2d(p=0.25, inplace=False)
  (dropout2): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
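
To actually resume training rather than just load the weights, the checkpoint path can be handed back to a Trainer. A minimal sketch, assuming the resume_from_checkpoint argument available in the Lightning versions this post targets:

from pytorch_lightning import Trainer

# Restore the model weights, optimizer state and epoch counter from the
# checkpoint, then continue training from where we left off.
trainer = Trainer(gpus=1, resume_from_checkpoint=ckpt_path)
trainer.fit(model)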

And finally ever after...❤️

PyTorch Lightning does not aim to abstract away any core functionality, just the repetitive housekeeping we shouldn't have to bother with. It still provides the same flexibility and speed as PyTorch, just in a more structured and organized manner. As you can see in the implementation, we set up a pretty decent training structure for our model. I can switch to CPU/TPU training or enable parallel/distributed processing without any major changes to my model, as sketched below. Such a wrapper really helps me focus on what is important, i.e. the model accuracy, rather than the training setup.
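
Here is a rough sketch of what that switch looks like. distributed_backend and the TPU argument are Trainer flags whose names have changed across Lightning versions, so treat these as illustrative rather than exact:

from pytorch_lightning import Trainer

# Same LightningModule, different hardware: only the Trainer flags change.
cpu_trainer = Trainer(max_epochs=14)  # plain CPU training
gpu_trainer = Trainer(gpus=2, distributed_backend='dp', max_epochs=14)  # multi-GPU data parallel
# tpu_trainer = Trainer(tpu_cores=8, max_epochs=14)  # TPU; argument name varies by version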

Another interesting feature Lightning provides is its support for loggers such as TensorBoard, Visdom, MLflow, Comet, Neptune, and Wandb. These give you an informative view of the training process at a glance. You can refer to the Loggers docs to set one up; a small sketch follows below.
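
For example, a TensorBoard logger can be attached to the Trainer in a couple of lines. A minimal sketch, where 'tb_logs' and 'mnist' are just illustrative names:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# Log metrics under ./tb_logs/mnist/version_x for viewing in TensorBoard.
logger = TensorBoardLogger('tb_logs', name='mnist')
trainer = Trainer(gpus=1, logger=logger)
trainer.fit(model)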

Well, folks, that's a wrap! I hope you can see the benefits of PyTorch Lightning and are ready to train your models the Lightning way. This article only shows the tip of the iceberg of what Lightning provides; it has a ton more features ready to be explored. Feel free to look at the documentation for more examples or ask questions on the forums. I hope you enjoyed this article and found it useful. If you have any feedback or suggestions, do mention them in the comments section.

Now go on, and start building your models, the Lightning way! ⚡

References

  1. MNIST Training Code in PyTorch: https://github.com/pytorch/examples

  2. MNIST Training Code in PyTorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning-conference-seed

  3. PyTorch Lightning Docs: https://pytorch-lightning.readthedocs.io/

  4. From PyTorch to PyTorch Lightning — A gentle introduction: https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09

  5. 36 Ways Pytorch Lightning Can Supercharge Your AI Research: https://towardsdatascience.com/supercharge-your-ai-research-with-pytorch-lightning-337948a99eec
