flair.trainers.ModelTrainer
- class flair.trainers.ModelTrainer(model, corpus)
Bases: Pluggable
Use this class to train a Flair model.
The ModelTrainer is initialized using a flair.nn.Model (the architecture you want to train) and a flair.data.Corpus (the labeled data you use to train and evaluate the model). It offers two main training functions for the two main modes of training a model: (1) train(), which is used to train a model from scratch or to fit a classification head on a frozen transformer language model, and (2) fine_tune(), which is used if you do not freeze the transformer language model but instead fine-tune it for a specific task. Additionally, the train_custom() method allows you to fully customize the training run.
ModelTrainer inherits from flair.trainers.plugins.base.Pluggable and thus uses a plugin system to inject specific functionality into the training process. You can add any number of plugins to the above-mentioned training modes. For instance, if you want to use an annealing scheduler during training, you can add the flair.trainers.plugins.functional.AnnealingPlugin plugin to the train command.

- __init__(model, corpus)
Initialize a model trainer by passing a flair.nn.Model (the architecture you want to train) and a flair.data.Corpus (the labeled data you use to train and evaluate the model).

- Parameters:
  - model (Model) – The model that you want to train. The model should inherit from flair.nn.Model. So for instance you should pass a flair.models.TextClassifier if you want to train a text classifier, or a flair.models.SequenceTagger if you want to train an RNN-based sequence labeler.
  - corpus (Corpus) – The dataset (of type flair.data.Corpus) used to train the model.
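For example, a trainer for an RNN-based sequence labeler could be set up as follows. This is a minimal sketch: the corpus, embedding choice, and label type are illustrative, not prescribed by this class.

```python
from flair.datasets import UD_ENGLISH
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# load a labeled corpus and derive the label dictionary for the task
corpus = UD_ENGLISH()
label_dictionary = corpus.make_label_dictionary(label_type="upos")

# an RNN-based sequence labeler over classic word embeddings
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=WordEmbeddings("glove"),
    tag_dictionary=label_dictionary,
    tag_type="upos",
)

# the trainer couples the model with the corpus
trainer = ModelTrainer(tagger, corpus)
```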
Methods

- __init__(model, corpus) – Initialize a model trainer by passing a flair.nn.Model (the architecture you want to train) and a flair.data.Corpus (the labeled data you use to train and evaluate the model).
- append_plugin(plugin)
- dispatch(event, *args, **kwargs) – Call all functions hooked to a certain event.
- fine_tune(base_path[, warmup_fraction, ...])
- get_batch_steps(batch, mini_batch_chunk_size)
- register_hook(func, *events) – Register a hook.
- remove_hook(handle) – Remove a hook handle from this instance.
- train(base_path[, anneal_factor, patience, ...])
- train_custom(base_path[, learning_rate, ...]) – Trains any class that implements the flair.nn.Model interface.
- validate_event(*events)

Attributes

- plugins – Returns all plugins attached to this instance as a list of BasePlugin.
- valid_events: Optional[set[EventIdenifier]] = {'_training_exception', '_training_finally', 'after_evaluation', 'after_setup', 'after_training', 'after_training_batch', 'after_training_epoch', 'after_training_loop', 'before_training_batch', 'before_training_epoch', 'before_training_optimizer_step', 'metric_recorded', 'training_interrupt'}
- model: flair.nn.Model
- corpus: Corpus
- return_values: dict
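The identifiers in valid_events are the points at which plugins can hook into training. As a rough sketch, a custom plugin reacting to the end of each epoch might look like the following. This assumes the @TrainerPlugin.hook decorator used by Flair's built-in plugins and assumes the event is dispatched with an epoch keyword; EpochAnnouncer is a hypothetical name.

```python
from flair.trainers.plugins import TrainerPlugin

class EpochAnnouncer(TrainerPlugin):
    """Hypothetical plugin that prints a message after every training epoch."""

    @TrainerPlugin.hook
    def after_training_epoch(self, epoch, **kwargs):
        # "after_training_epoch" is one of the identifiers in valid_events;
        # with no explicit event argument, the hook name selects the event
        print(f"finished epoch {epoch}")
```

Such a plugin could then be passed via the plugins argument of train(), fine_tune(), or train_custom().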
- reset_training_attributes()
- static check_for_and_delete_previous_best_models(base_path)
- static get_batch_steps(batch, mini_batch_chunk_size)
- train(base_path, anneal_factor=0.5, patience=3, min_learning_rate=0.0001, initial_extra_patience=0, anneal_with_restarts=False, learning_rate=0.1, decoder_learning_rate=None, mini_batch_size=32, eval_batch_size=64, mini_batch_chunk_size=None, max_epochs=100, optimizer=<class 'torch.optim.sgd.SGD'>, train_with_dev=False, train_with_test=False, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=False, gold_label_dictionary_for_eval=None, exclude_labels=None, sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='cpu', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, multi_gpu=False, plugins=None, attach_default_scheduler=True, **kwargs)
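A typical call for training from scratch might look like the following sketch; the path and hyperparameter values are illustrative. With attach_default_scheduler=True (the default), the annealing scheduler mentioned above is attached for you.

```python
# train from scratch with SGD; the learning rate is annealed by a factor
# of 0.5 after 3 epochs without improvement (the defaults in the signature)
trainer.train(
    "resources/taggers/upos",  # base_path: logs and models are written here
    learning_rate=0.1,
    mini_batch_size=32,
    max_epochs=100,
    anneal_factor=0.5,
    patience=3,
)
```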
- fine_tune(base_path, warmup_fraction=0.1, learning_rate=5e-05, decoder_learning_rate=None, mini_batch_size=4, eval_batch_size=16, mini_batch_chunk_size=None, max_epochs=10, optimizer=<class 'torch.optim.adamw.AdamW'>, train_with_dev=False, train_with_test=False, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=True, gold_label_dictionary_for_eval=None, exclude_labels=None, sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='none', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, use_amp=False, multi_gpu=False, plugins=None, attach_default_scheduler=True, **kwargs)
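Note the different defaults compared to train(): AdamW instead of SGD, a much smaller learning rate, fewer epochs, and a linear warmup phase. A minimal sketch of a typical call follows; the path and values are illustrative.

```python
# fine-tune an unfrozen transformer language model for a downstream task
trainer.fine_tune(
    "resources/classifiers/my-task",  # base_path: logs and models go here
    learning_rate=5e-5,
    mini_batch_size=4,
    max_epochs=10,
    warmup_fraction=0.1,  # fraction of steps used to linearly warm up the LR
)
```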
- train_custom(base_path, learning_rate=0.1, decoder_learning_rate=None, mini_batch_size=32, eval_batch_size=64, mini_batch_chunk_size=None, max_epochs=100, optimizer=<class 'torch.optim.sgd.SGD'>, train_with_dev=False, train_with_test=False, max_grad_norm=5.0, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=False, gold_label_dictionary_for_eval=None, exclude_labels=None, sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='cpu', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, use_amp=False, multi_gpu=False, plugins=None, **kwargs)
Trains any class that implements the flair.nn.Model interface.

- Parameters:
  - base_path (Union[Path, str]) – Main path to which all output during training is logged and models are saved
  - learning_rate (float) – The learning rate of the optimizer
  - decoder_learning_rate (Optional[float]) – Optional, if set, the decoder is trained with a separate learning rate
  - mini_batch_size (int) – Size of mini-batches during training
  - eval_batch_size (int) – Size of mini-batches during evaluation
  - mini_batch_chunk_size (Optional[int]) – If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes
  - max_epochs (int) – Maximum number of epochs to train. Terminates training if this number is surpassed.
  - optimizer (type[Optimizer]) – The optimizer to use (typically SGD or Adam)
  - train_with_dev (bool) – If True, the data from the dev split is added to the training data
  - train_with_test (bool) – If True, the data from the test split is added to the training data
  - reduce_transformer_vocab (bool) – If True, temporarily reduces the vocab size to limit RAM usage during training
  - main_evaluation_metric (tuple[str, str]) – The metric to optimize (often micro-average or macro-average F1-score, or accuracy)
  - monitor_test (bool) – If True, test data is evaluated at the end of each epoch
  - monitor_train_sample (float) – Set this to evaluate on a sample of the train data at the end of each epoch. If you set an int, it will sample this many sentences to evaluate on. If you set a float, it will sample a percentage of data points from train.
  - max_grad_norm (Optional[float]) – If not None, gradients are clipped to this value before an optimizer.step is called
  - use_final_model_for_eval (bool) – If True, the final model is used for the final evaluation. If False, the model from the best epoch as determined by main_evaluation_metric is used for the final evaluation.
  - gold_label_dictionary_for_eval (Optional[Dictionary]) – Set to force evaluation to use a particular label dictionary
  - exclude_labels (Optional[list[str]]) – Optionally define a list of labels to exclude from the evaluation
  - sampler (Union[FlairSampler, type[FlairSampler], None]) – You can pass a data sampler here for special sampling of data
  - shuffle (bool) – If True, data is shuffled during training
  - shuffle_first_epoch (bool) – If True, data is shuffled during the first epoch of training
  - embeddings_storage_mode (Literal['none', 'cpu', 'gpu']) – One of 'none' (all embeddings are deleted and freshly recomputed), 'cpu' (embeddings stored on CPU) or 'gpu' (embeddings stored on GPU)
  - epoch (int) – The starting epoch (normally 0, but could be higher if you continue training a model)
  - save_final_model (bool) – If True, the final model is saved at the end of training
  - save_optimizer_state (bool) – If True, the optimizer state is saved alongside the model
  - save_model_each_k_epochs (int) – Each k epochs, a model state will be written out. If set to '5', a model will be saved every 5 epochs. Default is 0, which means no model saving.
  - create_file_logs (bool) – If True, logging output is written to a file
  - create_loss_file (bool) – If True, a loss file logging output is created
  - use_amp (bool) – If True, uses torch automatic mixed precision
  - multi_gpu (bool) – If True, distributes training across local GPUs
  - write_weights (bool) – If True, write weights to weights.txt on each batch logging event
  - plugins (Optional[list[TrainerPlugin]]) – Any additional plugins you want to pass to the trainer
  - **kwargs – Additional arguments, for instance for the optimizer
- Return type: dict
- Returns: A dictionary with at least the key “test_score” containing the final evaluation score. Some plugins add additional information to this dictionary, such as the flair.trainers.plugins.MetricHistoryPlugin.
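For instance, the final score can be read off the returned dictionary. A minimal sketch; only the “test_score” key is guaranteed by the documentation above, and the path is illustrative.

```python
result = trainer.fine_tune("resources/classifiers/my-task", max_epochs=10)
print(result["test_score"])  # final evaluation score on the test split
```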