flair.trainers

Classes

ModelTrainer(model, corpus)

Use this class to train a Flair model.

LanguageModelTrainer(model, corpus[, ...])

TextCorpus(path, dictionary[, forward, ...])

class flair.trainers.ModelTrainer(model, corpus)

Bases: Pluggable

Use this class to train a Flair model.

The ModelTrainer is initialized with a flair.nn.Model (the architecture you want to train) and a flair.data.Corpus (the labeled data you use to train and evaluate the model). It offers two training functions for the two main modes of training a model: (1) train(), used to train a model from scratch or to fit a classification head on a frozen transformer language model, and (2) fine_tune(), used when you do not freeze the transformer language model and instead fine-tune it for a specific task.
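For example, a typical fine-tuning run might look like the following sketch (the dataset, embedding model and output path are illustrative choices, not part of this API):

```python
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# labeled data: the TREC question classification corpus with train/dev/test splits
corpus = TREC_6()
label_dict = corpus.make_label_dictionary(label_type="question_class")

# transformer document embeddings with fine-tuning enabled
embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True)
classifier = TextClassifier(embeddings, label_dictionary=label_dict, label_type="question_class")

# initialize the trainer with model and corpus, then fine-tune the transformer
trainer = ModelTrainer(classifier, corpus)
trainer.fine_tune(
    "resources/taggers/question-classification",  # illustrative output path
    learning_rate=5.0e-5,
    mini_batch_size=4,
    max_epochs=10,
)
```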

Additionally, there is also a train_custom method that allows you to fully customize the training run.

ModelTrainer inherits from flair.trainers.plugins.base.Pluggable and thus uses a plugin system to inject specific functionality into the training process. You can add any number of plugins to the above-mentioned training modes. For instance, if you want to use an annealing scheduler during training, you can add flair.trainers.plugins.functional.AnnealingPlugin to the train() call.
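As a sketch of the plugin mechanism (assuming that the TrainerPlugin.hook decorator registers a method under the event matching its name, as the built-in plugins do; the plugin itself is hypothetical):

```python
import time

from flair.trainers.plugins import TrainerPlugin


class EpochTimerPlugin(TrainerPlugin):
    """Hypothetical plugin that reports how long each training epoch takes."""

    def __init__(self):
        super().__init__()
        self._epoch_start = None

    @TrainerPlugin.hook
    def before_training_epoch(self, **kwargs):
        # fired at the start of every epoch (see valid_events below)
        self._epoch_start = time.time()

    @TrainerPlugin.hook
    def after_training_epoch(self, **kwargs):
        # fired once the epoch is complete
        print(f"epoch finished in {time.time() - self._epoch_start:.1f}s")


# pass any number of plugins to train(), fine_tune() or train_custom(), e.g.:
# trainer.train("resources/taggers/example", plugins=[EpochTimerPlugin()])
```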

valid_events: Optional[set[EventIdenifier]] = {'_training_exception', '_training_finally', 'after_evaluation', 'after_setup', 'after_training', 'after_training_batch', 'after_training_epoch', 'after_training_loop', 'before_training_batch', 'before_training_epoch', 'before_training_optimizer_step', 'metric_recorded', 'training_interrupt'}
reset_training_attributes()
static check_for_and_delete_previous_best_models(base_path)
static get_batch_steps(batch, mini_batch_chunk_size)
train(base_path, anneal_factor=0.5, patience=3, min_learning_rate=0.0001, initial_extra_patience=0, anneal_with_restarts=False, learning_rate=0.1, decoder_learning_rate=None, mini_batch_size=32, eval_batch_size=64, mini_batch_chunk_size=None, max_epochs=100, optimizer=<class 'torch.optim.sgd.SGD'>, train_with_dev=False, train_with_test=False, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=False, gold_label_dictionary_for_eval=None, exclude_labels=None, sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='cpu', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, multi_gpu=False, plugins=None, attach_default_scheduler=True, **kwargs)
fine_tune(base_path, warmup_fraction=0.1, learning_rate=5e-05, decoder_learning_rate=None, mini_batch_size=4, eval_batch_size=16, mini_batch_chunk_size=None, max_epochs=10, optimizer=<class 'torch.optim.adamw.AdamW'>, train_with_dev=False, train_with_test=False, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=True, gold_label_dictionary_for_eval=None, exclude_labels=None, sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='none', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, use_amp=False, multi_gpu=False, plugins=None, attach_default_scheduler=True, **kwargs)
train_custom(base_path, learning_rate=0.1, decoder_learning_rate=None, mini_batch_size=32, eval_batch_size=64, mini_batch_chunk_size=None, max_epochs=100, optimizer=<class 'torch.optim.sgd.SGD'>, train_with_dev=False, train_with_test=False, max_grad_norm=5.0, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=False, gold_label_dictionary_for_eval=None, exclude_labels=None, sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='cpu', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, use_amp=False, multi_gpu=False, plugins=None, **kwargs)

Trains any class that implements the flair.nn.Model interface.

Parameters:
  • base_path (Union[Path, str]) – Main path to which all output during training is logged and models are saved

  • learning_rate (float) – The learning rate of the optimizer

  • decoder_learning_rate (Optional[float]) – Optional, if set, the decoder is trained with a separate learning rate

  • mini_batch_size (int) – Size of mini-batches during training

  • eval_batch_size (int) – Size of mini-batches during evaluation

  • mini_batch_chunk_size (Optional[int]) – If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes

  • max_epochs (int) – Maximum number of epochs to train. Terminates training if this number is surpassed.

  • optimizer (type[Optimizer]) – The optimizer to use (typically SGD or Adam)

  • train_with_dev (bool) – If True, the data from dev split is added to the training data

  • train_with_test (bool) – If True, the data from test split is added to the training data

  • reduce_transformer_vocab (bool) – If True, temporarily reduces the vocabulary size to limit RAM usage during training.

  • main_evaluation_metric (tuple[str, str]) – The metric to optimize (often micro-average or macro-average F1-score, or accuracy)

  • monitor_test (bool) – If True, test data is evaluated at end of each epoch

  • monitor_train_sample (float) – Set this to evaluate on a sample of the training data at the end of each epoch. An integer value samples that many sentences to evaluate on; a float value samples that fraction of the training data.

  • max_grad_norm (Optional[float]) – If not None, gradients are clipped to this value before an optimizer.step is called.

  • use_final_model_for_eval (bool) – If True, the final model is used for the final evaluation. If False, the model from the best epoch as determined by main_evaluation_metric is used for the final evaluation.

  • gold_label_dictionary_for_eval (Optional[Dictionary]) – Set to force evaluation to use a particular label dictionary

  • exclude_labels (Optional[list[str]]) – Optionally define a list of labels to exclude from the evaluation

  • sampler (Union[FlairSampler, type[FlairSampler], None]) – You can pass a data sampler here for special sampling of data.

  • shuffle (bool) – If True, data is shuffled during training

  • shuffle_first_epoch (bool) – If True, data is shuffled during the first epoch of training

  • embeddings_storage_mode (Literal['none', 'cpu', 'gpu']) – One of ‘none’ (all embeddings are deleted and freshly recomputed), ‘cpu’ (embeddings stored on CPU) or ‘gpu’ (embeddings stored on GPU)

  • epoch (int) – The starting epoch (normally 0 but could be higher if you continue training model)

  • save_final_model (bool) – If True, the final model is saved at the end of training.

  • save_optimizer_state (bool) – If True, the optimizer state is saved alongside the model

  • save_model_each_k_epochs (int) – If set to a value k greater than 0, a model checkpoint is saved every k epochs. For example, a value of 5 saves the model every 5 epochs. Default is 0, meaning no intermediate models are saved.

  • create_file_logs (bool) – If True, logging output is written to a file

  • create_loss_file (bool) – If True, a loss file tracking the training loss is created

  • use_amp (bool) – If True, uses torch automatic mixed precision (AMP) during training

  • multi_gpu (bool) – If True, distributes training across local GPUs

  • write_weights (bool) – If True, write weights to weights.txt on each batch logging event.

  • plugins (Optional[list[TrainerPlugin]]) – Any additional plugins you want to pass to the trainer

  • **kwargs – Additional arguments, for instance for the optimizer

Return type:

dict

Returns:

A dictionary with at least the key “test_score” containing the final evaluation score. Some plugins, such as flair.trainers.plugins.MetricHistoryPlugin, add additional information to this dictionary.
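As an illustration, a classic training-from-scratch run with train() could look like the following sketch (dataset, embeddings and output path are again illustrative):

```python
from flair.datasets import UD_ENGLISH
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# universal part-of-speech tagging on the English Universal Dependencies corpus
corpus = UD_ENGLISH()
label_dict = corpus.make_label_dictionary(label_type="upos")

# classic (non-transformer) word embeddings, trained with SGD via train()
embeddings = WordEmbeddings("glove")
tagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=label_dict,
    tag_type="upos",
)

trainer = ModelTrainer(tagger, corpus)
result = trainer.train(
    "resources/taggers/example-upos",  # illustrative output path
    learning_rate=0.1,
    mini_batch_size=32,
    max_epochs=10,
)
print(result["test_score"])  # final evaluation score, as documented above
```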

class flair.trainers.LanguageModelTrainer(model, corpus, optimizer=<class 'torch.optim.sgd.SGD'>, test_mode=False, epoch=0, split=0, loss=10000, optimizer_state=None, scaler_state=None)

Bases: object

train(base_path, sequence_length, learning_rate=20, mini_batch_size=100, anneal_factor=0.25, patience=10, clip=0.25, max_epochs=1000, checkpoint=False, grow_to_sequence_length=0, num_workers=2, use_amp=False, **kwargs)
evaluate(data_source, eval_batch_size, sequence_length)
static load_checkpoint(checkpoint_file, corpus, optimizer=<class 'torch.optim.sgd.SGD'>)
class flair.trainers.TextCorpus(path, dictionary, forward=True, character_level=True, random_case_flip=True, document_delimiter='\n')

Bases: object
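A minimal sketch of training a character-level language model with these two classes (corpus path and hyperparameters are illustrative) might look like this:

```python
from flair.data import Dictionary
from flair.models import LanguageModel
from flair.trainers import LanguageModelTrainer, TextCorpus

# train a forward character-level language model
is_forward_lm = True

# default character dictionary shipped with Flair
dictionary = Dictionary.load("chars")

# the corpus directory is expected to contain a train/ split folder plus valid.txt and test.txt
corpus = TextCorpus("path/to/corpus", dictionary, is_forward_lm, character_level=True)

language_model = LanguageModel(dictionary, is_forward_lm, hidden_size=128, nlayers=1)

trainer = LanguageModelTrainer(language_model, corpus)
trainer.train(
    "resources/language_model",  # illustrative output path
    sequence_length=250,
    mini_batch_size=100,
    max_epochs=10,
)
```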