US 12,217,186 B2
Method and apparatus for semi-supervised learning
Ivana Balazevic, Staines (GB); Carl Allen, Staines (GB); and Timothy Hospedales, Staines (GB)
Assigned to SAMSUNG ELECTRONICS CO., LTD., Suwon-si (KR)
Appl. No. 17/436,927
Filed by SAMSUNG ELECTRONICS CO., LTD., Suwon-si (KR)
PCT Filed May 25, 2021, PCT No. PCT/KR2021/006497
§ 371(c)(1), (2) Date Sep. 7, 2021,
PCT Pub. No. WO2021/241983, PCT Pub. Date Dec. 2, 2021.
Claims priority of application No. 2008030 (GB), filed on May 28, 2020; and application No. 20215401 (EP), filed on Dec. 18, 2020.
Prior Publication US 2023/0177344 A1, Jun. 8, 2023
Int. Cl. G06N 3/088 (2023.01); G06N 3/084 (2023.01); G06V 10/764 (2022.01); G06V 10/774 (2022.01); G06V 10/778 (2022.01); G06V 10/82 (2022.01)
CPC G06N 3/084 (2013.01) [G06N 3/088 (2013.01); G06V 10/764 (2022.01); G06V 10/7753 (2022.01); G06V 10/7784 (2022.01); G06V 10/82 (2022.01)]
14 Claims
 
1. A computer-implemented method for training a machine learning (ML) model using labelled and unlabelled data, the method comprising:
obtaining a set of training data comprising a set of labelled data items and a set of unlabelled data items;
training a loss module of the ML model using labels in the set of labelled data items, to generate a trained loss module capable of estimating a likelihood of a label for a data item; and
training a task module of the ML model using the loss module, the set of labelled data items, and the set of unlabelled data items, to generate a trained task module capable of making a prediction of a label for input data,
wherein the training of the task module comprises:
inputting a data item in the set of labelled data items into the task module and outputting, from the task module, a predicted label for the data item in the set of labelled data items;
comparing, using a supervised loss module, the output predicted label for the data item in the set of labelled data items with an actual label of the data item in the set of labelled data items, and determining a supervised loss function based on a result of the comparing;
inputting a data item in the set of unlabelled data items into the task module and outputting, from the task module, a predicted label for the data item in the set of unlabelled data items;
determining an unsupervised loss function using the trained loss module and the output predicted label, output from the task module, for the data item in the set of unlabelled data items, wherein the unsupervised loss function is defined by a likelihood of the predicted label for the data item in the set of unlabelled data items;
calculating a total loss function based on a sum of the supervised loss function and the unsupervised loss function; and
using the total loss function and backpropagation to train the task module by minimizing the sum of the supervised loss function and the unsupervised loss function.
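The two training stages recited in claim 1 can be illustrated with a minimal pure-Python sketch. It rests on several assumptions: the claim does not specify any architecture, so the loss module here is a hypothetical add-alpha smoothed label-frequency model `q(k)` fitted only on the labels of the labelled set, and the task module is a hypothetical linear softmax classifier. The "likelihood of the predicted label" is read as the expected likelihood `s = sum_k p(k|x) q(k)`, giving an unsupervised loss of `-log s`; this is one possible reading, not necessarily the patentee's. The names `train_loss_module`, `TaskModule`, `losses_and_grads` and `train_task_module` are illustrative only. Gradients are derived by hand (chain rule through `z = Wx + b`) in place of an autodiff framework, and the total loss is the plain sum of the mean supervised and mean unsupervised losses, as the claim recites.

```python
import math
import random
from collections import Counter


def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def train_loss_module(labels, num_classes, alpha=1.0):
    """Hypothetical loss module: add-alpha smoothed estimate of q(k), the
    likelihood of label k, fitted only on labels from the labelled set."""
    counts = Counter(labels)
    total = len(labels) + alpha * num_classes
    return [(counts[k] + alpha) / total for k in range(num_classes)]


class TaskModule:
    """Hypothetical task module: a linear softmax classifier."""

    def __init__(self, num_features, num_classes, seed=0):
        rng = random.Random(seed)
        self.W = [[rng.gauss(0.0, 0.01) for _ in range(num_features)]
                  for _ in range(num_classes)]
        self.b = [0.0] * num_classes

    def predict(self, x):
        z = [sum(w * xi for w, xi in zip(row, x)) + bk
             for row, bk in zip(self.W, self.b)]
        return softmax(z)


def losses_and_grads(task, labelled, unlabelled, q):
    """Return supervised, unsupervised and total loss, plus the gradients of
    the total loss with respect to the task module's W and b."""
    K, D = len(task.W), len(task.W[0])
    gW = [[0.0] * D for _ in range(K)]
    gb = [0.0] * K

    def backprop(g, x, scale):
        # Chain rule through z = W x + b: dL/dW[k][d] = g[k] * x[d].
        for k in range(K):
            gb[k] += scale * g[k]
            for d in range(D):
                gW[k][d] += scale * g[k] * x[d]

    sup = 0.0
    for x, y in labelled:
        p = task.predict(x)
        sup -= math.log(p[y]) / len(labelled)  # cross-entropy vs. actual label
        g = [p[k] - (1.0 if k == y else 0.0) for k in range(K)]  # dL/dz
        backprop(g, x, 1.0 / len(labelled))

    unsup = 0.0
    for x in unlabelled:
        p = task.predict(x)
        s = sum(pk * qk for pk, qk in zip(p, q))  # likelihood under loss module
        unsup -= math.log(s) / len(unlabelled)
        g = [p[k] * (1.0 - q[k] / s) for k in range(K)]  # d(-log s)/dz
        backprop(g, x, 1.0 / len(unlabelled))

    return sup, unsup, sup + unsup, gW, gb


def train_task_module(task, labelled, unlabelled, q, lr=0.5, epochs=100):
    """Full-batch gradient descent on the total (supervised + unsupervised) loss."""
    for _ in range(epochs):
        *_, gW, gb = losses_and_grads(task, labelled, unlabelled, q)
        for k in range(len(task.W)):
            task.b[k] -= lr * gb[k]
            for d in range(len(task.W[k])):
                task.W[k][d] -= lr * gW[k][d]
    return task


# Toy usage: the loss module is trained first, then kept fixed while the
# task module is trained on labelled and unlabelled items together.
labelled = [([1.0, 0.0], 0), ([0.0, 1.0], 1), ([0.9, 0.1], 0)]
unlabelled = [[0.8, 0.2], [0.1, 0.9]]
q = train_loss_module([y for _, y in labelled], num_classes=2)
task = TaskModule(num_features=2, num_classes=2)
before = losses_and_grads(task, labelled, unlabelled, q)[2]
train_task_module(task, labelled, unlabelled, q)
after = losses_and_grads(task, labelled, unlabelled, q)[2]
```

With a single-label frequency model as the loss module, the unsupervised term only nudges predictions toward frequent labels. The sketch is meant to show the data flow of the claim (a loss module trained on labels, then used to score task-module predictions on unlabelled items), not a realistic choice of loss module.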