flair.trainers.plugins

class flair.trainers.plugins.AnnealingPlugin(base_path, min_learning_rate, anneal_factor, patience, initial_extra_patience, anneal_with_restarts)

Bases: TrainerPlugin

Plugin for annealing logic in Flair.

store_learning_rate()
after_setup(train_with_dev, optimizer, **kw)

Initialize the schedulers: the anneal target for AnnealOnPlateau, batch_growth_annealing, and scheduler loading.

after_evaluation(current_model_is_best, validation_scores, **kw)

Perform the AnnealOnPlateau scheduler step after each evaluation.

get_state()
Return type:

dict[str, Any]
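
The anneal-on-plateau idea behind AnnealingPlugin can be sketched in plain Python. This is a simplified illustration under stated assumptions, not Flair's AnnealOnPlateau: the class name PlateauAnnealer is hypothetical, a higher validation score is assumed to be better, and initial_extra_patience and anneal_with_restarts are ignored.

```python
class PlateauAnnealer:
    """Hypothetical, simplified anneal-on-plateau scheduler (not Flair's AnnealOnPlateau)."""

    def __init__(self, lr: float, anneal_factor: float = 0.5, patience: int = 3,
                 min_learning_rate: float = 1e-4) -> None:
        self.lr = lr
        self.anneal_factor = anneal_factor
        self.patience = patience
        self.min_learning_rate = min_learning_rate
        self.best_score = float("-inf")
        self.bad_epochs = 0

    def step(self, validation_score: float) -> bool:
        """Record one evaluation; return True if training should stop."""
        if validation_score > self.best_score:
            # Improvement: remember the score and reset the patience counter.
            self.best_score = validation_score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Plateau: shrink the learning rate and start counting again.
                self.lr *= self.anneal_factor
                self.bad_epochs = 0
        # Signal a stop once the learning rate has dropped below the minimum.
        return self.lr < self.min_learning_rate


annealer = PlateauAnnealer(lr=0.1, anneal_factor=0.5, patience=1)
for score in [0.70, 0.72, 0.71, 0.71]:
    annealer.step(score)
print(annealer.lr)  # prints 0.05: two epochs without improvement exceed patience=1
```

Returning a stop flag from step keeps the scheduler independent of the training loop; a plugin would act on it from its after_evaluation hook.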

class flair.trainers.plugins.CheckpointPlugin(save_model_each_k_epochs, save_optimizer_state, base_path)

Bases: TrainerPlugin

after_training_epoch(epoch, **kw)

Save the model every k epochs, where k is save_model_each_k_epochs.

get_state()
Return type:

dict[str, Any]
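
The "every k epochs" rule reduces to a modulo check. A minimal sketch, assuming that a non-positive save_model_each_k_epochs disables periodic saving; the helper checkpoint_path and the file-name pattern model_epoch_{epoch}.pt are hypothetical and only illustrate where a checkpoint could go under base_path.

```python
from pathlib import Path
from typing import Optional


def checkpoint_path(base_path: Path, epoch: int, save_model_each_k_epochs: int) -> Optional[Path]:
    """Return where to save a checkpoint after `epoch`, or None to skip saving."""
    # Assumption: 0 (or a negative value) turns periodic checkpointing off.
    if save_model_each_k_epochs > 0 and epoch % save_model_each_k_epochs == 0:
        # Hypothetical file-name pattern, for illustration only.
        return base_path / f"model_epoch_{epoch}.pt"
    return None


for epoch in range(1, 7):
    path = checkpoint_path(Path("resources/taggers/example"), epoch, save_model_each_k_epochs=3)
    if path is not None:
        print(f"epoch {epoch}: save to {path}")
```

With k = 3, only epochs 3 and 6 produce a path.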

class flair.trainers.plugins.ClearmlLoggerPlugin(task_id_or_task)

Bases: TrainerPlugin

property logger
metric_recorded(record)
Return type:

None

class flair.trainers.plugins.LinearSchedulerPlugin(warmup_fraction)

Bases: TrainerPlugin

Plugin for LinearSchedulerWithWarmup.

store_learning_rate()
after_setup(dataset_size, mini_batch_size, max_epochs, **kwargs)

Compute the number of training steps and warmup steps from dataset_size, mini_batch_size, max_epochs, and warmup_fraction, then initialize the LinearSchedulerWithWarmup.

before_training_epoch(**kwargs)

Store the current learning rate at the start of each training epoch.

after_training_batch(optimizer_was_run, **kwargs)

Step the linear scheduler after each training batch for which optimizer_was_run is true.

get_state()
Return type:

dict[str, Any]
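
The linear warmup schedule configured in after_setup can be sketched as two pure functions. This assumes the common shape of linear warmup followed by linear decay to zero; it uses ceiling division for steps per epoch, so the exact step counts may differ slightly from Flair's LinearSchedulerWithWarmup. The function names are hypothetical.

```python
import math


def schedule_steps(dataset_size: int, mini_batch_size: int, max_epochs: int,
                   warmup_fraction: float) -> tuple[int, int]:
    """Return (num_train_steps, num_warmup_steps) for a linear warmup schedule."""
    steps_per_epoch = math.ceil(dataset_size / mini_batch_size)  # last batch may be partial
    num_train_steps = steps_per_epoch * max_epochs
    num_warmup_steps = int(num_train_steps * warmup_fraction)
    return num_train_steps, num_warmup_steps


def lr_multiplier(step: int, num_train_steps: int, num_warmup_steps: int) -> float:
    """Factor applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 towards 1.
        return step / max(1, num_warmup_steps)
    # Decay: fall linearly from 1 to 0 at the final step.
    return max(0.0, (num_train_steps - step) / max(1, num_train_steps - num_warmup_steps))


train_steps, warmup_steps = schedule_steps(1000, 32, 10, warmup_fraction=0.1)
print(train_steps, warmup_steps)  # 320 32
print(lr_multiplier(16, train_steps, warmup_steps))  # 0.5, halfway through warmup
```

A multiplier function of this form is what torch.optim.lr_scheduler.LambdaLR expects, which is why the step only advances when the optimizer actually ran.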

class flair.trainers.plugins.WeightExtractorPlugin(base_path)

Bases: TrainerPlugin

Simple plugin for weight extraction.

after_training_batch(batch_no, epoch, total_number_of_batches, **kw)

Extracts weights.

get_state()
Return type:

dict[str, Any]

class flair.trainers.plugins.LogFilePlugin(base_path)

Bases: TrainerPlugin

Plugin for the training.log file.

close_file_handler(**kw)
get_state()
Return type:

dict[str, Any]
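
A training.log plugin typically attaches a logging.FileHandler to a logger during training and detaches and closes it at the end. The sketch below shows that pattern with the standard logging module only; the logger name "flair", the format string, and the helper add_file_handler are assumptions, not Flair's code.

```python
import logging
import tempfile
from pathlib import Path


def add_file_handler(logger: logging.Logger, log_file: Path) -> logging.FileHandler:
    """Attach a handler that mirrors log records into `log_file`."""
    log_file.parent.mkdir(parents=True, exist_ok=True)
    handler = logging.FileHandler(log_file, mode="w", encoding="utf-8")
    handler.setFormatter(logging.Formatter("%(asctime)-15s %(message)s"))
    logger.addHandler(handler)
    return handler


def close_file_handler(logger: logging.Logger, handler: logging.FileHandler) -> None:
    """Detach and close the handler so the file is flushed and released."""
    logger.removeHandler(handler)
    handler.close()


log = logging.getLogger("flair")  # logger name assumed for illustration
log.setLevel(logging.INFO)
base_path = Path(tempfile.mkdtemp())
handler = add_file_handler(log, base_path / "training.log")
log.info("epoch 1 done")
close_file_handler(log, handler)
print((base_path / "training.log").read_text(encoding="utf-8"))
```

Closing the handler explicitly matters because a lingering handler would keep writing later, unrelated log lines into an old run's file.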

class flair.trainers.plugins.LossFilePlugin(base_path, epoch, metrics_to_collect=None)

Bases: TrainerPlugin

Plugin that manages the loss.tsv file output.

get_state()
Return type:

dict[str, Any]

before_training_epoch(epoch, **kw)

Store the current epoch for loss-file logging.

metric_recorded(record)

Add the metric of a record to the current row.

after_evaluation(epoch, **kw)

Write all relevant metrics collected for the current epoch to the loss.tsv file.
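
The three hooks above (start an epoch, add each recorded metric, write the row) can be sketched with the standard csv module. The class LossTsvWriter, the column names, and the "-" placeholder for missing values are hypothetical; the actual loss.tsv layout may differ.

```python
import csv
import io
from typing import TextIO


class LossTsvWriter:
    """Hypothetical helper: header written once, then one TSV row per epoch."""

    def __init__(self, out: TextIO, columns: list[str]) -> None:
        self.writer = csv.writer(out, delimiter="\t", lineterminator="\n")
        self.columns = columns
        self.row: dict[str, object] = {}
        self.writer.writerow(["EPOCH", *columns])

    def start_epoch(self, epoch: int) -> None:
        # Mirrors before_training_epoch: remember the epoch, reset the row.
        self.row = {"EPOCH": epoch}

    def record(self, column: str, value: float) -> None:
        # Mirrors metric_recorded: add one metric to the current row.
        if column in self.columns:
            self.row[column] = value

    def end_epoch(self) -> None:
        # Mirrors after_evaluation: write the row, "-" for metrics not recorded.
        self.writer.writerow([self.row.get(c, "-") for c in ["EPOCH", *self.columns]])


buffer = io.StringIO()
loss_file = LossTsvWriter(buffer, ["TRAIN_LOSS", "DEV_LOSS"])
loss_file.start_epoch(1)
loss_file.record("TRAIN_LOSS", 0.42)
loss_file.end_epoch()
print(buffer.getvalue())
```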

class flair.trainers.plugins.MetricHistoryPlugin(metrics_to_collect={('dev', 'loss'): 'dev_loss_history', ('dev', 'score'): 'dev_score_history', ('train', 'loss'): 'train_loss_history'})

Bases: TrainerPlugin

metric_recorded(record)
after_training(**kw)

Returns metric history.

get_state()
Return type:

dict[str, Any]
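
The default metrics_to_collect maps a (split, metric) key to the name of a history list. A simplified sketch of collecting values under those names; the flat (split, metric, value) tuples are a hypothetical stand-in for MetricRecord objects.

```python
from collections import defaultdict

DEFAULT_METRICS = {
    ("dev", "loss"): "dev_loss_history",
    ("dev", "score"): "dev_score_history",
    ("train", "loss"): "train_loss_history",
}


def collect_history(records, metrics_to_collect=DEFAULT_METRICS):
    """Group metric values by history name, keeping only the configured keys."""
    history = defaultdict(list)
    for split, metric, value in records:
        key = (split, metric)
        if key in metrics_to_collect:
            history[metrics_to_collect[key]].append(value)
    return dict(history)


records = [
    ("train", "loss", 0.9), ("dev", "loss", 1.1), ("dev", "score", 0.61),
    ("train", "loss", 0.7), ("dev", "loss", 0.95), ("dev", "score", 0.66),
    ("dev", "precision", 0.7),  # not configured, so ignored
]
print(collect_history(records))
```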

class flair.trainers.plugins.TensorboardLogger(log_dir=None, comment='', tracked_metrics=())

Bases: TrainerPlugin

Plugin that takes care of tensorboard logging.

__init__(log_dir=None, comment='', tracked_metrics=())

Initializes the TensorboardLogger.

Parameters:
  • log_dir – Directory into which TensorBoard log files will be written.

  • comment – Suffix appended to the default log_dir. It has no effect if log_dir is set.

  • tracked_metrics – List of tuples specifying which metrics, in addition to the main_score, are plotted in TensorBoard; for example, [("macro avg", "f1-score"), ("macro avg", "precision")].
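
Each tracked_metrics tuple can be read as a path into a nested evaluation report. The sketch assumes a report shaped like scikit-learn's classification_report(output_dict=True); the report values and the helper lookup are hypothetical, and nothing is written to TensorBoard here.

```python
from functools import reduce
from typing import Any


def lookup(report: dict[str, Any], path: tuple[str, ...]) -> Any:
    """Follow a tuple such as ("macro avg", "f1-score") into a nested dict."""
    return reduce(lambda node, key: node[key], path, report)


# Hypothetical report, shaped like sklearn's classification_report(output_dict=True).
report = {
    "macro avg": {"precision": 0.81, "recall": 0.77, "f1-score": 0.79},
    "micro avg": {"precision": 0.84, "recall": 0.84, "f1-score": 0.84},
}
tracked_metrics = [("macro avg", "f1-score"), ("macro avg", "precision")]
for path in tracked_metrics:
    print("/".join(path), lookup(report, path))
```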

metric_recorded(record)
_training_finally(**kw)

Closes the writer.

get_state()
Return type:

dict[str, Any]

class flair.trainers.plugins.BasePlugin

Bases: object

Base class for all plugins.

__init__()

Initialize the base plugin.

attach_to(pluggable)

Attach this plugin to a Pluggable.

detach()

Detach a plugin from the Pluggable it is attached to.

classmethod mark_func_as_hook(func, *events)

Mark method as a hook triggered by the Pluggable.

Return type:

Callable

classmethod hook(first_arg=None, *other_args)

Convenience function for BasePlugin.mark_func_as_hook().

Enables using the @BasePlugin.hook syntax.

Can also be used as: @BasePlugin.hook("some_event", "another_event")

Return type:

Callable
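
Supporting both the bare @hook form and @hook("event", ...) is a common decorator pattern. A standalone sketch, assuming the bare form uses the function name as the event name and that events are stored on a hypothetical _events attribute; this is an illustration of the pattern, not Flair's code.

```python
from typing import Callable, Union


def mark_func_as_hook(func: Callable, *events: str) -> Callable:
    """Attach event names to `func`; default to the function's own name."""
    func._events = events or (func.__name__,)  # hypothetical attribute name
    return func


def hook(first_arg: Union[Callable, str, None] = None, *other_args: str) -> Callable:
    """Usable as @hook, @hook(), or @hook("event_a", "event_b")."""
    if callable(first_arg):
        # Bare decorator: Python passes the decorated function directly.
        return mark_func_as_hook(first_arg)
    events = tuple(e for e in (first_arg, *other_args) if e is not None)

    def decorator(func: Callable) -> Callable:
        return mark_func_as_hook(func, *events)

    return decorator


@hook
def after_setup(**kw):
    pass


@hook("after_training_epoch", "after_evaluation")
def save_something(**kw):
    pass


print(after_setup._events)     # ('after_setup',)
print(save_something._events)  # ('after_training_epoch', 'after_evaluation')
```

The callable check is what lets one entry point serve both forms: a bare decorator receives a function, while the parameterized form receives strings and must return the real decorator.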

property pluggable: Pluggable | None
get_state()
Return type:

dict[str, Any]

class flair.trainers.plugins.Pluggable(*, plugins=[])

Bases: object

Dispatches events which attached plugins can react to.

valid_events: Optional[set[str]] = None
__init__(*, plugins=[])

Initialize a Pluggable.

Parameters:

plugins (Sequence[Union[BasePlugin, type[BasePlugin]]]) – Plugins which should be attached to this Pluggable.

property plugins
append_plugin(plugin)
validate_event(*events)
register_hook(func, *events)

Register a hook.

Parameters:
  • func (Callable) – Function to be called when the event is emitted.

  • *events (str) – List of events to call this function on.

dispatch(event, *args, **kwargs)

Call all functions hooked to a certain event.

Return type:

None

remove_hook(handle)

Remove a hook handle from this instance.
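
A minimal event dispatcher with the same shape as register_hook, dispatch, remove_hook, and valid_events. Assumptions: a handle remembers its event and function, and an unknown event raises ValueError. The names HookHandle, EventBus, and TrainerEvents are hypothetical, and this is not Flair's implementation.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class HookHandle:
    """Hypothetical handle identifying one registered hook."""
    id: int
    event: str
    func: Callable


class EventBus:
    """Simplified stand-in for Pluggable."""

    valid_events: Optional[set[str]] = None  # None means "accept any event"

    def __init__(self) -> None:
        self._hooks: dict[str, dict[int, HookHandle]] = defaultdict(dict)
        self._next_id = 0

    def validate_event(self, *events: str) -> None:
        if self.valid_events is not None:
            unknown = set(events) - self.valid_events
            if unknown:
                raise ValueError(f"unknown event(s): {sorted(unknown)}")

    def register_hook(self, func: Callable, *events: str) -> list[HookHandle]:
        self.validate_event(*events)
        handles = []
        for event in events:
            handle = HookHandle(self._next_id, event, func)
            self._hooks[event][handle.id] = handle
            self._next_id += 1
            handles.append(handle)
        return handles

    def dispatch(self, event: str, *args, **kwargs) -> None:
        self.validate_event(event)
        # Iterate over a copy so hooks may remove themselves while dispatching.
        for handle in list(self._hooks[event].values()):
            handle.func(*args, **kwargs)

    def remove_hook(self, handle: HookHandle) -> None:
        del self._hooks[handle.event][handle.id]


class TrainerEvents(EventBus):
    valid_events = {"after_training_epoch", "after_evaluation"}


bus = TrainerEvents()
(handle,) = bus.register_hook(lambda epoch, **kw: print(f"epoch {epoch} finished"),
                              "after_training_epoch")
bus.dispatch("after_training_epoch", epoch=1)  # prints "epoch 1 finished"
bus.remove_hook(handle)
bus.dispatch("after_training_epoch", epoch=2)  # no hooks left, prints nothing
```

Keying hooks by an integer id inside each event makes removal a constant-time dictionary delete.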

class flair.trainers.plugins.TrainerPlugin

Bases: BasePlugin

property trainer
property model
property corpus

exception flair.trainers.plugins.TrainingInterrupt

Bases: Exception

Allows plugins to interrupt the training loop.
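
One way a plugin hook can stop training early is to raise an exception that the training loop catches. The sketch defines a local TrainingInterrupt so it runs without Flair; the toy loop, the early-stopping rule, and the patience value are hypothetical and do not reflect how Flair's trainer handles the exception internally.

```python
class TrainingInterrupt(Exception):
    """Local stand-in for flair.trainers.plugins.TrainingInterrupt."""


def early_stopping_hook(epoch: int, dev_losses: list[float], patience: int = 2) -> None:
    """Raise once the dev loss has not improved for `patience` epochs."""
    best_epoch = dev_losses.index(min(dev_losses))
    if epoch - best_epoch >= patience:
        raise TrainingInterrupt(f"no improvement since epoch {best_epoch}")


def train(dev_losses_per_epoch: list[float]) -> int:
    """Toy training loop; returns the number of epochs actually run."""
    seen: list[float] = []
    for epoch, dev_loss in enumerate(dev_losses_per_epoch):
        seen.append(dev_loss)
        try:
            early_stopping_hook(epoch, seen)
        except TrainingInterrupt as exc:
            print(f"stopped after epoch {epoch}: {exc}")
            return epoch + 1
    return len(dev_losses_per_epoch)


train([1.0, 0.8, 0.85, 0.9, 0.7])  # prints "stopped after epoch 3: no improvement since epoch 1"
```

Using an exception lets a hook deep inside the dispatch chain end the loop without the loop having to poll every plugin for a stop flag.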

class flair.trainers.plugins.ReduceTransformerVocabPlugin(base_path, save_optimizer_state)

Bases: TrainerPlugin

register_transformer_smaller_training_vocab(**kw)
save_model_at_the_end(**kw)

class flair.trainers.plugins.MetricName(name)

Bases: object

class flair.trainers.plugins.MetricRecord(name, value, global_step, typ, *, walltime=None)

Bases: object

Represents a recorded metric value.

__init__(name, value, global_step, typ, *, walltime=None)

Create a metric record.

Parameters:
  • name (Union[Iterable[str], str]) – Name of the metric.

  • typ (RecordType) – Type of metric.

  • value (Any) – Value of the metric (can be anything: scalar, tensor, image, etc.).

  • global_step (int) – The time step of the log entry. It should be incremented each time this metric is logged again; e.g., if you log every epoch, set global_step to the current epoch.

  • walltime (Optional[float]) – Time of recording this metric.

property joined_name: str
classmethod scalar(name, value, global_step, *, walltime=None)
classmethod scalar_list(name, value, global_step, *, walltime=None)
classmethod string(name, value, global_step, *, walltime=None)
classmethod histogram(name, value, global_step, *, walltime=None)
is_type(typ)
property is_scalar
property is_scalar_list
property is_string
property is_histogram
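
A simplified sketch of a metric record with a name, value, step, and type. Assumptions: multi-part names are joined with "/" to form joined_name, the record type is an Enum with the four kinds implied by the classmethods above, and walltime defaults to time.time(). SimpleMetricRecord and the local RecordType are hypothetical and not claimed to match Flair's source; only the scalar constructor is shown, and the others would follow the same pattern.

```python
import time
from enum import Enum
from typing import Any, Iterable, Optional, Union


class RecordType(Enum):
    scalar = "scalar"
    scalar_list = "scalar_list"
    string = "string"
    histogram = "histogram"


class SimpleMetricRecord:
    """Hypothetical, simplified counterpart of MetricRecord."""

    def __init__(self, name: Union[Iterable[str], str], value: Any, global_step: int,
                 typ: RecordType, *, walltime: Optional[float] = None) -> None:
        # Store the name as a tuple of parts, e.g. ("dev", "loss").
        self.parts = tuple(name.split("/")) if isinstance(name, str) else tuple(name)
        self.value = value
        self.global_step = global_step
        self.typ = typ
        self.walltime = time.time() if walltime is None else walltime

    @property
    def joined_name(self) -> str:
        return "/".join(self.parts)

    @classmethod
    def scalar(cls, name, value, global_step, *, walltime=None) -> "SimpleMetricRecord":
        return cls(name, value, global_step, RecordType.scalar, walltime=walltime)

    def is_type(self, typ: RecordType) -> bool:
        return self.typ is typ

    @property
    def is_scalar(self) -> bool:
        return self.is_type(RecordType.scalar)


record = SimpleMetricRecord.scalar(("dev", "loss"), 0.37, global_step=5)
print(record.joined_name, record.value, record.is_scalar)  # dev/loss 0.37 True
```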