US 12,271,822 B2
Active learning via a sample consistency assessment
Zizhao Zhang, San Jose, CA (US); Tomas Jon Pfister, Foster City, CA (US); Sercan Omer Arik, San Francisco, CA (US); and Mingfei Gao, Greenbelt, MD (US)
Assigned to GOOGLE LLC, Mountain View, CA (US)
Filed by Google LLC, Mountain View, CA (US)
Filed on Aug. 21, 2020, as Appl. No. 17/000,094.
Claims priority of provisional application 62/890,379, filed on Aug. 22, 2019.
Prior Publication US 2021/0056417 A1, Feb. 25, 2021
Int. Cl. G06N 3/044 (2023.01); G06F 7/24 (2006.01); G06F 18/211 (2023.01); G06F 18/214 (2023.01); G06N 3/045 (2023.01); G06N 3/08 (2023.01); G06N 3/084 (2023.01); G06N 7/01 (2023.01); G06N 20/00 (2019.01)
CPC G06N 3/084 (2013.01) [G06F 7/24 (2013.01); G06F 18/211 (2023.01); G06F 18/2155 (2023.01); G06N 3/08 (2013.01); G06N 20/00 (2019.01)]
15 Claims
 
1. A method for training a machine learning model, the method comprising:
obtaining, by data processing hardware, a set of unlabeled training samples; and
during each of a plurality of active learning cycles of training the machine learning model:
for each unlabeled training sample in the set of unlabeled training samples:
perturbing, by the data processing hardware, the unlabeled training sample to generate an augmented training sample;
generating, by the data processing hardware, using the machine learning model configured to receive the unlabeled training sample and the augmented training sample as inputs, a predicted label for the unlabeled training sample and a predicted label for the augmented training sample; and
determining, by the data processing hardware, an inconsistency value for the unlabeled training sample, the inconsistency value representing variance between the predicted label for the unlabeled training sample and the predicted label for the augmented training sample;
sorting, by the data processing hardware, the unlabeled training samples in the set of unlabeled training samples in a descending order based on the inconsistency values;
obtaining, by the data processing hardware, for each unlabeled training sample in a threshold number of unlabeled training samples selected from the sorted unlabeled training samples in the set of unlabeled training samples, a ground truth label; and
selecting, by the data processing hardware, a current set of labeled training samples, the current set of labeled training samples comprising each unlabeled training sample in the threshold number of unlabeled training samples selected from the sorted unlabeled training samples in the set of unlabeled training samples paired with the corresponding obtained ground truth label; and
training, by the data processing hardware, using the current set of labeled training samples and a subset of unlabeled training samples from the set of unlabeled training samples, the machine learning model;
during an initial active learning cycle:
randomly selecting, by the data processing hardware, a random set of unlabeled training samples from the set of unlabeled training samples;
obtaining, by the data processing hardware, corresponding ground truth labels for each unlabeled training sample in the random set of unlabeled training samples;
training, by the data processing hardware, using the random set of unlabeled training samples and the corresponding ground truth labels, the machine learning model;
identifying, by the data processing hardware, a candidate set of unlabeled training samples from the set of unlabeled training samples, wherein a cardinality of the candidate set of unlabeled training samples is less than a cardinality of the set of unlabeled training samples;
determining, by the data processing hardware, a first cross entropy between a distribution of ground truth labels and a distribution of predicted labels generated using the machine learning model for the unlabeled training samples in the candidate set of unlabeled training samples;
determining, by the data processing hardware, a second cross entropy between a distribution of ground truth labels and a distribution of predicted labels generated using the machine learning model for the unlabeled training samples in the set of unlabeled training samples;
determining, by the data processing hardware, whether the first cross entropy is greater than or equal to the second cross entropy; and
when the first cross entropy is greater than or equal to the second cross entropy, selecting, by the data processing hardware, the candidate set of unlabeled training samples as a starting size for initially training the machine learning model;
when the first cross entropy is less than the second cross entropy:
randomly selecting, by the data processing hardware, an expanded set of unlabeled training samples from the set of unlabeled training samples;
updating, by the data processing hardware, the candidate set of unlabeled training samples to include the expanded set of unlabeled training samples randomly selected from the set of unlabeled training samples;
updating, by the data processing hardware, the set of unlabeled training samples by removing each unlabeled training sample from the expanded set of unlabeled training samples from the set of unlabeled training samples; and
during an immediately subsequent active learning cycle:
determining, by the data processing hardware, the first cross entropy between a distribution of ground truth labels and a distribution of predicted labels generated using the machine learning model for the unlabeled training samples in the updated candidate set of unlabeled training samples;
determining, by the data processing hardware, the second cross entropy between the distribution of ground truth labels and a distribution of predicted labels generated using the machine learning model for the unlabeled training samples in the updated candidate set of unlabeled training samples;
determining, by the data processing hardware, whether the first cross entropy is greater than or equal to the second cross entropy; and
when the first cross entropy is greater than or equal to the second cross entropy, selecting, by the data processing hardware, the updated candidate set of unlabeled training samples as a starting size for initially training the machine learning model.
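
The selection steps recited in claim 1 can be illustrated with a minimal Python sketch. It is an illustration under stated assumptions, not the patented implementation. All names (`inconsistency`, `select_for_labeling`, `cross_entropy`, `candidate_is_large_enough`, `model`, `perturb`) are hypothetical. The sketch assumes the model returns a list of class probabilities, and that the inconsistency value is the per-class variance of those probabilities across the original and perturbed samples, summed over classes. The claim itself requires only that the value represent variance between the two predicted labels.

```python
import math
from statistics import pvariance
from typing import Callable, List, Sequence

Probs = List[float]  # assumed model output: one probability per class


def inconsistency(model: Callable[[float], Probs], sample: float,
                  perturb: Callable[[float], float], n_aug: int = 1) -> float:
    """Assumed metric: per-class variance of the predicted probabilities
    across the sample and its perturbed copies, summed over classes."""
    preds = [model(sample)] + [model(perturb(sample)) for _ in range(n_aug)]
    return sum(pvariance(class_probs) for class_probs in zip(*preds))


def select_for_labeling(model: Callable[[float], Probs],
                        unlabeled: Sequence[float],
                        perturb: Callable[[float], float],
                        k: int) -> List[int]:
    """Sort samples by inconsistency in descending order and return the
    indices of the first k, which would be sent for ground truth labels."""
    scores = [inconsistency(model, x, perturb) for x in unlabeled]
    order = sorted(range(len(unlabeled)), key=lambda i: scores[i], reverse=True)
    return order[:k]


def cross_entropy(truth: Probs, predicted: Probs, eps: float = 1e-12) -> float:
    """Cross entropy between a ground truth label distribution and a
    predicted label distribution (eps guards against log(0))."""
    return -sum(p * math.log(q + eps) for p, q in zip(truth, predicted))


def candidate_is_large_enough(truth: Probs, predicted_candidate: Probs,
                              predicted_full: Probs) -> bool:
    """Starting-size test from the claim: accept the candidate set when the
    first cross entropy is greater than or equal to the second."""
    first = cross_entropy(truth, predicted_candidate)
    second = cross_entropy(truth, predicted_full)
    return first >= second
```

In a full active learning cycle, the indices returned by `select_for_labeling` would be paired with their obtained ground truth labels and combined with a subset of the remaining unlabeled samples for training. When `candidate_is_large_enough` returns `False`, the claim expands the candidate set with further randomly selected samples and repeats the comparison in the next cycle.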