Pruning models
by: Chris Cave & Jaan Kasak
(Feyn version 3.0 or newer)
The function feyn.prune_models removes redundant and poorly performing Models from a list. It expects the list to be sorted by some metric. The best Model in the list is never removed.
Models are removed based on a decay function. A Model decays as a function of how many epochs it has been fitted and its performance compared to its peers. prune_models also removes structurally identical Models from the list, leaving only the best performing one.
This prevents us from fitting Models that are not going to improve any further.
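The behaviour described above can be sketched in plain Python. This is a conceptual illustration only, not Feyn's implementation: ToyModel, toy_prune, the max_decay threshold, and the decay formula are all hypothetical, and the sketch assumes the list is sorted best first with lower loss being better.

```python
from dataclasses import dataclass


@dataclass
class ToyModel:
    """Hypothetical stand-in for a fitted model; this is not a feyn.Model."""
    structure: str  # canonical description of the model's graph
    loss: float     # lower is better
    epochs: int     # how many epochs the model has been fitted


def toy_prune(models, max_decay=1.0):
    """Prune a list that is already sorted best first.

    The decay rule used here is invented purely for illustration.
    """
    if not models:
        return []

    best = models[0]
    kept = [best]              # the best model is never removed
    seen = {best.structure}

    for model in models[1:]:
        # The list is sorted, so the first occurrence of a structure is the
        # best performing one; later structural duplicates are dropped.
        if model.structure in seen:
            continue

        # Assumed decay: grows with epochs fitted and with the gap to the best.
        decay = model.epochs * (model.loss - best.loss)
        if decay > max_decay:
            continue

        seen.add(model.structure)
        kept.append(model)

    return kept


models = [
    ToyModel("add(x, y)", loss=0.10, epochs=5),
    ToyModel("add(x, y)", loss=0.12, epochs=5),  # duplicate structure
    ToyModel("mul(x, z)", loss=0.15, epochs=2),  # young and close to the best
    ToyModel("exp(z)", loss=0.60, epochs=5),     # old and far behind
]
pruned = toy_prune(models)
print([m.structure for m in pruned])
```

In this toy run the duplicate and the long-fitted, far-behind model are dropped, while the younger model that is still close to the best survives for further fitting.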
This is typically used after fitting and prior to updating. It ensures we keep Models with high potential, conserves computational resources, and helps break out of local minima.
Example
Continuing from the previous sections, we now add the prune_models function to our workflow.
import feyn
from feyn.datasets import make_classification

train, test = make_classification()

# Sample candidate models and fit them to the training data
ql = feyn.QLattice()
models = ql.sample_models(train.columns, 'y', 'classification', max_complexity=10)
models = feyn.fit_models(models, train, 'binary_cross_entropy', 'bic', 4)

# Remove redundant and poorly performing models from the fitted list
models = feyn.prune_models(
    models=models
)
Parameters of prune_models
models
This is the list of Models from which you want to remove the redundant and poorly performing ones.
keep_n
At most this many models will be returned. If keep_n is None, models are pruned by decay only.
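The effect of keep_n can be pictured with a small stand-alone sketch. The helper cap_models and the placeholder model names are hypothetical and not part of Feyn; the sketch assumes the cap is applied to a best-first list of Models that have already survived decay.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def cap_models(sorted_models: List[T], keep_n: Optional[int] = None) -> List[T]:
    """Hypothetical helper: return at most keep_n items from a best-first list."""
    if keep_n is None:
        # No cap: whatever survived the other pruning steps is returned as is.
        return list(sorted_models)
    return list(sorted_models[:keep_n])


survivors = ["model_a", "model_b", "model_c"]  # placeholder names, best first
capped = cap_models(survivors, keep_n=2)
uncapped = cap_models(survivors)
print(capped)
print(uncapped)
```

With keep_n=2 only the two best entries are returned; with the default of None the list is returned unchanged.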