Train a span classifier

Span classification models are used for problems such as entity linking, where you have already extracted some relevant spans within a sentence and want to predict more fine-grained labels for them.

This tutorial section shows you how to train models using the SpanClassifier in Flair.

Training an entity linker (NEL) model with transformers

For a state-of-the-art entity linking system you should fine-tune transformer embeddings and use full document context (see our FLERT paper for details).

Use the following script:

from flair.datasets import ZELDA
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SpanClassifier
from flair.models.entity_linker_model import CandidateGenerator
from flair.trainers import ModelTrainer
from flair.nn.decoder import PrototypicalDecoder

# 1. get the corpus
corpus = ZELDA()

# 2. what label do we want to predict?
label_type = 'nel'

# 3. make the label dictionary from the corpus
label_dict = corpus.make_label_dictionary(label_type=label_type, add_unk=True)

# 4. initialize fine-tuneable transformer embeddings WITH document context
embeddings = TransformerWordEmbeddings(
    model="bert-base-uncased",
    fine_tune=True,
    use_context=True,
)

# 5. initialize bare-bones span classifier (no CRF, no RNN, no reprojection)
tagger = SpanClassifier(
    embeddings=embeddings,
    label_dictionary=label_dict,
    label_type=label_type,
    # prototypical decoder: one prototype per label in the dictionary
    decoder=PrototypicalDecoder(
        num_prototypes=len(label_dict),
        embeddings_size=embeddings.embedding_length * 2,  # we use "first_last" encoding for spans
    ),
    # restrict predictions to candidates looked up from the span text
    candidates=CandidateGenerator("zelda"),
)

# 6. initialize trainer
trainer = ModelTrainer(tagger, corpus)

# 7. run fine-tuning
trainer.fine_tune(
    "resources/taggers/zelda-nel",  # output folder for model and logs (example path)
    mini_batch_chunk_size=1,  # remove this parameter to speed up computation if you have a big GPU
)

As you can see, we use TransformerWordEmbeddings based on bert-base-uncased embeddings. We enable fine-tuning and set use_context to True. We use Prototypical Networks to generalize better in the few-shot classification setting. We also pass a CandidateGenerator to the SpanClassifier. This limits classification to a small set of candidates, chosen according to the text of the respective span.
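The two ideas can be made concrete with a toy, pure-Python sketch of how a single span might be classified. This is not Flair's implementation. The span vector is the concatenation of its first and last token embeddings, which is why the decoder is sized with embedding_length * 2. That vector is scored against one prototype per label, but only for the labels that a candidate table returns for the span text. All vectors, the candidate table, the dot-product scoring and the "<unk>" fallback are made-up assumptions for illustration.

```python
from typing import Dict, List, Sequence, Set

Vector = List[float]


def first_last(token_embeddings: Sequence[Vector], start: int, end: int) -> Vector:
    """Span vector = first token embedding followed by last token embedding (inclusive end)."""
    return list(token_embeddings[start]) + list(token_embeddings[end])


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def classify_span(span_vector: Vector, prototypes: Dict[str, Vector], candidates: Set[str]) -> str:
    """Return the candidate label whose prototype scores highest against the span vector."""
    scores = {label: dot(span_vector, prototypes[label]) for label in candidates if label in prototypes}
    if not scores:
        return "<unk>"  # hypothetical fallback when no candidate has a prototype
    return max(scores, key=scores.get)


# Toy 2-dimensional token embeddings for "George Washington went to Washington"
token_embeddings = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [0.0, 0.2], [0.2, 0.8]]

# One made-up prototype per label, living in the 4-dimensional first_last space
prototypes = {
    "George_Washington": [1.0, 0.0, 0.5, 0.5],
    "Washington_D_C": [0.0, 1.0, 0.0, 1.0],
    "Sam_Houston": [0.0, 2.0, 0.0, 2.0],
}

# Hypothetical candidate table: mention text -> plausible target labels
candidate_table = {
    "George Washington": {"George_Washington"},
    "Washington": {"George_Washington", "Washington_D_C"},
}

span_vector = first_last(token_embeddings, 4, 4)  # the second "Washington"
print(classify_span(span_vector, prototypes, candidate_table["Washington"]))  # Washington_D_C
print(classify_span(span_vector, prototypes, set(prototypes)))  # Sam_Houston (no candidate filter)
```

The second call shows why the candidate restriction helps. Without it, the label with the globally highest score wins, even when that label is implausible for the mention text.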

Loading a ColumnCorpus

If you want to train on a custom named entity linking dataset, you can load it with the ColumnCorpus object. Most sequence labeling datasets in NLP use some sort of column format, in which each line is a word and each column is one level of linguistic annotation. See for instance these two sentences:

George B-George_Washington
Washington I-George_Washington
went O
to O
Washington B-Washington_D_C

Sam B-Sam_Houston
Houston I-Sam_Houston
stayed O
home O

The first column is the word itself. The second contains BIO-annotated tags that mark the spans to be classified, together with their labels: B- begins a span, I- continues it, and O marks tokens outside any span. To read such a dataset, define the column structure as a dictionary and instantiate a ColumnCorpus.
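ColumnCorpus does this decoding for you. The pure-Python sketch below only illustrates how BIO tags delimit spans, and is not Flair's reader. It returns (start_token, end_token, label) tuples with an inclusive end index, which is the same shape used by create_sentence in the in-memory section of this tutorial. The helper names read_column_sentences and bio_to_spans are hypothetical.

```python
from typing import List, Optional, Tuple

Span = Tuple[int, int, str]  # (start_token, end_token, label), end is inclusive

EXAMPLE = """George B-George_Washington
Washington I-George_Washington
went O
to O
Washington B-Washington_D_C

Sam B-Sam_Houston
Houston I-Sam_Houston
stayed O
home O
"""


def read_column_sentences(text: str) -> List[List[Tuple[str, str]]]:
    """Split column-format text into sentences of (word, tag) pairs; blank lines separate sentences."""
    sentences: List[List[Tuple[str, str]]] = []
    current: List[Tuple[str, str]] = []
    for line in text.splitlines():
        if not line.strip():
            if current:
                sentences.append(current)
                current = []
            continue
        word, tag = line.split()[:2]
        current.append((word, tag))
    if current:
        sentences.append(current)
    return sentences


def bio_to_spans(tags: List[str]) -> List[Span]:
    """Turn a BIO tag sequence into labeled spans with inclusive end indices."""
    spans: List[Span] = []
    start: Optional[int] = None
    label: Optional[str] = None
    for i, tag in enumerate(tags):
        if tag.startswith("I-") and label == tag[2:]:
            continue  # the open span goes on
        if label is not None:
            spans.append((start, i - 1, label))  # close the open span before this token
            start, label = None, None
        if tag.startswith(("B-", "I-")):  # an I- tag without a matching open span starts a new one
            start, label = i, tag[2:]
    if label is not None:
        spans.append((start, len(tags) - 1, label))
    return spans


for sentence in read_column_sentences(EXAMPLE):
    print(bio_to_spans([tag for _, tag in sentence]))
# [(0, 1, 'George_Washington'), (4, 4, 'Washington_D_C')]
# [(0, 1, 'Sam_Houston')]
```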

from flair.data import Corpus
from flair.datasets import ColumnCorpus

# define columns
columns = {0: "text", 1: "nel"}

# this is the folder in which train, test and dev files reside
data_folder = '/path/to/data/folder'

# init a corpus using the column format and data folder (pass train_file, dev_file and test_file to set the file names explicitly)
corpus: Corpus = ColumnCorpus(data_folder, columns)
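If you are starting from scratch, the following sketch writes the two example sentences into a temporary folder as train.txt, dev.txt and test.txt, so you can see the expected layout. All three files have identical contents purely for illustration. The file names are an assumption: as far as we know, ColumnCorpus finds the splits by looking for "train", "dev" and "test" in the file names. The ColumnCorpus call is left as a comment so the sketch runs without Flair installed.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple

TaggedSentence = List[Tuple[str, str]]  # (word, BIO tag) pairs

EXAMPLE_SENTENCES: List[TaggedSentence] = [
    [("George", "B-George_Washington"), ("Washington", "I-George_Washington"),
     ("went", "O"), ("to", "O"), ("Washington", "B-Washington_D_C")],
    [("Sam", "B-Sam_Houston"), ("Houston", "I-Sam_Houston"), ("stayed", "O"), ("home", "O")],
]


def write_column_file(path: Path, sentences: List[TaggedSentence]) -> None:
    """Write one "word tag" line per token and a blank line after each sentence."""
    with path.open("w", encoding="utf-8") as f:
        for sentence in sentences:
            for word, tag in sentence:
                f.write(f"{word} {tag}\n")
            f.write("\n")


data_folder = Path(tempfile.mkdtemp())  # stands in for '/path/to/data/folder'
for split in ("train", "dev", "test"):
    write_column_file(data_folder / f"{split}.txt", EXAMPLE_SENTENCES)

# With Flair installed, the folder can then be loaded as shown above:
# corpus = ColumnCorpus(data_folder, {0: "text", 1: "nel"})
```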

Constructing a dataset in memory

If you have a pipeline where you need to construct your dataset from a different data source, you can always construct a Corpus with FlairDatapointDataset by hand. Let’s assume you write a function create_sentence(datapoint) -> Sentence that looks roughly like this:

from flair.data import Sentence

def create_sentence(datapoint) -> Sentence:
    tokens = ...  # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
    spans = ...  # create a list of tuples (start_token, end_token, label) from your data structure
    sentence = Sentence(tokens)
    for (start, end, label) in spans:
        sentence[start:end+1].add_label("nel", label)
    return sentence

Then you can use this function to create a full dataset:

from flair.data import Corpus
from flair.datasets import FlairDatapointDataset

def construct_corpus(data):
    # data holds the raw train/dev/test splits; create_sentence converts each datapoint
    return Corpus(
        train=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["train"]]),
        dev=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["dev"]]),
        test=FlairDatapointDataset([create_sentence(datapoint) for datapoint in data["test"]]),
    )

And use this to construct a corpus instead of loading a dataset.
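The tokens = ... and spans = ... lines in create_sentence depend entirely on your own data. As an illustration only, assume a hypothetical raw format in which each datapoint has whitespace-tokenizable "text" and a list of "entities" with inclusive token indices. The hypothetical helper to_tokens_and_spans turns one such datapoint into the two values create_sentence needs, and rejects out-of-range spans early. None of this format is prescribed by Flair.

```python
from typing import Dict, List, Tuple

# Hypothetical raw data, split the way construct_corpus expects ("train", "dev", "test")
data = {
    "train": [
        {
            "text": "George Washington went to Washington",
            "entities": [
                {"start": 0, "end": 1, "label": "George_Washington"},
                {"start": 4, "end": 4, "label": "Washington_D_C"},
            ],
        },
    ],
    "dev": [{"text": "Sam Houston stayed home", "entities": [{"start": 0, "end": 1, "label": "Sam_Houston"}]}],
    "test": [{"text": "Sam Houston stayed home", "entities": [{"start": 0, "end": 1, "label": "Sam_Houston"}]}],
}


def to_tokens_and_spans(datapoint: Dict) -> Tuple[List[str], List[Tuple[int, int, str]]]:
    """Whitespace-tokenize the text and collect (start_token, end_token, label) tuples, end inclusive."""
    tokens = datapoint["text"].split()
    spans = [(e["start"], e["end"], e["label"]) for e in datapoint["entities"]]
    for start, end, _ in spans:
        if not 0 <= start <= end < len(tokens):
            raise ValueError(f"span ({start}, {end}) is out of range for {len(tokens)} tokens")
    return tokens, spans


tokens, spans = to_tokens_and_spans(data["train"][0])
print([tokens[start:end + 1] for start, end, _ in spans])  # [['George', 'Washington'], ['Washington']]
```

Inside create_sentence, the two placeholder lines would then become tokens, spans = to_tokens_and_spans(datapoint).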

Combining NEL with Mention Detection

Often, you don’t just want to use a named entity linking model on its own, but combine it with a mention detection or named entity recognition model. For this, you can use a multitask model that combines a SequenceTagger and a SpanClassifier.

from flair.datasets import NER_MULTI_WIKINER, ZELDA
from flair.embeddings import TransformerWordEmbeddings
from flair.models import SequenceTagger, SpanClassifier
from flair.models.entity_linker_model import CandidateGenerator
from flair.trainers import ModelTrainer
from flair.nn import PrototypicalDecoder
from flair.nn.multitask import make_multitask_model_and_corpus

# 1. get the corpus
ner_corpus = NER_MULTI_WIKINER()
nel_corpus = ZELDA(column_format={0: "text", 2: "nel"})  # need to set the label type to be the same as the ner one

# --- Embeddings that are shared by both models --- #
shared_embeddings = TransformerWordEmbeddings("distilbert-base-uncased", fine_tune=True)

ner_label_dict = ner_corpus.make_label_dictionary("ner", add_unk=False)

ner_model = SequenceTagger(
    embeddings=shared_embeddings,
    tag_dictionary=ner_label_dict,
    tag_type="ner",
)

nel_label_dict = nel_corpus.make_label_dictionary("nel", add_unk=True)

nel_model = SpanClassifier(
    embeddings=shared_embeddings,
    label_dictionary=nel_label_dict,
    label_type="nel",
    decoder=PrototypicalDecoder(
        num_prototypes=len(nel_label_dict),
        embeddings_size=shared_embeddings.embedding_length * 2,  # we use "first_last" encoding for spans
    ),
    candidates=CandidateGenerator("zelda"),
)

# -- Define mapping (which model should train on which corpus) -- #
multitask_model, multicorpus = make_multitask_model_and_corpus(
    [
        (ner_model, ner_corpus),
        (nel_model, nel_corpus),
    ]
)

# -- Create model trainer and train -- #
trainer = ModelTrainer(multitask_model, multicorpus)
trainer.fine_tune("resources/taggers/ner-nel-multitask")  # output folder (example path)

Here, the make_multitask_model_and_corpus function creates a multitask model and a multicorpus in which each sub-model is paired with its own sub-corpus.
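The pairing is positional. The following is a minimal sketch of the idea, using strings in place of models and corpora. It assumes that the i-th pair receives the task id "Task_i", which matches the Task_0 and Task_1 ids used in the aligned-data section. The helper pair_tasks is hypothetical and is not part of Flair.

```python
from typing import Dict, Sequence, Tuple


def pair_tasks(mapping: Sequence[Tuple[str, str]]) -> Dict[str, Tuple[str, str]]:
    """Give the i-th (model, corpus) pair the task id "Task_i"."""
    return {f"Task_{i}": (model, corpus) for i, (model, corpus) in enumerate(mapping)}


tasks = pair_tasks([("ner_model", "ner_corpus"), ("nel_model", "nel_corpus")])
for task_id, (model, corpus) in tasks.items():
    print(task_id, model, corpus)
# Task_0 ner_model ner_corpus
# Task_1 nel_model nel_corpus
```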

Multitask with aligned training data

If your sentences carry both NER and NEL annotations, you might want to use a single corpus for both models.

In this case, you need to manually add the multitask_id labels to the sentences:

from flair.data import Sentence

def create_sentence(datapoint) -> Sentence:
    tokens = ...  # calculate the tokens from your internal data structure (e.g. pandas dataframe or json dictionary)
    spans = ...  # create a list of tuples (start_token, end_token, ner_label, nel_label) from your data structure
    sentence = Sentence(tokens)
    for (start, end, ner_label, nel_label) in spans:
        sentence[start:end+1].add_label("ner", ner_label)
        sentence[start:end+1].add_label("nel", nel_label)
    sentence.add_label("multitask_id", "Task_0")  # Task_0 for the NER model
    sentence.add_label("multitask_id", "Task_1")  # Task_1 for the NEL model
    return sentence

Then you can run the multitask training script above, except that you create the MultitaskModel directly:

from flair.models import MultitaskModel

multitask_model = MultitaskModel([ner_model, nel_model], use_all_tasks=True)

Here, setting use_all_tasks=True means that we will jointly train on both tasks at the same time. This will save a lot of training time, as the shared embedding will be calculated once but used twice (once for each model).
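The effect of use_all_tasks on a batch can be pictured with the simplified pure-Python sketch below. Plain lists of task-id strings stand in for the multitask_id labels of Flair Sentence objects. The behaviour for use_all_tasks=False, where one of each sentence's ids is sampled per step, reflects our understanding of MultitaskModel and should be treated as an assumption. The helper split_batch_by_task is hypothetical.

```python
import random
from typing import Dict, List, Optional


def split_batch_by_task(
    batch_task_ids: List[List[str]],
    use_all_tasks: bool,
    rng: Optional[random.Random] = None,
) -> Dict[str, List[int]]:
    """Map each task id to the indices of the batch sentences that this task trains on."""
    rng = rng or random.Random()
    split: Dict[str, List[int]] = {}
    for index, task_ids in enumerate(batch_task_ids):
        # all tasks of the sentence, or a single randomly chosen one
        chosen = task_ids if use_all_tasks else [rng.choice(task_ids)]
        for task_id in chosen:
            split.setdefault(task_id, []).append(index)
    return split


# multitask_id labels of a three-sentence batch
batch = [["Task_0", "Task_1"], ["Task_0", "Task_1"], ["Task_0"]]
print(split_batch_by_task(batch, use_all_tasks=True))  # {'Task_0': [0, 1, 2], 'Task_1': [0, 1]}
```

With use_all_tasks=True every sentence contributes to each task it is labeled for. With False, each sentence is used by exactly one of its tasks per step.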