Custom models, losses, and metrics with CREsted

This tutorial demonstrates CREsted’s flexibility in working with custom model architectures, loss functions, and metrics across different deep learning backends (TensorFlow and PyTorch) via Keras 3.

CREsted is built on Keras 3, which means you can seamlessly switch between TensorFlow and PyTorch backends while using the same high-level API. This tutorial shows how to:

  1. Create custom model architectures for both backends

  2. Implement custom loss functions and metrics

  3. Use custom training loops (PyTorch example)

The key advantage is that components written with Keras operations work identically regardless of the backend, while components written with native TensorFlow or PyTorch operations still plug into CREsted when the matching backend is active. This gives you the flexibility to choose the best framework for your use case. You could, for example, quickly test a novel loss function from a paper’s codebase, whether it is implemented in torch, tf, or Keras code, all within the CREsted framework.

Setup and backend detection

import numpy as np
import crested  # import crested before keras to ensure correct backend setup
import keras

# Check which backend we're using
print(f"Keras backend: {keras.backend.backend()}")
print(f"Keras version: {keras.__version__}")

# Import backend-specific modules if needed
if keras.backend.backend() == "tensorflow":
    import tensorflow as tf

    print(f"TensorFlow version: {tf.__version__}")
elif keras.backend.backend() == "torch":
    import torch

    print(f"PyTorch version: {torch.__version__}")
Keras backend: torch
Keras version: 3.11.2
PyTorch version: 2.8.0+cu128

Data loading

We’ll use the same mouse cortex dataset from the main tutorial.

# Load preprocessed data (assuming you've run the main tutorial)
import anndata as ad

adata = ad.read_h5ad("data/mouse_cortex_filtered.h5ad")

# Register genome
genome = crested.Genome(fasta="data/genomes/mm10/mm10.fa", chrom_sizes="data/genomes/mm10/mm10.chrom.sizes")
crested.register_genome(genome)

print(f"Dataset shape: {adata.shape}")
print(f"Cell types: {list(adata.obs_names)}")

# Setup datamodule
datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=64,
    max_stochastic_shift=3,
    always_reverse_complement=True,
)
2025-08-21T16:09:49.058203+0200 INFO Genome mm10 registered.
Dataset shape: (19, 91477)
Cell types: ['Astro', 'Endo', 'L2_3IT', 'L5ET', 'L5IT', 'L5_6NP', 'L6CT', 'L6IT', 'L6b', 'Lamp5', 'Micro_PVM', 'OPC', 'Oligo', 'Pvalb', 'Sncg', 'Sst', 'SstChodl', 'VLMC', 'Vip']

Custom model architectures

CREsted works with any Keras model, regardless of how it’s implemented under the hood. Here we show the same simple CNN architecture implemented in different ways depending on your backend preference.

Standard Keras model (framework-agnostic)

This approach works identically with both the TensorFlow and PyTorch backends.
If you build your model from Keras layers (recent TensorFlow versions use the same syntax via `tf.keras`), Keras runs the computations on whichever backend is active in your environment.

def create_simple_cnn_keras(seq_len=2114, num_classes=19):
    """Simple CNN model using standard Keras layers."""
    inputs = keras.layers.Input(shape=(seq_len, 4))

    # First conv block
    x = keras.layers.Conv1D(128, 15, activation="relu", padding="same")(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling1D(8)(x)
    x = keras.layers.Dropout(0.2)(x)

    # Second conv block
    x = keras.layers.Conv1D(256, 7, activation="relu", padding="same")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling1D(4)(x)
    x = keras.layers.Dropout(0.2)(x)

    # Global pooling and dense layers
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.3)(x)
    outputs = keras.layers.Dense(num_classes, activation="softplus")(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name="simple_cnn_keras")
    return model


# Create the model
keras_model = create_simple_cnn_keras(seq_len=2114, num_classes=len(adata.obs_names))
print(f"Model created with {keras_model.count_params():,} parameters")

# You would now pass this model to the crested.tl.Crested trainer, as in the main tutorial
Model created with 309,651 parameters
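The printed count can be checked by hand. Only the convolution, batch-normalization, and dense layers carry parameters; pooling, dropout, and global average pooling have none, so the sequence length does not enter the total. The arithmetic below uses the standard per-layer formulas; the helper functions are illustrative and not part of CREsted or Keras.

```python
# Per-layer parameter counts for create_simple_cnn_keras(seq_len=2114, num_classes=19).

def conv1d_params(kernel_size, in_channels, filters):
    # kernel weights + one bias per filter
    return kernel_size * in_channels * filters + filters


def batchnorm_params(channels):
    # Keras BatchNormalization: gamma, beta, moving_mean, moving_variance per channel
    return 4 * channels


def dense_params(in_features, units):
    # weight matrix + one bias per unit
    return in_features * units + units


num_classes = 19
layer_params = {
    "conv1": conv1d_params(15, 4, 128),
    "bn1": batchnorm_params(128),
    "conv2": conv1d_params(7, 128, 256),
    "bn2": batchnorm_params(256),
    "dense1": dense_params(256, 256),
    "dense2": dense_params(256, num_classes),
}
total = sum(layer_params.values())

for name, n in layer_params.items():
    print(f"{name:>7}: {n:>7,}")
print(f"  total: {total:,}")  # 309,651
```

The second convolution dominates (229,632 of the 309,651 parameters), so widening or narrowing that layer is the quickest way to change model size.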

Custom torch models

If you prefer to write your custom models in PyTorch, that works too!
You just need a CREsted environment with PyTorch installed instead of TensorFlow, so that Keras uses the torch backend.
There are two options for writing your models in torch directly:

  1. Write your model in pure torch, without any changes. In that case you can’t use the tl.Crested training class, because it trains the model through Keras and therefore expects a keras.Model, which a plain torch.nn.Module is not. Instead, you write a custom training loop that uses CREsted’s dataloaders (see the bottom of this tutorial). This restriction only applies to model definitions: losses and metrics can still be written with torch operations and used within CREsted as usual.

  2. The second option is easier: write your torch model inside a class that inherits from keras.Model, and wrap every torch layer that has parameters in a keras.layers.TorchModuleWrapper so that Keras can track and train its weights.

if keras.backend.backend() == "torch":
    import torch.nn as nn
    import torch.nn.functional as F
    from keras.layers import TorchModuleWrapper

    # Method 1: Pure PyTorch model (for custom training loops)
    class PurePyTorchModel(nn.Module):
        """Pure PyTorch model equivalent to the Keras CNN."""

        def __init__(self, num_classes):
            super().__init__()
            self.conv1 = nn.Conv1d(4, 128, 15, padding=7)
            self.bn1 = nn.BatchNorm1d(128)
            self.pool1 = nn.MaxPool1d(8)
            self.dropout1 = nn.Dropout(0.2)

            self.conv2 = nn.Conv1d(128, 256, 7, padding=3)
            self.bn2 = nn.BatchNorm1d(256)
            self.pool2 = nn.MaxPool1d(4)
            self.dropout2 = nn.Dropout(0.2)

            self.fc1 = nn.Linear(256, 256)
            self.dropout3 = nn.Dropout(0.3)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            """Forward pass for the PyTorch model."""
            # Input: (batch, seq_len, 4) -> (batch, 4, seq_len) for Conv1d
            x = x.transpose(1, 2)

            # First conv block (same layer order as the Keras model; note that the
            # BatchNorm momentum/epsilon defaults differ between Keras and PyTorch)
            x = F.relu(self.conv1(x))
            x = self.bn1(x)
            x = self.pool1(x)
            x = self.dropout1(x)

            # Second conv block
            x = F.relu(self.conv2(x))
            x = self.bn2(x)
            x = self.pool2(x)
            x = self.dropout2(x)

            # Global average pooling
            x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)

            # Dense layers
            x = F.relu(self.fc1(x))
            x = self.dropout3(x)
            x = F.softplus(self.fc2(x))
            return x

    # Method 2: Keras model using PyTorch operations
    class KerasModelWithPyTorch(keras.Model):
        """Keras model that uses PyTorch operations wrapped in TorchModuleWrapper."""

        def __init__(self, seq_len, num_classes, **kwargs):
            super().__init__(**kwargs)
            self.seq_len = seq_len

            # Wrap PyTorch modules with TorchModuleWrapper when they contain parameters
            self.conv1 = TorchModuleWrapper(nn.Conv1d(4, 128, 15, padding=7))
            self.bn1 = TorchModuleWrapper(nn.BatchNorm1d(128))
            self.pool1 = nn.MaxPool1d(8)  # No parameters, no need to wrap
            self.dropout1 = nn.Dropout(0.2)

            self.conv2 = TorchModuleWrapper(nn.Conv1d(128, 256, 7, padding=3))
            self.bn2 = TorchModuleWrapper(nn.BatchNorm1d(256))
            self.pool2 = nn.MaxPool1d(4)
            self.dropout2 = nn.Dropout(0.2)

            self.fc1 = TorchModuleWrapper(nn.Linear(256, 256))
            self.dropout3 = nn.Dropout(0.3)
            self.fc2 = TorchModuleWrapper(nn.Linear(256, num_classes))

        def call(self, inputs):
            """Forward pass for the Keras model using PyTorch operations."""
            # Transpose for PyTorch Conv1d: (batch, seq_len, 4) -> (batch, 4, seq_len)
            x = inputs.transpose(1, 2)

            # First conv block using PyTorch operations
            x = F.relu(self.conv1(x))
            x = self.bn1(x)
            x = self.pool1(x)
            x = self.dropout1(x)

            # Second conv block
            x = F.relu(self.conv2(x))
            x = self.bn2(x)
            x = self.pool2(x)
            x = self.dropout2(x)

            # Global average pooling
            x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)

            # Dense layers
            x = F.relu(self.fc1(x))
            x = self.dropout3(x)
            x = F.softplus(self.fc2(x))
            return x

    # Create both models
    num_classes = len(adata.obs_names)
    seq_len = 2114

    # Pure PyTorch model (for custom training loops)
    pure_pytorch_model = PurePyTorchModel(num_classes)
    pytorch_params = sum(p.numel() for p in pure_pytorch_model.parameters())
    print(f"Pure PyTorch model created with {pytorch_params:,} parameters")

    # Keras model with PyTorch components (for CREsted workflow)
    keras_pytorch_model = KerasModelWithPyTorch(seq_len, num_classes)
    # Build the model to count parameters
    keras_pytorch_model.build((None, seq_len, 4))
    print(f"Keras+PyTorch hybrid model created with {keras_pytorch_model.count_params():,} parameters")

    # Use the hybrid model for CREsted training (it works with crested.tl.Crested)
    model = keras_pytorch_model

else:
    print("PyTorch-specific examples only available with PyTorch backend")
    model = keras_model  # Use the standard Keras model
Pure PyTorch model created with 308,883 parameters
Keras+PyTorch hybrid model created with 308,883 parameters
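Both torch-based models report 308,883 parameters, 768 fewer than the 309,651 of the Keras model, even though the layers match. The difference is BatchNorm bookkeeping: Keras counts the moving mean and variance as (non-trainable) weights, whereas PyTorch stores `running_mean` and `running_var` as buffers, which `parameters()` does not return. A quick arithmetic check:

```python
bn_channels = [128, 256]  # channels of the two BatchNorm layers

# Keras BatchNormalization: gamma, beta (trainable) + moving_mean, moving_variance (non-trainable)
keras_bn = sum(4 * c for c in bn_channels)
# torch.nn.BatchNorm1d: weight and bias are parameters; running stats are buffers
torch_bn = sum(2 * c for c in bn_channels)

keras_total = 309_651
torch_total = keras_total - (keras_bn - torch_bn)
print(keras_bn - torch_bn, torch_total)  # 768 308883
```

The running statistics are still part of the torch model's `state_dict()`, so they are saved and restored with the model; they simply aren't counted as parameters.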

Custom loss functions

CREsted works with any loss that Keras accepts, such as a subclass of keras.losses.Loss. Inside the loss you can use keras.ops for backend-agnostic code, or native TensorFlow/PyTorch operations if you only target one backend:

# Framework-agnostic custom loss using Keras ops
class WeightedMSELoss(keras.losses.Loss):
    """Custom weighted MSE loss that works with any backend."""

    def __init__(self, class_weights=None, name="weighted_mse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.class_weights = class_weights

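    def call(self, y_true, y_pred):
        """Compute the per-sample weighted mean squared error."""
        # Use keras.ops for backend-agnostic operations
        mse = keras.ops.square(y_true - y_pred)

        if self.class_weights is not None:
            # Apply class-specific weights; cast to the prediction dtype so that
            # float64 NumPy weights don't clash with float32 predictions
            weights = keras.ops.convert_to_tensor(self.class_weights, dtype=y_pred.dtype)
            mse = mse * weights

        # Average over classes; Keras then reduces over the batch
        return keras.ops.mean(mse, axis=-1)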


# Backend-specific loss implementations
def create_custom_loss_backend_specific():
    """Create custom loss using backend-specific operations."""
    if keras.backend.backend() == "tensorflow":
        import tensorflow as tf

        class TensorFlowCustomLoss(keras.losses.Loss):
            def __init__(self, alpha=1.0, name="tf_custom_loss", **kwargs):
                super().__init__(name=name, **kwargs)
                self.alpha = alpha

            def call(self, y_true, y_pred):
                # Use TensorFlow operations directly
                mse = tf.square(y_true - y_pred)
                # Add a penalty on the magnitude of the predictions, scaled by alpha
                l2_reg = tf.reduce_sum(tf.square(y_pred)) * 0.001
                return tf.reduce_mean(mse) + self.alpha * l2_reg

        return TensorFlowCustomLoss()

    elif keras.backend.backend() == "torch":
        import torch

        class PyTorchCustomLoss(keras.losses.Loss):
            def __init__(self, alpha=1.0, name="torch_custom_loss", **kwargs):
                super().__init__(name=name, **kwargs)
                self.alpha = alpha

            def call(self, y_true, y_pred):
                # Use PyTorch operations directly
                mse = torch.square(y_true - y_pred)
                l2_reg = torch.sum(torch.square(y_pred)) * 0.001
                return torch.mean(mse) + self.alpha * l2_reg

        return PyTorchCustomLoss()


# Create loss functions
print(f"Creating losses for {keras.backend.backend()} backend.")

# Framework-agnostic loss
class_weights = np.ones(len(adata.obs_names))  # Equal weights for demo
weighted_loss = WeightedMSELoss(class_weights=class_weights)

# Backend-specific loss
backend_loss = create_custom_loss_backend_specific()
print(backend_loss)
Creating losses for torch backend.
<__main__.create_custom_loss_backend_specific.<locals>.PyTorchCustomLoss object at 0x72a3be54c590>
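Before training with a new loss, it can help to check it against a plain NumPy reference on a tiny array. The sketch below reimplements the two losses from the cell above in NumPy; `weighted_mse_reference` and `mse_plus_pred_penalty_reference` are hypothetical helper names, not CREsted functions.

```python
import numpy as np


def weighted_mse_reference(y_true, y_pred, class_weights):
    """NumPy version of WeightedMSELoss: mean over all entries of w_c * (y - y_hat)^2."""
    return float(np.mean(class_weights * (y_true - y_pred) ** 2))


def mse_plus_pred_penalty_reference(y_true, y_pred, alpha=1.0):
    """NumPy version of the backend-specific loss: MSE + alpha * 0.001 * sum(y_pred^2)."""
    mse = np.mean((y_true - y_pred) ** 2)
    penalty = np.sum(y_pred**2) * 0.001
    return float(mse + alpha * penalty)


y_true = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
y_pred = np.array([[1.5, 2.0], [2.0, 4.0]], dtype="float32")
weights = np.array([2.0, 1.0], dtype="float32")  # doubles the weight of the first class

print(weighted_mse_reference(y_true, y_pred, weights))
print(mse_plus_pred_penalty_reference(y_true, y_pred))
```

We would expect the Keras loss objects to agree with these references up to float32 rounding when given the same arrays (converted with keras.ops.convert_to_tensor). Note that the penalty term sums over the whole batch, so its contribution grows with batch size.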

Custom metrics

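Custom metrics follow the same pattern as losses: subclass keras.metrics.Metric, keep running state in weights created with add_weight, and implement update_state, result, and reset_state. Use keras.ops for backend-agnostic metrics, or native TensorFlow/PyTorch operations for backend-specific ones: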

# Framework-agnostic custom metric
class SpearmanCorrelation(keras.metrics.Metric):
    """Custom Spearman correlation metric that works with any backend."""

    def __init__(self, name="spearman_corr", **kwargs):
        super().__init__(name=name, **kwargs)
        self.correlation_sum = self.add_weight(name="correlation_sum", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Update state with new predictions and true values."""
        # Simplified Spearman correlation (rank correlation)
        # For demo - real implementation would be more complex
        pearson = keras.ops.sum(
            (y_true - keras.ops.mean(y_true, axis=0)) * (y_pred - keras.ops.mean(y_pred, axis=0)),
            axis=0,
        )
        norm_factor = keras.ops.sqrt(
            keras.ops.sum(keras.ops.square(y_true - keras.ops.mean(y_true, axis=0)), axis=0)
            * keras.ops.sum(keras.ops.square(y_pred - keras.ops.mean(y_pred, axis=0)), axis=0)
        )
        correlation = keras.ops.mean(pearson / (norm_factor + 1e-8))

        self.correlation_sum.assign_add(correlation)
        self.count.assign_add(1.0)

    def result(self):
        """Returns the average correlation."""
        return self.correlation_sum / self.count

    def reset_state(self):
        """Reset the state of the metric."""
        self.correlation_sum.assign(0.0)
        self.count.assign(0.0)


# Or, write backend-specific metrics
def create_custom_metric_backend_specific():
    """Create custom metric using backend-specific operations."""
    if keras.backend.backend() == "tensorflow":
        import tensorflow as tf

        class TensorFlowR2Score(keras.metrics.Metric):
            def __init__(self, name="tf_r2_score", **kwargs):
                super().__init__(name=name, **kwargs)
                self.total_sum_squares = self.add_weight(name="tss", initializer="zeros")
                self.residual_sum_squares = self.add_weight(name="rss", initializer="zeros")

            def update_state(self, y_true, y_pred, sample_weight=None):
                # Using TensorFlow operations directly
                y_mean = tf.reduce_mean(y_true)
                tss = tf.reduce_sum(tf.square(y_true - y_mean))
                rss = tf.reduce_sum(tf.square(y_true - y_pred))

                self.total_sum_squares.assign_add(tss)
                self.residual_sum_squares.assign_add(rss)

            def result(self):
                return 1.0 - (self.residual_sum_squares / (self.total_sum_squares + 1e-8))

            def reset_state(self):
                self.total_sum_squares.assign(0.0)
                self.residual_sum_squares.assign(0.0)

        return TensorFlowR2Score()

    elif keras.backend.backend() == "torch":
        import torch

        class PyTorchR2Score(keras.metrics.Metric):
            def __init__(self, name="torch_r2_score", **kwargs):
                super().__init__(name=name, **kwargs)
                self.total_sum_squares = self.add_weight(name="tss", initializer="zeros")
                self.residual_sum_squares = self.add_weight(name="rss", initializer="zeros")

            def update_state(self, y_true, y_pred, sample_weight=None):
                # Using PyTorch operations directly
                y_mean = torch.mean(y_true)
                tss = torch.sum(torch.square(y_true - y_mean))
                rss = torch.sum(torch.square(y_true - y_pred))

                self.total_sum_squares.assign_add(tss)
                self.residual_sum_squares.assign_add(rss)

            def result(self):
                return 1.0 - (self.residual_sum_squares / (self.total_sum_squares + 1e-8))

            def reset_state(self):
                self.total_sum_squares.assign(0.0)
                self.residual_sum_squares.assign(0.0)

        return PyTorchR2Score()


# Create metrics
print(f"Creating metrics for {keras.backend.backend()} backend:")

# Framework-agnostic metric
spearman_metric = SpearmanCorrelation()

# Backend-specific metric
r2_metric = create_custom_metric_backend_specific()

# Built-in and custom metrics to track during training
standard_metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.CosineSimilarity(axis=1),
    spearman_metric,
    r2_metric,
]
print(standard_metrics)
Creating metrics for torch backend:
[<MeanAbsoluteError name=mean_absolute_error>, <CosineSimilarity name=cosine_similarity>, <SpearmanCorrelation name=spearman_corr>, <PyTorchR2Score name=torch_r2_score>]
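Keras metrics are stateful: update_state runs once per batch, result reads the accumulated state, and reset_state clears it at the start of each epoch. The R² metrics above accumulate TSS and RSS across batches, but y_mean is recomputed for every batch (and pooled over all classes), so the epoch value is not exactly the R² of the full dataset. The pure-NumPy sketch below mimics that lifecycle; `StreamingR2` is a hypothetical stand-in, not a Keras class.

```python
import numpy as np


class StreamingR2:
    """NumPy mimic of the R2 metrics above: accumulate TSS and RSS batch by batch."""

    def __init__(self):
        self.reset_state()

    def update_state(self, y_true, y_pred):
        y_mean = np.mean(y_true)  # per-batch mean, pooled over samples and classes
        self.tss += np.sum((y_true - y_mean) ** 2)
        self.rss += np.sum((y_true - y_pred) ** 2)

    def result(self):
        return 1.0 - self.rss / (self.tss + 1e-8)

    def reset_state(self):
        self.tss = 0.0
        self.rss = 0.0


rng = np.random.default_rng(0)
y_true = rng.normal(size=(256, 19))
y_true[128:] += 3.0  # shift the second half so batch means differ from the global mean
y_pred = y_true + rng.normal(scale=0.5, size=y_true.shape)

metric = StreamingR2()
for start in range(0, len(y_true), 64):  # four batches of 64
    metric.update_state(y_true[start : start + 64], y_pred[start : start + 64])

global_r2 = 1.0 - np.sum((y_true - y_pred) ** 2) / np.sum((y_true - y_true.mean()) ** 2)
print(f"streaming R2: {metric.result():.3f}, single-pass R2: {global_r2:.3f}")
```

With per-batch means, the between-batch variation is missing from TSS, so the streaming value comes out lower here. If you need the dataset-level value, one option is to accumulate the count, sum(y_true), and sum(y_true**2) alongside the RSS, and compute TSS in result() as sum(y²) − (sum y)²/n.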

Custom training with CREsted

You can train custom components with CREsted’s standard workflow or write your own training loop. Below, we first train the custom model, loss, and metrics with tl.Crested (using the PyTorch backend in this run), and then show a fully custom PyTorch training loop.

Training custom components with the tl.Crested framework

# Create TaskConfig with custom components
custom_config = crested.tl.TaskConfig(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=backend_loss,  # Custom loss (native operations of the active backend)
    metrics=standard_metrics,  # Built-in and custom metrics
)

# Train using CREsted framework
trainer = crested.tl.Crested(
    data=datamodule,
    model=model,  # Under the torch backend, this is the Keras-PyTorch hybrid model
    config=custom_config,
    project_name="custom_tutorial",
    run_name="framework_agnostic",
)

# Quick training for demo (just 2 epochs)
trainer.fit(epochs=2, learning_rate_reduce_patience=1, early_stopping_patience=2)
Model: "keras_model_with_py_torch_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ torch_module_wrapper_23         │ ?                      │         7,808 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_24         │ ?                      │           256 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_25         │ ?                      │             0 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_26         │ ?                      │             0 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_27         │ ?                      │       229,632 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_28         │ ?                      │           512 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_29         │ ?                      │             0 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_30         │ ?                      │             0 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_31         │ ?                      │        65,792 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_32         │ ?                      │             0 │
│ (TorchModuleWrapper)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ torch_module_wrapper_33         │ ?                      │         4,883 │
│ (TorchModuleWrapper)            │                        │               │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 308,883 (1.18 MB)
 Trainable params: 308,883 (1.18 MB)
 Non-trainable params: 0 (0.00 B)
None
2025-08-21T16:10:17.146763+0200 INFO Loading sequences into memory...
2025-08-21T16:10:24.141689+0200 INFO Loading sequences into memory...
Epoch 1/2
2292/2292 ━━━━━━━━━━━━━━━━━━━━ 98s 42ms/step - cosine_similarity: 0.6523 - loss: 28.2246 - mean_absolute_error: 2.8233 - spearman_corr: 0.1557 - torch_r2_score: 0.0570 - val_cosine_similarity: 0.6516 - val_loss: 28.7826 - val_mean_absolute_error: 2.3636 - val_spearman_corr: 0.1725 - val_torch_r2_score: -5.8913e-04 - learning_rate: 0.0010
Epoch 2/2
2292/2292 ━━━━━━━━━━━━━━━━━━━━ 100s 43ms/step - cosine_similarity: 0.6580 - loss: 27.5134 - mean_absolute_error: 2.7868 - spearman_corr: 0.1910 - torch_r2_score: 0.0807 - val_cosine_similarity: 0.6574 - val_loss: 28.2391 - val_mean_absolute_error: 2.3365 - val_spearman_corr: 0.2021 - val_torch_r2_score: 0.0183 - learning_rate: 0.0010
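The spearman_corr values logged above come from the simplified SpearmanCorrelation metric, which computes a Pearson correlation on raw values. A true Spearman correlation is the Pearson correlation of the ranks. The NumPy sketch below shows how the two differ for a relation that is monotonic but not linear; it ignores ties (which would need average ranks), and `rankdata_no_ties` is a hypothetical helper.

```python
import numpy as np


def pearson(a, b):
    """Pearson correlation of two 1D arrays."""
    a = a - a.mean()
    b = b - b.mean()
    return float(np.sum(a * b) / np.sqrt(np.sum(a**2) * np.sum(b**2)))


def rankdata_no_ties(x):
    """Ranks 0..n-1 via a double argsort (valid only when x has no ties)."""
    return np.argsort(np.argsort(x)).astype(float)


def spearman(a, b):
    """Spearman correlation = Pearson correlation of the ranks."""
    return pearson(rankdata_no_ties(a), rankdata_no_ties(b))


x = np.linspace(0.1, 5.0, 50)
y = np.exp(x)  # strictly increasing, but far from linear

print(f"Pearson:  {pearson(x, y):.3f}")
print(f"Spearman: {spearman(x, y):.3f}")  # 1.000 for any strictly increasing relation
```

In a Keras metric, the same double-argsort trick could be written with keras.ops.argsort along the batch axis. Argsort is not differentiable, which is fine for a metric but rules out using it directly as a loss.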

Complete custom training loop example (PyTorch backend only)

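If you want a completely custom training loop but still want to use CREsted’s dataloaders, that’s possible too. With the torch backend, datamodule.train_dataloader.data is a regular PyTorch DataLoader that yields (sequences, targets) batches: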

# datamodule contains pytorch DataLoader when using torch backend
for x, y in datamodule.train_dataloader.data:
    print(x.shape, y.shape)
    break
torch.Size([64, 2114, 4]) torch.Size([64, 19])
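The batches are channels-last, (batch, seq_len, 4), while torch.nn.Conv1d expects (batch, channels, seq_len); that is why both torch models start with a transpose(1, 2). The sketch below traces the shapes through the CNN for a 2114 bp input, using a zero NumPy array as a stand-in for a real CREsted batch.

```python
import numpy as np

batch = np.zeros((64, 2114, 4), dtype="float32")  # stand-in for one CREsted batch
x = batch.transpose(0, 2, 1)  # same axis swap as torch's x.transpose(1, 2)
print("conv input:", x.shape)  # (64, 4, 2114)

length = x.shape[-1]
# "same"-style padding (padding=7 for k=15, padding=3 for k=7) keeps the length;
# MaxPool1d(k) with its default stride k floors length / k.
length = length // 8  # after pool1 -> 264
print("after pool1:", (64, 128, length))
length = length // 4  # after pool2 -> 66
print("after pool2:", (64, 256, length))
# Global average pooling removes the length axis, so fc1 always sees 256 features.
print("after global pooling:", (64, 256))
```

Because the length axis is averaged away, the same model definition accepts other sequence lengths without changing fc1, although the receptive field and pooling arithmetic change.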
# Custom training loop example (PyTorch backend)
if keras.backend.backend() == "torch":
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    custom_model = pure_pytorch_model.to(device)  # Use the pure PyTorch model
    optimizer = torch.optim.Adam(custom_model.parameters(), lr=1e-3)
    custom_model.train()  # Enable dropout and BatchNorm statistics updates

    # Get training data from Crested datamodule
    train_loader = datamodule.train_dataloader.data

    # Custom training loop
    for epoch in range(2):  # Just 2 epochs for demo
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, (sequences, targets) in enumerate(train_loader):
            if batch_idx >= 10:  # Limit batches for demo
                break

            # Make sure the batch lives on the same device as the model
            sequences = sequences.to(device)
            targets = targets.to(device)

            with torch.enable_grad():
                optimizer.zero_grad()

                # Forward pass
                predictions = custom_model(sequences)

                # Calculate loss with the Keras loss object defined earlier
                loss_value = backend_loss(targets, predictions)

                # Backward pass: compute gradients, then update the weights
                loss_value.backward()
                optimizer.step()

                epoch_loss += loss_value.item()
                num_batches += 1

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1}/2 - Custom Loss: {avg_loss:.4f}")

    print("Custom PyTorch training completed")

else:
    print("Custom training loop example only available with PyTorch backend")
    print("Current backend:", keras.backend.backend())
Epoch 1/2 - Custom Loss: 21.7358
Epoch 2/2 - Custom Loss: 21.3225
Custom PyTorch training completed