deel.lip.losses module
This module contains losses used in Wasserstein distance estimation. See https://arxiv.org/abs/2006.06520 for more information.
- class deel.lip.losses.HKR(alpha, min_margin=1.0, reduction='auto', name='HKR')
Bases:
keras.losses.Loss
Wasserstein loss with a regularization param based on hinge loss.
\[\inf_{f \in Lip_1(\Omega)} \underset{\textbf{x} \sim P_-}{\mathbb{E}} \left[f(\textbf{x} )\right] - \underset{\textbf{x} \sim P_+} {\mathbb{E}} \left[f(\textbf{x} )\right] + \alpha \underset{\textbf{x}}{\mathbb{E}} \left(\text{min_margin} -Yf(\textbf{x})\right)_+\]
- Parameters
alpha – regularization factor
min_margin – minimal margin (see hinge_margin_loss)
Note on the Kantorovich-Rubinstein term of the loss: to be consistent between hinge and KR, the first label must yield the positive class while the second yields the negative class.
reduction – passed to the tf.keras.losses.Loss constructor
name – passed to the tf.keras.losses.Loss constructor
- Returns
a function that computes the regularized Wasserstein loss
- call(y_true, y_pred)
- get_config()
Returns the config dictionary for a Loss instance.
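As an illustration of the formula above, the following pure-Python sketch evaluates the HKR objective on a small batch. It is not the library's implementation: the name `hkr_sketch` is hypothetical, labels are assumed to be +1 for the positive class and anything else for the negative class, and the expectations are replaced by batch means.

```python
def hkr_sketch(y_true, y_pred, alpha, min_margin=1.0):
    """Illustrative HKR value on a batch (hypothetical helper, not deel.lip code)."""
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    neg = [p for t, p in zip(y_true, y_pred) if t != 1]
    # Kantorovich-Rubinstein part: E_{P_-}[f(x)] - E_{P_+}[f(x)]
    kr_term = sum(neg) / len(neg) - sum(pos) / len(pos)
    # Hinge part: E[(min_margin - Y f(x))_+], with Y mapped to {-1, +1}
    signs = [1.0 if t == 1 else -1.0 for t in y_true]
    hinge_terms = [max(0.0, min_margin - s * p) for s, p in zip(signs, y_pred)]
    hinge_term = sum(hinge_terms) / len(hinge_terms)
    return kr_term + alpha * hinge_term


y_true = [1, 1, -1, -1]
y_pred = [2.0, 0.5, -1.5, 0.5]
loss = hkr_sketch(y_true, y_pred, alpha=10.0)
print(loss)  # 3.25 for this batch
```

Here the KR part contributes -1.75 and the hinge part 0.5, so alpha controls how strongly margin violations outweigh the Wasserstein term.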
- class deel.lip.losses.HingeMargin(min_margin=1.0, reduction='auto', name='HingeMargin')
Bases:
keras.losses.Loss
Compute the hinge margin loss.
\[\underset{\textbf{x}}{\mathbb{E}} \left(\text{min_margin} -Yf(\textbf{x})\right)_+\]
- Parameters
min_margin – the minimal margin to enforce.
reduction – passed to the tf.keras.losses.Loss constructor
name – passed to the tf.keras.losses.Loss constructor
- Returns
a function that computes the hinge loss.
- call(y_true, y_pred)
- get_config()
Returns the config dictionary for a Loss instance.
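The hinge margin formula can be checked by hand with a few lines of plain Python. This is only a sketch under assumptions: `hinge_margin_sketch` is a hypothetical name, label 1 is treated as Y = +1 and any other label as Y = -1, and the expectation is taken as a batch mean.

```python
def hinge_margin_sketch(y_true, y_pred, min_margin=1.0):
    """Illustrative mean of (min_margin - Y f(x))_+ over a batch."""
    signs = [1.0 if t == 1 else -1.0 for t in y_true]
    terms = [max(0.0, min_margin - s * p) for s, p in zip(signs, y_pred)]
    return sum(terms) / len(terms)


# Only samples whose signed score Y f(x) is below min_margin contribute.
value = hinge_margin_sketch([1, -1, 1, -1], [2.0, -0.5, 0.5, 0.5])
print(value)  # 0.625
```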
- deel.lip.losses.KR(y_true, y_pred)
Loss to estimate the Wasserstein-1 distance using Kantorovich-Rubinstein duality. The Kantorovich-Rubinstein duality is formulated as follows:
\[W_1(\mu, \nu) = \sup_{f \in Lip_1(\Omega)} \underset{\textbf{x} \sim \mu}{\mathbb{E}} \left[f(\textbf{x} )\right] - \underset{\textbf{x} \sim \nu}{\mathbb{E}} \left[f(\textbf{x} )\right]\]
Where mu and nu stand for the two distributions: the distribution where the label is 1, and the rest.
- Returns
Callable, the function to compute Wasserstein loss
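For a fixed function f, the quantity inside the supremum is simply a difference of two means. The sketch below shows that batch estimate in plain Python; `kr_sketch` is a hypothetical name, and the batch is assumed to contain at least one sample with label 1 and one without.

```python
def kr_sketch(y_true, y_pred):
    """Illustrative KR estimate: mean score where label == 1 minus mean of the rest."""
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    rest = [p for t, p in zip(y_true, y_pred) if t != 1]
    return sum(pos) / len(pos) - sum(rest) / len(rest)


estimate = kr_sketch([1, 1, 0, 0], [1.5, 0.5, -1.0, -2.0])
print(estimate)  # 2.5
```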
- class deel.lip.losses.MultiMargin(min_margin=1, reduction='auto', name='MultiMargin')
Bases:
keras.losses.Loss
Compute the mean hinge margin loss for multiclass problems (equivalent to PyTorch multi_margin_loss).
- Parameters
min_margin – the minimal margin to enforce.
reduction – passed to the tf.keras.losses.Loss constructor
name – passed to the tf.keras.losses.Loss constructor
Notes
y_true has to be one-hot encoded (for example with to_categorical).
- Returns
Callable, the function to compute multi margin loss
- call(y_true, y_pred)
- get_config()
Returns the config dictionary for a Loss instance.
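Since the description refers to PyTorch's multi_margin_loss, the sketch below writes out that formula (with p=1 and no class weights) for one-hot targets. It is an illustration rather than the library's code: `multi_margin_sketch` is a hypothetical name, and averaging over the batch is an assumption about the reduction.

```python
def multi_margin_sketch(y_true, y_pred, min_margin=1.0):
    """Illustrative multi-margin loss: per sample, the sum over non-target classes
    of max(0, min_margin - score[target] + score[i]), divided by the class count."""
    per_sample = []
    for onehot, scores in zip(y_true, y_pred):
        target = onehot.index(1)  # position of the true class in the one-hot row
        total = sum(
            max(0.0, min_margin - scores[target] + s)
            for i, s in enumerate(scores)
            if i != target
        )
        per_sample.append(total / len(scores))
    return sum(per_sample) / len(per_sample)


y_true = [[1, 0, 0], [0, 0, 1]]
y_pred = [[2.0, 1.5, 0.0], [0.5, 1.0, 3.0]]
value = multi_margin_sketch(y_true, y_pred)
print(value)
```

In this batch only the first sample is penalised, because its second-best score (1.5) is within the margin of the true-class score (2.0).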
- class deel.lip.losses.MulticlassHKR(alpha=10.0, min_margin=1.0, reduction='auto', name='MulticlassHKR')
Bases:
keras.losses.Loss
The multiclass version of HKR. This is done by computing the HKR term over each class and averaging the results.
- Parameters
alpha – regularization factor
min_margin – minimal margin (see Hinge_multiclass_loss)
reduction – passed to the tf.keras.losses.Loss constructor
name – passed to the tf.keras.losses.Loss constructor
- Returns
Callable, the function to compute HKR loss. Note: y_true has to be one-hot encoded.
- call(y_true, y_pred)
- get_config()
Returns the config dictionary for a Loss instance.
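The description "HKR term over each class, then averaged" can be read literally as the following plain-Python sketch. The names `binary_hkr` and `multiclass_hkr_sketch` are hypothetical, each class column is treated as a one-vs-rest binary problem, and the library's actual weighting and reduction may differ from this reading.

```python
def binary_hkr(labels, scores, alpha, min_margin):
    """One-vs-rest HKR term for a single class column (illustrative)."""
    pos = [s for t, s in zip(labels, scores) if t == 1]
    neg = [s for t, s in zip(labels, scores) if t != 1]
    kr_term = sum(neg) / len(neg) - sum(pos) / len(pos)
    hinge = sum(
        max(0.0, min_margin - (s if t == 1 else -s)) for t, s in zip(labels, scores)
    ) / len(scores)
    return kr_term + alpha * hinge


def multiclass_hkr_sketch(y_true, y_pred, alpha=10.0, min_margin=1.0):
    """Average of the per-class HKR terms; assumes every class has both
    positive and negative samples in the batch."""
    n_classes = len(y_true[0])
    per_class = [
        binary_hkr([row[c] for row in y_true], [row[c] for row in y_pred],
                   alpha, min_margin)
        for c in range(n_classes)
    ]
    return sum(per_class) / n_classes


y_true = [[1, 0], [0, 1]]           # one-hot labels
y_pred = [[2.0, -0.5], [-2.0, 2.0]]
loss = multiclass_hkr_sketch(y_true, y_pred)
print(loss)  # -2.0
```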
- class deel.lip.losses.MulticlassHinge(min_margin=1.0, reduction='auto', name='MulticlassHinge')
Bases:
keras.losses.Loss
Loss to estimate the hinge loss in a multiclass setup. It computes the elementwise hinge term. Note that this formulation differs from the one commonly found in TensorFlow/PyTorch (which maximises the difference between the two largest logits). This formulation is consistent with the binary classification loss used in a multiclass fashion. Note that y_true should be one-hot encoded, with labels in (1, 0).
- Parameters
min_margin – positive float, margin to enforce.
reduction – passed to the tf.keras.losses.Loss constructor
name – passed to the tf.keras.losses.Loss constructor
- Returns
Callable, the function to compute multiclass hinge loss. Note: y_true has to be one-hot encoded.
- call(y_true, y_pred)
- get_config()
Returns the config dictionary for a Loss instance.
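To make "elementwise hinge term" concrete, the sketch below maps the one-hot labels (1, 0) to signs (+1, -1) and applies the binary hinge to every output. The name `multiclass_hinge_sketch` is hypothetical, and taking the mean over all elements is an assumption; the library's reduction may differ.

```python
def multiclass_hinge_sketch(y_true, y_pred, min_margin=1.0):
    """Illustrative elementwise hinge over one-hot labels in (1, 0)."""
    terms = [
        max(0.0, min_margin - (2 * t - 1) * p)  # 2*t - 1 maps 1 -> +1 and 0 -> -1
        for true_row, pred_row in zip(y_true, y_pred)
        for t, p in zip(true_row, pred_row)
    ]
    return sum(terms) / len(terms)


y_true = [[1, 0], [0, 1]]
y_pred = [[2.0, -0.5], [-2.0, 2.0]]
value = multiclass_hinge_sketch(y_true, y_pred)
print(value)  # 0.125
```

Only the output -0.5 for a non-target class violates the margin here, which shows how each logit is penalised independently rather than relative to the other logits.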
- class deel.lip.losses.MulticlassKR(reduction='auto', name='MulticlassKR')
Bases:
keras.losses.Loss
Loss to estimate the average W1 distance using Kantorovich-Rubinstein duality over outputs. Note that y_true should be one-hot encoded (labels being 1s and 0s). In this multiclass setup the KR term is computed for each class and then averaged.
- Parameters
reduction – passed to the tf.keras.losses.Loss constructor
name – passed to the tf.keras.losses.Loss constructor
- Returns
Callable, the function to compute Wasserstein multiclass loss. Note: y_true has to be one-hot encoded.
- call(y_true, y_pred)
- get_config()
Returns the config dictionary for a Loss instance.
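A literal reading of "KR term computed for each class and then averaged" is sketched below in plain Python. `multiclass_kr_sketch` is a hypothetical name, each class column is treated as a label-1-versus-rest problem, and the sign convention and reduction of the real loss are not confirmed by this page.

```python
def multiclass_kr_sketch(y_true, y_pred):
    """Illustrative average over classes of (mean score where label == 1)
    minus (mean score of the rest); assumes each class has both kinds of samples."""
    n_classes = len(y_true[0])
    per_class = []
    for c in range(n_classes):
        pos = [p[c] for t, p in zip(y_true, y_pred) if t[c] == 1]
        rest = [p[c] for t, p in zip(y_true, y_pred) if t[c] != 1]
        per_class.append(sum(pos) / len(pos) - sum(rest) / len(rest))
    return sum(per_class) / n_classes


y_true = [[1, 0], [0, 1]]
y_pred = [[2.0, -0.5], [-2.0, 2.0]]
estimate = multiclass_kr_sketch(y_true, y_pred)
print(estimate)  # 3.25
```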
- deel.lip.losses.negative_KR(y_true, y_pred)
Loss to compute the negative Wasserstein-1 distance using Kantorovich-Rubinstein duality. This allows the term to be maximised with a conventional (minimising) optimizer.
- Returns
Callable, the function to compute negative Wasserstein loss
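The reason for the sign flip can be seen on two toy batches: a better-separated set of scores gives a larger KR estimate and therefore a smaller negative KR value, which is what a minimising optimizer rewards. The helper names below are hypothetical and the code is a sketch, not the library's implementation.

```python
def kr_sketch(y_true, y_pred):
    """Mean score where label == 1 minus mean score of the rest (illustrative)."""
    pos = [p for t, p in zip(y_true, y_pred) if t == 1]
    rest = [p for t, p in zip(y_true, y_pred) if t != 1]
    return sum(pos) / len(pos) - sum(rest) / len(rest)


def negative_kr_sketch(y_true, y_pred):
    """Negated KR estimate, so that minimising it maximises the KR term."""
    return -kr_sketch(y_true, y_pred)


y_true = [1, 1, 0, 0]
separated = negative_kr_sketch(y_true, [1.5, 0.5, -1.0, -2.0])
overlapping = negative_kr_sketch(y_true, [0.2, 0.0, 0.1, -0.1])
print(separated < overlapping)  # True: the separated scores give the lower loss
```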