Deep neural networks (DNNs) are highly overparameterized, a feature that allows them to learn complex input-to-output mappings. However, overparameterization also makes DNNs difficult to interpret and computationally expensive. In convolutional neural networks (CNNs), filters are updated during training to find an optimal feature mapping for an image dataset. CNNs such as ResNet-18 perform well on many image classification tasks, but likely contain redundant filters that could be removed to reduce the network's memory footprint. Here, we use neural network pruning to reduce the number of filters used in image classification tasks, and we explore the relationship between image dataset complexity and the number of CNN filters necessary for accurate classification.
We performed three experiments on three canonical image classification datasets (MNIST, CIFAR-10, and CIFAR-100). We started with a ResNet-18 model pretrained on the ImageNet dataset and fine-tuned it on one of the experimental datasets (e.g., MNIST). We then used magnitude-based filter pruning to remove the filters with the lowest L2 norm. Each layer of the network was pruned independently and in reverse order, starting with the final layer. Pruning in reverse order allowed us to maximally sparsify the deeper layers of the network, which capture higher-order features, before pruning the shallower layers, which capture lower-order features. After each group of filters was removed, the network underwent a cycle of retraining to maintain a baseline level of performance. If performance fell too far, pruning of that layer was interrupted and pruning of the next layer began.
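The pruning procedure described above can be illustrated with a minimal sketch. This is not the implementation used in the experiments: it is a toy in pure Python in which each layer is a list of filters and each filter is a flat list of weights. The function names, the `evaluate` and `retrain` callbacks, and the `min_accuracy` and `group_size` parameters are all hypothetical and introduced here only for illustration.

```python
import math


def filter_l2_norms(layer):
    """Return the L2 norm of each filter in a layer (filter = flat list of weights)."""
    return [math.sqrt(sum(w * w for w in f)) for f in layer]


def lowest_norm_indices(layer, group_size):
    """Indices of the `group_size` filters with the smallest L2 norm."""
    norms = filter_l2_norms(layer)
    order = sorted(range(len(layer)), key=lambda i: norms[i])
    return sorted(order[:group_size])


def prune_in_reverse(layers, evaluate, retrain, min_accuracy, group_size=1):
    """Prune layers one at a time, from the final layer back to the first.

    evaluate(net) -> accuracy and retrain(net) -> net are hypothetical
    callbacks standing in for validation and the retraining cycle.
    """
    layers = [list(layer) for layer in layers]  # leave the caller's lists intact
    for idx in reversed(range(len(layers))):
        # Keep at least one filter so the layer still produces an output.
        while len(layers[idx]) > group_size:
            drop = set(lowest_norm_indices(layers[idx], group_size))
            kept = [f for i, f in enumerate(layers[idx]) if i not in drop]
            trial = retrain(layers[:idx] + [kept] + layers[idx + 1:])
            if evaluate(trial) < min_accuracy:
                break  # performance fell too far: stop pruning this layer
            layers = trial  # accept the pruned network and keep going
    return layers
```

In a real CNN, removing a filter also requires removing the matching input channel in the following layer, and residual connections constrain which filters can be dropped together; the sketch ignores this bookkeeping and only shows the ranking by L2 norm and the reverse-order control flow.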
We compared the maximum sparsity achieved via neural network pruning across networks trained on the three datasets. We found that the maximum achievable sparsity is correlated with the task complexity of the dataset: simpler tasks, such as classifying MNIST digits, can be done with fewer parameters than more challenging tasks, such as classifying color images into 10 classes (CIFAR-10) or 100 classes (CIFAR-100).