flair.models.RelationClassifier
- class flair.models.RelationClassifier(embeddings, label_dictionary, label_type, entity_label_types, entity_pair_labels=None, entity_threshold=None, cross_augmentation=True, encoding_strategy=<flair.models.relation_classifier_model.TypedEntityMarker object>, zero_tag_value='O', allow_unk_tag=True, **classifierargs)
Bases: DefaultClassifier[EncodedSentence, EncodedSentence]
Relation Classifier to predict the relation between two entities.
Task
Relation Classification (RC) is the task of identifying the semantic relation between two entities in a text. In contrast to (end-to-end) Relation Extraction (RE), RC requires pre-labelled entities.
Example:
For the founded_by relation from ORG (head) to PER (tail) and the sentence “Larry Page and Sergey Brin founded Google .”, we extract the relations
- founded_by(head=’Google’, tail=’Larry Page’) and
- founded_by(head=’Google’, tail=’Sergey Brin’).
Architecture
The Relation Classifier Model builds upon a text classifier. The model generates an encoded sentence for each entity pair in the cross product of all entities in the original sentence. In the encoded representation, the entities in the current entity pair are masked/marked with control tokens. (For an example, see the docstrings of different encoding strategies, e.g. TypedEntityMarker.) Then, for each encoded sentence, the model takes its document embedding and puts the resulting text representation(s) through a linear layer to get the class relation label.
The implemented encoding strategies are taken from this paper by Zhou et al.: https://arxiv.org/abs/2102.01373
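The encoding step described above can be pictured with a small, library-independent sketch. It is not Flair's implementation: the `Entity` tuple, the `encode_pair` helper, and the marker strings (`[H-ORG]`, `[/H-ORG]`, `[T-PER]`, `[/T-PER]`) are assumptions chosen for illustration, and the sketch assumes that an entity is never paired with itself. The TypedEntityMarker docstring remains the authority for the real control tokens.

```python
from itertools import permutations
from typing import NamedTuple


class Entity(NamedTuple):
    start: int  # index of the first token (inclusive)
    end: int    # index after the last token (exclusive)
    label: str  # entity type, e.g. "PER" or "ORG"


def encode_pair(tokens: list[str], head: Entity, tail: Entity) -> str:
    """Wrap the head and tail spans in typed marker tokens (illustrative format)."""
    opening = {head.start: f"[H-{head.label}]", tail.start: f"[T-{tail.label}]"}
    closing = {head.end: f"[/H-{head.label}]", tail.end: f"[/T-{tail.label}]"}
    encoded: list[str] = []
    for index, token in enumerate(tokens):
        if index in closing:  # a span ended just before this token
            encoded.append(closing[index])
        if index in opening:  # a span starts at this token
            encoded.append(opening[index])
        encoded.append(token)
    if len(tokens) in closing:  # a span that runs to the end of the sentence
        encoded.append(closing[len(tokens)])
    return " ".join(encoded)


tokens = "Larry Page and Sergey Brin founded Google .".split()
entities = [Entity(0, 2, "PER"), Entity(3, 5, "PER"), Entity(6, 7, "ORG")]

# One encoded sentence per ordered (head, tail) pair: 3 * 2 = 6 candidates.
encoded_sentences = [
    encode_pair(tokens, head, tail) for head, tail in permutations(entities, 2)
]

for text in encoded_sentences:
    print(text)
```

In this sketch, each of the six strings would then be embedded as a whole document and classified independently, which is why the number of candidates grows quadratically with the number of entities.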
Warning
Currently, the model has no multi-label support.
- __init__(embeddings, label_dictionary, label_type, entity_label_types, entity_pair_labels=None, entity_threshold=None, cross_augmentation=True, encoding_strategy=<flair.models.relation_classifier_model.TypedEntityMarker object>, zero_tag_value='O', allow_unk_tag=True, **classifierargs)
Initializes a RelationClassifier.
- Parameters:
  - embeddings (DocumentEmbeddings) – The document embeddings used to embed each sentence
  - label_dictionary (Dictionary) – A Dictionary containing all predictable labels from the corpus
  - label_type (str) – The label type to predict, in case a corpus has multiple annotations
  - entity_label_types (Union[str, Sequence[str], dict[str, Optional[set[str]]]]) – A label type or sequence of label types of the required relation entities. You can also specify a label filter in a dictionary with the label type as key and the valid entity labels as values in a set. E.g. to use only ‘PER’ and ‘ORG’ labels from a NER-tagger: {‘ner’: {‘PER’, ‘ORG’}}. To use all labels from ‘ner’, pass ‘ner’.
  - entity_pair_labels (Optional[set[tuple[str, str]]]) – A set of valid relation entity pair combinations, used as relation candidates. Specify valid entity pairs in a set of tuples of labels (<HEAD>, <TAIL>). E.g. for the born_in relation, only relations from ‘PER’ to ‘LOC’ make sense. Here, relations from ‘PER’ to ‘PER’ are not meaningful, so it is advised to specify the entity_pair_labels as {(‘PER’, ‘LOC’)}. This setting may help to reduce the number of relation candidates. Leaving this parameter as None (default) disables the relation-candidate-filter, i.e. the model classifies the relation for each entity pair in the cross product of all entities (inefficient).
  - entity_threshold (Optional[float]) – Only pre-labelled entities above this threshold are taken into account by the model.
  - cross_augmentation (bool) – If True, use cross augmentation to transform Sentences into EncodedSentences. When cross augmentation is enabled, the transformation functions, e.g. transform_corpus, generate an encoded sentence for each entity pair in the cross product of all entities in the original sentence. When cross augmentation is disabled, the transformation functions only generate encoded sentences for each gold relation annotation in the original sentence.
  - encoding_strategy (EncodingStrategy) – An instance of a class conforming to the EncodingStrategy protocol
  - zero_tag_value (str) – The label to use for out-of-class relations
  - allow_unk_tag (bool) – If False, removes <unk> from the passed label dictionary; otherwise does nothing.
  - classifierargs – The remaining parameters passed to the underlying flair.models.DefaultClassifier
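How the three entity filters narrow down the relation candidates can be illustrated with a plain-Python sketch. This is a simplified model of the documented behaviour, not Flair's code: the `Entity` tuple and the `relation_candidates` helper are hypothetical, the sketch assumes that a plain label type such as ‘ner’ corresponds to the dictionary form {‘ner’: None}, and it reads "above this threshold" as a strict comparison.

```python
from itertools import permutations
from typing import NamedTuple, Optional


class Entity(NamedTuple):
    text: str
    label_type: str  # annotation layer, e.g. "ner"
    label: str       # entity label, e.g. "PER"
    score: float     # confidence assigned by the upstream tagger


def relation_candidates(
    entities: list[Entity],
    entity_label_types: dict[str, Optional[set[str]]],
    entity_pair_labels: Optional[set[tuple[str, str]]] = None,
    entity_threshold: Optional[float] = None,
) -> list[tuple[Entity, Entity]]:
    """Return the ordered (head, tail) pairs that survive all three filters."""
    valid: list[Entity] = []
    for entity in entities:
        if entity.label_type not in entity_label_types:
            continue  # not one of the required annotation layers
        allowed = entity_label_types[entity.label_type]
        if allowed is not None and entity.label not in allowed:
            continue  # label excluded by the label filter
        if entity_threshold is not None and entity.score <= entity_threshold:
            continue  # entity score is not above the threshold
        valid.append(entity)
    return [
        (head, tail)
        for head, tail in permutations(valid, 2)
        if entity_pair_labels is None or (head.label, tail.label) in entity_pair_labels
    ]


entities = [
    Entity("Larry Page", "ner", "PER", 0.99),
    Entity("Sergey Brin", "ner", "PER", 0.98),
    Entity("Google", "ner", "ORG", 0.97),
    Entity("founded", "pos", "VERB", 0.99),
]

# No pair filter: every ordered pair of the three 'ner' entities is a candidate.
unfiltered = relation_candidates(entities, {"ner": None})

# Pair filter for founded_by: only ORG (head) to PER (tail) pairs remain.
filtered = relation_candidates(
    entities, {"ner": {"PER", "ORG"}}, entity_pair_labels={("ORG", "PER")}
)

print(len(unfiltered), len(filtered))
```

Under these assumptions the pair filter cuts the six unfiltered candidates down to the two that can actually carry a founded_by relation.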
Methods
__init__(embeddings, label_dictionary, ...) – Initializes a RelationClassifier.
add_module(name, module) – Add a child module to the current module.
apply(fn) – Apply fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Return an iterator over module buffers.
children() – Return an iterator over immediate children modules.
compile(*args, **kwargs) – Compile this Module's forward using torch.compile().
cpu() – Move all model parameters and buffers to the CPU.
cuda([device]) – Move all model parameters and buffers to the GPU.
double() – Casts all floating point parameters and buffers to double datatype.
eval() – Set the module in evaluation mode.
evaluate(data_points, gold_label_type[, ...]) – Evaluates the model.
extra_repr() – Set the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(*input) – Define the computation performed at every call.
forward_loss(sentences) – Performs a forward pass and returns a loss tensor for backpropagation.
get_buffer(target) – Return the buffer given by target if it exists, otherwise throw an error.
get_extra_state() – Return any extra state to include in the module's state_dict.
get_parameter(target) – Return the parameter given by target if it exists, otherwise throw an error.
get_submodule(target) – Return the submodule given by target if it exists, otherwise throw an error.
get_used_tokens(corpus[, context_length, ...])
half() – Casts all floating point parameters and buffers to half datatype.
ipu([device]) – Move all model parameters and buffers to the IPU.
load(model_path) – Loads a Flair model from the given file or state dictionary.
load_state_dict(state_dict[, strict, assign]) – Copy parameters and buffers from state_dict into this module and its descendants.
modules() – Return an iterator over all modules in the network.
mtia([device]) – Move all model parameters and buffers to the MTIA.
named_buffers([prefix, recurse, ...]) – Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix, remove_duplicate]) – Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse, ...]) – Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]) – Return an iterator over module parameters.
predict(sentences[, mini_batch_size, ...]) – Predicts the class labels for the given sentence(s).
print_model_card() – This method produces a log message that includes all recorded parameters the model was trained with.
register_backward_hook(hook) – Register a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Add a buffer to the module.
register_forward_hook(hook, *[, prepend, ...]) – Register a forward hook on the module.
register_forward_pre_hook(hook, *[, ...]) – Register a forward pre-hook on the module.
register_full_backward_hook(hook[, prepend]) – Register a backward hook on the module.
register_full_backward_pre_hook(hook[, prepend]) – Register a backward pre-hook on the module.
register_load_state_dict_post_hook(hook) – Register a post-hook to be run after module's load_state_dict() is called.
register_load_state_dict_pre_hook(hook) – Register a pre-hook to be run before module's load_state_dict() is called.
register_module(name, module) – Alias for add_module().
register_parameter(name, param) – Add a parameter to the module.
register_state_dict_post_hook(hook) – Register a post-hook for the state_dict() method.
register_state_dict_pre_hook(hook) – Register a pre-hook for the state_dict() method.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save(model_file[, checkpoint]) – Saves the current model to the provided file.
set_extra_state(state) – Set extra state contained in the loaded state_dict.
set_submodule(target, module) – Set the submodule given by target if it exists, otherwise throw an error.
share_memory() – See torch.Tensor.share_memory_().
state_dict(*args[, destination, prefix, ...]) – Return a dictionary containing references to the whole state of the module.
to(*args, **kwargs) – Move and/or cast the parameters and buffers.
to_empty(*, device[, recurse]) – Move the parameters and buffers to the specified device without copying storage.
train([mode]) – Set the module in training mode.
transform_corpus(corpus) – Transforms a corpus into a corpus containing encoded sentences specific to the RelationClassifier.
transform_dataset(dataset) – Transforms a dataset into a dataset containing encoded sentences specific to the RelationClassifier.
transform_sentence(sentences) – Transforms sentences into encoded sentences specific to the RelationClassifier.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Move all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Reset gradients of all model parameters.
Attributes
T_destination
call_super_init
dump_patches
label_type – Each model predicts labels of a certain type.
model_card
multi_label_threshold
training
- transform_sentence(sentences)
Transforms sentences into encoded sentences specific to the RelationClassifier.
For more information on the internal sentence transformation procedure, see the flair.models.RelationClassifier architecture and the docstrings of the different flair.models.relation_classifier_model.EncodingStrategy variants.
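A usage sketch for inspecting the encoded sentences, assuming a constructed or loaded RelationClassifier is passed in as `model` with `entity_label_types` covering the ‘ner’ layer. The helper names `build_example_sentence` and `encoded_texts` are hypothetical, the ‘ner’ layer name and the PER/ORG labels are assumptions, and the token offsets assume the tokenizer splits the example on its existing spaces. The Flair import is deferred so that the helpers can be defined without Flair installed.

```python
def build_example_sentence():
    """Create the example sentence with hand-labelled entities (requires flair)."""
    from flair.data import Sentence  # deferred import, see the note above

    sentence = Sentence("Larry Page and Sergey Brin founded Google .")
    sentence[0:2].add_label("ner", "PER")  # Larry Page
    sentence[3:5].add_label("ner", "PER")  # Sergey Brin
    sentence[6:7].add_label("ner", "ORG")  # Google
    return sentence


def encoded_texts(model, sentence) -> list[str]:
    """Return the tokenized text of every encoded sentence derived from `sentence`."""
    return [
        encoded.to_tokenized_string()
        for encoded in model.transform_sentence(sentence)
    ]
```

Calling `encoded_texts(model, build_example_sentence())` should list one string per relation candidate, with the marker tokens determined by the model's encoding strategy.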
- transform_dataset(dataset)
Transforms a dataset into a dataset containing encoded sentences specific to the RelationClassifier.
The returned dataset is stored in memory. For more information on the internal sentence transformation procedure, see the RelationClassifier architecture and the docstrings of the different EncodingStrategy variants.
- Parameters:
  - dataset (Dataset[Sentence]) – A dataset of sentences to transform
- Return type: FlairDatapointDataset[EncodedSentence]
- Returns: A dataset of encoded sentences specific to the RelationClassifier
- transform_corpus(corpus)
Transforms a corpus into a corpus containing encoded sentences specific to the RelationClassifier.
The splits of the returned corpus are stored in memory. For more information on the internal sentence transformation procedure, see the RelationClassifier architecture and the docstrings of the different EncodingStrategy variants.
- Parameters:
  - corpus (Corpus[Sentence]) – A corpus of sentences to transform
- Return type: Corpus[EncodedSentence]
- Returns: A corpus of encoded sentences specific to the RelationClassifier
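A training sketch showing where transform_corpus would typically fit, under several assumptions: the corpus carries entity annotations under ‘ner’ and relation annotations under ‘relation’, the entity labels are ‘ORG’ and ‘PER’, and the embedding model name, output path, and epoch count are placeholders. The function name `train_relation_classifier` is hypothetical, and the Flair imports are deferred so the function can be defined without Flair installed.

```python
def train_relation_classifier(corpus, output_dir: str = "resources/relation-classifier"):
    """Fine-tune a RelationClassifier on a corpus with entity and relation labels."""
    # Deferred imports: calling the function needs flair, defining it does not.
    from flair.embeddings import TransformerDocumentEmbeddings
    from flair.models import RelationClassifier
    from flair.trainers import ModelTrainer

    # The label dictionary is built from the original (untransformed) corpus.
    label_dictionary = corpus.make_label_dictionary("relation")
    embeddings = TransformerDocumentEmbeddings(
        model="distilbert-base-uncased",  # placeholder model name
        fine_tune=True,
    )

    model = RelationClassifier(
        embeddings=embeddings,
        label_dictionary=label_dictionary,
        label_type="relation",                # assumed relation annotation name
        entity_label_types="ner",             # assumed entity annotation name
        entity_pair_labels={("ORG", "PER")},  # assumed labels; adapt to the corpus
        allow_unk_tag=False,
    )

    # The classifier is trained on encoded sentences, so the corpus is transformed first.
    trainer = ModelTrainer(model=model, corpus=model.transform_corpus(corpus))
    trainer.fine_tune(output_dir, max_epochs=1)
    return model
```

Because the splits of the transformed corpus are held in memory, a large corpus combined with many entities per sentence may need the `entity_pair_labels` filter to stay manageable.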
- predict(sentences, mini_batch_size=32, return_probabilities_for_all_classes=False, verbose=False, label_name=None, return_loss=False, embedding_storage_mode='none')
Predicts the class labels for the given sentence(s).
Standard Sentence objects and EncodedSentences specific to the RelationClassifier are allowed as input. The (relation) labels are directly added to the sentences.
- Parameters:
  - sentences (Union[list[Sentence], list[EncodedSentence], Sentence, EncodedSentence]) – A single (encoded) sentence or a list of (encoded) sentences.
  - mini_batch_size (int) – The mini batch size to use
  - return_probabilities_for_all_classes (bool) – Return probabilities for all classes instead of only the best predicted class
  - verbose (bool) – Set to display a progress bar
  - return_loss (bool) – Set to return loss
  - label_name (Optional[str]) – Set to change the predicted label type name
  - embedding_storage_mode (Literal['none', 'cpu', 'gpu']) – The default is ‘none’, which is always best. Only set to ‘cpu’ or ‘gpu’ if you wish to predict and keep the generated embeddings in CPU or GPU memory, respectively.
- Return type: Optional[tuple[Tensor, int]]
- Returns: The loss and the total number of classes, if return_loss is set
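A prediction sketch, assuming `model` is a trained RelationClassifier (for example obtained via `RelationClassifier.load`) and `sentence` is a standard Sentence whose entities are already labelled. The helper name `predict_relations` is hypothetical, and the sketch assumes that each predicted label is attached to a relation data point exposing `first` (head) and `second` (tail) spans.

```python
def predict_relations(model, sentence) -> list[tuple[str, str, str, float]]:
    """Run prediction and collect (head text, tail text, label, score) tuples."""
    # predict() adds the relation labels to the sentence in place.
    model.predict(sentence)
    return [
        (
            label.data_point.first.text,   # head entity
            label.data_point.second.text,  # tail entity
            label.value,                   # predicted relation label
            label.score,                   # confidence of the prediction
        )
        for label in sentence.get_labels(model.label_type)
    ]
```

The labels are read back under `model.label_type`, which applies when `label_name` is left at its default of None.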
- property label_type: str
Each model predicts labels of a certain type.
- property zero_tag_value: str
- property allow_unk_tag: bool
- get_used_tokens(corpus, context_length=0, respect_document_boundaries=True)
- Return type: Iterable[list[str]]
- classmethod load(model_path)
Loads a Flair model from the given file or state dictionary.
- Parameters:
  - model_path (Union[str, Path, dict[str, Any]]) – Either the path to the model (as string or Path variable) or the already loaded state dict.
- Return type:
- Returns: The loaded Flair model.