crested.tl.Crested
- class crested.tl.Crested(data, model=None, config=None, project_name=None, run_name=None, logger=None, seed=None)
Main class to handle training, testing, predicting and calculation of contribution scores.
- Parameters:
  - data (AnnDataModule) – AnnDataModule object containing the data.
  - model (Model | None, default: None) – Model architecture to use for training.
  - config (TaskConfig | None, default: None) – Task configuration (optimizer, loss, and metrics) for use in tl.Crested.
  - project_name (str | None, default: None) – Name of the project. Used for logging and creating output directories. If not provided, the default project name “CREsted” is used.
  - run_name (str | None, default: None) – Name of the run. Used for wandb logging and creating output directories. If not provided, the current date and time are used.
  - logger (str | None, default: None) – Logger to use for logging. Can be “wandb”, “tensorboard”, or “dvc” (tensorboard is not implemented yet). If not provided, no additional logging is done.
  - seed (int, default: None) – Seed to use for reproducibility. WARNING: this does not make everything fully reproducible, especially on GPU. Some (GPU) operations are non-deterministic and cannot be controlled by the seed.
Examples
>>> from crested.tl import Crested
>>> from crested.tl import default_configs
>>> from crested.tl.data import AnnDataModule
>>> from crested.tl.zoo import deeptopic_cnn
>>> # Load data
>>> anndatamodule = AnnDataModule(anndata, genome="path/to/genome.fa")
>>> model_architecture = deeptopic_cnn(seq_len=1000, n_classes=10)
>>> configs = default_configs("topic_classification")
>>> # Initialize trainer
>>> trainer = Crested(
...     data=anndatamodule,
...     model=model_architecture,
...     config=configs,
...     project_name="test",
... )
>>> # Fit the model
>>> trainer.fit(epochs=100)
>>> # Evaluate the model
>>> trainer.test()
Methods table
| fit | Fit the model on the training and validation set. |
| load_model | Load a (pretrained) model from a file. |
| test | Evaluate the model on the test set. |
| transferlearn | Perform transfer learning on the model. |
Methods
- Crested.fit(epochs=100, mixed_precision=False, model_checkpointing=True, model_checkpointing_best_only=True, model_checkpointing_metric='val_loss', model_checkpointing_mode='min', early_stopping=True, early_stopping_patience=10, early_stopping_metric='val_loss', early_stopping_mode='min', learning_rate_reduce=True, learning_rate_reduce_patience=5, learning_rate_reduce_metric='val_loss', learning_rate_reduce_mode='min', save_dir=None, custom_callbacks=None)
Fit the model on the training and validation set.
- Parameters:
  - epochs (int, default: 100) – Number of epochs to train the model.
  - mixed_precision (bool, default: False) – Enable mixed precision training.
  - model_checkpointing (bool, default: True) – Save model checkpoints.
  - model_checkpointing_best_only (bool, default: True) – Save only the best model checkpoint.
  - model_checkpointing_metric (str, default: 'val_loss') – Metric to monitor when choosing the best model.
  - model_checkpointing_mode (str, default: 'min') – ‘max’ if a higher metric value is better, ‘min’ if a lower value is better.
  - early_stopping (bool, default: True) – Enable early stopping.
  - early_stopping_patience (int, default: 10) – Number of epochs with no improvement after which training is stopped.
  - early_stopping_metric (str, default: 'val_loss') – Metric to monitor for early stopping.
  - early_stopping_mode (str, default: 'min') – ‘max’ if a higher metric value is better, ‘min’ if a lower value is better.
  - learning_rate_reduce (bool, default: True) – Enable learning rate reduction.
  - learning_rate_reduce_patience (int, default: 5) – Number of epochs with no improvement after which the learning rate is reduced.
  - learning_rate_reduce_metric (str, default: 'val_loss') – Metric to monitor for reducing the learning rate.
  - learning_rate_reduce_mode (str, default: 'min') – ‘max’ if a higher metric value is better, ‘min’ if a lower value is better.
  - save_dir (str | None, default: None) – Directory to save the model to. Defaults to the project name.
  - custom_callbacks (list | None, default: None) – List of custom callbacks to use during training.
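The patience and mode parameters of Crested.fit can be illustrated with a small, library-independent sketch. It is not CREsted's (or its backend's) implementation; `stopping_epoch` is a hypothetical helper, and it assumes that any strict improvement of the monitored metric resets the patience counter.

```python
def stopping_epoch(history, patience=10, mode="min"):
    """Return the 1-based epoch at which training would stop, or None.

    `history` holds one monitored metric value per epoch (e.g. val_loss).
    Assumption: any strict improvement resets the patience counter.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = None
    epochs_without_improvement = 0
    for epoch, value in enumerate(history, start=1):
        if best is None:
            improved = True
        elif mode == "min":
            improved = value < best  # lower is better
        else:
            improved = value > best  # higher is better
        if improved:
            best = value
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation losses, for illustration only.
val_loss = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.65, 0.66, 0.66]
print(stopping_epoch(val_loss, patience=3, mode="min"))  # prints 9
```

With mode='max' the same loss curve would stop much earlier, because a falling loss never counts as an improvement. This is why the metric and mode arguments should be changed together when monitoring something other than 'val_loss'.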
- Crested.load_model(model_path, compile=True)
Load a (pretrained) model from a file.
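- Parameters:
  - model_path – Path to the file containing the (pretrained) model.
  - compile (bool, default: True) – Whether to compile the model after loading.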
- Crested.test(return_metrics=False)
Evaluate the model on the test set.
Before calling this function, load a model with Crested.load_model() and make sure the model is compiled.
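The load-then-test order described above can be wrapped in a small helper. This is a sketch only: `load_and_test`, `_StubTrainer`, and the model path are hypothetical. The stub stands in for a real Crested object so the snippet runs without CREsted installed. The source does not state what test() returns when return_metrics=True, so the stub's return value is a placeholder.

```python
def load_and_test(trainer, model_path):
    """Load a saved model into `trainer`, then evaluate it on the test set."""
    # compile=True because test() expects a compiled model.
    trainer.load_model(model_path, compile=True)
    return trainer.test(return_metrics=True)


class _StubTrainer:
    """Hypothetical stand-in that only records calls; swap in a real Crested object."""

    def __init__(self):
        self.calls = []

    def load_model(self, model_path, compile=True):
        self.calls.append(("load_model", model_path, compile))

    def test(self, return_metrics=False):
        self.calls.append(("test", return_metrics))
        # Placeholder value; the real return format is not given in the source.
        return {"loss": 0.0} if return_metrics else None


stub = _StubTrainer()
result = load_and_test(stub, "path/to/model.keras")  # hypothetical path
print(stub.calls)
```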
- Crested.transferlearn(epochs_first_phase=50, epochs_second_phase=50, learning_rate_first_phase=0.0001, learning_rate_second_phase=1e-06, freeze_until_layer_name=None, freeze_until_layer_index=None, set_output_activation=None, **kwargs)
Perform transfer learning on the model.
The first phase freezes layers up to a specified layer (if provided), removes the later layers, adds a dense output layer, and trains with a low learning rate. The second phase unfreezes all layers and continues training with an even lower learning rate.
Before calling this function, load a model with Crested.load_model() and make sure your Crested object has a datamodule and config loaded.
One of freeze_until_layer_name or freeze_until_layer_index must be provided.
- Parameters:
  - epochs_first_phase (int, default: 50) – Number of epochs to train in the first phase.
  - epochs_second_phase (int, default: 50) – Number of epochs to train in the second phase.
  - learning_rate_first_phase (float, default: 0.0001) – Learning rate for the first phase.
  - learning_rate_second_phase (float, default: 1e-06) – Learning rate for the second phase.
  - freeze_until_layer_name (str | None, default: None) – Name of the layer up to which to freeze layers. If None, defaults to freezing all layers except the last layer.
  - freeze_until_layer_index (int | None, default: None) – Index of the layer up to which to freeze layers. If None, defaults to freezing all layers except the last layer.
  - set_output_activation (str | None, default: None) – Set output activation if different from the previous model.
  - kwargs – Additional keyword arguments to pass to the fit method.
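The two-phase schedule described above can be sketched as a plain-Python plan of which layers are trainable in each phase. This is an illustration under stated assumptions, not CREsted's implementation. `transfer_plan` and the layer names are hypothetical, the boundary layer itself is assumed to be frozen (inclusive), the layer name takes precedence if both arguments are given, and an error is raised if neither is given.

```python
def transfer_plan(layer_names, freeze_until_layer_name=None,
                  freeze_until_layer_index=None):
    """Return (phase_one, phase_two): dicts mapping layer name -> trainable flag."""
    if freeze_until_layer_name is None and freeze_until_layer_index is None:
        raise ValueError(
            "Provide freeze_until_layer_name or freeze_until_layer_index."
        )
    if freeze_until_layer_name is not None:
        boundary = layer_names.index(freeze_until_layer_name)
    else:
        # Normalises negative indices and raises IndexError if out of range.
        boundary = range(len(layer_names))[freeze_until_layer_index]

    # Assumption: the boundary layer is kept and frozen; later layers are removed.
    kept = layer_names[: boundary + 1]

    # Phase one: kept layers frozen, a new dense output layer is trained.
    phase_one = {name: False for name in kept}
    phase_one["new_dense_output"] = True

    # Phase two: every remaining layer is unfrozen.
    phase_two = {name: True for name in phase_one}
    return phase_one, phase_two


layers = ["conv1", "conv2", "dense_hidden", "dense_out"]  # made-up layer names
phase_one, phase_two = transfer_plan(layers, freeze_until_layer_name="conv2")
print(phase_one)
print(phase_two)
```

In the actual method, phase one would run for epochs_first_phase at learning_rate_first_phase and phase two for epochs_second_phase at learning_rate_second_phase.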
See also