flair.trainers

class flair.trainers.ModelTrainer(model, corpus)

Bases: Pluggable

valid_events: Optional[Set[str]] = {'_training_exception', '_training_finally', 'after_evaluation', 'after_setup', 'after_training', 'after_training_batch', 'after_training_epoch', 'after_training_loop', 'before_training_batch', 'before_training_epoch', 'before_training_optimizer_step', 'metric_recorded', 'training_interrupt'}
__init__(model, corpus)

Initialize a model trainer.

Parameters:
  • model (Model) – The model to train; it must inherit from flair.nn.Model

  • corpus (Corpus) – The dataset used to train the model; it must be of type Corpus
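
A minimal construction sketch, assuming a part-of-speech tagging setup; the UD_ENGLISH corpus, GloVe embeddings and the "upos" label type are illustrative choices, not requirements:

    from flair.datasets import UD_ENGLISH
    from flair.embeddings import WordEmbeddings
    from flair.models import SequenceTagger
    from flair.trainers import ModelTrainer

    # load a corpus (downloaded automatically by flair) and build the label dictionary
    corpus = UD_ENGLISH()
    label_dict = corpus.make_label_dictionary(label_type="upos")

    # a simple sequence tagger over pretrained word embeddings
    tagger = SequenceTagger(
        hidden_size=256,
        embeddings=WordEmbeddings("glove"),
        tag_dictionary=label_dict,
        tag_type="upos",
    )

    # the trainer only needs the model and the corpus
    trainer = ModelTrainer(tagger, corpus)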

reset_training_attributes()
static check_for_and_delete_previous_best_models(base_path)
static get_batch_steps(batch, mini_batch_chunk_size)
_backward(loss)

Calls backward on the loss.

This allows plugins to overwrite the backward call.

train(base_path, anneal_factor=0.5, patience=3, min_learning_rate=0.0001, initial_extra_patience=0, anneal_with_restarts=False, learning_rate=0.1, decoder_learning_rate=None, mini_batch_size=32, eval_batch_size=64, mini_batch_chunk_size=None, max_epochs=100, optimizer=<class 'torch.optim.sgd.SGD'>, train_with_dev=False, train_with_test=False, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=False, gold_label_dictionary_for_eval=None, exclude_labels=[], sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='cpu', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, plugins=None, attach_default_scheduler=True, **kwargs)
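
A rough usage sketch of the classic SGD training recipe with learning-rate annealing; the output path and hyperparameters below are placeholders, not recommendations:

    # all output (logs and model checkpoints) is written to base_path
    trainer.train(
        "resources/taggers/example-upos",
        learning_rate=0.1,
        mini_batch_size=32,
        max_epochs=10,
    )
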
fine_tune(base_path, warmup_fraction=0.1, learning_rate=5e-05, decoder_learning_rate=None, mini_batch_size=4, eval_batch_size=16, mini_batch_chunk_size=None, max_epochs=10, optimizer=<class 'torch.optim.adamw.AdamW'>, train_with_dev=False, train_with_test=False, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=True, gold_label_dictionary_for_eval=None, exclude_labels=[], sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='none', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, use_amp=False, plugins=None, attach_default_scheduler=True, **kwargs)
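
Similarly, a sketch of the fine-tuning recipe (AdamW, small learning rate, linear warmup), which assumes a model built on transformer embeddings; the path and values are placeholders:

    trainer.fine_tune(
        "resources/taggers/example-ft",
        learning_rate=5e-5,
        mini_batch_size=4,
        max_epochs=10,
    )
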
train_custom(base_path, learning_rate=0.1, decoder_learning_rate=None, mini_batch_size=32, eval_batch_size=64, mini_batch_chunk_size=None, max_epochs=100, optimizer=<class 'torch.optim.sgd.SGD'>, train_with_dev=False, train_with_test=False, max_grad_norm=5.0, reduce_transformer_vocab=False, main_evaluation_metric=('micro avg', 'f1-score'), monitor_test=False, monitor_train_sample=0.0, use_final_model_for_eval=False, gold_label_dictionary_for_eval=None, exclude_labels=[], sampler=None, shuffle=True, shuffle_first_epoch=True, embeddings_storage_mode='cpu', epoch=0, save_final_model=True, save_optimizer_state=False, save_model_each_k_epochs=0, create_file_logs=True, create_loss_file=True, write_weights=False, use_amp=False, plugins=[], **kwargs)

Trains any class that implements the flair.nn.Model interface.

Parameters:
  • base_path (Union[Path, str]) – Main path to which all output during training is logged and models are saved

  • learning_rate (float) – The learning rate of the optimizer

  • decoder_learning_rate (Optional[float]) – Optional. If set, the decoder is trained with a separate learning rate

  • mini_batch_size (int) – Size of mini-batches during training

  • eval_batch_size (int) – Size of mini-batches during evaluation

  • mini_batch_chunk_size (Optional[int]) – If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes

  • max_epochs (int) – Maximum number of epochs to train. Terminates training if this number is surpassed.

  • optimizer (Type[Optimizer]) – The optimizer to use (typically SGD or Adam)

  • train_with_dev (bool) – If True, the data from dev split is added to the training data

  • train_with_test (bool) – If True, the data from test split is added to the training data

  • reduce_transformer_vocab (bool) – If True, temporarily reduces the vocabulary size to limit RAM usage during training.

  • main_evaluation_metric (Tuple[str, str]) – The metric to optimize (often micro-average or macro-average F1-score, or accuracy)

  • monitor_test (bool) – If True, the test data is evaluated at the end of each epoch

  • monitor_train_sample (float) – Set this to evaluate on a sample of the training data at the end of each epoch. An int value samples that many sentences; a float value samples that fraction of the training data points.

  • max_grad_norm (Optional[float]) – If not None, gradients are clipped to this value before an optimizer.step is called.

  • use_final_model_for_eval (bool) – If True, the final model is used for the final evaluation. If False, the model from the best epoch as determined by main_evaluation_metric is used for the final evaluation.

  • gold_label_dictionary_for_eval (Optional[Dictionary]) – Set to force evaluation to use a particular label dictionary

  • exclude_labels (List[str]) – Optionally define a list of labels to exclude from the evaluation

  • sampler (Optional[FlairSampler]) – You can pass a data sampler here for special sampling of data.

  • shuffle (bool) – If True, data is shuffled during training

  • shuffle_first_epoch (bool) – If True, data is shuffled during the first epoch of training

  • embeddings_storage_mode (str) – One of ‘none’ (all embeddings are deleted and freshly recomputed), ‘cpu’ (embeddings stored on CPU) or ‘gpu’ (embeddings stored on GPU)

  • epoch (int) – The starting epoch (normally 0 but could be higher if you continue training model)

  • save_final_model (bool) – If True, the final model is saved at the end of training.

  • save_optimizer_state (bool) – If True, the optimizer state is saved alongside the model

  • save_model_each_k_epochs (int) – A model checkpoint is written out every k epochs. For example, if set to 5, a model is saved every 5 epochs. The default of 0 disables intermediate saving.

  • create_file_logs (bool) – If True, logging output is written to a file

  • create_loss_file (bool) – If True, a loss file is created that logs the loss during training

  • use_amp (bool) – If True, uses torch automatic mixed precision (AMP)

  • write_weights (bool) – If True, weights are written to weights.txt on each batch logging event.

  • plugins (List[TrainerPlugin]) – Any additional plugins you want to pass to the trainer

  • **kwargs – Additional arguments, for instance for the optimizer

Return type:

dict

Returns:

A dictionary containing at least the key “test_score” with the final evaluation score. Some plugins, such as flair.trainers.plugins.MetricHistoryPlugin, add further information to this dictionary.
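
For illustration, the returned dictionary can be used like this (the output path is a placeholder):

    result = trainer.train_custom(
        "resources/taggers/example-custom",
        learning_rate=0.1,
        max_epochs=10,
    )

    # "test_score" is always present; plugins such as MetricHistoryPlugin may add further keys
    print(result["test_score"])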

_initialize_model_card(**training_parameters)

Initializes model card with library versions and parameters.

class flair.trainers.LanguageModelTrainer(model, corpus, optimizer=<class 'torch.optim.sgd.SGD'>, test_mode=False, epoch=0, split=0, loss=10000, optimizer_state=None, scaler_state=None)

Bases: object

train(base_path, sequence_length, learning_rate=20, mini_batch_size=100, anneal_factor=0.25, patience=10, clip=0.25, max_epochs=1000, checkpoint=False, grow_to_sequence_length=0, num_workers=2, use_amp=False, **kwargs)
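
A training sketch following the flair language-model tutorial; the corpus path and output path are placeholders, and the model size is kept deliberately small:

    from flair.data import Dictionary
    from flair.models import LanguageModel
    from flair.trainers import LanguageModelTrainer, TextCorpus

    # character dictionary shipped with flair
    dictionary = Dictionary.load("chars")

    # see TextCorpus below for the expected folder layout
    corpus = TextCorpus("path/to/corpus", dictionary, forward=True, character_level=True)

    # a small forward character-level language model
    language_model = LanguageModel(dictionary, is_forward_lm=True, hidden_size=128, nlayers=1)

    trainer = LanguageModelTrainer(language_model, corpus)
    trainer.train(
        "resources/language_models/example-forward",
        sequence_length=250,
        mini_batch_size=100,
        max_epochs=10,
    )
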
evaluate(data_source, eval_batch_size, sequence_length)
static _repackage_hidden(h)

Wraps hidden states in new tensors to detach them from their history.

static load_checkpoint(checkpoint_file, corpus, optimizer=<class 'torch.optim.sgd.SGD'>)
class flair.trainers.TextCorpus(path, dictionary, forward=True, character_level=True, random_case_flip=True, document_delimiter='\\n')

Bases: object
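
A construction sketch for a backward character-level corpus; the path is a placeholder, and the expected folder layout (a train folder of split files plus valid.txt and test.txt) follows the flair language-model tutorial:

    from flair.data import Dictionary
    from flair.trainers import TextCorpus

    dictionary = Dictionary.load("chars")

    # forward=False yields a backward corpus; document_delimiter marks document
    # boundaries so that documents are not mixed within a training sequence
    corpus = TextCorpus(
        "path/to/corpus",
        dictionary,
        forward=False,
        character_level=True,
        document_delimiter="\n",
    )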