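Contrastive learning is a self-supervised method that learns representations by contrasting positive and negative examples. For self-supervised contrastive learning, the loss for an anchor sample $i$ is commonly written as:

$$\mathcal{L}_i = -\log \frac{\exp\left(\mathrm{sim}(z_i, z_i^{+}) / \tau\right)}{\sum_{k} \exp\left(\mathrm{sim}(z_i, z_k) / \tau\right)}$$

where $z_i$ is the embedding of sample $i$, $z_i^{+}$ is the embedding of its positive (augmented) view, the sum over $k$ runs over the positive and all negatives, $\mathrm{sim}$ is a similarity function such as cosine similarity, and $\tau$ is the temperature parameter.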
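To make the roles of the embeddings and the temperature concrete, here is a minimal sketch in plain Python (standard library only) of a common InfoNCE-style form of this loss for a single anchor. The function names `cosine` and `info_nce`, the toy vectors, and the choice of cosine similarity are assumptions for illustration and are not taken from the original post.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def info_nce(anchor, positive, negatives, tau):
    """Loss for one anchor: -log of the softmax weight given to the positive."""
    pos = math.exp(cosine(anchor, positive) / tau)
    neg = sum(math.exp(cosine(anchor, n) / tau) for n in negatives)
    return -math.log(pos / (pos + neg))


# Toy embeddings (hypothetical values, chosen only for illustration).
anchor = [1.0, 0.0]
positive = [0.9, 0.1]
negatives = [[0.0, 1.0], [-1.0, 0.0]]

# Same embeddings, two temperatures.
loss_warm = info_nce(anchor, positive, negatives, tau=1.0)
loss_cold = info_nce(anchor, positive, negatives, tau=0.1)
```

In this toy setup the positive is already the most similar candidate, so lowering the temperature sharpens the softmax around it and `loss_cold` comes out smaller than `loss_warm`. When the positive is not the closest candidate, a low temperature would instead increase the loss.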
Code
There are two ways to implement the contrastive loss:
Augment only the positive data, as in graph classification
def contrastive_loss(x, x_aug, T):
    """
    :param x: the hidden vectors of the original data
    :param x_aug: the positive vectors of the augmented data
    :param T: temperature
    :return: loss
    """
    batch_size, _ = x.size()
    # Per-row L2 norms of the original and augmented embeddings
    x_abs = x.norm(dim=1)
    x_aug_abs = x_aug.norm(dim=1)