US 12,456,047 B2
Distilling from ensembles to improve reproducibility of neural networks
Gil Shamir, Sewickley, PA (US); and Lorenzo Coviello, Pittsburgh, PA (US)
Assigned to GOOGLE LLC, Mountain View, CA (US)
Filed by Google LLC, Mountain View, CA (US)
Filed on Sep. 18, 2020, as Appl. No. 17/025,418.
Claims priority of provisional application 62/938,645, filed on Nov. 21, 2019.
Prior Publication US 2021/0158156 A1, May 27, 2021
Int. Cl. G06N 3/08 (2023.01); G06N 3/045 (2023.01)
CPC G06N 3/08 (2013.01) [G06N 3/045 (2023.01)] 27 Claims
OG exemplary drawing
 
1. A computing system configured to improve the reproducibility of neural networks, the computing system comprising:
one or more processors; and
one or more non-transitory computer-readable media that collectively store:
an ensemble that comprises a plurality of neural networks;
a single neural network, wherein the single neural network exhibits a greater accuracy than the ensemble when trained on a shared training dataset for both the single neural network and the ensemble, wherein the single neural network comprises at least a first head and a second head that respectively generate a first head output and a second head output; and
instructions that, when executed by the one or more processors, cause the computing system to perform operations, the operations comprising:
accessing, by the computing system, one or more training examples;
processing, by the computing system, each of the one or more training examples with the ensemble to obtain an ensemble output from the ensemble;
processing, by the computing system, each of the one or more training examples with the single neural network to obtain a network output from the single neural network; and
training, by the computing system, the single neural network using a loss function that, at least in part, penalizes a difference between the network output and the ensemble output;
wherein the loss function used to train the single neural network comprises a distillation loss term and a supervised loss term,
wherein the distillation loss term penalizes a difference between the first head output and the ensemble output, and
wherein the supervised loss term penalizes a difference between the second head output and one or more ground truth labels associated with the one or more training examples.
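The loss structure recited in claim 1 can be illustrated with a short sketch. The Python (PyTorch-style) code below is not taken from the patent; it is an illustrative assumption that the task is classification, that the ensemble output is the average of the member networks' softmax outputs, that the distillation loss term is a KL divergence, that the supervised loss term is a cross entropy, and that the two terms are combined with a hypothetical weighting factor alpha. None of these specifics appear in the claim itself.

```python
# Hypothetical sketch of the claimed training objective; module names,
# dimensions, and the specific loss functions are assumptions, not the
# patent's reference implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoHeadStudent(nn.Module):
    """Single neural network with a shared trunk and two output heads."""

    def __init__(self, in_dim: int, hidden_dim: int, num_classes: int):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        self.distill_head = nn.Linear(hidden_dim, num_classes)  # first head
        self.label_head = nn.Linear(hidden_dim, num_classes)    # second head

    def forward(self, x):
        h = self.trunk(x)
        return self.distill_head(h), self.label_head(h)


def ensemble_output(ensemble, x):
    """One possible ensemble output: the mean of member softmax distributions."""
    with torch.no_grad():
        probs = torch.stack([F.softmax(member(x), dim=-1) for member in ensemble])
    return probs.mean(dim=0)


def combined_loss(student, ensemble, x, labels, alpha=0.5):
    """Distillation term on the first head plus supervised term on the second head."""
    teacher_probs = ensemble_output(ensemble, x)
    distill_logits, label_logits = student(x)
    # Distillation loss: penalizes the difference between the first head's
    # output and the ensemble output.
    distill_loss = F.kl_div(
        F.log_softmax(distill_logits, dim=-1), teacher_probs, reduction="batchmean"
    )
    # Supervised loss: penalizes the difference between the second head's
    # output and the ground-truth labels.
    supervised_loss = F.cross_entropy(label_logits, labels)
    return alpha * distill_loss + (1.0 - alpha) * supervised_loss
```

The claim only specifies how the two loss terms attach to the two heads during training; which head (or combination of heads) is used at inference time is left open, and the weighting between the terms shown above is an illustrative choice.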