HunFlair2 Tutorial 3: Training NER models#
This part of the tutorial shows how you can train your own biomedical named entity recognition models using state-of-the-art pretrained Transformer embeddings.
For this tutorial, we assume that you're familiar with the base types of Flair and with Transformer word embeddings. You should also know how to load a corpus.
Train a biomedical NER model from scratch: single entity type#
Here is example code for a biomedical NER model trained on the NCBI_DISEASE corpus using Transformer word embeddings.
This will result in a tagger specialized for a single entity type, i.e. “Disease”.
from flair.datasets import NCBI_DISEASE
# 1. get the corpus
corpus = NCBI_DISEASE()
print(corpus)
# 2. make the tag dictionary from the corpus
tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False)
# 3. initialize embeddings
from flair.embeddings import TransformerWordEmbeddings
embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings(
    "michiyasunaga/BioLinkBERT-base",
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
    model_max_length=512,
)
# 4. initialize sequence tagger
from flair.models import SequenceTagger
tagger: SequenceTagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_format="BIOES",
    tag_type="ner",
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)
# 5. initialize trainer
from flair.trainers import ModelTrainer
trainer: ModelTrainer = ModelTrainer(tagger, corpus)
trainer.fine_tune(
    base_path="taggers/ncbi-disease",
    train_with_dev=False,
    max_epochs=16,
    learning_rate=2.0e-5,
    mini_batch_size=16,
    shuffle=False,
)
Once the model is trained, you can use it to predict tags for new sentences by calling the model's predict method.
# load the model you trained
model = SequenceTagger.load("taggers/ncbi-disease/best-model.pt")
# create example sentence
from flair.data import Sentence
sentence = Sentence("Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer.")
# predict tags and print
model.predict(sentence)
print(sentence.to_tagged_string())
If the model works well, it will correctly tag “breast cancer” as disease in this example:
Women who smoke 20 cigarettes a day are four times more likely to develop breast <B-Disease> cancer <E-Disease> .
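The tags in this output follow the BIOES scheme (Begin, Inside, Outside, End, Single) selected via tag_format above. As a framework-independent illustration of how such tags map back to entity spans, here is a minimal decoder sketch (bioes_to_spans is a hypothetical helper, not part of the Flair API):

```python
def bioes_to_spans(tokens, tags):
    """Decode a BIOES tag sequence into (text, entity_type) spans."""
    spans, current = [], []
    for token, tag in zip(tokens, tags):
        if tag.startswith("S-"):  # single-token entity
            spans.append((token, tag[2:]))
        elif tag.startswith("B-"):  # begin a multi-token entity
            current = [token]
        elif tag.startswith("I-") and current:  # inside an entity
            current.append(token)
        elif tag.startswith("E-") and current:  # end of an entity
            current.append(token)
            spans.append((" ".join(current), tag[2:]))
            current = []
        else:  # "O" or a malformed sequence
            current = []
    return spans

tokens = ["develop", "breast", "cancer", "."]
tags = ["O", "B-Disease", "E-Disease", "O"]
print(bioes_to_spans(tokens, tags))  # [('breast cancer', 'Disease')]
```

In practice you would not decode tags manually: sentence.get_spans("ner") returns the predicted spans directly.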
Train a biomedical NER model: multiple entity types#
If you are dealing with multiple entity types, e.g. “Disease” and “Chemical”, you can opt
to train a single model capable of handling all of them at once.
This can be achieved by using the PrefixedSequenceTagger class, which implements the method described in [1].
This model uses prompting, i.e., it prepends a prefix string (hence the name) to each sentence specifying the
entity types requested for tagging: [Tag <entity-type-0>, <entity-type-1>, ...].
This is especially useful for training, as it allows combining multiple corpora even if they cover different subsets of entity types.
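To illustrate the prompting idea, the augmentation conceptually turns a plain sentence into a prompted one by prepending the tag prefix. The following is a simplified sketch of that string transformation only (the actual EntityTypeTaskPromptAugmentationStrategy operates on Flair Sentence objects and also shifts label offsets accordingly):

```python
def add_task_prompt(sentence: str, entity_types: list[str]) -> str:
    """Prepend the '[Tag <entity-type-0>, <entity-type-1>, ...]' prompt."""
    prefix = "[Tag " + ", ".join(sorted(entity_types)) + "]"
    return f"{prefix} {sentence}"

print(add_task_prompt(
    "Aspirin reduces the risk of heart disease.",
    ["diseases", "chemicals"],
))
# [Tag chemicals, diseases] Aspirin reduces the risk of heart disease.
```

Because the requested entity types are part of the input, a corpus annotated only for chemicals and a corpus annotated for chemicals and diseases can be mixed in one training set without conflicting supervision signals.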
# 1. get the corpora
from flair.datasets.biomedical import HUNER_ALL_CDR, HUNER_CHEMICAL_NLM_CHEM
corpora = (HUNER_ALL_CDR(), HUNER_CHEMICAL_NLM_CHEM())
# 2. add prefixed strings to each corpus by prepending its tagged entity
#    types "[Tag <entity-type-0>, <entity-type-1>, ...]"
from flair.data import MultiCorpus
from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy
from flair.datasets.biomedical import (
    BIGBIO_NER_CORPUS,
    CELL_LINE_TAG,
    CHEMICAL_TAG,
    DISEASE_TAG,
    GENE_TAG,
    SPECIES_TAG,
)
mapping = {
    CELL_LINE_TAG: "cell lines",
    CHEMICAL_TAG: "chemicals",
    DISEASE_TAG: "diseases",
    GENE_TAG: "genes",
    SPECIES_TAG: "species",
}
prefixed_corpora = []
all_entity_types = set()
for corpus in corpora:
    entity_types = sorted(
        {mapping[tag] for tag in corpus.get_entity_type_mapping().values()}
    )
    all_entity_types.update(entity_types)
    print(f"Entity types in {corpus}: {entity_types}")
    augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy(
        entity_types
    )
    prefixed_corpora.append(
        augmentation_strategy.augment_corpus(corpus)
    )
corpora = MultiCorpus(prefixed_corpora)
all_entity_types = sorted(all_entity_types)
# 3. make the tag dictionary from the corpus
tag_dictionary = corpora.make_label_dictionary(label_type="ner")
# 4. by default, the final model will predict all entity types seen
#    in the training corpora, i.e., chemicals and diseases here
augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy(
    all_entity_types
)
# 5. initialize embeddings
from flair.embeddings import TransformerWordEmbeddings
embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings(
    "michiyasunaga/BioLinkBERT-base",
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
    model_max_length=512,
)
# 6. initialize sequence tagger
from flair.models.prefixed_tagger import PrefixedSequenceTagger
tagger: PrefixedSequenceTagger = PrefixedSequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_format="BIOES",
    tag_type="ner",
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
    augmentation_strategy=augmentation_strategy,
)
# 7. initialize trainer
from flair.trainers import ModelTrainer
trainer: ModelTrainer = ModelTrainer(tagger, corpora)
trainer.fine_tune(
    base_path="taggers/cdr_nlm_chem",
    train_with_dev=False,
    max_epochs=16,
    learning_rate=2.0e-5,
    mini_batch_size=16,
    shuffle=False,
)
Training HunFlair2 from scratch#
HunFlair2 uses the PrefixedSequenceTagger class as described above, but is trained on the following corpora instead:
from flair.datasets.biomedical import (
    HUNER_ALL_BIORED, HUNER_GENE_NLM_GENE,
    HUNER_GENE_GNORMPLUS, HUNER_ALL_SCAI,
    HUNER_CHEMICAL_NLM_CHEM, HUNER_SPECIES_LINNEAUS,
    HUNER_SPECIES_S800, HUNER_DISEASE_NCBI
)
corpora = (
    HUNER_ALL_BIORED(), HUNER_GENE_NLM_GENE(),
    HUNER_GENE_GNORMPLUS(), HUNER_ALL_SCAI(),
    HUNER_CHEMICAL_NLM_CHEM(), HUNER_SPECIES_LINNEAUS(),
    HUNER_SPECIES_S800(), HUNER_DISEASE_NCBI()
)
References#
[1] Luo, L., Wei, C. H., Lai, P. T., Leaman, R., Chen, Q., & Lu, Z. (2023). AIONER: all-in-one scheme-based biomedical named entity recognition using deep learning. Bioinformatics, 39(5), btad310.