flair.trainers.plugins
- class flair.trainers.plugins.AnnealingPlugin(base_path, min_learning_rate, anneal_factor, patience, initial_extra_patience, anneal_with_restarts)
Bases: TrainerPlugin
Plugin for annealing logic in Flair.
- store_learning_rate()
- after_setup(train_with_dev, optimizer, **kw)
Initialize the schedulers, including the anneal target for AnnealOnPlateau, batch_growth_annealing, and scheduler loading.
- after_evaluation(current_model_is_best, validation_scores, **kw)
Perform the scheduler step of AnnealOnPlateau.
- get_state()
- Return type:
dict[str, Any]
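The annealing behaviour described above (reduce the learning rate when the validation score stops improving) can be illustrated with a small pure-Python sketch. This is not Flair's AnnealOnPlateau: the class PlateauAnnealer is hypothetical, it assumes a higher score is better, it assumes training should stop once the rate drops below min_learning_rate, and it ignores initial_extra_patience and anneal_with_restarts.

```python
from dataclasses import dataclass


@dataclass
class PlateauAnnealer:
    """Hypothetical anneal-on-plateau scheduler (illustration only, not Flair's AnnealOnPlateau)."""

    learning_rate: float
    anneal_factor: float = 0.5
    patience: int = 3
    min_learning_rate: float = 1e-4
    best_score: float = float("-inf")
    bad_epochs: int = 0

    def step(self, score: float) -> bool:
        """Record one validation score; return False once the rate falls below the minimum."""
        if score > self.best_score:
            # Improvement: remember the new best score and reset the patience counter.
            self.best_score = score
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                # Plateau: shrink the learning rate and start counting again.
                self.learning_rate *= self.anneal_factor
                self.bad_epochs = 0
        # Assumption: training continues only while the rate stays at or above the minimum.
        return self.learning_rate >= self.min_learning_rate


annealer = PlateauAnnealer(learning_rate=0.1, anneal_factor=0.5, patience=1, min_learning_rate=0.03)
keep_training = [annealer.step(score) for score in [0.5, 0.4, 0.4, 0.3, 0.3]]
```

With patience=1, the rate is halved after the second epoch in a row without improvement, so it goes 0.1, 0.05, 0.025 and the last call signals a stop.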
- class flair.trainers.plugins.CheckpointPlugin(save_model_each_k_epochs, save_optimizer_state, base_path)
Bases: TrainerPlugin
- after_training_epoch(epoch, **kw)
Saves the model every k epochs.
- get_state()
- Return type:
dict[str, Any]
- class flair.trainers.plugins.ClearmlLoggerPlugin(task_id_or_task)
Bases: TrainerPlugin
- property logger
- metric_recorded(record)
- Return type:
None
- class flair.trainers.plugins.LinearSchedulerPlugin(warmup_fraction)
Bases: TrainerPlugin
Plugin for LinearSchedulerWithWarmup.
- store_learning_rate()
- after_setup(dataset_size, mini_batch_size, max_epochs, **kwargs)
Initialize different schedulers, including anneal target for AnnealOnPlateau, batch_growth_annealing, loading schedulers.
- before_training_epoch(**kwargs)
Load state for anneal_with_restarts, batch_growth_annealing, logic for early stopping.
- after_training_batch(optimizer_was_run, **kwargs)
Perform the scheduler step if the schedule is one-cycle or linear decay.
- get_state()
- Return type:
dict[str, Any]
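LinearSchedulerPlugin receives dataset_size, mini_batch_size and max_epochs in after_setup and steps its scheduler after training batches. A minimal sketch of a linear warmup-then-decay multiplier built from those inputs follows; the function name linear_warmup_factor and the step arithmetic (one step per mini-batch, warmup covering warmup_fraction of all steps, linear decay to zero afterwards) are assumptions, not Flair's LinearSchedulerWithWarmup code.

```python
import math


def linear_warmup_factor(step: int, dataset_size: int, mini_batch_size: int,
                         max_epochs: int, warmup_fraction: float) -> float:
    """Hypothetical learning-rate multiplier: linear warmup followed by linear decay."""
    # Assumption: one optimizer step per mini-batch, so total steps = batches per epoch * epochs.
    total_steps = math.ceil(dataset_size / mini_batch_size) * max_epochs
    warmup_steps = int(warmup_fraction * total_steps)
    if step < warmup_steps:
        # Ramp up linearly from 0 towards the base learning rate.
        return step / max(1, warmup_steps)
    # Decay linearly from the base learning rate down to 0 at the final step.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


settings = dict(dataset_size=100, mini_batch_size=10, max_epochs=2, warmup_fraction=0.1)
factors = [linear_warmup_factor(step, **settings) for step in (0, 1, 2, 11, 20)]
```

Here 100 examples in batches of 10 over 2 epochs give 20 steps, of which the first 2 are warmup; the multiplier would be applied to the optimizer's base learning rate.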
- class flair.trainers.plugins.WeightExtractorPlugin(base_path)
Bases: TrainerPlugin
Simple plugin for weight extraction.
- after_training_batch(batch_no, epoch, total_number_of_batches, **kw)
Extracts weights.
- get_state()
- Return type:
dict[str, Any]
- class flair.trainers.plugins.LogFilePlugin(base_path)
Bases: TrainerPlugin
Plugin for the training.log file.
- close_file_handler(**kw)
- get_state()
- Return type:
dict[str, Any]
- class flair.trainers.plugins.LossFilePlugin(base_path, epoch, metrics_to_collect=None)
Bases: TrainerPlugin
Plugin that manages the loss.tsv file output.
- get_state()
- Return type:
dict[str, Any]
- before_training_epoch(epoch, **kw)
Get the current epoch for loss file logging.
- metric_recorded(record)
Add the metric of a record to the current row.
- after_evaluation(epoch, **kw)
Print all relevant metrics.
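The hook order of LossFilePlugin (start a row in before_training_epoch, fill it in metric_recorded, write it in after_evaluation) can be sketched with the standard csv module. LossFileSketch and its column names are hypothetical, an in-memory buffer stands in for loss.tsv, and the real plugin receives MetricRecord objects rather than plain name/value pairs.

```python
import csv
import io


class LossFileSketch:
    """Hypothetical per-epoch TSV writer mirroring the hook order described above."""

    def __init__(self, columns):
        self.columns = list(columns)
        self.buffer = io.StringIO()  # Stand-in for the loss.tsv file on disk.
        self.writer = csv.writer(self.buffer, delimiter="\t", lineterminator="\n")
        self.writer.writerow(["EPOCH", *self.columns])
        self.current_row = {}

    def before_training_epoch(self, epoch):
        # Start a fresh row for the new epoch.
        self.current_row = {"EPOCH": epoch}

    def metric_recorded(self, name, value):
        # Only keep metrics that have a column in the file.
        if name in self.columns:
            self.current_row[name] = value

    def after_evaluation(self):
        # Write the finished row; metrics that were never recorded stay empty.
        self.writer.writerow([self.current_row.get(c, "") for c in ["EPOCH", *self.columns]])


loss_file = LossFileSketch(["TRAIN_LOSS", "DEV_LOSS"])
loss_file.before_training_epoch(1)
loss_file.metric_recorded("TRAIN_LOSS", 0.5)
loss_file.metric_recorded("IGNORED", 1)  # Not a column, so it is dropped.
loss_file.after_evaluation()
```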
- class flair.trainers.plugins.MetricHistoryPlugin(metrics_to_collect={('dev', 'loss'): 'dev_loss_history', ('dev', 'score'): 'dev_score_history', ('train', 'loss'): 'train_loss_history'})
Bases: TrainerPlugin
- metric_recorded(record)
- after_training(**kw)
Returns metric history.
- get_state()
- Return type:
dict[str, Any]
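The default metrics_to_collect mapping above pairs a (split, metric) name with the key of a history list. The hypothetical MetricHistorySketch below shows how such a mapping could route recorded values into per-metric histories; for brevity it takes name tuples and values directly instead of MetricRecord objects.

```python
from collections import defaultdict

# Same shape as the documented default of MetricHistoryPlugin.
DEFAULT_METRICS = {
    ("dev", "loss"): "dev_loss_history",
    ("dev", "score"): "dev_score_history",
    ("train", "loss"): "train_loss_history",
}


class MetricHistorySketch:
    """Hypothetical collector: appends each matching metric value to a named history list."""

    def __init__(self, metrics_to_collect=DEFAULT_METRICS):
        self.metrics_to_collect = dict(metrics_to_collect)  # Copy to avoid sharing the default.
        self.history = defaultdict(list)

    def metric_recorded(self, name, value):
        # name is a tuple such as ("dev", "loss"); names without a mapping are ignored.
        key = self.metrics_to_collect.get(tuple(name))
        if key is not None:
            self.history[key].append(value)

    def after_training(self):
        # Hand back plain lists, e.g. for plotting learning curves.
        return dict(self.history)


collector = MetricHistorySketch()
collector.metric_recorded(("train", "loss"), 1.0)
collector.metric_recorded(("dev", "loss"), 0.9)
collector.metric_recorded(("train", "loss"), 0.7)
collector.metric_recorded(("test", "score"), 0.5)  # Not mapped, so not collected.
histories = collector.after_training()
```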
- class flair.trainers.plugins.TensorboardLogger(log_dir=None, comment='', tracked_metrics=())
Bases: TrainerPlugin
Plugin that takes care of TensorBoard logging.
- __init__(log_dir=None, comment='', tracked_metrics=())
Initializes the TensorboardLogger.
- Parameters:
log_dir – Directory into which TensorBoard log files will be written.
comment – Comment suffix appended to the default log_dir. If log_dir is assigned, this argument has no effect.
tracked_metrics – List of tuples that specify which metrics (in addition to the main_score) shall be plotted in TensorBoard, for example [("macro avg", "f1-score"), ("macro avg", "precision")].
- metric_recorded(record)
- _training_finally(**kw)
Closes the writer.
- get_state()
- Return type:
dict[str, Any]
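Each tracked_metrics tuple reads like a path into a nested classification report, such as the dictionary scikit-learn's classification_report returns with output_dict=True. The helper below is a hypothetical sketch of resolving those tuples into scalar values; both the lookup scheme and the tag naming are assumptions, not TensorboardLogger's actual code.

```python
def resolve_tracked_metrics(report, tracked_metrics):
    """Hypothetical helper: look up each (group, metric) tuple in a nested report dict."""
    scalars = {}
    for group, metric in tracked_metrics:
        # Assumption: tuples index a nested dict like {"macro avg": {"f1-score": ...}}.
        scalars[f"{group}_{metric}"] = report[group][metric]
    return scalars


report = {"macro avg": {"f1-score": 0.8, "precision": 0.75, "recall": 0.85}}
scalars = resolve_tracked_metrics(report, [("macro avg", "f1-score"), ("macro avg", "precision")])
```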
- class flair.trainers.plugins.BasePlugin
Bases: object
Base class for all plugins.
- __init__()
Initialize the base plugin.
- attach_to(pluggable)
Attach this plugin to a Pluggable.
- detach()
Detach a plugin from the Pluggable it is attached to.
- classmethod mark_func_as_hook(func, *events)
Mark method as a hook triggered by the Pluggable.
- Return type:
Callable
- classmethod hook(first_arg=None, *other_args)
Convenience function for BasePlugin.mark_func_as_hook().
Enables using the @BasePlugin.hook syntax.
Can also be used as @BasePlugin.hook("some_event", "another_event").
- Return type:
Callable
- get_state()
- Return type:
dict[str, Any]
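The two decorator forms of hook can be illustrated with a self-contained stand-in. HookablePlugin and MyPlugin are hypothetical; the sketch assumes that the bare form uses the method's own name as the event (consistent with TrainerPlugin methods being named after_training_epoch and so on) and that events are stored on the function for later collection.

```python
import inspect
from typing import Callable


class HookablePlugin:
    """Hypothetical stand-in showing both forms of a hook decorator (not Flair's BasePlugin)."""

    @classmethod
    def mark_func_as_hook(cls, func: Callable, *events: str) -> Callable:
        # Assumption: with no explicit events, the function name is used as the event name.
        func._hook_events = events or (func.__name__,)
        return func

    @classmethod
    def hook(cls, first_arg=None, *other_args):
        if callable(first_arg):
            # Bare form: @HookablePlugin.hook
            return cls.mark_func_as_hook(first_arg)
        # Parameterized form: @HookablePlugin.hook("some_event", "another_event")
        events = tuple(e for e in (first_arg, *other_args) if e is not None)
        return lambda func: cls.mark_func_as_hook(func, *events)

    def collect_hooks(self):
        # Return (event, bound method) pairs for every method marked as a hook.
        found = []
        for _, method in inspect.getmembers(self, predicate=inspect.ismethod):
            for event in getattr(method, "_hook_events", ()):
                found.append((event, method))
        return found


class MyPlugin(HookablePlugin):
    @HookablePlugin.hook
    def after_training_epoch(self, epoch, **kw):
        return f"epoch {epoch} done"

    @HookablePlugin.hook("some_event", "another_event")
    def on_custom_events(self, **kw):
        return "custom"


plugin = MyPlugin()
handlers = dict(plugin.collect_hooks())
```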
- class flair.trainers.plugins.Pluggable(*, plugins=[])
Bases: object
Dispatches events which attached plugins can react to.
- valid_events: Optional[set[str]] = None
- __init__(*, plugins=[])
Initialize a Pluggable.
- Parameters:
plugins (Sequence[Union[BasePlugin, type[BasePlugin]]]) – Plugins which should be attached to this Pluggable.
- property plugins
- append_plugin(plugin)
- validate_event(*events)
- register_hook(func, *events)
Register a hook.
- Parameters:
func (Callable) – Function to be called when the event is emitted.
*events (str) – List of events to call this function on.
- dispatch(event, *args, **kwargs)
Call all functions hooked to a certain event.
- Return type:
None
- remove_hook(handle)
Remove a hook handle from this instance.
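A minimal dispatcher with the same responsibilities as Pluggable (optional valid_events check, register_hook returning a handle, dispatch, remove_hook) might look as follows. MiniPluggable and ToyTrainer are hypothetical, the handle is a plain tuple, and the ValueError on unknown events is an assumption; Flair's own handle and error types may differ.

```python
from collections import defaultdict
from itertools import count
from typing import Callable, Dict, Optional, Set, Tuple


class MiniPluggable:
    """Hypothetical event dispatcher with hook handles (not Flair's Pluggable)."""

    valid_events: Optional[Set[str]] = None  # None: any event name is accepted.

    def __init__(self) -> None:
        self._hooks: Dict[str, Dict[int, Callable]] = defaultdict(dict)
        self._ids = count()

    def validate_event(self, *events: str) -> None:
        # Reject event names outside valid_events, if the subclass defines it.
        for event in events:
            if self.valid_events is not None and event not in self.valid_events:
                raise ValueError(f"unknown event: {event!r}")

    def register_hook(self, func: Callable, *events: str) -> Tuple[int, Tuple[str, ...]]:
        self.validate_event(*events)
        hook_id = next(self._ids)
        for event in events:
            self._hooks[event][hook_id] = func
        # The handle records what remove_hook has to undo.
        return hook_id, events

    def dispatch(self, event: str, *args, **kwargs) -> None:
        self.validate_event(event)
        # Iterate over a copy so that a hook may remove itself while running.
        for func in list(self._hooks[event].values()):
            func(*args, **kwargs)

    def remove_hook(self, handle: Tuple[int, Tuple[str, ...]]) -> None:
        hook_id, events = handle
        for event in events:
            self._hooks[event].pop(hook_id, None)


class ToyTrainer(MiniPluggable):
    valid_events = {"after_training_epoch"}


trainer = ToyTrainer()
seen = []
handle = trainer.register_hook(lambda epoch: seen.append(epoch), "after_training_epoch")
trainer.dispatch("after_training_epoch", epoch=1)
trainer.remove_hook(handle)
trainer.dispatch("after_training_epoch", epoch=2)  # No hook left, so nothing is appended.
```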
- class flair.trainers.plugins.TrainerPlugin
Bases: BasePlugin
- property trainer
- property model
- property corpus
- exception flair.trainers.plugins.TrainingInterrupt
Bases: Exception
Allows plugins to interrupt the training loop.
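A toy loop shows the intended control flow: a plugin hook raises the exception and the loop ends early instead of crashing. The local TrainingInterrupt class, run_epochs, and stop_after_two are stand-ins, and catching the exception around the epoch loop is an assumption about how a trainer would react.

```python
class TrainingInterrupt(Exception):
    """Local stand-in for flair.trainers.plugins.TrainingInterrupt."""


def run_epochs(max_epochs, hooks):
    """Hypothetical training loop: call each hook after every epoch, stop cleanly on interrupt."""
    completed = []
    try:
        for epoch in range(1, max_epochs + 1):
            completed.append(epoch)  # Pretend the epoch's training happened here.
            for hook in hooks:
                hook(epoch)
    except TrainingInterrupt as exc:
        # The loop ends early, but the epochs finished so far are kept.
        print(f"training interrupted: {exc}")
    return completed


def stop_after_two(epoch):
    # Hypothetical early-stopping hook: interrupt once two epochs are done.
    if epoch >= 2:
        raise TrainingInterrupt("stopping after epoch 2")
```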
- class flair.trainers.plugins.ReduceTransformerVocabPlugin(base_path, save_optimizer_state)
Bases: TrainerPlugin
- register_transformer_smaller_training_vocab(**kw)
- save_model_at_the_end(**kw)
- class flair.trainers.plugins.MetricName(name)
Bases: object
- class flair.trainers.plugins.MetricRecord(name, value, global_step, typ, *, walltime=None)
Bases: object
Represents a recorded metric value.
- __init__(name, value, global_step, typ, *, walltime=None)
Create a metric record.
- Parameters:
name (Union[Iterable[str], str]) – Name of the metric.
typ (RecordType) – Type of the metric.
value (Any) – Value of the metric (can be anything: scalar, tensor, image, etc.).
global_step (int) – The time step of the log. It should be incremented each time this metric is logged; e.g. if you log every epoch, set global_step to the current epoch.
walltime (Optional[float]) – Time of recording this metric.
- property joined_name: str
- classmethod scalar(name, value, global_step, *, walltime=None)
- classmethod scalar_list(name, value, global_step, *, walltime=None)
- classmethod string(name, value, global_step, *, walltime=None)
- classmethod histogram(name, value, global_step, *, walltime=None)
- is_type(typ)
- property is_scalar
- property is_scalar_list
- property is_string
- property is_histogram
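A reduced stand-in for MetricRecord illustrates the constructor arguments and the typed factory methods. SimpleMetricRecord, the RecordType members defined here, the "/" separator in joined_name, and defaulting walltime to the current time are all assumptions; only the scalar and string factories are shown.

```python
import time
from enum import Enum, auto
from typing import Any, Iterable, Optional, Union


class RecordType(Enum):
    # Hypothetical members matching the factory methods listed above.
    SCALAR = auto()
    SCALAR_LIST = auto()
    STRING = auto()
    HISTOGRAM = auto()


class SimpleMetricRecord:
    """Hypothetical stand-in for MetricRecord: a typed value logged at a global step."""

    def __init__(self, name: Union[Iterable[str], str], value: Any, global_step: int,
                 typ: RecordType, *, walltime: Optional[float] = None):
        # Store the name as a tuple of parts, e.g. ("dev", "loss").
        self.name = (name,) if isinstance(name, str) else tuple(name)
        self.value = value
        self.global_step = global_step
        self.typ = typ
        # Assumption: default to the current time when no walltime is given.
        self.walltime = time.time() if walltime is None else walltime

    @property
    def joined_name(self) -> str:
        # Assumption: parts are joined with "/", as TensorBoard-style tags often are.
        return "/".join(self.name)

    @classmethod
    def scalar(cls, name, value, global_step, *, walltime=None):
        return cls(name, value, global_step, RecordType.SCALAR, walltime=walltime)

    @classmethod
    def string(cls, name, value, global_step, *, walltime=None):
        return cls(name, value, global_step, RecordType.STRING, walltime=walltime)

    def is_type(self, typ: RecordType) -> bool:
        return self.typ == typ

    @property
    def is_scalar(self) -> bool:
        return self.is_type(RecordType.SCALAR)


# Logging once per epoch, so global_step is the current epoch.
dev_loss = SimpleMetricRecord.scalar(("dev", "loss"), 0.42, global_step=3, walltime=0.0)
note = SimpleMetricRecord.string("note", "ok", 1)
```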