Finetune Borzoi for scATAC peaks#

In this tutorial, we’ll show how to finetune the Borzoi model to do peak regression on scATAC data.

Note

If you just want to use the Borzoi model (or Enformer, or Borzoi Prime) to predict pre-existing classes, please see this tutorial.


# Set package settings
%matplotlib inline
import matplotlib
import os

## Set the font type to ensure text is saved as whole words
matplotlib.rcParams["pdf.fonttype"] = 42  # Use TrueType fonts instead of Type 3 fonts
matplotlib.rcParams["ps.fonttype"] = 42  # For PostScript as well, if needed

## Set the base directory for data retrieval with crested.get_dataset()/get_model()
os.environ['CRESTED_DATA_DIR'] = '/staging/leuven/stg_00002/lcb/cblaauw/'
import os
import zipfile
import tempfile

import anndata as ad
import matplotlib.pyplot as plt
import pandas as pd
import keras
import crested
resources_dir = "/staging/leuven/res_00001/genomes/mus_musculus/mm10_ucsc/fasta/"
genome_file = os.path.join(resources_dir, "mm10.fa")
chromsizes_file = os.path.join(resources_dir, "mm10.chrom.sizes")
folds_file = "consensus_peaks_borsplit.bed" # See 'Add train/val/test split' for how this file was created
genome = crested.Genome(genome_file, chromsizes_file)
crested.register_genome(genome)
2026-02-17T16:41:19.095165+0100 INFO Genome mm10 registered.

Read in scATAC data#

We’ll use the same dataset as in the default tutorial: the mouse BICCN dataset, derived from the brain cortex.

bigwigs_folder, regions_file = crested.get_dataset("mouse_cortex_bigwig_cut_sites")
adata = crested.import_bigwigs(
    bigwigs_folder=bigwigs_folder,
    regions_file=regions_file,
    target_region_width=1000,
    target="count",
)
adata
2026-02-17T16:21:41.930416+0100 INFO Extracting values from 19 bigWig files...
AnnData object with n_obs × n_vars = 19 × 546993
    obs: 'file_path'
    var: 'chr', 'start', 'end', 'target_start', 'target_end'

Add train/val/test split#

Generally, for finetuning, it’s recommended to use the train/val/test split of the original model, in this case Borzoi.
This can be derived by intersecting your consensus peaks with sequences_mouse.bed from the Borzoi repository, like with BEDTools:

regions_file="consensus_peaks_biccn.bed" # regions_file from crested.get_dataset()
folds_file="sequences_mouse.bed" # From Borzoi repo
output_file="consensus_peaks_borsplit.bed"

grep fold3 ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\ttest/' > ${output_file}
grep fold4 ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\tval/' >> ${output_file}
for i in 0 1 2 5 6 7; do
    grep fold${i} ${folds_file} | sort -k1,1 -k2,2n | bedtools merge -i stdin -d 10 | bedtools intersect -a ${regions_file} -b stdin -wa -f 0.5 | sed $'s/$/\ttrain/' >> ${output_file}
done
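The `sed` step at the end of each pipeline simply appends a tab plus the split label to every BED line; a quick standalone check of that idiom, using a made-up BED line:

```shell
# $'...' quoting turns \t into a literal tab, so this appends "<tab>train"
printf 'chr1\t100\t200\tregion1\n' | sed $'s/$/\ttrain/'
```

This prints the input line with an extra tab-separated `train` column, which is the `split` column read in with pandas below.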

folds = pd.read_csv(folds_file, sep="\t", names=["name", "split"], usecols=[3, 4]).set_index("name")
print(f"% of regions found in folds file: {adata.var_names.isin(folds.index).sum() / adata.n_vars * 100:.3f}%")
% of regions found in folds file: 99.425%
# Drop regions not in any folds
print(f"Dropping {(~adata.var_names.isin(folds.index)).sum()} regions because they are not in any fold.")
adata = adata[:, adata.var_names.isin(folds.index)].copy()

# Add fold data to var
adata.var = adata.var.join(folds)

# Check result
adata.var["split"].value_counts(dropna=False)
Dropping 3146 regions because they are not in any fold.
split
train    412229
val       72744
test      58874
Name: count, dtype: int64

Alternatively, you could use the default train/val/test split function to set a chromosome-based or random split:

# crested.pp.train_val_test_split(
#     adata, strategy="chr", val_chroms=["chr8", "chr10"], test_chroms=["chr9", "chr18"]
# )

Preprocessing#

For the preprocessing, we’ll again follow the default steps, except for the adjusted input size.

Region width#

In this example, we’ll use 2048bp inputs, close to the 2114bp input size of the standard CNN peak regression models while remaining a multiple of 128 (as required by the Borzoi architecture). Therefore, we need to resize our regions:

crested.pp.change_regions_width(adata, 2048)
2026-02-17T16:22:12.788853+0100 INFO Lazily importing module crested.pp. This could take a second...

Peak normalization#

We can normalize our peak values based on the variability of the top peak heights per cell type using the crested.pp.normalize_peaks() function.

This function applies a normalization scalar to each cell type, obtained by comparing, per cell type, the distribution of peak heights of the maximally accessible regions that are not specific to any cell type.

crested.pp.normalize_peaks(adata, top_k_percent=0.03)
2026-02-17T16:22:29.346358+0100 INFO Filtering on top k Gini scores...
2026-02-17T16:22:33.657138+0100 INFO Added normalization weights to adata.obsm['weights']...
chr start end target_start target_end split
region
chr9:76566142-76568190 chr9 76566142 76568190 76566666 76567666 train
chr5:98328510-98330558 chr5 98328510 98330558 98329034 98330034 train
chr5:98347819-98349867 chr5 98347819 98349867 98348343 98349343 train
chr13:34635167-34637215 chr13 34635167 34637215 34635691 34636691 train
chr13:34642109-34644157 chr13 34642109 34644157 34642633 34643633 train
... ... ... ... ... ... ...
chr13:34344270-34346318 chr13 34344270 34346318 34344794 34345794 train
chr5:98166140-98168188 chr5 98166140 98168188 98166664 98167664 train
chr5:98166667-98168715 chr5 98166667 98168715 98167191 98168191 train
chr13:34344974-34347022 chr13 34344974 34347022 34345498 34346498 train
chr5:98185712-98187760 chr5 98185712 98187760 98186236 98187236 train

48089 rows × 6 columns
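Conceptually, the result is one scalar weight per cell type (stored in `adata.obsm['weights']`, per the log above) that rescales that cell type's row of the count matrix. A minimal numpy sketch with made-up numbers:

```python
import numpy as np

# Toy count matrix: 3 cell types (rows) x 4 regions (columns), made-up values
counts = np.array([
    [10.0, 4.0, 0.0, 6.0],
    [5.0, 2.0, 1.0, 3.0],
    [20.0, 8.0, 2.0, 12.0],
])
# Hypothetical per-cell-type normalization weights
weights = np.array([1.0, 2.0, 0.5])

# Broadcasting scales every region value of a cell type by that one scalar
normalized = counts * weights[:, None]
```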

We can visualize the normalization factor for each cell type using the crested.pl.qc.normalization_weights() function to inspect which cell types’ peaks were up- or down-weighted.

%matplotlib inline
crested.pl.qc.normalization_weights(adata, title="Normalization weights per cell type")

Subset to specific regions#

Like in the main tutorial, we also create a subset of specific regions. We found that fine-tuning in two rounds (first on all peaks, then on the specific peaks) works better than a single round on either all peaks or only the specific peaks. This is the same subset as used in the default tutorial.

adata_filtered = crested.pp.filter_regions_on_specificity(adata, gini_std_threshold=1.0, inplace=False)
2026-02-17T16:23:22.003894+0100 INFO After specificity filtering, kept 90995 out of 543847 regions.

Load in model#

We load the Borzoi model’s weights into its architecture, with one change: the input length. All of Borzoi’s layers are width-independent, so the length can be set to any value divisible by the internal bin size (128bp).
target_length is set to the total number of output bins (64 bins of 32bp, covering the full 2048bp output), since no cropping is needed when predicting local features.
num_classes is set to the original size simply to avoid weight shape mismatches when loading the initial weights; the head created based on num_classes will be replaced below by a new head for the number of cell types we’d like to predict.

# Create default Borzoi architecture, with shrunk input size and target_length
base_model_architecture = crested.tl.zoo.borzoi(seq_len=2048, target_length=2048//32, num_classes=2608)
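The shape arithmetic behind these arguments: the convolution tower downsamples 128x in total (the stem pool plus six tower pools of 2x each, visible in the model summary below), which is why the input must be a multiple of 128, while predictions come out in 32bp bins, so target_length is simply seq_len // 32. A plain-Python check:

```python
def borzoi_target_length(seq_len: int, bin_size: int = 32) -> int:
    # Number of output bins for an uncropped Borzoi input of seq_len bp
    if seq_len % 128 != 0:
        raise ValueError("Borzoi input length must be a multiple of 128bp")
    return seq_len // bin_size

print(borzoi_target_length(2048))  # 64, matching target_length=2048//32 above
```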

To load the weights, we can’t load the model directly from the .keras file, since that fixes the input length at the previously set value (524288bp). However, we can extract the model.weights.h5 file containing only the weights and use that.

# Load pretrained Borzoi weights
model_file, _ = crested.get_model("Borzoi_mouse_rep0")
# Put weights into base architecture
with zipfile.ZipFile(model_file) as model_archive, tempfile.TemporaryDirectory() as tmpdir:
    model_weights_path = model_archive.extract("model.weights.h5", tmpdir)
    base_model_architecture.load_weights(model_weights_path)

Now that we have the base model with the adjusted input shape, we need to adjust the final layers to return a value for each cell type per region, instead of per-bin values. Therefore, we drop the final head, add a flatten layer after the model’s final embedding, and add a new head predicting adata.n_obs values.

# Replace track head by flatten+dense to predict single vector of scalars per region
## Get last layer before head
current = base_model_architecture.get_layer("final_conv_activation").output
## Flatten and add new layer
current = keras.layers.Flatten()(current)
current = keras.layers.Dense(adata.n_obs, activation="softplus", name="dense_out")(current)

# Turn into model
model_architecture = keras.Model(inputs=base_model_architecture.inputs, outputs=current, name="Borzoi_scalar")
print(model_architecture.summary())


Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 2048, 4)   │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv1D)  │ (None, 2048, 512) │     31,232 │ input[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_pool           │ (None, 1024, 512) │          0 │ stem_conv[0][0]   │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_batch… │ (None, 1024, 512) │      2,048 │ stem_pool[0][0]   │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_activ… │ (None, 1024, 512) │          0 │ tower_conv_1_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_conv   │ (None, 1024, 608) │  1,557,088 │ tower_conv_1_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_pool   │ (None, 512, 608)  │          0 │ tower_conv_1_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_batch… │ (None, 512, 608)  │      2,432 │ tower_conv_1_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_activ… │ (None, 512, 608)  │          0 │ tower_conv_2_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_conv   │ (None, 512, 736)  │  2,238,176 │ tower_conv_2_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_pool   │ (None, 256, 736)  │          0 │ tower_conv_2_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_batch… │ (None, 256, 736)  │      2,944 │ tower_conv_2_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_activ… │ (None, 256, 736)  │          0 │ tower_conv_3_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_conv   │ (None, 256, 896)  │  3,298,176 │ tower_conv_3_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_pool   │ (None, 128, 896)  │          0 │ tower_conv_3_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_batch… │ (None, 128, 896)  │      3,584 │ tower_conv_3_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_activ… │ (None, 128, 896)  │          0 │ tower_conv_4_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_conv   │ (None, 128, 1056) │  4,731,936 │ tower_conv_4_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_pool   │ (None, 64, 1056)  │          0 │ tower_conv_4_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_batch… │ (None, 64, 1056)  │      4,224 │ tower_conv_4_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_activ… │ (None, 64, 1056)  │          0 │ tower_conv_5_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_conv   │ (None, 64, 1280)  │  6,759,680 │ tower_conv_5_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_pool   │ (None, 32, 1280)  │          0 │ tower_conv_5_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_batch… │ (None, 32, 1280)  │      5,120 │ tower_conv_5_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_activ… │ (None, 32, 1280)  │          0 │ tower_conv_6_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_conv   │ (None, 32, 1536)  │  9,831,936 │ tower_conv_6_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_pool   │ (None, 16, 1536)  │          0 │ tower_conv_6_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │      3,072 │ tower_conv_6_poo… │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 16, 1536)  │          0 │ tower_conv_6_poo… │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_l… │ (None, 16, 1536)  │      3,072 │ add[0][0]         │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_a… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 1536)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 16, 1536)  │          0 │ add[0][0],        │
│                     │                   │            │ transformer_ff_1… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │      3,072 │ add_1[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 16, 1536)  │          0 │ add_1[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_l… │ (None, 16, 1536)  │      3,072 │ add_2[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_a… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 1536)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 16, 1536)  │          0 │ add_2[0][0],      │
│                     │                   │            │ transformer_ff_2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │      3,072 │ add_3[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 16, 1536)  │          0 │ add_3[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_l… │ (None, 16, 1536)  │      3,072 │ add_4[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_a… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 1536)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 16, 1536)  │          0 │ add_4[0][0],      │
│                     │                   │            │ transformer_ff_3… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │      3,072 │ add_5[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_6 (Add)         │ (None, 16, 1536)  │          0 │ add_5[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_l… │ (None, 16, 1536)  │      3,072 │ add_6[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_a… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 1536)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_7 (Add)         │ (None, 16, 1536)  │          0 │ add_6[0][0],      │
│                     │                   │            │ transformer_ff_4… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │      3,072 │ add_7[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_8 (Add)         │ (None, 16, 1536)  │          0 │ add_7[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_l… │ (None, 16, 1536)  │      3,072 │ add_8[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_a… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 1536)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_9 (Add)         │ (None, 16, 1536)  │          0 │ add_8[0][0],      │
│                     │                   │            │ transformer_ff_5… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │      3,072 │ add_9[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_10 (Add)        │ (None, 16, 1536)  │          0 │ add_9[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_l… │ (None, 16, 1536)  │      3,072 │ add_10[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_a… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 1536)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_11 (Add)        │ (None, 16, 1536)  │          0 │ add_10[0][0],     │
│                     │                   │            │ transformer_ff_6… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │      3,072 │ add_11[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_12 (Add)        │ (None, 16, 1536)  │          0 │ add_11[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_l… │ (None, 16, 1536)  │      3,072 │ add_12[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_a… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 1536)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_13 (Add)        │ (None, 16, 1536)  │          0 │ add_12[0][0],     │
│                     │                   │            │ transformer_ff_7… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │      3,072 │ add_13[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_14 (Add)        │ (None, 16, 1536)  │          0 │ add_13[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_l… │ (None, 16, 1536)  │      3,072 │ add_14[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_a… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 1536)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_15 (Add)        │ (None, 16, 1536)  │          0 │ add_14[0][0],     │
│                     │                   │            │ transformer_ff_8… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │      6,144 │ add_15[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_batchn… │ (None, 32, 1536)  │      6,144 │ tower_conv_6_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_activa… │ (None, 32, 1536)  │          0 │ unet_skip_2_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d       │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_conv    │ (None, 32, 1536)  │  2,360,832 │ unet_skip_2_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_16 (Add)        │ (None, 32, 1536)  │          0 │ up_sampling1d[0]… │
│                     │                   │            │ unet_skip_2_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 32, 1536)  │  2,365,440 │ add_16[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_batchn… │ (None, 64, 1280)  │      5,120 │ tower_conv_5_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_activa… │ (None, 64, 1280)  │          0 │ unet_skip_1_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_1     │ (None, 64, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_conv    │ (None, 64, 1536)  │  1,967,616 │ unet_skip_1_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_17 (Add)        │ (None, 64, 1536)  │          0 │ up_sampling1d_1[ │
│                     │                   │            │ unet_skip_1_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 64, 1536)  │  2,365,440 │ add_17[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_batchno… │ (None, 64, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_activat… │ (None, 64, 1536)  │          0 │ final_conv_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten_1 (Flatten) │ (None, 98304)     │          0 │ final_conv_activ… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │  1,867,795 │ flatten_1[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 170,213,747 (649.31 MB)
 Trainable params: 170,188,723 (649.22 MB)
 Non-trainable params: 25,024 (97.75 KB)
None

Model training#

Parameters#

The DataModule and TaskConfig let you set standard training parameters, such as batch size and learning rate.
We use the same parameters as for peak regression in the default tutorial, except for a lower learning rate, since we are starting from a pre-trained model.

datamodule = crested.tl.data.AnnDataModule(
    adata,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True; doubles the effective training-set size by also using reverse complements
)
optimizer = keras.optimizers.Adam(learning_rate=1e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
]

config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x1469e0eb4ec0>, loss=CosineMSELogLoss: {'name': 'CosineMSELogLoss', 'reduction': 'sum_over_batch_size', 'max_weight': 100}, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
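The `always_reverse_complement` augmentation above presents each region on both strands. For a one-hot sequence in ACGT channel order, the reverse complement amounts to reversing the sequence axis and flipping each channel vector (A↔T, C↔G). A minimal illustrative sketch (not CREsted's internal implementation):

```python
def reverse_complement_onehot(seq):
    """Reverse-complement a one-hot encoded sequence.

    seq: list of [A, C, G, T] one-hot rows, 5'->3'.
    In ACGT channel order, complementing a base equals reversing its
    channel vector (A<->T, C<->G), so the full reverse complement
    reverses both the sequence axis and each row.
    """
    return [row[::-1] for row in seq[::-1]]

# "ACGT" one-hot encoded; its reverse complement is again "ACGT" (a palindrome)
acgt = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
print(reverse_complement_onehot(acgt) == acgt)  # True
```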

Finetune on full peak set#

By default:

  1. The model trains until the validation loss has not improved for 10 epochs, up to a maximum of 100 epochs.

  2. The best model so far is saved whenever the validation loss improves.

  3. The learning rate is multiplied by 0.25 if the validation loss has not improved for 5 epochs.
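These defaults follow standard early-stopping and reduce-on-plateau bookkeeping, which CREsted configures for you via the equivalent Keras callbacks. A pure-Python sketch of the logic (illustrative only, not the trainer's actual code):

```python
def train_schedule(val_losses, lr=1e-5, patience_stop=10, patience_lr=5,
                   lr_factor=0.25, max_epochs=100):
    """Simulate the default stopping/LR schedule over a list of per-epoch
    validation losses. Returns (epochs_run, final_lr, best_loss)."""
    best = float("inf")
    since_best = 0      # epochs since the last improvement (early stopping)
    since_lr_drop = 0   # epochs since the last improvement or LR reduction
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:  # new best: this is where a checkpoint would be saved
            best = loss
            since_best = 0
            since_lr_drop = 0
        else:
            since_best += 1
            since_lr_drop += 1
            if since_lr_drop >= patience_lr:   # plateau: multiply LR by 0.25
                lr *= lr_factor
                since_lr_drop = 0
            if since_best >= patience_stop:    # no improvement for 10 epochs
                return epoch, lr, best
    return min(len(val_losses), max_epochs), lr, best
```

For example, a run that improves for two epochs and then plateaus stops after 10 further epochs, with the learning rate reduced twice along the way: `train_schedule([1.0, 0.9] + [0.95] * 10)`.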

# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="biccn_borzoi_atac",
    run_name="testrun",
    logger="wandb",
)
# train the model
trainer.fit(epochs=10)

Hide code cell output

Tracking run with wandb version 0.24.1
Run data is saved locally in /lustre1/project/stg_00002/lcb/cblaauw/python_modules/CREsted/docs/tutorials/wandb/run-20260217_115238-vgmw92y2
Syncing run testrun to Weights & Biases (docs)
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape          Param #  Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 2048, 4)   │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv1D)  │ (None, 2048, 512) │     31,232 │ input[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_pool           │ (None, 1024, 512) │          0 │ stem_conv[0][0]   │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_batch… │ (None, 1024, 512) │      2,048 │ stem_pool[0][0]   │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_activ… │ (None, 1024, 512) │          0 │ tower_conv_1_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_conv   │ (None, 1024, 608) │  1,557,088 │ tower_conv_1_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_pool   │ (None, 512, 608)  │          0 │ tower_conv_1_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_batch… │ (None, 512, 608)  │      2,432 │ tower_conv_1_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_activ… │ (None, 512, 608)  │          0 │ tower_conv_2_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_conv   │ (None, 512, 736)  │  2,238,176 │ tower_conv_2_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_pool   │ (None, 256, 736)  │          0 │ tower_conv_2_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_batch… │ (None, 256, 736)  │      2,944 │ tower_conv_2_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_activ… │ (None, 256, 736)  │          0 │ tower_conv_3_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_conv   │ (None, 256, 896)  │  3,298,176 │ tower_conv_3_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_pool   │ (None, 128, 896)  │          0 │ tower_conv_3_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_batch… │ (None, 128, 896)  │      3,584 │ tower_conv_3_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_activ… │ (None, 128, 896)  │          0 │ tower_conv_4_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_conv   │ (None, 128, 1056) │  4,731,936 │ tower_conv_4_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_pool   │ (None, 64, 1056)  │          0 │ tower_conv_4_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_batch… │ (None, 64, 1056)  │      4,224 │ tower_conv_4_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_activ… │ (None, 64, 1056)  │          0 │ tower_conv_5_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_conv   │ (None, 64, 1280)  │  6,759,680 │ tower_conv_5_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_pool   │ (None, 32, 1280)  │          0 │ tower_conv_5_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_batch… │ (None, 32, 1280)  │      5,120 │ tower_conv_5_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_activ… │ (None, 32, 1280)  │          0 │ tower_conv_6_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_conv   │ (None, 32, 1536)  │  9,831,936 │ tower_conv_6_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_pool   │ (None, 16, 1536)  │          0 │ tower_conv_6_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │      3,072 │ tower_conv_6_poo… │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 16, 1536)  │          0 │ tower_conv_6_poo… │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_l… │ (None, 16, 1536)  │      3,072 │ add[0][0]         │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_a… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 1536)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 16, 1536)  │          0 │ add[0][0],        │
│                     │                   │            │ transformer_ff_1… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │      3,072 │ add_1[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 16, 1536)  │          0 │ add_1[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_l… │ (None, 16, 1536)  │      3,072 │ add_2[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_a… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 1536)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 16, 1536)  │          0 │ add_2[0][0],      │
│                     │                   │            │ transformer_ff_2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │      3,072 │ add_3[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 16, 1536)  │          0 │ add_3[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_l… │ (None, 16, 1536)  │      3,072 │ add_4[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_a… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 1536)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 16, 1536)  │          0 │ add_4[0][0],      │
│                     │                   │            │ transformer_ff_3… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │      3,072 │ add_5[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_6 (Add)         │ (None, 16, 1536)  │          0 │ add_5[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_l… │ (None, 16, 1536)  │      3,072 │ add_6[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_a… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 1536)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_7 (Add)         │ (None, 16, 1536)  │          0 │ add_6[0][0],      │
│                     │                   │            │ transformer_ff_4… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │      3,072 │ add_7[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_8 (Add)         │ (None, 16, 1536)  │          0 │ add_7[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_l… │ (None, 16, 1536)  │      3,072 │ add_8[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_a… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 1536)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_9 (Add)         │ (None, 16, 1536)  │          0 │ add_8[0][0],      │
│                     │                   │            │ transformer_ff_5… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │      3,072 │ add_9[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_10 (Add)        │ (None, 16, 1536)  │          0 │ add_9[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_l… │ (None, 16, 1536)  │      3,072 │ add_10[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_a… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 1536)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_11 (Add)        │ (None, 16, 1536)  │          0 │ add_10[0][0],     │
│                     │                   │            │ transformer_ff_6… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │      3,072 │ add_11[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_12 (Add)        │ (None, 16, 1536)  │          0 │ add_11[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_l… │ (None, 16, 1536)  │      3,072 │ add_12[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_a… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 1536)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_13 (Add)        │ (None, 16, 1536)  │          0 │ add_12[0][0],     │
│                     │                   │            │ transformer_ff_7… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │      3,072 │ add_13[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_14 (Add)        │ (None, 16, 1536)  │          0 │ add_13[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_l… │ (None, 16, 1536)  │      3,072 │ add_14[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_a… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 1536)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_15 (Add)        │ (None, 16, 1536)  │          0 │ add_14[0][0],     │
│                     │                   │            │ transformer_ff_8… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │      6,144 │ add_15[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_batchn… │ (None, 32, 1536)  │      6,144 │ tower_conv_6_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_activa… │ (None, 32, 1536)  │          0 │ unet_skip_2_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d       │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_conv    │ (None, 32, 1536)  │  2,360,832 │ unet_skip_2_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_16 (Add)        │ (None, 32, 1536)  │          0 │ up_sampling1d[0]… │
│                     │                   │            │ unet_skip_2_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 32, 1536)  │  2,365,440 │ add_16[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_batchn… │ (None, 64, 1280)  │      5,120 │ tower_conv_5_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_activa… │ (None, 64, 1280)  │          0 │ unet_skip_1_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_1     │ (None, 64, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_conv    │ (None, 64, 1536)  │  1,967,616 │ unet_skip_1_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_17 (Add)        │ (None, 64, 1536)  │          0 │ up_sampling1d_1[… │
│                     │                   │            │ unet_skip_1_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 64, 1536)  │  2,365,440 │ add_17[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_batchno… │ (None, 64, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_activat… │ (None, 64, 1536)  │          0 │ final_conv_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten_1 (Flatten) │ (None, 98304)     │          0 │ final_conv_activ… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │  1,867,795 │ flatten_1[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 170,213,747 (649.31 MB)
 Trainable params: 170,188,723 (649.22 MB)
 Non-trainable params: 25,024 (97.75 KB)
None
2026-02-17T11:52:40.950585+0100 INFO Loading sequences into memory...
2026-02-17T11:52:48.870067+0100 INFO Loading sequences into memory...
Epoch 1/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 1022s 37ms/step - concordance_correlation_coefficient: 0.8346 - cosine_similarity: 0.8710 - loss: -0.6125 - mean_absolute_error: 2.3299 - mean_squared_error: 21.4265 - pearson_correlation: 0.8681 - pearson_correlation_log: 0.6517 - zero_penalty_metric: 134.1213 - val_concordance_correlation_coefficient: 0.8756 - val_cosine_similarity: 0.8773 - val_loss: -0.6307 - val_mean_absolute_error: 2.1718 - val_mean_squared_error: 17.8461 - val_pearson_correlation: 0.8773 - val_pearson_correlation_log: 0.6478 - val_zero_penalty_metric: 141.3275 - learning_rate: 1.0000e-05
Epoch 2/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 905s 35ms/step - concordance_correlation_coefficient: 0.8932 - cosine_similarity: 0.8861 - loss: -0.6628 - mean_absolute_error: 2.0703 - mean_squared_error: 15.0244 - pearson_correlation: 0.9068 - pearson_correlation_log: 0.6739 - zero_penalty_metric: 133.0103 - val_concordance_correlation_coefficient: 0.8621 - val_cosine_similarity: 0.8815 - val_loss: -0.6404 - val_mean_absolute_error: 2.1159 - val_mean_squared_error: 17.2662 - val_pearson_correlation: 0.8824 - val_pearson_correlation_log: 0.6569 - val_zero_penalty_metric: 139.1693 - learning_rate: 1.0000e-05
Epoch 3/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 902s 35ms/step - concordance_correlation_coefficient: 0.9088 - cosine_similarity: 0.8931 - loss: -0.6864 - mean_absolute_error: 1.9587 - mean_squared_error: 13.1128 - pearson_correlation: 0.9186 - pearson_correlation_log: 0.6840 - zero_penalty_metric: 132.3485 - val_concordance_correlation_coefficient: 0.8751 - val_cosine_similarity: 0.8825 - val_loss: -0.6417 - val_mean_absolute_error: 2.1054 - val_mean_squared_error: 16.5962 - val_pearson_correlation: 0.8836 - val_pearson_correlation_log: 0.6561 - val_zero_penalty_metric: 138.5822 - learning_rate: 1.0000e-05
Epoch 4/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 901s 35ms/step - concordance_correlation_coefficient: 0.9194 - cosine_similarity: 0.8988 - loss: -0.7060 - mean_absolute_error: 1.8694 - mean_squared_error: 11.7583 - pearson_correlation: 0.9270 - pearson_correlation_log: 0.6924 - zero_penalty_metric: 131.6307 - val_concordance_correlation_coefficient: 0.8808 - val_cosine_similarity: 0.8826 - val_loss: -0.6398 - val_mean_absolute_error: 2.1156 - val_mean_squared_error: 16.3505 - val_pearson_correlation: 0.8849 - val_pearson_correlation_log: 0.6537 - val_zero_penalty_metric: 140.8437 - learning_rate: 1.0000e-05
Epoch 5/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 900s 35ms/step - concordance_correlation_coefficient: 0.9271 - cosine_similarity: 0.9040 - loss: -0.7237 - mean_absolute_error: 1.7944 - mean_squared_error: 10.7341 - pearson_correlation: 0.9335 - pearson_correlation_log: 0.7006 - zero_penalty_metric: 130.8759 - val_concordance_correlation_coefficient: 0.8678 - val_cosine_similarity: 0.8821 - val_loss: -0.6354 - val_mean_absolute_error: 2.1226 - val_mean_squared_error: 16.8899 - val_pearson_correlation: 0.8834 - val_pearson_correlation_log: 0.6544 - val_zero_penalty_metric: 138.6294 - learning_rate: 1.0000e-05
Epoch 6/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 894s 35ms/step - concordance_correlation_coefficient: 0.9330 - cosine_similarity: 0.9088 - loss: -0.7400 - mean_absolute_error: 1.7288 - mean_squared_error: 9.9419 - pearson_correlation: 0.9384 - pearson_correlation_log: 0.7083 - zero_penalty_metric: 130.0102 - val_concordance_correlation_coefficient: 0.8732 - val_cosine_similarity: 0.8812 - val_loss: -0.6340 - val_mean_absolute_error: 2.1171 - val_mean_squared_error: 16.4989 - val_pearson_correlation: 0.8845 - val_pearson_correlation_log: 0.6573 - val_zero_penalty_metric: 137.7851 - learning_rate: 1.0000e-05
Epoch 7/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 891s 35ms/step - concordance_correlation_coefficient: 0.9377 - cosine_similarity: 0.9134 - loss: -0.7551 - mean_absolute_error: 1.6696 - mean_squared_error: 9.2848 - pearson_correlation: 0.9425 - pearson_correlation_log: 0.7157 - zero_penalty_metric: 129.1130 - val_concordance_correlation_coefficient: 0.8642 - val_cosine_similarity: 0.8797 - val_loss: -0.6255 - val_mean_absolute_error: 2.1472 - val_mean_squared_error: 17.0836 - val_pearson_correlation: 0.8828 - val_pearson_correlation_log: 0.6474 - val_zero_penalty_metric: 138.5760 - learning_rate: 1.0000e-05
Epoch 8/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 889s 34ms/step - concordance_correlation_coefficient: 0.9420 - cosine_similarity: 0.9178 - loss: -0.7689 - mean_absolute_error: 1.6156 - mean_squared_error: 8.6974 - pearson_correlation: 0.9462 - pearson_correlation_log: 0.7233 - zero_penalty_metric: 128.1267 - val_concordance_correlation_coefficient: 0.8752 - val_cosine_similarity: 0.8770 - val_loss: -0.6183 - val_mean_absolute_error: 2.1505 - val_mean_squared_error: 16.5057 - val_pearson_correlation: 0.8834 - val_pearson_correlation_log: 0.6468 - val_zero_penalty_metric: 137.4193 - learning_rate: 1.0000e-05
Epoch 9/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 890s 35ms/step - concordance_correlation_coefficient: 0.9473 - cosine_similarity: 0.9242 - loss: -0.7888 - mean_absolute_error: 1.5369 - mean_squared_error: 7.9236 - pearson_correlation: 0.9512 - pearson_correlation_log: 0.7343 - zero_penalty_metric: 126.7948 - val_concordance_correlation_coefficient: 0.8741 - val_cosine_similarity: 0.8771 - val_loss: -0.6163 - val_mean_absolute_error: 2.1496 - val_mean_squared_error: 16.6648 - val_pearson_correlation: 0.8828 - val_pearson_correlation_log: 0.6447 - val_zero_penalty_metric: 137.2684 - learning_rate: 2.5000e-06
Epoch 10/10
25765/25765 ━━━━━━━━━━━━━━━━━━━━ 922s 36ms/step - concordance_correlation_coefficient: 0.9486 - cosine_similarity: 0.9259 - loss: -0.7936 - mean_absolute_error: 1.5176 - mean_squared_error: 7.7285 - pearson_correlation: 0.9524 - pearson_correlation_log: 0.7374 - zero_penalty_metric: 126.2074 - val_concordance_correlation_coefficient: 0.8729 - val_cosine_similarity: 0.8759 - val_loss: -0.6131 - val_mean_absolute_error: 2.1524 - val_mean_squared_error: 16.7594 - val_pearson_correlation: 0.8828 - val_pearson_correlation_log: 0.6455 - val_zero_penalty_metric: 136.1508 - learning_rate: 2.5000e-06


Run history:


batch/batch_step ▁▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇████
batch/concordance_correlation_coefficient ▁▃▄▄▄▆▆▇▇▇▇▇▇▇▇▇████████████████████████
batch/cosine_similarity ▁▂▃▄▄▄▄▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇▇▇█████
batch/learning_rate █████████████████████████████████▁▁▁▁▁▁▁
batch/loss █▇▆▆▆▅▅▅▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
batch/mean_absolute_error ██▇▇▇▅▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
batch/mean_squared_error █▅▅▅▅▅▅▅▅▅▄▄▃▃▃▃▃▃▃▃▂▂▂▂▁▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
batch/pearson_correlation ▁▄▅▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇████████████████████
batch/pearson_correlation_log ▁▁▂▂▂▃▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇████████████
batch/zero_penalty_metric █▆▆▆▆▅▅▅▅▄▄▅▅▅▄▄▄▄▄▄▄▃▃▃▃▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁
+18...

Run summary:


batch/batch_step 257690
batch/concordance_correlation_coefficient 0.94863
batch/cosine_similarity 0.92586
batch/learning_rate 0.0
batch/loss -0.79362
batch/mean_absolute_error 1.5176
batch/mean_squared_error 7.72866
batch/pearson_correlation 0.95243
batch/pearson_correlation_log 0.73744
batch/zero_penalty_metric 126.21301
+18...

View run testrun at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac/runs/vgmw92y2
View project at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20260217_115238-vgmw92y2/logs

Further finetuning on specific regions#

We found that finetuning first on the full peak set and then on the filtered peak set improved performance over training on either set alone. Therefore, we’ll filter the peaks down to the cell type-specific ones and finetune the model further on those.
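Note that this cell assumes an `adata_filtered` object holding only the cell type-specific peaks; the filtering call itself is not shown here. Conceptually, such a filter keeps regions whose signal is concentrated in a few cell types, for example by scoring each region with a Gini coefficient across cell types and keeping the high-scoring outliers. A minimal NumPy sketch of that idea (the `gini` and `specific_region_mask` helpers below are illustrative, not part of the CREsted API):

```python
import numpy as np


def gini(x: np.ndarray) -> float:
    """Gini coefficient of a non-negative 1D array (0 = uniform, near 1 = concentrated)."""
    x = np.sort(x.astype(float))
    n = x.size
    total = x.sum()
    if total == 0:
        return 0.0
    # G = 2 * sum(i * x_i) / (n * sum(x)) - (n + 1) / n, with x sorted ascending, i = 1..n
    return 2 * np.sum(np.arange(1, n + 1) * x) / (n * total) - (n + 1) / n


def specific_region_mask(counts: np.ndarray, std_threshold: float = 1.0) -> np.ndarray:
    """Mask of regions whose specificity (Gini across cell types) is an upward outlier.

    counts: (n_cell_types, n_regions) matrix, like adata.X in this tutorial.
    """
    scores = np.apply_along_axis(gini, 0, counts)
    return scores > scores.mean() + std_threshold * scores.std()


# Toy example: region 0 is specific to one cell type, regions 1-3 are uniform.
counts = np.array([
    [10.0, 5.0, 5.0, 5.0],
    [0.0, 5.0, 5.0, 5.0],
    [0.0, 5.0, 5.0, 5.0],
    [0.0, 5.0, 5.0, 5.0],
])
print(specific_region_mask(counts))  # [ True False False False]
```

In the actual workflow you would filter on the AnnData object itself (e.g. subsetting with `adata[:, mask]`), so that the `var` annotations and the train/val/test split travel along with the kept regions.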

datamodule = crested.tl.data.AnnDataModule(
    adata_filtered,
    genome=genome,
    batch_size=32,  # lower this if you encounter OOM errors
    max_stochastic_shift=3,  # optional augmentation
    always_reverse_complement=True,  # default True. Will double the effective size of the training dataset.
)
optimizer = keras.optimizers.Adam(learning_rate=5e-5)
loss = crested.tl.losses.CosineMSELogLoss(max_weight=100, multiplier=1)
metrics = [
    keras.metrics.MeanAbsoluteError(),
    keras.metrics.MeanSquaredError(),
    keras.metrics.CosineSimilarity(axis=1),
    crested.tl.metrics.PearsonCorrelation(),
    crested.tl.metrics.ConcordanceCorrelationCoefficient(),
    crested.tl.metrics.PearsonCorrelationLog(),
    crested.tl.metrics.ZeroPenaltyMetric(),
]

config = crested.tl.TaskConfig(optimizer, loss, metrics)
print(config)
TaskConfig(optimizer=<keras.src.optimizers.adam.Adam object at 0x1469e0d60910>, loss=CosineMSELogLoss: {'name': 'CosineMSELogLoss', 'reduction': 'sum_over_batch_size', 'max_weight': 100}, metrics=[<MeanAbsoluteError name=mean_absolute_error>, <MeanSquaredError name=mean_squared_error>, <CosineSimilarity name=cosine_similarity>, <PearsonCorrelation name=pearson_correlation>, <ConcordanceCorrelationCoefficient name=concordance_correlation_coefficient>, <PearsonCorrelationLog name=pearson_correlation_log>, <ZeroPenaltyMetric name=zero_penalty_metric>])
# Load the epoch 3 checkpoint from the previous run, the epoch with the lowest validation loss
model_architecture = keras.models.load_model("biccn_borzoi_atac/testrun/checkpoints/03.keras", compile=False)
# setup the trainer
trainer = crested.tl.Crested(
    data=datamodule,
    model=model_architecture,
    config=config,
    project_name="biccn_borzoi_atac",
    run_name="testrun_ft",
    logger="wandb",
)
# train the model
trainer.fit(epochs=5)

Hide code cell output

Tracking run with wandb version 0.24.1
Run data is saved locally in /lustre1/project/stg_00002/lcb/cblaauw/python_modules/CREsted/docs/tutorials/wandb/run-20260217_153059-zs2yvgas
Model: "Borzoi_scalar"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ input (InputLayer)  │ (None, 2048, 4)   │          0 │ -                 │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_conv (Conv1D)  │ (None, 2048, 512) │     31,232 │ input[0][0]       │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ stem_pool           │ (None, 1024, 512) │          0 │ stem_conv[0][0]   │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_batch… │ (None, 1024, 512) │      2,048 │ stem_pool[0][0]   │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_activ… │ (None, 1024, 512) │          0 │ tower_conv_1_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_conv   │ (None, 1024, 608) │  1,557,088 │ tower_conv_1_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_1_pool   │ (None, 512, 608)  │          0 │ tower_conv_1_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_batch… │ (None, 512, 608)  │      2,432 │ tower_conv_1_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_activ… │ (None, 512, 608)  │          0 │ tower_conv_2_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_conv   │ (None, 512, 736)  │  2,238,176 │ tower_conv_2_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_2_pool   │ (None, 256, 736)  │          0 │ tower_conv_2_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_batch… │ (None, 256, 736)  │      2,944 │ tower_conv_2_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_activ… │ (None, 256, 736)  │          0 │ tower_conv_3_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_conv   │ (None, 256, 896)  │  3,298,176 │ tower_conv_3_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_3_pool   │ (None, 128, 896)  │          0 │ tower_conv_3_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_batch… │ (None, 128, 896)  │      3,584 │ tower_conv_3_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_activ… │ (None, 128, 896)  │          0 │ tower_conv_4_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_conv   │ (None, 128, 1056) │  4,731,936 │ tower_conv_4_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_4_pool   │ (None, 64, 1056)  │          0 │ tower_conv_4_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_batch… │ (None, 64, 1056)  │      4,224 │ tower_conv_4_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_activ… │ (None, 64, 1056)  │          0 │ tower_conv_5_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_conv   │ (None, 64, 1280)  │  6,759,680 │ tower_conv_5_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_5_pool   │ (None, 32, 1280)  │          0 │ tower_conv_5_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_batch… │ (None, 32, 1280)  │      5,120 │ tower_conv_5_poo… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_activ… │ (None, 32, 1280)  │          0 │ tower_conv_6_bat… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_conv   │ (None, 32, 1536)  │  9,831,936 │ tower_conv_6_act… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ tower_conv_6_pool   │ (None, 16, 1536)  │          0 │ tower_conv_6_con… │
│ (MaxPooling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │      3,072 │ tower_conv_6_poo… │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_1_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add (Add)           │ (None, 16, 1536)  │          0 │ tower_conv_6_poo… │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_l… │ (None, 16, 1536)  │      3,072 │ add[0][0]         │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_a… │ (None, 16, 3072)  │          0 │ transformer_ff_1… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_1… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_1_d… │ (None, 16, 1536)  │          0 │ transformer_ff_1… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_1 (Add)         │ (None, 16, 1536)  │          0 │ add[0][0],        │
│                     │                   │            │ transformer_ff_1… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │      3,072 │ add_1[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_2_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_2 (Add)         │ (None, 16, 1536)  │          0 │ add_1[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_l… │ (None, 16, 1536)  │      3,072 │ add_2[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_a… │ (None, 16, 3072)  │          0 │ transformer_ff_2… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_2… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_2_d… │ (None, 16, 1536)  │          0 │ transformer_ff_2… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_3 (Add)         │ (None, 16, 1536)  │          0 │ add_2[0][0],      │
│                     │                   │            │ transformer_ff_2… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │      3,072 │ add_3[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_3_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_4 (Add)         │ (None, 16, 1536)  │          0 │ add_3[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_l… │ (None, 16, 1536)  │      3,072 │ add_4[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_a… │ (None, 16, 3072)  │          0 │ transformer_ff_3… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_3… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_3_d… │ (None, 16, 1536)  │          0 │ transformer_ff_3… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_5 (Add)         │ (None, 16, 1536)  │          0 │ add_4[0][0],      │
│                     │                   │            │ transformer_ff_3… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │      3,072 │ add_5[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_4_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_6 (Add)         │ (None, 16, 1536)  │          0 │ add_5[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_l… │ (None, 16, 1536)  │      3,072 │ add_6[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_a… │ (None, 16, 3072)  │          0 │ transformer_ff_4… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_4… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_4_d… │ (None, 16, 1536)  │          0 │ transformer_ff_4… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_7 (Add)         │ (None, 16, 1536)  │          0 │ add_6[0][0],      │
│                     │                   │            │ transformer_ff_4… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │      3,072 │ add_7[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_5_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_8 (Add)         │ (None, 16, 1536)  │          0 │ add_7[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_l… │ (None, 16, 1536)  │      3,072 │ add_8[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_a… │ (None, 16, 3072)  │          0 │ transformer_ff_5… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_5… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_5_d… │ (None, 16, 1536)  │          0 │ transformer_ff_5… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_9 (Add)         │ (None, 16, 1536)  │          0 │ add_8[0][0],      │
│                     │                   │            │ transformer_ff_5… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │      3,072 │ add_9[0][0]       │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_6_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_10 (Add)        │ (None, 16, 1536)  │          0 │ add_9[0][0],      │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_l… │ (None, 16, 1536)  │      3,072 │ add_10[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_a… │ (None, 16, 3072)  │          0 │ transformer_ff_6… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_6… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_6_d… │ (None, 16, 1536)  │          0 │ transformer_ff_6… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_11 (Add)        │ (None, 16, 1536)  │          0 │ add_10[0][0],     │
│                     │                   │            │ transformer_ff_6… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │      3,072 │ add_11[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_7_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_12 (Add)        │ (None, 16, 1536)  │          0 │ add_11[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_l… │ (None, 16, 1536)  │      3,072 │ add_12[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_a… │ (None, 16, 3072)  │          0 │ transformer_ff_7… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_7… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_7_d… │ (None, 16, 1536)  │          0 │ transformer_ff_7… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_13 (Add)        │ (None, 16, 1536)  │          0 │ add_12[0][0],     │
│                     │                   │            │ transformer_ff_7… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │      3,072 │ add_13[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │  6,310,400 │ transformer_mha_… │
│ (MultiheadAttentio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_mha_8_… │ (None, 16, 1536)  │          0 │ transformer_mha_… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_14 (Add)        │ (None, 16, 1536)  │          0 │ add_13[0][0],     │
│                     │                   │            │ transformer_mha_… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_l… │ (None, 16, 1536)  │      3,072 │ add_14[0][0]      │
│ (LayerNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 3072)  │  4,721,664 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_a… │ (None, 16, 3072)  │          0 │ transformer_ff_8… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_p… │ (None, 16, 1536)  │  4,720,128 │ transformer_ff_8… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_ff_8_d… │ (None, 16, 1536)  │          0 │ transformer_ff_8… │
│ (Dropout)           │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_15 (Add)        │ (None, 16, 1536)  │          0 │ add_14[0][0],     │
│                     │                   │            │ transformer_ff_8… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │      6,144 │ add_15[0][0]      │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_batchn… │ (None, 32, 1536)  │      6,144 │ tower_conv_6_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_1_… │ (None, 16, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_activa… │ (None, 32, 1536)  │          0 │ unet_skip_2_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d       │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_2_conv    │ (None, 32, 1536)  │  2,360,832 │ unet_skip_2_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_16 (Add)        │ (None, 32, 1536)  │          0 │ up_sampling1d[0]… │
│                     │                   │            │ unet_skip_2_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 32, 1536)  │  2,365,440 │ add_16[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │          0 │ upsampling_conv_… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_batchn… │ (None, 64, 1280)  │      5,120 │ tower_conv_5_con… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_conv_2_… │ (None, 32, 1536)  │  2,360,832 │ upsampling_conv_… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_activa… │ (None, 64, 1280)  │          0 │ unet_skip_1_batc… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ up_sampling1d_1     │ (None, 64, 1536)  │          0 │ upsampling_conv_… │
│ (UpSampling1D)      │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ unet_skip_1_conv    │ (None, 64, 1536)  │  1,967,616 │ unet_skip_1_acti… │
│ (Conv1D)            │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ add_17 (Add)        │ (None, 64, 1536)  │          0 │ up_sampling1d_1[ │
│                     │                   │            │ unet_skip_1_conv… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ upsampling_separab… │ (None, 64, 1536)  │  2,365,440 │ add_17[0][0]      │
│ (SeparableConv1D)   │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_batchno… │ (None, 64, 1536)  │      6,144 │ upsampling_separ… │
│ (BatchNormalizatio… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ final_conv_activat… │ (None, 64, 1536)  │          0 │ final_conv_batch… │
│ (Activation)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ flatten_1 (Flatten) │ (None, 98304)     │          0 │ final_conv_activ… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_out (Dense)   │ (None, 19)        │  1,867,795 │ flatten_1[0][0]   │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
 Total params: 170,213,747 (649.31 MB)
 Trainable params: 170,188,723 (649.22 MB)
 Non-trainable params: 25,024 (97.75 KB)
None
2026-02-17T15:31:03.903866+0100 INFO Loading sequences into memory...
2026-02-17T15:31:14.837551+0100 INFO Loading sequences into memory...
Epoch 1/5
4195/4197 ━━━━━━━━━━━━━━━━━━━ 0s 34ms/step - concordance_correlation_coefficient: 0.7304 - cosine_similarity: 0.8712 - loss: -0.6451 - mean_absolute_error: 1.5989 - mean_squared_error: 12.1889 - pearson_correlation: 0.7960 - pearson_correlation_log: 0.6209 - zero_penalty_metric: 320.1040
4197/4197 ━━━━━━━━━━━━━━━━━━━━ 0s 43ms/step - concordance_correlation_coefficient: 0.7304 - cosine_similarity: 0.8712 - loss: -0.6451 - mean_absolute_error: 1.5989 - mean_squared_error: 12.1888 - pearson_correlation: 0.7960 - pearson_correlation_log: 0.6209 - zero_penalty_metric: 320.1049
4197/4197 ━━━━━━━━━━━━━━━━━━━━ 279s 47ms/step - concordance_correlation_coefficient: 0.7365 - cosine_similarity: 0.8727 - loss: -0.6489 - mean_absolute_error: 1.5869 - mean_squared_error: 12.0048 - pearson_correlation: 0.8006 - pearson_correlation_log: 0.6207 - zero_penalty_metric: 322.0251 - val_concordance_correlation_coefficient: 0.7076 - val_cosine_similarity: 0.8538 - val_loss: -0.5959 - val_mean_absolute_error: 1.7093 - val_mean_squared_error: 13.5235 - val_pearson_correlation: 0.7544 - val_pearson_correlation_log: 0.5745 - val_zero_penalty_metric: 313.7826 - learning_rate: 5.0000e-05
Epoch 2/5
4197/4197 ━━━━━━━━━━━━━━━━━━━━ 147s 35ms/step - concordance_correlation_coefficient: 0.8039 - cosine_similarity: 0.9009 - loss: -0.7157 - mean_absolute_error: 1.4239 - mean_squared_error: 9.4808 - pearson_correlation: 0.8448 - pearson_correlation_log: 0.6465 - zero_penalty_metric: 318.5771 - val_concordance_correlation_coefficient: 0.6928 - val_cosine_similarity: 0.8511 - val_loss: -0.5913 - val_mean_absolute_error: 1.6979 - val_mean_squared_error: 13.6837 - val_pearson_correlation: 0.7604 - val_pearson_correlation_log: 0.6195 - val_zero_penalty_metric: 309.6161 - learning_rate: 5.0000e-05
Epoch 3/5
4197/4197 ━━━━━━━━━━━━━━━━━━━━ 147s 35ms/step - concordance_correlation_coefficient: 0.8477 - cosine_similarity: 0.9220 - loss: -0.7674 - mean_absolute_error: 1.2879 - mean_squared_error: 7.6633 - pearson_correlation: 0.8762 - pearson_correlation_log: 0.6729 - zero_penalty_metric: 314.6130 - val_concordance_correlation_coefficient: 0.7266 - val_cosine_similarity: 0.8512 - val_loss: -0.5924 - val_mean_absolute_error: 1.7060 - val_mean_squared_error: 13.1160 - val_pearson_correlation: 0.7594 - val_pearson_correlation_log: 0.6219 - val_zero_penalty_metric: 317.2416 - learning_rate: 5.0000e-05
Epoch 4/5
4197/4197 ━━━━━━━━━━━━━━━━━━━━ 146s 35ms/step - concordance_correlation_coefficient: 0.8756 - cosine_similarity: 0.9375 - loss: -0.8059 - mean_absolute_error: 1.1800 - mean_squared_error: 6.4305 - pearson_correlation: 0.8971 - pearson_correlation_log: 0.6932 - zero_penalty_metric: 309.9731 - val_concordance_correlation_coefficient: 0.7357 - val_cosine_similarity: 0.8477 - val_loss: -0.5801 - val_mean_absolute_error: 1.7317 - val_mean_squared_error: 13.4246 - val_pearson_correlation: 0.7537 - val_pearson_correlation_log: 0.6214 - val_zero_penalty_metric: 306.4696 - learning_rate: 5.0000e-05
Epoch 5/5
4197/4197 ━━━━━━━━━━━━━━━━━━━━ 146s 35ms/step - concordance_correlation_coefficient: 0.8959 - cosine_similarity: 0.9490 - loss: -0.8358 - mean_absolute_error: 1.0895 - mean_squared_error: 5.4970 - pearson_correlation: 0.9126 - pearson_correlation_log: 0.7111 - zero_penalty_metric: 304.5094 - val_concordance_correlation_coefficient: 0.7285 - val_cosine_similarity: 0.8486 - val_loss: -0.5834 - val_mean_absolute_error: 1.6984 - val_mean_squared_error: 13.1764 - val_pearson_correlation: 0.7603 - val_pearson_correlation_log: 0.6204 - val_zero_penalty_metric: 306.8181 - learning_rate: 5.0000e-05


Run history:


batch/batch_step ▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▆▆▇▇▇▇█
batch/concordance_correlation_coefficient ▁▁▁▁▁▁▁▁▁▁▄▄▄▄▄▄▄▄▄▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██████
batch/cosine_similarity ▁▁▁▁▁▁▁▄▄▄▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████████
batch/learning_rate ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
batch/loss ████████▅▅▅▅▅▅▅▅▃▃▃▃▃▃▃▄▄▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁
batch/mean_absolute_error ████████▆▆▆▆▆▆▆▆▆▄▄▄▄▄▄▄▄▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/mean_squared_error ██████████▅▅▅▅▅▅▅▅▅▅▃▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁
batch/pearson_correlation ▁▃▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▅▅▇▆▆▆▇▇▇▇▇▇▇▇████████
batch/pearson_correlation_log ▁▂▂▂▂▂▂▂▂▂▂▂▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▅▅▇▇▇█████
batch/zero_penalty_metric ▄▇▇▇▇▇▇███▇▆▇▇▇▇▇▇▇▇▅▅▅▅▄▄▄▄▄▄▄▂▁▁▁▁▁▁▁▁
+18...

Run summary:


batch/batch_step 20990
batch/concordance_correlation_coefficient 0.89586
batch/cosine_similarity 0.94899
batch/learning_rate 5e-05
batch/loss -0.83585
batch/mean_absolute_error 1.08931
batch/mean_squared_error 5.49498
batch/pearson_correlation 0.91263
batch/pearson_correlation_log 0.71108
batch/zero_penalty_metric 304.48483
+18...

View run testrun_ft at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac/runs/zs2yvgas
View project at: https://wandb.ai/cas-blaauw/biccn_borzoi_atac
Synced 4 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
Find logs at: ./wandb/run-20260217_153059-zs2yvgas/logs

Evaluate model#

We’ll evaluate both the finetuned and further finetuned models, and compare them against DeepBICCN2, a dilated CNN model from our registry that was trained on the same dataset (closely matching the main tutorial).

model = keras.models.load_model("biccn_borzoi_atac/testrun/checkpoints/03.keras", compile=False)
model_ft = keras.models.load_model("biccn_borzoi_atac/testrun_ft/checkpoints/01.keras", compile=False)
# Add the predictions from each model checkpoint to the AnnData as layers
adata.layers["Once finetuned"] = crested.tl.predict(adata, model).T
adata.layers["Double finetuned"] = crested.tl.predict(adata, model_ft).T

# Copy the predictions over for the specific peaks
# (adata_filtered is created below with crested.pp.filter_regions_on_specificity())
adata_filtered.layers["Once finetuned"] = adata.layers["Once finetuned"][:, adata.var_names.get_indexer(adata_filtered.var_names)]
adata_filtered.layers["Double finetuned"] = adata.layers["Double finetuned"][:, adata.var_names.get_indexer(adata_filtered.var_names)]
# Load in the DeepBICCN2 model
model_file_db2, output_names_db2 = crested.get_model('deepbiccn2')
deepbiccn2 = keras.models.load_model(model_file_db2, compile=False)
# Get resized anndatas to deal with the slightly different input size
adata_2114 = crested.pp.change_regions_width(adata, 2114, inplace=False)
# Predict values
adata.layers["DeepBICCN2"] = crested.tl.predict(adata_2114, deepbiccn2).T
adata_filtered.layers["DeepBICCN2"] = adata.layers["DeepBICCN2"][:, adata.var_names.get_indexer(adata_filtered.var_names)]
del adata_2114
4245/4249 ━━━━━━━━━━━━━━━━━━━ 0s 17ms/step
4249/4249 ━━━━━━━━━━━━━━━━━━━━ 121s 22ms/step
# Save the anndata with these predictions
adata.write_h5ad("crested/mouse_cortex_preds.h5ad")
# Read in data again
adata = ad.read_h5ad("crested/mouse_cortex_preds.h5ad")
adata_filtered = crested.pp.filter_regions_on_specificity(adata, gini_std_threshold=1.0, inplace=False)
2026-02-17T16:41:31.725655+0100 INFO Lazily importing module crested.pp. This could take a second...
2026-02-17T16:41:37.760591+0100 INFO After specificity filtering, kept 90995 out of 543847 regions.

Many of the plotting functions in the crested.pl module can be used to visualize these model predictions.

# Define a dataframe with test set regions
test_df = adata.var[adata.var["split"] == "test"]
test_df_ft = adata_filtered.var[adata_filtered.var["split"] == "test"]
test_df
chr start end target_start target_end split
region
chr1:3094031-3096079 chr1 3094031 3096079 3094555 3095555 test
chr1:3094696-3096744 chr1 3094696 3096744 3095220 3096220 test
chr1:3111400-3113448 chr1 3111400 3113448 3111924 3112924 test
chr1:3112760-3114808 chr1 3112760 3114808 3113284 3114284 test
chr1:3118972-3121020 chr1 3118972 3121020 3119496 3120496 test
... ... ... ... ... ... ...
chrX:21361135-21363183 chrX 21361135 21363183 21361659 21362659 test
chrX:21388522-21390570 chrX 21388522 21390570 21389046 21390046 test
chrX:21392726-21394774 chrX 21392726 21394774 21393250 21394250 test
chrX:21427413-21429461 chrX 21427413 21429461 21427937 21428937 test
chrX:21433814-21435862 chrX 21433814 21435862 21434338 21435338 test

58874 rows × 6 columns

# Plot predictions vs ground truth for a region in the filtered test set, selected by index
%matplotlib inline
idx = 22
region = test_df_ft.index[idx]
print(region)
crested.pl.region.bar(adata_filtered, region, suptitle=f"Predictions vs ground truth ({region})")
chr1:3899526-3901574
../_images/8f85ee330517c9442cc2fad5bee7dab13fb1905c7a1b21008dfcb8922c33599b.png
# Self-correlation of the ground-truth values for all peaks and for specific peaks;
# this serves as an upper bound on the correlations between truth and predictions
%matplotlib inline
fig, axs = plt.subplots(1, 2, figsize = (15, 8))
crested.pl.corr.heatmap_self(
    adata,
    title="All peaks",
    show=False,
    ax=axs[0],
    cbar_kws={'shrink': 0.5}
)
crested.pl.corr.heatmap_self(
    adata_filtered,
    title="Specific peaks",
    suptitle="Self-correlation heatmap",
    ax=axs[1],
    show=False,
    cbar_kws={'shrink': 0.5}
)
plt.show()
2026-02-17T16:43:27.543253+0100 WARNING Using keyword argument layout does not do anything when passing a pre-existing axis.
2026-02-17T16:43:27.640792+0100 WARNING Using keyword argument layout does not do anything when passing a pre-existing axis.
(<Figure size 1500x800 with 4 Axes>,
 <Axes: title={'center': 'Specific peaks'}>)
../_images/5280be6d00cb99cb9e5d1c6bbc62feb3992321a272e2a76e0a404315d4c0a97c.png
%matplotlib inline
crested.pl.corr.heatmap(
    adata,
    split="test",
    suptitle="Correlations between ground truths and predictions - all test regions",
    log_transform=True,
    vmin = 0,
    vmax = 1,
)
crested.pl.corr.heatmap(
    adata_filtered,
    split="test",
    suptitle="Correlations between ground truths and predictions - specific test regions",
    log_transform=True,
    vmin = 0,
    vmax = 1,
)
../_images/70a7f6fa50a8bd20f83b636c87a3e47818cae4778e49335241d8a090c135c30c.png ../_images/b91732c9cddbffc5d7346390aae24f4c51020755053c1de4ed9c5056e1895c51.png
%matplotlib inline
crested.pl.corr.scatter(
    adata,
    split="test",
    suptitle="All test regions",
    log_transform=True,
    density_indication=True,
    identity_line=True,
    square=True
)

crested.pl.corr.scatter(
    adata_filtered,
    split="test",
    suptitle="Specific test regions",
    log_transform=True,
    density_indication=True,
    identity_line=True,
    square=True
)
2026-02-17T16:45:22.573983+0100 INFO Plotting density scatter for all targets and predictions, models: ['DeepBICCN2', 'Double finetuned', 'Once finetuned'], split: test
2026-02-17T16:46:41.194634+0100 INFO Plotting density scatter for all targets and predictions, models: ['DeepBICCN2', 'Double finetuned', 'Once finetuned'], split: test
../_images/6ee01b4ace22075282b34095f11af909e309f26cd976a7e3a3473ebaec678981.png ../_images/2d54e6abbe3874642a999e1bcfed055dc5f8267279077e7eaa6d921f6ef3d118.png
%matplotlib inline
model_order = ["DeepBICCN2", "Once finetuned", "Double finetuned"]
fig, axs = plt.subplots(1, 2, figsize=(15, 8))
crested.pl.corr.violin(
    adata,
    ax=axs[0],
    title="All peaks",
    suptitle="Per-class correlations compared between models and datasets",
    plot_kws={'order': model_order},
    show=False,
)
crested.pl.corr.violin(adata_filtered, ax=axs[1], title="Specific peaks", plot_kws={'order': model_order}, show=False)
plt.show()
../_images/570ec748449ed9c36ebca6011eb2a35751e69671c20e6facf7cec3594d7d65d6.png

Here, we see that the finetuned models generally perform comparably to the CNN-based models on this dataset, but don’t have a clear edge over them.
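If you want to quantify this comparison beyond the violin plots, you can compute per-class Pearson correlations between the ground truth and any prediction layer yourself. Below is a minimal NumPy sketch with toy arrays standing in for `adata.X` and a prediction layer such as `adata.layers["Once finetuned"]` (both shaped classes × regions); the helper function here is illustrative, not part of CRESTED.

```python
import numpy as np

def per_class_pearson(truth: np.ndarray, preds: np.ndarray) -> np.ndarray:
    """Pearson correlation per row (class) between truth and predictions.

    Both arrays are (n_classes, n_regions), matching adata.X and a
    prediction layer.
    """
    t = truth - truth.mean(axis=1, keepdims=True)
    p = preds - preds.mean(axis=1, keepdims=True)
    num = (t * p).sum(axis=1)
    denom = np.sqrt((t**2).sum(axis=1) * (p**2).sum(axis=1))
    return num / denom

# Toy example: 3 "classes" x 5 "regions"
truth = np.array([[1., 2., 3., 4., 5.],
                  [5., 4., 3., 2., 1.],
                  [1., 1., 2., 2., 3.]])
preds = truth + np.array([[0.1], [0.2], [-0.1]])  # per-class shifts don't change correlation
print(per_class_pearson(truth, preds))  # ~[1. 1. 1.]
```

On real data, you would call this once per prediction layer and compare the resulting per-class vectors between models, e.g. restricted to the test-split regions.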

Besides looking at prediction scores, we can also use these models to explain which sequence features contributed to the predicted accessibility in a certain cell type.
Here, we’ll look at three regions, expected to be active in microglia (Micro_PVM), Sst/Chodl GABAergic neurons (SstChodl), and layer 6b glutamatergic neurons (L6b), respectively.

regions_of_interest = [
    "chr18:61107803-61109851",
    "chr13:92952218-92954266",
    "chr9:56036511-56038559",
]
classes_of_interest = ["Micro_PVM", "SstChodl", "L6b"]
class_idx = list(adata.obs_names.get_indexer(classes_of_interest))

scores, one_hot_encoded_sequences = crested.tl.contribution_scores(
    regions_of_interest,
    target_idx=class_idx,
    model=model_ft,
)
2026-02-17T16:52:53.490277+0100 INFO Calculating contribution scores for 3 class(es) and 3 region(s).
# Plot attribution scores
crested.pl.explain.contribution_scores(
    scores,
    one_hot_encoded_sequences,
    sequence_labels=regions_of_interest,
    class_labels=classes_of_interest,
    zoom_n_bases=500,
    title_fontsize=20,
)
../_images/81aae77b6b14bb715800e4d3c904c917ed34931503b7a5b8efbff39b2b88787b.png
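Beyond plotting, you can work with the returned arrays directly, for instance to locate the strongest-contributing positions per region and class. The sketch below uses random stand-in data; the assumed shapes — `scores` as (n_regions, n_classes, seq_len, 4) with one channel per nucleotide — are an assumption here, so check them against your own `scores.shape` before reusing this.

```python
import numpy as np

# Stand-in for contribution scores: 3 regions, 3 classes, 2048 bp, 4 nucleotides.
# (Shapes are assumed; verify against your actual scores array.)
rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 3, 2048, 4))

# The contribution of the observed base at each position is typically
# scores * one_hot; without the sequence at hand, summing over the
# nucleotide axis gives a rough per-position signal.
per_position = scores.sum(axis=-1)      # (regions, classes, positions)
top_pos = per_position.argmax(axis=-1)  # strongest position per (region, class)
print(top_pos.shape)  # (3, 3)
```

With the real `scores` and `one_hot_encoded_sequences` from above, multiplying them elementwise before reducing would recover the same per-base attributions that `crested.pl.explain.contribution_scores` visualizes.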