A date with PyTorch Lightning⚡
Supercharge your model with a beautiful wrapper for PyTorch
Do you often forget to call model.eval() before evaluating your model's performance? Do you spend time setting it up for training on TPUs? Sure, early stopping and model checkpointing are essential parts of the training process, and we should definitely take some time to configure them properly. But isn't it frustrating to set all of this up for every new project? Wouldn't it be great to have a wrapper that handles the boilerplate code while following best practices, and lets you focus on improving your model's accuracy? Well, you're in luck, because that's exactly what PyTorch Lightning is here to do.
What is PyTorch Lightning?

PyTorch Lightning is a wrapper for PyTorch code that abstracts away the unimportant, repetitive details of training. The most beautiful thing about it is the way it lets you structure and organize your code, which makes it even more fun to work with. If you wish to know more about the inception and goals of this project, I suggest you read From PyTorch to PyTorch Lightning — A gentle introduction by William Falcon, the creator of PyTorch Lightning.

In this post, I’d like to share how I came across PyTorch Lightning and how it helped me structure and improve my code. We will also look at the differences between the vanilla implementation and the Lightning implementation of a PyTorch model.
Love at first sight
I happened to come across PyTorch Lightning while implementing Abhishek Thakur's tutorial on Bengali.AI: Handwritten Grapheme Classification Using PyTorch. I was looking for a package that would help me implement early stopping as quickly as possible. As some of you might know, early stopping halts the training process once the model's performance stops improving. I came across Bjarten/early-stopping-pytorch and PyTorch Ignite, but neither matched the elegance of Lightning. All I had to do was pass early_stop_callback=True to my trainer object and start training my model. It's that easy!
%%capture
!pip install pytorch-lightning

%%capture
# Fixes a TQDM progress bar issue
!pip install git+https://github.com/lezwon/pytorch-lightning.git@tqdm-fix --upgrade

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(early_stop_callback=True)

If I wanted to customize my callback, I could configure it accordingly and add it to my trainer object.
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=3,
    verbose=False,
    mode='min'
)

trainer = Trainer(early_stop_callback=early_stop_callback)

In case you are wondering what a trainer object is: it is simply an instance of the Trainer class, which bundles all the boilerplate code for a project, i.e. early stopping, model checkpointing, distributed training, etc. It takes care of all the nitty-gritty details for you while training your model. Let's see how to set it up from scratch.
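Before we do that, one more example of how little wiring the Trainer needs: model checkpointing is configured much like early stopping. The snippet below is only a rough sketch; the checkpoints/ directory name is my own placeholder, and argument names such as filepath and save_top_k have changed across Lightning releases, so check the docs for the version you are using.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep only the best checkpoint (lowest validation loss) on disk.
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/',  # placeholder output directory for .ckpt files
    monitor='val_loss',       # metric to track
    save_top_k=1,             # keep just the single best checkpoint
    mode='min',               # lower val_loss is better
    verbose=True
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)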
Before Lightning came along
Let's train an MNIST model in vanilla PyTorch. As given in the example here, we would declare our model using the nn.Module class.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

Our train and test logic is written separately, along with our data loading and pre-processing logic.
def train(model, device, train_loader, optimizer, epoch):
    global LOG_INTERVAL
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
EPOCHS = 14
LR = 1.0
GAMMA = 0.7
NO_CUDA = False
SEED = 1
LOG_INTERVAL = 10
SAVE_MODEL = False

use_cuda = not NO_CUDA and torch.cuda.is_available()
torch.manual_seed(SEED)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=TEST_BATCH_SIZE, shuffle=True, **kwargs)

model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=LR)
scheduler = StepLR(optimizer, step_size=1, gamma=GAMMA)

for epoch in range(1, EPOCHS + 1):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()

if SAVE_MODEL:
    torch.save(model.state_dict(), "mnist_cnn.pt")

As you can see from the output, we reach an accuracy of around 99%, which is great. Our implementation, however, contains a lot of boilerplate code and does not look very tidy. For example, we had to do a lot of things manually: from checking for CUDA support and copying our tensors to the device, to building the model checkpointing functionality ourselves. PyTorch gives us the flexibility to define our workflow on our own terms, which is why it is so popular. Nevertheless, wouldn't it be great if all of this came out of the box? PyTorch Lightning was born to solve exactly this problem. Let's see how this would be done the Lightning way.
After Lightning came along
We define our model the same way in PyTorch Lightning as we do in vanilla PyTorch, except that our class extends LightningModule instead of nn.Module. A LightningModule is an nn.Module with some added functionality on top of it. We will add our data loading, pre-processing and training logic within the LightningModule itself.
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import random_split
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping


class Net(LightningModule):

    def __init__(self, **kwargs):
        # Store the passed hyperparameters (BATCH_SIZE, LR, GAMMA, ...) as attributes
        self.__dict__.update(kwargs)
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

    def configure_optimizers(self):
        optimizer = optim.Adadelta(self.parameters(), lr=self.LR)
        scheduler = StepLR(optimizer, step_size=1, gamma=self.GAMMA)
        return [optimizer], [scheduler]

    def prepare_data(self):
        dataset = MNIST(
            '../data', train=True, download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))
        self.train_dataset, self.validation_dataset = random_split(
            dataset, [50000, 10000])
        self.test_dataset = MNIST(
            '../data', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ]))

    def val_dataloader(self):
        return DataLoader(
            self.validation_dataset,
            batch_size=self.BATCH_SIZE,
            shuffle=True
        )

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = F.nll_loss(y_hat, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'val_loss': avg_loss}

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.BATCH_SIZE,
            shuffle=True
        )

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.nll_loss(y_hat, y)}

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.TEST_BATCH_SIZE,
            shuffle=True
        )

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        accurate = torch.max(y_hat, 1)[1].view(y.size()) == y
        return {'test_loss': F.nll_loss(y_hat, y), 'accurate': accurate}

    def test_epoch_end(self, outputs):
        preds = [x['accurate'] for x in outputs]
        accuracy = torch.flatten(torch.cat(preds).float()).mean()
        avg_test_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        tensorboard_logs = {'avg_test_loss': avg_test_loss, 'accuracy': accuracy}
        return {'avg_test_loss': avg_test_loss, 'log': tensorboard_logs,
                'progress_bar': tensorboard_logs}

The LightningModule helps you structure your code by providing a set of standard hooks that you fill in with the required code. Let's go through some of these.
prepare_data: download and load your dataset here
configure_optimizers: configure the optimizers and schedulers required
train_dataloader: set up a dataloader for the training dataset
training_step: do a forward pass on a training batch and return the loss
val_dataloader: set up a dataloader for the validation dataset
validation_step: do a forward pass on a validation batch and return the loss
validation_epoch_end: calculate the average loss at the end of a validation epoch
test_dataloader: set up a dataloader for the test dataset
test_step: do a forward pass on a test batch
test_epoch_end: calculate the average test loss and model accuracy
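To see how these hooks fit together, here is a rough, simplified sketch of what trainer.fit() does with them. This is not Lightning's actual implementation (it ignores devices, callbacks, logging, checkpointing and much more); it is only meant to illustrate the order in which the hooks are called, assuming a LightningModule like the Net class above.

import torch

def simplified_fit(model, max_epochs=1):
    # Toy stand-in for Trainer.fit(): same hook order, none of the extras.
    model.prepare_data()
    optimizers, schedulers = model.configure_optimizers()
    optimizer, scheduler = optimizers[0], schedulers[0]

    for epoch in range(max_epochs):
        # Training loop: training_step returns a dict containing the loss.
        model.train()
        for batch_idx, batch in enumerate(model.train_dataloader()):
            optimizer.zero_grad()
            output = model.training_step(batch, batch_idx)
            output['loss'].backward()
            optimizer.step()
        scheduler.step()

        # Validation loop: collect per-batch outputs, then aggregate them.
        model.eval()
        with torch.no_grad():
            outputs = [model.validation_step(batch, i)
                       for i, batch in enumerate(model.val_dataloader())]
            print(model.validation_epoch_end(outputs))

The real Trainer does all of this (and much more) for you; the point is simply that each hook has one clearly defined job.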
You can check out the rest of the functions in the Lightning module in the Lightning documentation.
As you can see, I no longer have to worry about calling cuda(), to(), eval(), no_grad(), zero_grad(), backward(), step(). PyTorch Lightning takes all this off my plate and handles it for me. The code is much cleaner and easier to read than ever before. Now let's go ahead and initialize the model and trainer objects.
model = Net(
    BATCH_SIZE=64,
    TEST_BATCH_SIZE=64,
    LR=1.0,
    GAMMA=0.7
)

early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=3,
    verbose=True,
    mode='min'
)

trainer = Trainer(gpus=1, early_stop_callback=early_stop_callback)

Pass the required arguments and hyperparameters to the model, and make sure you set the number of GPUs your system supports in the gpus argument. Once this is done, we can go ahead and start training our model.
trainer.fit(model)

As we can see, the model has completed the training process. Thanks to the early stopping callback, training stopped automatically once the validation loss stopped decreasing. We can now check the accuracy of our model; to do this, we just call the .test() method on our trainer object.
trainer.test()

As you can see, the accuracy score and test loss are shown in the output. Lightning didn't increase our score, but it definitely did help us save some training time. You can access the best-performing model in the lightning_logs folder.
!ls lightning_logs/version_0/checkpoints

We could load this model directly using Lightning and resume the training process anytime we want. Lightning saves the training state along with the weights of the model.
from pathlib import Path

PATH = "lightning_logs/version_0/checkpoints"
ckpt_path = str(list(Path(PATH).glob('*.ckpt'))[-1])
ckpt_path

model = Net.load_from_checkpoint(ckpt_path)
model.eval()

And finally... happily ever after ❤️
PyTorch Lightning does not aim to abstract away any core functionality, just the repetitive housekeeping we shouldn't have to bother with. It still provides the same flexibility and speed as PyTorch, just in a more structured and organized manner. As you can see in the implementation, we set up a pretty decent training structure for our model. I can switch to CPU/TPU training or enable parallel/distributed processing without any major changes to my model. Such a wrapper really helps me focus on what's important, i.e. the model's accuracy, rather than the training setup.
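For example, switching hardware mostly comes down to changing Trainer arguments. The sketch below uses the argument names from the Lightning version used in this post (gpus, num_tpu_cores, distributed_backend); newer releases have renamed some of these, so treat it as an illustration rather than a reference.

from pytorch_lightning import Trainer

trainer = Trainer()                                   # plain CPU training
trainer = Trainer(gpus=1)                             # single GPU
trainer = Trainer(gpus=4, distributed_backend='ddp')  # multi-GPU with DistributedDataParallel
trainer = Trainer(num_tpu_cores=8)                    # TPU training, e.g. on Colab or Kaggle

# The LightningModule itself needs no changes:
# trainer.fit(model)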
Another interesting feature Lightning provides is its support for loggers such as TensorBoard, Visdom, MLflow, Comet, Neptune, and Weights & Biases (wandb). They give you an informative, at-a-glance view of the training process. You can refer to the Loggers docs to set one up.
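A minimal example with the TensorBoard logger (Lightning's default) looks something like this; the directory name tb_logs and experiment name mnist are just placeholders chosen for illustration.

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# Runs are written under tb_logs/mnist/version_x/
logger = TensorBoardLogger(save_dir='tb_logs', name='mnist')
trainer = Trainer(logger=logger)

# Inspect the logged metrics with: tensorboard --logdir tb_logs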
Well, folks, that's a wrap! I hope you have seen the benefits of PyTorch Lightning and are ready to train your models the Lightning way. This article demonstrates only the tip of the iceberg of what Lightning provides; it has a ton more features ready to be explored. Feel free to look at the documentation for more examples or ask any questions on the forums. I hope you enjoyed this article and found it useful. If you have any feedback or suggestions, do mention them in the comments.
Now go on, and start building your models, the Lightning way! ⚡
References
MNIST Training Code in PyTorch: https://github.com/pytorch/examples
MNIST Training Code in PyTorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning-conference-seed
PyTorch Lightning Docs: https://pytorch-lightning.readthedocs.io/
From PyTorch to PyTorch Lightning — A gentle introduction: https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09
36 Ways Pytorch Lightning Can Supercharge Your AI Research: https://towardsdatascience.com/supercharge-your-ai-research-with-pytorch-lightning-337948a99eec