US 11,941,531 B1
Attention-based prototypical learning
Sercan Omer Arik, San Francisco, CA (US); and Tomas Jon Pfister, Foster City, CA (US)
Assigned to Google LLC, Mountain View, CA (US)
Filed by Google LLC, Mountain View, CA (US)
Filed on Feb. 7, 2020, as Appl. No. 16/785,032.
Claims priority of provisional application 62/802,663, filed on Feb. 7, 2019.
Int. Cl. G06N 3/088 (2023.01); G06F 40/30 (2020.01); G06N 3/045 (2023.01)
CPC G06N 3/088 (2013.01) [G06F 40/30 (2020.01); G06N 3/045 (2023.01)] 19 Claims
 
1. A method performed by one or more data processing apparatus, the method comprising:
determining a respective attention weight between an input data element and each of a plurality of reference data elements, comprising:
processing the input data element using an encoder neural network to generate a query embedding of the input data element;
processing each of the reference data elements using the encoder neural network to generate a respective key embedding of each reference data element; and
for each reference data element, determining the attention weight between the input data element and the reference data element based on a measure of alignment between the query embedding of the input data element and the key embedding of the reference data element using a sparsemax normalization function;
generating a prediction output that characterizes the input data element based on at least the attention weights and the reference data elements, comprising:
processing each of the reference data elements using the encoder neural network to generate a respective value embedding of each reference data element;
determining a combined value embedding of the reference data elements based on (i) the respective value embedding of each reference data element, and (ii) the respective attention weight between the input data element and each reference data element; and
processing the combined value embedding of the reference data elements using a prediction neural network to generate the prediction output that characterizes the input data element;
identifying a proper subset of the reference data elements based on the attention weights; and
providing the identified proper subset of the reference data elements for use in interpreting the prediction output that characterizes the input data element.
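Claim 1 relies on sparsemax to turn alignment scores into attention weights. Sparsemax (Martins and Astudillo, 2016) is the Euclidean projection of a score vector onto the probability simplex. Unlike softmax, it can assign exactly zero weight to low-scoring elements, and this is what allows a proper subset of reference data elements to stand out. The standard closed form is shown below, where z_(1) ≥ … ≥ z_(K) are the scores sorted in descending order. The numeric example is illustrative only and is not taken from the patent.

```latex
\operatorname{sparsemax}(\mathbf{z}) = \arg\min_{\mathbf{p} \in \Delta^{K-1}} \lVert \mathbf{p} - \mathbf{z} \rVert_2^2,
\qquad \operatorname{sparsemax}_i(\mathbf{z}) = \max\bigl(z_i - \tau(\mathbf{z}),\, 0\bigr)

k(\mathbf{z}) = \max\Bigl\{ k \in \{1,\dots,K\} : 1 + k\, z_{(k)} > \textstyle\sum_{j \le k} z_{(j)} \Bigr\},
\qquad \tau(\mathbf{z}) = \frac{\sum_{j \le k(\mathbf{z})} z_{(j)} - 1}{k(\mathbf{z})}

\text{Example: } \mathbf{z} = (0.9,\ 0.6,\ 0.1) \;\Rightarrow\; k = 2,\ \tau = \tfrac{1.5 - 1}{2} = 0.25,\
\operatorname{sparsemax}(\mathbf{z}) = (0.65,\ 0.35,\ 0)
```

Softmax would give all three elements a nonzero weight in this example. Sparsemax places the third element outside the support, so it drops out of the combined value embedding.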
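The steps of claim 1 can be traced end to end in a small pure-Python sketch. This is a hypothetical illustration under stated assumptions, not the patented implementation. The "encoder neural network" is replaced by two linear projections (`W_key` shared by the query and key embeddings, `W_value` for value embeddings). The "measure of alignment" is a plain dot product. The "prediction neural network" is a single linear layer (`W_out`). The proper subset is taken to be the references with nonzero sparsemax weight. All names (`encode`, `predict_with_prototypes`, `W_key`, `W_value`, `W_out`) are invented for this sketch.

```python
def sparsemax(scores):
    """Closed-form sparsemax: project `scores` onto the probability simplex."""
    z = sorted(scores, reverse=True)
    cumsum, k_max, support_sum = 0.0, 0, 0.0
    for k, z_k in enumerate(z, start=1):
        cumsum += z_k
        if 1 + k * z_k > cumsum:  # z_(k) is still inside the support
            k_max, support_sum = k, cumsum
    tau = (support_sum - 1) / k_max  # threshold subtracted from every score
    return [max(s - tau, 0.0) for s in scores]


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def encode(x, W_key, W_value):
    """Hypothetical stand-in encoder: returns (query/key embedding, value embedding)."""
    return matvec(W_key, x), matvec(W_value, x)


def predict_with_prototypes(x, refs, W_key, W_value, W_out):
    # Query embedding of the input and key/value embeddings of every reference.
    query, _ = encode(x, W_key, W_value)
    keys, values = zip(*(encode(r, W_key, W_value) for r in refs))
    # Alignment (assumed: dot product) normalized with sparsemax.
    weights = sparsemax([dot(query, k) for k in keys])
    # Combined value embedding: attention-weighted sum of value embeddings.
    combined = [sum(w * v[d] for w, v in zip(weights, values))
                for d in range(len(values[0]))]
    # Hypothetical linear prediction head.
    logits = matvec(W_out, combined)
    # References with nonzero weight; this is a proper subset whenever
    # sparsemax zeroes out at least one reference.
    prototypes = [i for i, w in enumerate(weights) if w > 0]
    return logits, weights, prototypes


if __name__ == "__main__":
    I2 = [[1.0, 0.0], [0.0, 1.0]]  # identity projections, for illustration only
    refs = [[1.0, 0.0], [0.5, 0.5], [-1.0, 0.0]]
    logits, weights, prototypes = predict_with_prototypes([1.0, 0.0], refs, I2, I2, I2)
    print("weights:", weights)        # [0.75, 0.25, 0.0]
    print("prototypes:", prototypes)  # [0, 1]
    print("logits:", logits)          # [0.875, 0.125]
```

In this toy run, the third reference gets zero weight. It therefore contributes nothing to the prediction and is excluded from the subset returned for interpretation. This sparsity is the property that makes sparsemax a natural fit for the interpretability step of the claim.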