CPC G06N 3/084 (2013.01) [G06N 3/088 (2013.01); G06V 10/764 (2022.01); G06V 10/7753 (2022.01); G06V 10/7784 (2022.01); G06V 10/82 (2022.01)] | 14 Claims |
1. A computer-implemented method for training a machine learning (ML) model using labelled and unlabelled data, the method comprising:
obtaining a set of training data comprising a set of labelled data items and a set of unlabelled data items;
training a loss module of the ML model using labels in the set of labelled data items, to generate a trained loss module capable of estimating a likelihood of a label for a data item; and
training a task module of the ML model using the loss module, the set of labelled data items, and the set of unlabelled data items, to generate a trained task module capable of making a prediction of a label for input data,
wherein the training of the task module comprises:
inputting a data item in the set of labelled data items into the task module and outputting, from the task module, a predicted label for the data item in the set of labelled data items;
comparing, using a supervised loss module, the output predicted label for the data item in the set of labelled data items with an actual label of the data item in the set of labelled data items, and determining a supervised loss function based on a result of the comparing;
inputting a data item in the set of unlabelled data items into the task module and outputting, from the task module, a predicted label for the data item in the set of unlabelled data items;
determining an unsupervised loss function using the trained loss module and the output predicted label, output from the task module, for the data item in the set of unlabelled data items, wherein the unsupervised loss function is defined by a likelihood of the predicted label for the data item in the set of unlabelled data items;
calculating a total loss function based on a sum of the supervised loss function and the unsupervised loss function; and
using the total loss function and backpropagation to train the task module by minimizing the sum of the supervised loss function and the unsupervised loss function.
|
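The training procedure recited in claim 1 can be illustrated with a minimal sketch in plain Python. Everything beyond the claim language is an assumption made for illustration only. A one-feature binary logistic classifier stands in for the task module, and a smoothed Bernoulli label prior stands in for the loss module. The unsupervised loss is taken to be the negative log-likelihood of the soft predicted label, and each loss is averaged over its own data set before the two are summed. Gradients are derived by hand with the chain rule. All function names, hyperparameters, and data values are hypothetical and do not come from the patent.

```python
import math

def predict(w, b, x):
    """Hypothetical task module: logistic model giving P(label = 1 | x)."""
    p = 1.0 / (1.0 + math.exp(-(w * x + b)))
    return min(max(p, 1e-9), 1.0 - 1e-9)  # clamp so the logs below stay finite

def train_loss_module(labels):
    """Hypothetical loss module: Laplace-smoothed Bernoulli prior fitted to the labels only."""
    return (sum(labels) + 1.0) / (len(labels) + 2.0)

def label_likelihood(q, p):
    """Likelihood of a soft predicted label p under the fitted prior q."""
    return p * q + (1.0 - p) * (1.0 - q)

def total_loss(w, b, q, labelled, unlabelled):
    """Sum of the (mean) supervised loss and the (mean) unsupervised loss."""
    sup = 0.0
    for x, y in labelled:
        p = predict(w, b, x)
        sup -= y * math.log(p) + (1 - y) * math.log(1.0 - p)  # cross-entropy vs. actual label
    unsup = 0.0
    for x in unlabelled:
        unsup -= math.log(label_likelihood(q, predict(w, b, x)))  # negative log-likelihood
    return sup / len(labelled) + unsup / len(unlabelled)

def train_task_module(q, labelled, unlabelled, lr=0.5, steps=200):
    """Gradient descent on the total loss, with gradients backpropagated by hand."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        grad_w, grad_b = 0.0, 0.0
        for x, y in labelled:
            p = predict(w, b, x)
            dz = (p - y) / len(labelled)  # d(cross-entropy)/d(logit)
            grad_w += dz * x
            grad_b += dz
        for x in unlabelled:
            p = predict(w, b, x)
            dp = -(2.0 * q - 1.0) / label_likelihood(q, p)  # d(-log likelihood)/dp
            dz = dp * p * (1.0 - p) / len(unlabelled)       # chain rule through the sigmoid
            grad_w += dz * x
            grad_b += dz
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b

# Toy data (made up): (feature, label) pairs and unlabelled features.
labelled = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1), (3.0, 1)]
unlabelled = [-1.5, -0.5, 0.5, 1.5, 2.5]

q = train_loss_module([y for _, y in labelled])   # step 1: train the loss module
initial_loss = total_loss(0.0, 0.0, q, labelled, unlabelled)
w, b = train_task_module(q, labelled, unlabelled)  # step 2: train the task module
final_loss = total_loss(w, b, q, labelled, unlabelled)
```

In this sketch the loss module is trained first and then held fixed while the task module is trained, mirroring the order of the two training steps in the claim. A practical system would presumably use a learned network for each module and an automatic differentiation library rather than hand-derived gradients.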