Sample Complexity of Deep Active Learning

29 Sep 2021 · Zhao Song, Baocheng Sun, Danyang Zhuo

Many machine learning algorithms require large amounts of labeled training data to deliver state-of-the-art results. However, in many domains of AI, such as medical diagnosis and fraud detection, unlabeled data are abundant but it is costly to have data labeled by experts. In these domains, active learning, where an algorithm maximizes model accuracy while requiring the fewest labeled data points, is appealing. Active learning uses both labeled and unlabeled data to train models, and the learning algorithm decides which subset of the data should be labeled. Because label acquisition is costly, it is natural to ask whether one can determine, from a theoretical perspective, how many labeled data points are actually needed to train a machine learning model. This question is known as the sample complexity problem, and it has been extensively explored for training linear machine learning models (e.g., linear regression). Today, deep learning has become the de facto method for machine learning, but the sample complexity problem for deep active learning remains unsolved. This problem is challenging due to the non-linear nature of neural networks. In this paper, we present the first deep active learning algorithm with a provable sample complexity. Using this algorithm, we derive the first upper bound on the number of labeled data points required to train neural networks. Our upper bound shows that the minimum number of labeled data points a neural network needs does not depend on the data distribution or the width of the neural network but is determined by the smoothness of the non-linear activation and the dimension of the input data.
