US 12,353,981 B2
Training of large neural networks
Slav Petrov, New York, NY (US); Yonghui Wu, Fremont, CA (US); Andrew M. Dai, San Francisco, CA (US); David Richard So, Brooklyn, NY (US); Dmitry Lepikhin, Menlo Park, CA (US); Erica Ann Moreira, Fremont, CA (US); Gaurav Mishra, Sunnyvale, CA (US); Jonathan Hudson Clark, Seattle, WA (US); Maxim Krikun, Castro Valley, CA (US); Melvin Jose Johnson Premkumar, Sunnyvale, CA (US); Nan Du, San Jose, CA (US); Orhan Firat, Mountain View, CA (US); Rohan Anil, San Francisco, CA (US); Siamak Shakeri, New York, NY (US); Xavier Garcia, New York, NY (US); Yanping Huang, Mountain View, CA (US); Yong Cheng, Mountain View, CA (US); Yuanzhong Xu, Mountain View, CA (US); Yujing Zhang, Sunnyvale, CA (US); Zachary Alexander Nado, Brookline, MA (US); Eric Jun Jie Ni, Mountain View, CA (US); Kefan Xiao, Sunnyvale, CA (US); Vladimir Feinberg, San Francisco, CA (US); Jin Young Sohn, Jersey City, NJ (US); and Aurko Roy, San Francisco, CA (US)
Assigned to Google LLC, Mountain View, CA (US)
Filed by Google LLC, Mountain View, CA (US)
Filed on May 10, 2024, as Appl. No. 18/661,499.
Claims priority of provisional application 63/465,487, filed on May 10, 2023.
Prior Publication US 2024/0378427 A1, Nov. 14, 2024
Int. Cl. G06F 40/284 (2020.01); G06N 3/04 (2023.01); G06N 3/0475 (2023.01); G06N 3/08 (2023.01)
CPC G06N 3/0475 (2023.01) [G06F 40/284 (2020.01); G06N 3/08 (2013.01)] 17 Claims
 
1. A method performed by one or more computers, wherein the method comprises:
obtaining a plurality of unlabeled text sequences, wherein each unlabeled text sequence comprises a plurality of text tokens;
training an autoregressive generative neural network comprising one or more self-attention layers based on optimizing multiple different pre-training objective functions that comprise (i) a causal language modeling objective function and (ii) a prefix language modeling objective function, wherein training the autoregressive generative neural network based on optimizing the multiple different pre-training objective functions comprises:
obtaining data specifying a respective weight assigned to each of the multiple different pre-training objective functions; and
repeatedly (a) selecting, based on the respective weights, a pre-training objective function from the multiple different pre-training objective functions and (b) training the autoregressive generative neural network on the selected pre-training objective function,
wherein training the autoregressive generative neural network based on optimizing the causal language modeling objective function comprises:
generating, from the plurality of unlabeled text sequences, a plurality of causal language modeling text sequences, wherein generating each causal language modeling text sequence comprises using a corresponding unlabeled text sequence as the causal language modeling text sequence without further processing the corresponding unlabeled text sequence to add to the corresponding unlabeled text sequence any additional tokens that were not included in the corresponding unlabeled text sequence;
processing, using the autoregressive generative neural network, each causal language modeling text sequence to generate, for each token in the causal language modeling text sequence, a causal prediction of a text token that should occupy a particular position of the text token in the causal language modeling text sequence conditioned on text tokens at preceding positions in the causal language modeling text sequence, wherein the one or more self-attention layers within the autoregressive generative neural network apply a masked self-attention mechanism over the preceding positions in the causal language modeling text sequence; and
determining, based on a quality of the causal predictions, an update to parameter values of the autoregressive generative neural network, and
wherein training the autoregressive generative neural network based on optimizing the prefix language modeling objective function comprises:
generating, from the plurality of unlabeled text sequences, a plurality of prefix language modeling text sequences, wherein generating each prefix language modeling text sequence comprises further processing a corresponding unlabeled text sequence to divide the corresponding unlabeled text sequence into a prefix text sequence and a suffix text sequence that follows the prefix text sequence;
processing, using the autoregressive generative neural network, each prefix language modeling text sequence to generate, for each token in the suffix text sequence, a prefix prediction of a text token that should occupy a particular position of the token in the suffix text sequence conditioned on tokens in the prefix text sequence and tokens at any preceding positions in the suffix text sequence, wherein the one or more self-attention layers within the autoregressive generative neural network apply a bidirectional, unmasked attention mechanism over the positions in the prefix text sequence and apply a masked self-attention mechanism over positions in the suffix text sequence so that each position in the suffix text sequence attends over the positions in the prefix text sequence and any preceding positions in the suffix text sequence; and
determining, based on a quality of the prefix predictions, an update to the parameter values of the autoregressive generative neural network.
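The weighted selection of a pre-training objective recited in claim 1 can be illustrated with a short Python sketch. The claim does not specify how the weights drive the selection, so this sketch assumes that each training step samples one objective with probability proportional to its weight. The objective names `causal_lm` and `prefix_lm`, the 0.6/0.4 weights, and the helper functions are hypothetical.

```python
import random
from collections import Counter

# Hypothetical objective names and weights. The claim only requires that each
# pre-training objective function has a respective weight.
OBJECTIVE_WEIGHTS = {"causal_lm": 0.6, "prefix_lm": 0.4}


def select_objective(weights, rng):
    """Sample one objective name with probability proportional to its weight."""
    names = list(weights)
    return rng.choices(names, weights=[weights[n] for n in names], k=1)[0]


def training_schedule(weights, num_steps, seed=0):
    """Return the objective selected at each training step (assumed sampling scheme)."""
    rng = random.Random(seed)
    return [select_objective(weights, rng) for _ in range(num_steps)]


if __name__ == "__main__":
    schedule = training_schedule(OBJECTIVE_WEIGHTS, num_steps=10_000)
    # Roughly 60% causal LM steps and 40% prefix LM steps.
    print(Counter(schedule))
```

A deterministic alternative, such as cycling through objectives in fixed proportions, would also select "based on the respective weights"; random sampling is used here only because it is the simplest to show.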
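The causal language modeling branch of claim 1 can be sketched as follows. The sketch assumes the common shift-by-one formulation, in which the output at position i is trained to predict the token at position i + 1. The unlabeled sequence is used as-is with no added tokens, and a lower-triangular boolean mask restricts each position to itself and earlier positions. The helper names are hypothetical, and plain Python lists stand in for the tensors a real implementation would use.

```python
import math


def causal_lm_example(tokens):
    """Build a causal LM example from an unlabeled token sequence, used as-is.

    No sentinel or other extra tokens are added. The output at input position i
    is trained to predict tokens[i + 1] from tokens[0..i].
    """
    inputs = list(tokens[:-1])
    targets = list(tokens[1:])
    return inputs, targets


def causal_mask(length):
    """mask[q][k] is True when query position q may attend to key position k."""
    return [[k <= q for k in range(length)] for q in range(length)]


def masked_softmax(scores, mask_row):
    """Softmax over one query's attention scores, giving zero weight to masked keys."""
    top = max(s for s, m in zip(scores, mask_row) if m)
    exps = [math.exp(s - top) if m else 0.0 for s, m in zip(scores, mask_row)]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    inputs, targets = causal_lm_example([5, 6, 7, 8])
    print(inputs, targets)  # [5, 6, 7] [6, 7, 8]
    mask = causal_mask(len(inputs))
    # The first position can only attend to itself.
    print(masked_softmax([1.0, 2.0, 3.0], mask[0]))  # [1.0, 0.0, 0.0]
```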
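The prefix language modeling branch of claim 1 can be sketched in the same style. The claim only requires dividing each unlabeled sequence into a prefix and a following suffix; this sketch assumes the split point is drawn uniformly at random. The attention mask lets every position see the entire prefix without masking, lets suffix positions also see themselves and earlier suffix positions, and the loss weights keep only suffix tokens as prediction targets. All function names are hypothetical, and the shift-by-one input/target layout is an assumption.

```python
import random


def split_prefix_suffix(tokens, rng):
    """Split a sequence at a random point into a non-empty prefix and suffix."""
    if len(tokens) < 2:
        raise ValueError("need at least two tokens to form a prefix and a suffix")
    cut = rng.randint(1, len(tokens) - 1)  # inclusive bounds
    return list(tokens[:cut]), list(tokens[cut:])


def prefix_lm_example(prefix, suffix):
    """Shift-by-one inputs/targets, with loss weight 1 only on suffix targets."""
    tokens = prefix + suffix
    inputs, targets = tokens[:-1], tokens[1:]
    # targets[t] is tokens[t + 1]; it is a suffix token when t + 1 >= len(prefix).
    loss_weights = [1 if t + 1 >= len(prefix) else 0 for t in range(len(targets))]
    return inputs, targets, loss_weights


def prefix_lm_mask(prefix_len, total_len):
    """mask[q][k] is True when query q may attend to key k.

    Every position sees the whole prefix (bidirectional, unmasked); suffix
    positions additionally see themselves and earlier suffix positions.
    """
    return [
        [k < prefix_len or k <= q for k in range(total_len)]
        for q in range(total_len)
    ]


if __name__ == "__main__":
    prefix, suffix = split_prefix_suffix(list(range(10)), random.Random(0))
    inputs, targets, weights = prefix_lm_example(prefix, suffix)
    print(prefix, suffix, weights)
    for row in prefix_lm_mask(len(prefix), len(inputs)):
        print("".join("x" if allowed else "." for allowed in row))
```

Because prefix tokens carry zero loss weight, letting prefix positions attend bidirectionally does not leak any prediction target. Under the shift-by-one layout, the first suffix token is predicted from the output at the last prefix position, which sees the whole prefix.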