crested.tl.losses.CosineMSELogLoss
- class crested.tl.losses.CosineMSELogLoss(max_weight=1.0, name='CosineMSELogLoss', reduction='sum_over_batch_size', multiplier=1000)
Custom loss function combining a logarithmic transformation, cosine similarity, and mean squared error (MSE).
This loss applies a logarithmic transformation to the true and predicted values, normalizes them, and computes both the MSE and the cosine similarity. A dynamic weight based on the MSE balances these two components.
- Parameters:
  - max_weight (float (default: 1.0)) – The maximum weight applied to the cosine similarity loss component. Lower values emphasize the MSE component, while higher values emphasize the cosine similarity component.
  - name (str | None (default: 'CosineMSELogLoss')) – Name of the loss function.
  - reduction (str (default: 'sum_over_batch_size')) – Type of reduction to apply to the loss.
  - multiplier (float (default: 1000)) – Scalar to multiply the predicted value with. When predicting mean coverage, we recommend setting this to the number of base pairs (bp) the coverage was averaged over, so that values are converted back to actual counts (1000 by default, which is also the default here). When predicting insertion counts, we recommend keeping it at 1.
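The effect of multiplier can be illustrated with a small pure-Python sketch of the signed log transform described in the Notes. It assumes that multiplier takes the place of the constant 1000 used there (the two defaults match). The helper name signed_log is hypothetical and not part of CREsted, and the sketch works on scalars rather than tensors.

```python
import math


def signed_log(y: float, multiplier: float = 1000.0) -> float:
    """Hypothetical helper: log(1 + m*y) for y >= 0, -log(1 + |m*y|) for y < 0.

    Assumes `multiplier` plays the role of the constant 1000 in the Notes.
    """
    scaled = multiplier * y
    if scaled >= 0:
        return math.log1p(scaled)      # log(1 + m*y)
    return -math.log1p(abs(scaled))    # -log(1 + |m*y|)


# Mean coverage averaged over 1000 bp: multiplying by 1000 rescales it to counts.
coverage_value = signed_log(0.5, multiplier=1000)  # log(1 + 500)
# Insertion counts are already counts, so a multiplier of 1 is recommended.
count_value = signed_log(500.0, multiplier=1)      # log(1 + 500), same value
negative_value = signed_log(-0.1, multiplier=1000) # -log(1 + 100)
```

Both calls map to the same log-space value, which is why the recommended multiplier depends on whether the targets are averaged coverage or raw insertion counts.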
Notes
The log transformation is log(1 + 1000 * y) for positive values and -log(1 + abs(1000 * y)) for negative values.
The cosine similarity is computed between L2-normalized true and predicted values.
The dynamic weight for the cosine similarity component is constrained between 1.0 and max_weight.
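The cosine similarity and weighting steps can be sketched in pure Python on plain lists of floats instead of tensors. The helper names are hypothetical. Treating the MSE itself as the raw value that gets clamped into [1.0, max_weight] is an assumption based on the description above, not CREsted's implementation.

```python
import math
from typing import List, Sequence


def l2_normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm (an all-zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        return [0.0 for _ in v]
    return [x / norm for x in v]


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Dot product of the L2-normalized vectors, in [-1, 1]."""
    return sum(x * y for x, y in zip(l2_normalize(a), l2_normalize(b)))


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def dynamic_weight(raw: float, max_weight: float = 1.0) -> float:
    """Hypothetical: clamp a raw (assumed MSE-based) value into [1.0, max_weight]."""
    return min(max(raw, 1.0), max_weight)


y_true = [1.0, 0.0, -1.0]
y_pred = [1.2, -0.1, -0.9]
cos = cosine_similarity(y_true, y_pred)
w = dynamic_weight(mse(y_true, y_pred), max_weight=2.0)
```

With the default max_weight=1.0 the clamp always returns 1.0, so the weight only varies when max_weight is raised above 1.0.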
Examples
>>> import numpy as np
>>> from crested.tl.losses import CosineMSELogLoss
>>> loss = CosineMSELogLoss(max_weight=2.0)
>>> y_true = np.array([1.0, 0.0, -1.0])
>>> y_pred = np.array([1.2, -0.1, -0.9])
>>> loss(y_true, y_pred)
Attributes table
Methods table
| Method | Description |
| --- | --- |
| call(y_true, y_pred) | Compute the loss value. |
| from_config(config) | Create a loss function from the configuration. |
| get_config() | Return the configuration of the loss function. |
Attributes
- CosineMSELogLoss.dtype
Methods
- CosineMSELogLoss.call(y_true, y_pred)
Compute the loss value.
- classmethod CosineMSELogLoss.from_config(config)
Create a loss function from the configuration.
- CosineMSELogLoss.get_config()
Return the configuration of the loss function.
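get_config and from_config follow the usual Keras serialization pattern. get_config returns the constructor arguments, and from_config rebuilds the object by passing them back to the constructor. The sketch below uses a hypothetical stdlib-only stand-in class, CosineMSELogLossConfigSketch. It only mirrors the documented constructor arguments and is not CREsted's actual implementation, which may include additional keys inherited from the Keras base class.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, Optional


@dataclass
class CosineMSELogLossConfigSketch:
    """Hypothetical stand-in holding the documented constructor arguments."""

    max_weight: float = 1.0
    name: Optional[str] = "CosineMSELogLoss"
    reduction: str = "sum_over_batch_size"
    multiplier: float = 1000

    def get_config(self) -> Dict[str, Any]:
        # Return every constructor argument so the object can be rebuilt later.
        return asdict(self)

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CosineMSELogLossConfigSketch":
        # Rebuild by passing the stored arguments back to the constructor.
        return cls(**config)


original = CosineMSELogLossConfigSketch(max_weight=2.0, multiplier=1)
restored = CosineMSELogLossConfigSketch.from_config(original.get_config())
```

Because the config is a plain dictionary of constructor arguments, it can be stored alongside a saved model and used to recreate an equivalent loss object when the model is reloaded.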