Custom models, losses, and metrics with CREsted#
This tutorial demonstrates CREsted’s flexibility in working with custom model architectures, loss functions, and metrics across different deep learning backends (TensorFlow and PyTorch) via Keras 3.
CREsted is built on Keras 3, which means you can seamlessly switch between TensorFlow and PyTorch backends while using the same high-level API. This tutorial shows how to:
Create custom model architectures for both backends
Implement custom loss functions and metrics
Use custom training loops (PyTorch example)
The key advantage is that your custom components work identically regardless of the backend, giving you the flexibility to choose the best framework for your specific use case. You could, for example, quickly test a novel loss function that you found in some paper’s codebase, whether it’s implemented in torch, tf, or keras code, all within the CREsted framework.
Setup and backend detection#
import numpy as np
import crested # import crested before keras to ensure correct backend setup
import keras
# Check which backend we're using
print(f"Keras backend: {keras.backend.backend()}")
print(f"Keras version: {keras.__version__}")
# Import backend-specific modules if needed
if keras.backend.backend() == "tensorflow":
    import tensorflow as tf

    print(f"TensorFlow version: {tf.__version__}")
elif keras.backend.backend() == "torch":
    import torch

    print(f"PyTorch version: {torch.__version__}")
Keras backend: torch
Keras version: 3.11.2
PyTorch version: 2.8.0+cu128
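If you want to pick the backend yourself, Keras 3 reads it from the KERAS_BACKEND environment variable once, at import time, which is also why crested is imported before keras above. A minimal sketch:

```python
import os

# Keras 3 reads the backend from the KERAS_BACKEND environment variable at
# import time, so set it before the first `import keras`.
os.environ["KERAS_BACKEND"] = "torch"  # or "tensorflow" / "jax"
```

Setting the variable after keras has already been imported has no effect; restart the kernel instead.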
Data loading#
We’ll use the same mouse cortex dataset from the main tutorial.
# Load preprocessed data (assuming you've run the main tutorial)
import anndata as ad
adata = ad.read_h5ad("data/mouse_cortex_filtered.h5ad")
# Register genome
genome = crested.Genome(fasta="data/genomes/mm10/mm10.fa", chrom_sizes="data/genomes/mm10/mm10.chrom.sizes")
crested.register_genome(genome)
print(f"Dataset shape: {adata.shape}")
print(f"Cell types: {list(adata.obs_names)}")
# Setup datamodule
datamodule = crested.tl.data.AnnDataModule(
    adata,
    batch_size=64,
    max_stochastic_shift=3,
    always_reverse_complement=True,
)
2025-08-21T16:09:49.058203+0200 INFO Genome mm10 registered.
Dataset shape: (19, 91477)
Cell types: ['Astro', 'Endo', 'L2_3IT', 'L5ET', 'L5IT', 'L5_6NP', 'L6CT', 'L6IT', 'L6b', 'Lamp5', 'Micro_PVM', 'OPC', 'Oligo', 'Pvalb', 'Sncg', 'Sst', 'SstChodl', 'VLMC', 'Vip']
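The always_reverse_complement=True option augments each batch with reverse-complemented sequences. As an aside, if the one-hot channel order is A, C, G, T (an assumption for illustration, not taken from CREsted's source), reverse-complementing a one-hot array reduces to flipping it along both axes:

```python
import numpy as np

# One-hot encoding of "ACT", assuming channel order A, C, G, T
seq = np.array([[1, 0, 0, 0],   # A
                [0, 1, 0, 0],   # C
                [0, 0, 0, 1]])  # T

# Reversing the sequence axis and the channel axis at once yields the
# reverse complement: A<->T and C<->G are mirror positions in this order
rc = seq[::-1, ::-1]
print(rc)  # one-hot for "AGT", the reverse complement of "ACT"
```

This trick only works because the complement pairs sit at mirrored channel indices in the assumed ordering.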
Custom model architectures#
CREsted works with any Keras model, regardless of how it’s implemented under the hood. Here we show the same simple CNN architecture implemented in different ways depending on your backend preference.
Standard Keras model (framework agnostic)#
This approach works identically with both TensorFlow and PyTorch backends.
If you write your model in Keras code (recent TensorFlow versions follow the same syntax), Keras will run the computations on whichever backend it detects.
def create_simple_cnn_keras(seq_len=2114, num_classes=19):
    """Simple CNN model using standard Keras layers."""
    inputs = keras.layers.Input(shape=(seq_len, 4))

    # First conv block
    x = keras.layers.Conv1D(128, 15, activation="relu", padding="same")(inputs)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling1D(8)(x)
    x = keras.layers.Dropout(0.2)(x)

    # Second conv block
    x = keras.layers.Conv1D(256, 7, activation="relu", padding="same")(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.MaxPooling1D(4)(x)
    x = keras.layers.Dropout(0.2)(x)

    # Global pooling and dense layers
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Dense(256, activation="relu")(x)
    x = keras.layers.Dropout(0.3)(x)
    outputs = keras.layers.Dense(num_classes, activation="softplus")(x)

    model = keras.Model(inputs=inputs, outputs=outputs, name="simple_cnn_keras")
    return model
# Create the model
keras_model = create_simple_cnn_keras(seq_len=2114, num_classes=len(adata.obs_names))
print(f"Model created with {keras_model.count_params():,} parameters")
# you would now use this model architecture as input to the crested.tl.Crested trainer class as before
Model created with 309,651 parameters
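For reference, the parameter count can be checked by hand. Note that Keras counts BatchNormalization's moving mean and variance (which are non-trainable) toward the total:

```python
# Per-layer parameter counts for the CNN above (19 output classes)
conv1 = 4 * 128 * 15 + 128   # kernel + bias = 7,808
bn1 = 4 * 128                # gamma, beta, moving mean, moving variance = 512
conv2 = 128 * 256 * 7 + 256  # 229,632
bn2 = 4 * 256                # 1,024
fc1 = 256 * 256 + 256        # 65,792
out = 256 * 19 + 19          # 4,883
print(conv1 + bn1 + conv2 + bn2 + fc1 + out)  # 309651
```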
Custom torch models#
If you prefer to write your custom models in PyTorch, then that works too!
You just have to make sure you have a CREsted environment with torch installed instead of tensorflow.
There are two options if you prefer to write your models in torch directly:
1. Write your model in pure torch without any changes. This means you won't be able to use the tl.Crested training class though, since keras and torch compile models differently, so you'll have to write a custom training loop that uses CREsted's dataloaders (see the bottom of this tutorial). This only applies to model definitions; losses and metrics can still be written in pure torch and used within CREsted as usual.
2. The second option is easier: write your torch model inside a class that inherits from keras.Model, and wrap any torch layer that has parameters in a keras.layers.TorchModuleWrapper.
if keras.backend.backend() == "torch":
    import torch.nn as nn
    import torch.nn.functional as F
    from keras.layers import TorchModuleWrapper

    # Method 1: Pure PyTorch model (for custom training loops)
    class PurePyTorchModel(nn.Module):
        """Pure PyTorch model equivalent to the Keras CNN."""

        def __init__(self, num_classes):
            super().__init__()
            self.conv1 = nn.Conv1d(4, 128, 15, padding=7)
            self.bn1 = nn.BatchNorm1d(128)
            self.pool1 = nn.MaxPool1d(8)
            self.dropout1 = nn.Dropout(0.2)
            self.conv2 = nn.Conv1d(128, 256, 7, padding=3)
            self.bn2 = nn.BatchNorm1d(256)
            self.pool2 = nn.MaxPool1d(4)
            self.dropout2 = nn.Dropout(0.2)
            self.fc1 = nn.Linear(256, 256)
            self.dropout3 = nn.Dropout(0.3)
            self.fc2 = nn.Linear(256, num_classes)

        def forward(self, x):
            """Forward pass for the PyTorch model."""
            # Input: (batch, seq_len, 4) -> (batch, 4, seq_len) for Conv1d
            x = x.transpose(1, 2)
            # First conv block (matching Keras model exactly)
            x = F.relu(self.conv1(x))
            x = self.bn1(x)
            x = self.pool1(x)
            x = self.dropout1(x)
            # Second conv block
            x = F.relu(self.conv2(x))
            x = self.bn2(x)
            x = self.pool2(x)
            x = self.dropout2(x)
            # Global average pooling
            x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
            # Dense layers
            x = F.relu(self.fc1(x))
            x = self.dropout3(x)
            x = F.softplus(self.fc2(x))
            return x

    # Method 2: Keras model using PyTorch operations
    class KerasModelWithPyTorch(keras.Model):
        """Keras model that uses PyTorch operations wrapped in TorchModuleWrapper."""

        def __init__(self, seq_len, num_classes, **kwargs):
            super().__init__(**kwargs)
            self.seq_len = seq_len
            # Wrap PyTorch modules with TorchModuleWrapper when they contain parameters
            self.conv1 = TorchModuleWrapper(nn.Conv1d(4, 128, 15, padding=7))
            self.bn1 = TorchModuleWrapper(nn.BatchNorm1d(128))
            self.pool1 = nn.MaxPool1d(8)  # No parameters, no need to wrap
            self.dropout1 = nn.Dropout(0.2)
            self.conv2 = TorchModuleWrapper(nn.Conv1d(128, 256, 7, padding=3))
            self.bn2 = TorchModuleWrapper(nn.BatchNorm1d(256))
            self.pool2 = nn.MaxPool1d(4)
            self.dropout2 = nn.Dropout(0.2)
            self.fc1 = TorchModuleWrapper(nn.Linear(256, 256))
            self.dropout3 = nn.Dropout(0.3)
            self.fc2 = TorchModuleWrapper(nn.Linear(256, num_classes))

        def call(self, inputs):
            """Forward pass for the Keras model using PyTorch operations."""
            # Transpose for PyTorch Conv1d: (batch, seq_len, 4) -> (batch, 4, seq_len)
            x = inputs.transpose(1, 2)
            # First conv block using PyTorch operations
            x = F.relu(self.conv1(x))
            x = self.bn1(x)
            x = self.pool1(x)
            x = self.dropout1(x)
            # Second conv block
            x = F.relu(self.conv2(x))
            x = self.bn2(x)
            x = self.pool2(x)
            x = self.dropout2(x)
            # Global average pooling
            x = F.adaptive_avg_pool1d(x, 1).squeeze(-1)
            # Dense layers
            x = F.relu(self.fc1(x))
            x = self.dropout3(x)
            x = F.softplus(self.fc2(x))
            return x

    # Create both models
    num_classes = len(adata.obs_names)
    seq_len = 2114

    # Pure PyTorch model (for custom training loops)
    pure_pytorch_model = PurePyTorchModel(num_classes)
    pytorch_params = sum(p.numel() for p in pure_pytorch_model.parameters())
    print(f"Pure PyTorch model created with {pytorch_params:,} parameters")

    # Keras model with PyTorch components (for the CREsted workflow)
    keras_pytorch_model = KerasModelWithPyTorch(seq_len, num_classes)
    # Build the model to count parameters
    keras_pytorch_model.build((None, seq_len, 4))
    print(f"Keras+PyTorch hybrid model created with {keras_pytorch_model.count_params():,} parameters")

    # Use the hybrid model for CREsted training (it works with Crested.fit())
    model = keras_pytorch_model
else:
    print("PyTorch-specific examples only available with PyTorch backend")
    model = keras_model  # Use the standard Keras model
Pure PyTorch model created with 308,883 parameters
Keras+PyTorch hybrid model created with 308,883 parameters
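The torch models report slightly fewer parameters than the plain Keras version (308,883 vs 309,651). That is because nn.BatchNorm1d registers its running mean and variance as buffers rather than parameters, so only gamma and beta are counted:

```python
# The Keras count includes BatchNorm's moving mean/variance; PyTorch stores
# those as buffers. Two BatchNorm layers, each with two extra tensors:
torch_param_total = 309651 - 2 * (128 + 256)
print(torch_param_total)  # 308883
```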
Custom loss functions#
CREsted can work with any Keras or torch-compatible loss function. Here we show how to implement custom losses that work across different backends:
# Framework-agnostic custom loss using Keras ops
class WeightedMSELoss(keras.losses.Loss):
    """Custom weighted MSE loss that works with any backend."""

    def __init__(self, class_weights=None, name="weighted_mse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.class_weights = class_weights

    def call(self, y_true, y_pred):
        """Compute weighted mean squared error loss."""
        # Use keras.ops for backend-agnostic operations
        mse = keras.ops.square(y_true - y_pred)
        if self.class_weights is not None:
            # Apply class-specific weights
            weights = keras.ops.convert_to_tensor(self.class_weights)
            mse = mse * weights
        return keras.ops.mean(mse)

# Backend-specific loss implementations
def create_custom_loss_backend_specific():
    """Create custom loss using backend-specific operations."""
    if keras.backend.backend() == "tensorflow":
        import tensorflow as tf

        class TensorFlowCustomLoss(keras.losses.Loss):
            def __init__(self, alpha=1.0, name="tf_custom_loss", **kwargs):
                super().__init__(name=name, **kwargs)
                self.alpha = alpha

            def call(self, y_true, y_pred):
                # Use TensorFlow operations directly
                mse = tf.square(y_true - y_pred)
                # Add custom TF-specific regularization
                l2_reg = tf.reduce_sum(tf.square(y_pred)) * 0.001
                return tf.reduce_mean(mse) + self.alpha * l2_reg

        return TensorFlowCustomLoss()
    elif keras.backend.backend() == "torch":
        import torch

        class PyTorchCustomLoss(keras.losses.Loss):
            def __init__(self, alpha=1.0, name="torch_custom_loss", **kwargs):
                super().__init__(name=name, **kwargs)
                self.alpha = alpha

            def call(self, y_true, y_pred):
                # Use PyTorch operations directly
                mse = torch.square(y_true - y_pred)
                l2_reg = torch.sum(torch.square(y_pred)) * 0.001
                return torch.mean(mse) + self.alpha * l2_reg

        return PyTorchCustomLoss()
# Create loss functions
print(f"Creating losses for {keras.backend.backend()} backend.")
# Framework-agnostic loss
class_weights = np.ones(len(adata.obs_names)) # Equal weights for demo
weighted_loss = WeightedMSELoss(class_weights=class_weights)
# Backend-specific loss
backend_loss = create_custom_loss_backend_specific()
print(backend_loss)
Creating losses for torch backend.
<__main__.create_custom_loss_backend_specific.<locals>.PyTorchCustomLoss object at 0x72a3be54c590>
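When writing a custom loss, it is worth sanity-checking it against a plain-NumPy reference on small inputs. The helper below is our own illustrative reference for the WeightedMSELoss above, not part of CREsted:

```python
import numpy as np

def weighted_mse_reference(y_true, y_pred, class_weights=None):
    """Plain-NumPy equivalent of WeightedMSELoss.call, for sanity checks."""
    se = np.square(y_true - y_pred)
    if class_weights is not None:
        se = se * class_weights  # broadcast per-class weights over the batch
    return se.mean()

y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.5, 2.0], [2.0, 4.0]])
print(weighted_mse_reference(y_true, y_pred))  # 0.3125
```

Feeding the same arrays through the Keras loss should give the same value, regardless of backend.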
Custom metrics#
Similar to losses, you can create custom metrics that work across backends:
# Framework-agnostic custom metric
class SpearmanCorrelation(keras.metrics.Metric):
    """Custom Spearman correlation metric that works with any backend."""

    def __init__(self, name="spearman_corr", **kwargs):
        super().__init__(name=name, **kwargs)
        self.correlation_sum = self.add_weight(name="correlation_sum", initializer="zeros")
        self.count = self.add_weight(name="count", initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Update state with new predictions and true values."""
        # Simplified stand-in for demo purposes: this actually computes Pearson
        # correlation; a true Spearman implementation would rank the values first
        pearson = keras.ops.sum(
            (y_true - keras.ops.mean(y_true, axis=0)) * (y_pred - keras.ops.mean(y_pred, axis=0)),
            axis=0,
        )
        norm_factor = keras.ops.sqrt(
            keras.ops.sum(keras.ops.square(y_true - keras.ops.mean(y_true, axis=0)), axis=0)
            * keras.ops.sum(keras.ops.square(y_pred - keras.ops.mean(y_pred, axis=0)), axis=0)
        )
        correlation = keras.ops.mean(pearson / (norm_factor + 1e-8))
        self.correlation_sum.assign_add(correlation)
        self.count.assign_add(1.0)

    def result(self):
        """Return the average correlation over all seen batches."""
        return self.correlation_sum / self.count

    def reset_state(self):
        """Reset the state of the metric."""
        self.correlation_sum.assign(0.0)
        self.count.assign(0.0)
# Or, write backend-specific metrics
def create_custom_metric_backend_specific():
    """Create custom metric using backend-specific operations."""
    if keras.backend.backend() == "tensorflow":
        import tensorflow as tf

        class TensorFlowR2Score(keras.metrics.Metric):
            def __init__(self, name="tf_r2_score", **kwargs):
                super().__init__(name=name, **kwargs)
                self.total_sum_squares = self.add_weight(name="tss", initializer="zeros")
                self.residual_sum_squares = self.add_weight(name="rss", initializer="zeros")

            def update_state(self, y_true, y_pred, sample_weight=None):
                # Using TensorFlow operations directly
                y_mean = tf.reduce_mean(y_true)
                tss = tf.reduce_sum(tf.square(y_true - y_mean))
                rss = tf.reduce_sum(tf.square(y_true - y_pred))
                self.total_sum_squares.assign_add(tss)
                self.residual_sum_squares.assign_add(rss)

            def result(self):
                return 1.0 - (self.residual_sum_squares / (self.total_sum_squares + 1e-8))

            def reset_state(self):
                self.total_sum_squares.assign(0.0)
                self.residual_sum_squares.assign(0.0)

        return TensorFlowR2Score()
    elif keras.backend.backend() == "torch":
        import torch

        class PyTorchR2Score(keras.metrics.Metric):
            def __init__(self, name="torch_r2_score", **kwargs):
                super().__init__(name=name, **kwargs)
                self.total_sum_squares = self.add_weight(name="tss", initializer="zeros")
                self.residual_sum_squares = self.add_weight(name="rss", initializer="zeros")

            def update_state(self, y_true, y_pred, sample_weight=None):
                # Using PyTorch operations directly
                y_mean = torch.mean(y_true)
                tss = torch.sum(torch.square(y_true - y_mean))
                rss = torch.sum(torch.square(y_true - y_pred))
                self.total_sum_squares.assign_add(tss)
                self.residual_sum_squares.assign_add(rss)

            def result(self):
                return 1.0 - (self.residual_sum_squares / (self.total_sum_squares + 1e-8))

            def reset_state(self):
                self.total_sum_squares.assign(0.0)
                self.residual_sum_squares.assign(0.0)

        return PyTorchR2Score()
# Create metrics
print(f"Creating metrics for {keras.backend.backend()} backend:")
# Framework-agnostic metric
spearman_metric = SpearmanCorrelation()
# Backend-specific metric
r2_metric = create_custom_metric_backend_specific()
# Combine standard Keras metrics (which work with both backends) with our custom ones
standard_metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.CosineSimilarity(axis=1),
    spearman_metric,
    r2_metric,
]
print(standard_metrics)
Creating metrics for torch backend:
[<MeanAbsoluteError name=mean_absolute_error>, <CosineSimilarity name=cosine_similarity>, <SpearmanCorrelation name=spearman_corr>, <PyTorchR2Score name=torch_r2_score>]
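The R² metrics above accumulate TSS and RSS across batches, so on a single batch they should agree with a direct computation. A plain-NumPy reference (our own illustrative helper, not a CREsted API) is an easy cross-check; note that because the streaming version uses per-batch means, across many batches it only approximates the dataset-level R²:

```python
import numpy as np

def r2_reference(y_true, y_pred):
    """Direct R^2 = 1 - RSS/TSS, matching the streaming metric on one batch."""
    tss = np.sum(np.square(y_true - y_true.mean()))
    rss = np.sum(np.square(y_true - y_pred))
    return 1.0 - rss / (tss + 1e-8)

y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8])
print(round(r2_reference(y_true, y_pred), 4))  # 0.98
```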
Custom training with CREsted#
You can use CREsted’s standard training workflow or implement custom training loops. Here’s an example of custom training when using the PyTorch backend with our custom model, metrics, and loss.
Custom components with the tl.Crested framework#
# Create TaskConfig with custom components
custom_config = crested.tl.TaskConfig(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=backend_loss,  # our custom loss (pure torch in this case)
    metrics=standard_metrics,  # our custom metrics
)
# Train using CREsted framework
trainer = crested.tl.Crested(
    data=datamodule,
    model=model,  # with a torch environment this uses our "keras-torch hybrid" model
    config=custom_config,
    project_name="custom_tutorial",
    run_name="framework_agnostic",
)
# Quick training for demo (just 2 epochs)
trainer.fit(epochs=2, learning_rate_reduce_patience=1, early_stopping_patience=2)
Model: "keras_model_with_py_torch_2"
Layer (type)                                   Output Shape   Param #
torch_module_wrapper_23 (TorchModuleWrapper)   ?                7,808
torch_module_wrapper_24 (TorchModuleWrapper)   ?                  256
torch_module_wrapper_25 (TorchModuleWrapper)   ?                    0
torch_module_wrapper_26 (TorchModuleWrapper)   ?                    0
torch_module_wrapper_27 (TorchModuleWrapper)   ?              229,632
torch_module_wrapper_28 (TorchModuleWrapper)   ?                  512
torch_module_wrapper_29 (TorchModuleWrapper)   ?                    0
torch_module_wrapper_30 (TorchModuleWrapper)   ?                    0
torch_module_wrapper_31 (TorchModuleWrapper)   ?               65,792
torch_module_wrapper_32 (TorchModuleWrapper)   ?                    0
torch_module_wrapper_33 (TorchModuleWrapper)   ?                4,883
Total params: 308,883 (1.18 MB)
Trainable params: 308,883 (1.18 MB)
Non-trainable params: 0 (0.00 B)
None
2025-08-21T16:10:17.146763+0200 INFO Loading sequences into memory...
2025-08-21T16:10:24.141689+0200 INFO Loading sequences into memory...
Epoch 1/2
2292/2292 ━━━━━━━━━━━━━━━━━━━━ 98s 42ms/step - cosine_similarity: 0.6523 - loss: 28.2246 - mean_absolute_error: 2.8233 - spearman_corr: 0.1557 - torch_r2_score: 0.0570 - val_cosine_similarity: 0.6516 - val_loss: 28.7826 - val_mean_absolute_error: 2.3636 - val_spearman_corr: 0.1725 - val_torch_r2_score: -5.8913e-04 - learning_rate: 0.0010
Epoch 2/2
2292/2292 ━━━━━━━━━━━━━━━━━━━━ 100s 43ms/step - cosine_similarity: 0.6580 - loss: 27.5134 - mean_absolute_error: 2.7868 - spearman_corr: 0.1910 - torch_r2_score: 0.0807 - val_cosine_similarity: 0.6574 - val_loss: 28.2391 - val_mean_absolute_error: 2.3365 - val_spearman_corr: 0.2021 - val_torch_r2_score: 0.0183 - learning_rate: 0.0010
Complete custom training loop example (PyTorch backend only)#
If you want a complete custom training loop but still want to make use of the Crested dataloaders, then that’s possible too.
# datamodule contains pytorch DataLoader when using torch backend
for x, y in datamodule.train_dataloader.data:
    print(x.shape, y.shape)
    break
torch.Size([64, 2114, 4]) torch.Size([64, 19])
# Custom training loop example (PyTorch backend)
if keras.backend.backend() == "torch":
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    custom_model = pure_pytorch_model  # Use the pure PyTorch model
    custom_model.to(device)
    optimizer = torch.optim.Adam(custom_model.parameters(), lr=1e-3)
    custom_model.train()

    # Get training data from the CREsted datamodule
    train_loader = datamodule.train_dataloader.data

    # Custom training loop
    for epoch in range(2):  # Just 2 epochs for demo
        epoch_loss = 0.0
        num_batches = 0
        for batch_idx, (sequences, targets) in enumerate(train_loader):
            if batch_idx >= 10:  # Limit batches for demo
                break
            # .to(device) is a no-op if the batch already lives on the right device
            sequences, targets = sequences.to(device), targets.to(device)
            with torch.enable_grad():
                optimizer.zero_grad()
                # Forward pass
                predictions = custom_model(sequences)
                # Calculate loss
                loss_value = backend_loss(targets, predictions)
                # Backward pass - get gradients
                loss_value.backward()
                optimizer.step()
            epoch_loss += loss_value.item()
            num_batches += 1
        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch + 1}/2 - Custom Loss: {avg_loss:.4f}")
    print("Custom PyTorch training completed")
else:
    print("Custom training loop example only available with PyTorch backend")
    print("Current backend:", keras.backend.backend())
Epoch 1/2 - Custom Loss: 21.7358
Epoch 2/2 - Custom Loss: 21.3225
Custom PyTorch training completed
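The loop above only trains; in practice you would add a validation pass and early stopping yourself, similar to what trainer.fit() handles for you. A minimal tracker (our own illustrative helper, not part of CREsted) could look like this:

```python
class EarlyStopper:
    """Track validation loss and signal when to stop (illustrative sketch)."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Return True once val_loss hasn't improved for `patience` epochs."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = EarlyStopper(patience=2)
for loss in [3.0, 2.5, 2.6, 2.7]:
    if stopper.step(loss):
        print("stopping early")
        break
```

You would call stopper.step() with the validation loss at the end of each epoch and break out of the training loop when it returns True.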