SLEP023: Callback API

Author: Jérémie du Boisberranger

Status: Accepted

Type: Standards Track

Created: 2024-03-01

Resolution: https://github.com/scikit-learn/enhancement_proposals/pull/103

Abstract

This SLEP proposes an API to allow users to register callbacks to be called at specific points during the training of scikit-learn estimators.

Motivation

The current scikit-learn API gives users very limited ways to inspect the steps of the training process (fit) of an estimator, and even fewer when several estimators and meta-estimators are composed. Setting a verbosity level provides some information, and a few estimators expose public attributes containing information accumulated during training (e.g. HistGradientBoostingClassifier.train_score_), but that’s pretty much it.

Use cases that have been requested many times on the scikit-learn issue tracker but are not possible with the current API include, for instance, progress bars, structured logging, metric monitoring and snapshots. A callback API could also provide a generic and consistent API for early stopping, which is currently implemented, differently, in only a few estimators. For such widely requested features, scikit-learn will provide built-in callbacks, and users would also be able to implement their own callbacks for less common use cases.

By providing a way to gather information during the training process, effectively making the scikit-learn implementation of machine learning algorithms more transparent, the callback API would also bring a lot of value for educational, testing, and debugging purposes.

Public interfaces

This section describes the proposed public interfaces for callbacks. They are divided into three subsets which target different kinds of users:

  • the end users of scikit-learn

  • the scikit-learn and third-party developers implementing estimators

  • the scikit-learn and third-party developers implementing callbacks

Using callbacks

Callbacks are objects that can be registered on an estimator to be called at specific points during fit, gathering information about the training process. scikit-learn provides a set of built-in callbacks for common use cases, exposed in the sklearn.callback module.

Callbacks can be registered on an estimator using its set_callbacks method:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.callback import ProgressBar

X, y = make_classification(random_state=0)

callback = ProgressBar()
clf = LogisticRegression()
clf.set_callbacks(callback)
clf.fit(X, y)

When composing estimators and meta-estimators, some callbacks are intuitively expected to only be applied to some of them (e.g. early stopping on inner estimators) while others are expected to be applied to all of them (e.g. progress bars). This SLEP proposes that some callbacks (referred to as “auto-propagated”) have the property of being automatically propagated down to sub-estimators, such that they only need to be registered on the outermost meta-estimator:

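from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.callback import EarlyStopping, ProgressBar

X, y = make_classification(random_state=0)
param_grid = {"C": [0.1, 1.0, 10.0]}

# EarlyStopping is registered on the inner estimator only.
clf = LogisticRegression().set_callbacks(EarlyStopping())
gs = GridSearchCV(clf, param_grid)
# ProgressBar is auto-propagated one nesting level down, to the clones of clf.
gs.set_callbacks(ProgressBar(max_propagation_depth=1))
gs.fit(X, y)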

Callback support in estimators

To add callback support to an estimator, scikit-learn provides three components:

  • CallbackContext: a class that is used to manage the callbacks during fit. Instances of this class represent contexts of tasks being executed, where tasks are units of work defined by the estimator. There is one CallbackContext for each task. Tasks (and therefore callback contexts) have a natural tree structure, where each task can be decomposed into subtasks and so on, with the root task being the whole fit task. Usually tasks correspond to iterations of loops during fit and nested loops correspond to nested tasks, but in general a task can be any unit of work defined by the estimator.

    The CallbackContext class exposes the following methods:

    • call_on_fit_task_begin, to call the callbacks at the beginning of the task.

    • call_on_fit_task_end, to call the callbacks at the end of the task.

    • subcontext, to create a child CallbackContext for a subtask.

    • propagate_callbacks, to propagate the callback context and the auto-propagated callbacks from a meta-estimator to its sub-estimators.

    When a callback is called for a given task, it receives the corresponding context, whose attributes provide information about the task being executed that the callback can use to adapt its behavior. It also receives all the information that the estimator is able to provide about the state of the fitting process at this task.

  • CallbackSupportMixin: a class that the estimator should inherit from. This mixin exposes the following methods:

    • set_callbacks, to register callbacks on the estimator.

    • _init_callback_context, to create the root callback context for the estimator and set up the callbacks for the estimator.

  • with_callbacks: a decorator that the estimator should use to decorate the fit method. It runs fit in a try-finally block to ensure that callbacks are properly torn down no matter what happens during fit.

A typical implementation of callback support in an estimator would look like this:

from sklearn.base import BaseEstimator
from sklearn.callback import CallbackSupportMixin, with_callbacks

class MyEstimator(CallbackSupportMixin, BaseEstimator):

    def __init__(self, n_iter=10):
        self.n_iter = n_iter

    @with_callbacks
    def fit(self, X, y):
        # Root context for the whole fit task, with one subtask per iteration.
        callback_ctx = self._init_callback_context(
            task_name="fit", max_subtasks=self.n_iter
        )
        callback_ctx.call_on_fit_task_begin(estimator=self, X=X, y=y)

        for i in range(self.n_iter):
            # Child context for this iteration.
            callback_subctx = callback_ctx.subcontext(
                task_name=f"iteration {i}", task_id=i
            )
            callback_subctx.call_on_fit_task_begin(estimator=self, X=X, y=y)

            # <computation part of the estimator>

            callback_subctx.call_on_fit_task_end(estimator=self, X=X, y=y)

        callback_ctx.call_on_fit_task_end(estimator=self, X=X, y=y)

        return self
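The sketch below is a toy, in-memory stand-in for the three components, written only to make the control flow of the example above concrete; it is not the scikit-learn implementation, which lives in the callbacks feature branch. Everything in it is an assumption: the context attributes (task_name, task_id, depth, parent), the private _toy_callbacks and _toy_root_ctx attributes and the RecordingCallback class are hypothetical, auto-propagation and propagate_callbacks are omitted, and MyEstimator does not inherit from BaseEstimator so that the snippet runs without scikit-learn.

```python
import functools


class CallbackContext:
    """Toy task context: one node of the task tree (hypothetical attributes)."""

    def __init__(self, callbacks, task_name, task_id=0, max_subtasks=None, parent=None):
        self.callbacks = callbacks
        self.task_name = task_name
        self.task_id = task_id
        self.max_subtasks = max_subtasks
        self.parent = parent
        self.depth = 0 if parent is None else parent.depth + 1

    def subcontext(self, task_name, task_id=0, max_subtasks=None):
        # A subtask shares the callbacks of its parent task.
        return CallbackContext(self.callbacks, task_name, task_id, max_subtasks, parent=self)

    def call_on_fit_task_begin(self, estimator, **kwargs):
        for callback in self.callbacks:
            callback.on_fit_task_begin(estimator, self, **kwargs)

    def call_on_fit_task_end(self, estimator, **kwargs):
        for callback in self.callbacks:
            callback.on_fit_task_end(estimator, self, **kwargs)


class CallbackSupportMixin:
    def set_callbacks(self, *callbacks):
        self._toy_callbacks = list(callbacks)
        return self

    def _init_callback_context(self, task_name="fit", max_subtasks=None):
        callbacks = getattr(self, "_toy_callbacks", [])
        self._toy_root_ctx = CallbackContext(callbacks, task_name, max_subtasks=max_subtasks)
        for callback in callbacks:
            callback.setup(self, self._toy_root_ctx)
        return self._toy_root_ctx


def with_callbacks(fit):
    @functools.wraps(fit)
    def wrapper(self, *args, **kwargs):
        try:
            return fit(self, *args, **kwargs)
        finally:
            # Runs whether fit returns or raises, so callbacks are always torn down.
            root_ctx = self.__dict__.pop("_toy_root_ctx", None)
            if root_ctx is not None:
                for callback in root_ctx.callbacks:
                    callback.teardown(self, root_ctx)

    return wrapper


class RecordingCallback:
    """Hypothetical callback recording (hook, task_name, depth) for each call."""

    def __init__(self):
        self.events = []

    def setup(self, estimator, context):
        self.events.append(("setup", context.task_name, context.depth))

    def on_fit_task_begin(self, estimator, context, **kwargs):
        self.events.append(("begin", context.task_name, context.depth))

    def on_fit_task_end(self, estimator, context, **kwargs):
        self.events.append(("end", context.task_name, context.depth))

    def teardown(self, estimator, context):
        self.events.append(("teardown", context.task_name, context.depth))


class MyEstimator(CallbackSupportMixin):
    def __init__(self, n_iter=10):
        self.n_iter = n_iter

    @with_callbacks
    def fit(self, X, y):
        ctx = self._init_callback_context(task_name="fit", max_subtasks=self.n_iter)
        ctx.call_on_fit_task_begin(estimator=self, X=X, y=y)
        for i in range(self.n_iter):
            subctx = ctx.subcontext(task_name=f"iteration {i}", task_id=i)
            subctx.call_on_fit_task_begin(estimator=self, X=X, y=y)
            # <computation part of the estimator>
            subctx.call_on_fit_task_end(estimator=self, X=X, y=y)
        ctx.call_on_fit_task_end(estimator=self, X=X, y=y)
        return self
```

Fitting MyEstimator(n_iter=2).set_callbacks(RecordingCallback()) records setup, then the begin and end of the root fit task at depth 0 wrapping the begin and end of each iteration at depth 1, and finally teardown, which is the same shape as the first trace in the next section. Because teardown happens in the finally clause, it is also recorded when a hook or the computation raises.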

The callback protocol

Callbacks must implement the following protocol:

from typing import Protocol

class FitCallback(Protocol):
    def setup(self, estimator, context): ...
    def on_fit_task_begin(self, estimator, context, *, X=None, y=None, metadata=None, fitted_estimator=None): ...
    def on_fit_task_end(self, estimator, context, *, X=None, y=None, metadata=None, fitted_estimator=None): ...
    def teardown(self, estimator, context): ...

The four protocol methods are referred to as callback hooks in the rest of this SLEP. The setup and teardown hooks are called at the beginning and end of fit and should be used to set up and tear down the callback for the estimator. The on_fit_task_begin and on_fit_task_end hooks are called at the beginning and end of each task performed during fit, including the root task.

The keyword arguments received by the on_fit_task_begin and on_fit_task_end hooks contain all the available information provided by the estimator about the state of the fitting process at this task:

  • “X”: the training data.

  • “y”: the training target.

  • “metadata”: a dictionary containing training and validation metadata, e.g. sample weights, X_val, y_val, etc.

  • “fitted_estimator”: an estimator instance that is ready to predict, transform, etc. as if fit stopped at the end of this task.

Note that some estimators may not be able to provide values for all of these keys for every task.
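A user-defined callback only has to implement the four hooks. As an illustration, the hypothetical ValidationScoreMonitor below scores the intermediate model at the end of every task. It assumes that validation data is passed in metadata under the keys “X_val” and “y_val” (as the list above suggests) and that fitted_estimator exposes a score method; when any of these is missing, the hook does nothing, which keeps the callback usable with estimators that cannot provide every key.

```python
class ValidationScoreMonitor:
    """Hypothetical callback recording a validation score at the end of each task."""

    def setup(self, estimator, context):
        # Fresh state for every fit.
        self.scores = []

    def on_fit_task_begin(self, estimator, context, *, X=None, y=None,
                          metadata=None, fitted_estimator=None):
        pass  # nothing to do when a task starts

    def on_fit_task_end(self, estimator, context, *, X=None, y=None,
                        metadata=None, fitted_estimator=None):
        metadata = metadata or {}
        # Estimators may not provide every key for every task: skip quietly.
        if fitted_estimator is None or "X_val" not in metadata or "y_val" not in metadata:
            return
        self.scores.append(fitted_estimator.score(metadata["X_val"], metadata["y_val"]))

    def teardown(self, estimator, context):
        pass  # no resources to release
```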

Auto-propagated callbacks must implement a small extension of this protocol:

class AutoPropagatedCallback(FitCallback, Protocol):
    @property
    def max_propagation_depth(self): ...

max_propagation_depth defines the maximum nesting level of sub-estimators to propagate the callback to.
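A meta-estimator has to tell auto-propagated callbacks apart from the others and decide how deep to forward them. The sketch below assumes the protocols are declared with typing.runtime_checkable so that isinstance can be used, and reads max_propagation_depth as “a sub-estimator at nesting level d receives the callback when d ≤ max_propagation_depth”, which is consistent with ProgressBar(max_propagation_depth=1) on GridSearchCV reaching the inner LogisticRegression. ToyProgressBar and callbacks_for_sub_estimator are hypothetical names, not part of the proposed API.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class FitCallback(Protocol):
    def setup(self, estimator, context): ...
    def on_fit_task_begin(self, estimator, context, *, X=None, y=None, metadata=None, fitted_estimator=None): ...
    def on_fit_task_end(self, estimator, context, *, X=None, y=None, metadata=None, fitted_estimator=None): ...
    def teardown(self, estimator, context): ...


@runtime_checkable
class AutoPropagatedCallback(FitCallback, Protocol):
    @property
    def max_propagation_depth(self): ...


class ToyProgressBar:
    """Hypothetical auto-propagated callback; its hooks are no-ops here."""

    def __init__(self, max_propagation_depth=1):
        self._max_propagation_depth = max_propagation_depth

    @property
    def max_propagation_depth(self):
        return self._max_propagation_depth

    def setup(self, estimator, context): pass
    def on_fit_task_begin(self, estimator, context, **kwargs): pass
    def on_fit_task_end(self, estimator, context, **kwargs): pass
    def teardown(self, estimator, context): pass


def callbacks_for_sub_estimator(callbacks, depth):
    """Hypothetical helper: callbacks forwarded to a sub-estimator at nesting level `depth`."""
    return [
        cb for cb in callbacks
        # Only auto-propagated callbacks travel down, and only up to their max depth.
        if isinstance(cb, AutoPropagatedCallback) and depth <= cb.max_propagation_depth
    ]
```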

Example traces of hook calls

This section gives example traces of callback hook calls in different scenarios. First, consider a callback registered on an estimator with a single level of iterations:

>>> estimator = MyEstimator(n_iter=10).set_callbacks(MyCallback())
>>> estimator.fit(X, y)
setup by MyEstimator for fit
on_fit_task_begin by MyEstimator for fit
    on_fit_task_begin by MyEstimator for iteration 0
    on_fit_task_end by MyEstimator for iteration 0
    ...  # iterations 1 to 8
    on_fit_task_begin by MyEstimator for iteration 9
    on_fit_task_end by MyEstimator for iteration 9
on_fit_task_end by MyEstimator for fit
teardown by MyEstimator for fit

It doesn’t matter whether MyCallback is auto-propagated or not since there is no sub-estimator to propagate it to in that case. Next, consider the case where the estimator from the previous example is wrapped in a meta-estimator that fits clones of that sub-estimator for different cross-validation folds:

>>> estimator = MyEstimator(n_iter=10).set_callbacks(MyCallback())
>>> MetaEstimatorCV(estimator, cv=5).fit(X, y)
setup by MyEstimator for fit (MetaEstimatorCV fold 0)
on_fit_task_begin by MyEstimator for fit (MetaEstimatorCV fold 0)
    on_fit_task_begin by MyEstimator for iteration 0
    on_fit_task_end by MyEstimator for iteration 0
    ...  # iterations 1 to 8
    on_fit_task_begin by MyEstimator for iteration 9
    on_fit_task_end by MyEstimator for iteration 9
on_fit_task_end by MyEstimator for fit (MetaEstimatorCV fold 0)
teardown by MyEstimator for fit (MetaEstimatorCV fold 0)
...  # folds 1 to 3
setup by MyEstimator for fit (MetaEstimatorCV fold 4)
on_fit_task_begin by MyEstimator for fit (MetaEstimatorCV fold 4)
    on_fit_task_begin by MyEstimator for iteration 0
    on_fit_task_end by MyEstimator for iteration 0
    ...  # iterations 1 to 8
    on_fit_task_begin by MyEstimator for iteration 9
    on_fit_task_end by MyEstimator for iteration 9
on_fit_task_end by MyEstimator for fit (MetaEstimatorCV fold 4)
teardown by MyEstimator for fit (MetaEstimatorCV fold 4)

The trace is similar to the first example, except that it is repeated for each fold. In particular, the setup and teardown hooks are called for every fit of the sub-estimator clone. Finally, consider the same composition with an auto-propagated callback registered on the meta-estimator:

>>> estimator = MyEstimator(n_iter=10)
>>> MetaEstimatorCV(estimator, cv=5).set_callbacks(MyAutoPropagatedCallback()).fit(X, y)
setup by MetaEstimatorCV for fit
on_fit_task_begin by MetaEstimatorCV for fit
    on_fit_task_begin by MyEstimator for fit (MetaEstimatorCV fold 0)
        on_fit_task_begin by MyEstimator for iteration 0
        on_fit_task_end by MyEstimator for iteration 0
        ...  # iterations 1 to 8
        on_fit_task_begin by MyEstimator for iteration 9
        on_fit_task_end by MyEstimator for iteration 9
    on_fit_task_end by MyEstimator for fit (MetaEstimatorCV fold 0)
    ...  # folds 1 to 3
    on_fit_task_begin by MyEstimator for fit (MetaEstimatorCV fold 4)
        on_fit_task_begin by MyEstimator for iteration 0
        on_fit_task_end by MyEstimator for iteration 0
        ...  # iterations 1 to 8
        on_fit_task_begin by MyEstimator for iteration 9
        on_fit_task_end by MyEstimator for iteration 9
    on_fit_task_end by MyEstimator for fit (MetaEstimatorCV fold 4)
on_fit_task_end by MetaEstimatorCV for fit
teardown by MetaEstimatorCV for fit

This time, the callback hooks are also called for the tasks of the meta-estimator. Note that the setup and teardown hooks are called only once, by the meta-estimator.

Considerations

This section discusses the main challenges and the proposed solutions for the callback API to be generic and flexible enough despite the wide diversity of scikit-learn estimators and meta-estimators.

Protocol extensions

To keep the scope of this SLEP reasonable, it only considers callbacks called during fit, but the API could be extended to other methods (e.g. predict or transform) or to unbound functions (e.g. cross_validate) in the future. PR #33404, for instance, proposes a new protocol for callback support in unbound functions.

The setup and teardown hooks

The setup and teardown hooks are not strictly necessary at the time of writing, since on_fit_task_begin and on_fit_task_end are also called at the beginning and end of fit (for the root task) and could in principle be used to set up and tear down the callback. The main motivations for including them are:

  • in anticipation of future extensions of the protocol to methods other than fit;

  • to separate technical concerns of resource management from semantic concerns related to reacting to specific events during fit.
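To illustrate the second motivation, the hypothetical JsonLinesLogger below keeps resource management in setup and teardown (opening and closing a file) and leaves on_fit_task_end to react to events only. It assumes the context may expose a task_name attribute and falls back to str(context) otherwise; the JSON Lines format is an arbitrary choice for the sketch.

```python
import json


class JsonLinesLogger:
    """Hypothetical callback writing one JSON line per finished task."""

    def __init__(self, path):
        self.path = path
        self._file = None

    def setup(self, estimator, context):
        # Resource management: acquire the file once per fit.
        self._file = open(self.path, "a", encoding="utf-8")

    def on_fit_task_begin(self, estimator, context, **kwargs):
        pass

    def on_fit_task_end(self, estimator, context, **kwargs):
        # Event handling only: the file is assumed to be open.
        # `task_name` is a hypothetical context attribute.
        record = {
            "estimator": type(estimator).__name__,
            "task": getattr(context, "task_name", str(context)),
        }
        self._file.write(json.dumps(record) + "\n")

    def teardown(self, estimator, context):
        # Release the resource; with_callbacks calls this even if fit fails.
        if self._file is not None:
            self._file.close()
            self._file = None
```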

Task granularity

The smallest task granularity considered in this SLEP is a task that processes the full dataset. For instance, the on_fit_task_end hook is called at the end of a loop iterating over the full dataset, but not at the end of each step of such a loop. A smaller granularity would not allow the same level of flexibility and consistency across estimators.

Performance

Callbacks inevitably have a performance cost, especially when called from within Cython nogil code. The most important requirement is that performance is not affected when no callbacks are registered (for instance, the GIL must not be acquired in that case).

Moreover, some information given to the hooks might be expensive to compute but is not needed by all callbacks. This SLEP proposes that callbacks must explicitly request the information which is expensive to compute, and that the estimator passes it to the hooks in a lazy way.
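A minimal sketch of both ideas, under assumptions: callbacks opt in through a hypothetical requires_fitted_estimator class attribute, and the estimator wraps the expensive snapshot in a small LazyValue helper (also hypothetical) built on functools.cached_property, so that it is computed at most once per task and only if some callback asked for it. The dispatch_task_end function is illustrative and is not the CallbackContext method proposed above; it also returns immediately when no callback is registered.

```python
import functools


class LazyValue:
    """Hypothetical helper: computes an expensive value at most once, on first access."""

    def __init__(self, compute):
        self._compute = compute

    @functools.cached_property
    def value(self):
        return self._compute()


def dispatch_task_end(callbacks, estimator, context, build_fitted_estimator):
    """Hypothetical dispatch an estimator could run at the end of a task."""
    if not callbacks:
        return  # fast path: no callbacks registered, nothing is computed
    snapshot = LazyValue(build_fitted_estimator)
    for callback in callbacks:
        kwargs = {}
        # `requires_fitted_estimator` is a hypothetical opt-in flag on the callback.
        if getattr(callback, "requires_fitted_estimator", False):
            kwargs["fitted_estimator"] = snapshot.value  # built on first request only
        callback.on_fit_task_end(estimator, context, **kwargs)
```

Creating the LazyValue costs almost nothing, so callbacks that do not request the snapshot never trigger its computation, and several requesting callbacks share a single one.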

Parallelism

Many scikit-learn estimators use multiprocessing or multithreading, which can make the design of callbacks more complex because callbacks don’t share their state between processes. To overcome this issue, state can be shared, for instance, through the file system or through multiprocessing.Manager objects.
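The sketch below shows the file-system option with a hypothetical FileEventLog callback. Its only state is a directory path, so a pickled copy, like the one a worker process would receive, still writes to the same place; each process writes to its own file to avoid concurrent writes to a single file. For simplicity, the usage lines simulate the worker copy with a pickle round trip in the same process instead of starting real processes.

```python
import os
import pickle
import tempfile


class FileEventLog:
    """Hypothetical callback sharing its state through the file system."""

    def __init__(self, directory):
        # Only a path is stored, so pickled copies point to the same directory.
        self.directory = directory

    def setup(self, estimator, context): pass

    def on_fit_task_begin(self, estimator, context, **kwargs): pass

    def on_fit_task_end(self, estimator, context, **kwargs):
        # One file per process avoids concurrent writes to the same file.
        path = os.path.join(self.directory, f"events-{os.getpid()}.log")
        with open(path, "a", encoding="utf-8") as f:
            f.write(f"{context}\n")

    def teardown(self, estimator, context): pass

    def read_events(self):
        # Gather the events written by every process.
        events = []
        for name in sorted(os.listdir(self.directory)):
            with open(os.path.join(self.directory, name), encoding="utf-8") as f:
                events.extend(f.read().splitlines())
        return events


# Usage: a pickled copy stands in for the copy a worker process would receive.
log = FileEventLog(tempfile.mkdtemp())
worker_copy = pickle.loads(pickle.dumps(log))
worker_copy.on_fit_task_end(None, "fold 0")
print(log.read_events())  # ['fold 0']
```

The same idea applies to a multiprocessing.Manager list or dict: the proxy can be pickled into worker processes while the data itself stays in the manager process.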

Non-callback-aware meta-estimators

Callbacks can be registered on estimators that are passed to meta-estimators or unbound functions that do not support callbacks. In this case, callbacks should not break the workflow but might not provide full value and may perform suboptimally.

Additional dependencies

Some callbacks may require additional dependencies (e.g. rich for progress bars). Such dependencies must be optional and only imported when the corresponding callback is used, similarly to how the display objects handle matplotlib.

Implementation

An implementation of this SLEP is being developed in the callbacks feature branch. PR #33322 keeps an updated diff against the main branch. It currently contains the callback framework and the progress bar callback.

Discussion

The goal of this SLEP is to provide a common solution for a wide range of use cases discussed in many long-standing issues and PRs, going back to issue #78. Here are some related issues and PRs: #4863, #7574, #8433, #8994, #9136, #10489, #10973, #12325, #14338, #14531, #16118, #18507, #18748, #18773, #20127, #20668, #23156, #24524, #25187, #26395, #26494, #26532, #1171, #3817, #8317, #16925, #22000.

References and Footnotes