HunFlair2 Tutorial 3: Training NER models#
This part of the tutorial shows how you can train your own biomedical named entity recognition models using state-of-the-art pretrained transformer embeddings.
For this tutorial, we assume that you're familiar with the base types of Flair and with how transformer word embeddings work. You should also know how to load a corpus.
Train a biomedical NER model from scratch: single entity type#
Here is example code for a biomedical NER model trained on the NCBI_DISEASE
corpus using transformer word embeddings.
This will result in a tagger specialized for a single entity type, i.e. "Disease".
from flair.datasets import NCBI_DISEASE
# 1. get the corpus
corpus = NCBI_DISEASE()
print(corpus)
# 2. make the tag dictionary from the corpus
tag_dictionary = corpus.make_label_dictionary(label_type="ner", add_unk=False)
# 3. initialize embeddings
from flair.embeddings import TransformerWordEmbeddings
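# note: BioLinkBERT is a transformer pretrained on biomedical text;
# layers="-1" uses only the last transformer layer, fine_tune=True
# unfreezes the transformer weights during training, and
# use_context=True adds FLERT-style context from neighboring sentences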
embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings(
    "michiyasunaga/BioLinkBERT-base",
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
    model_max_length=512,
)
# 4. initialize sequence tagger
from flair.models import SequenceTagger
tagger: SequenceTagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_format="BIOES",
    tag_type="ner",
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
)
# 5. initialize trainer
from flair.trainers import ModelTrainer
trainer: ModelTrainer = ModelTrainer(tagger, corpus)
trainer.fine_tune(
    base_path="taggers/ncbi-disease",
    train_with_dev=False,
    max_epochs=16,
    learning_rate=2.0e-5,
    mini_batch_size=16,
    shuffle=False,
)
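Training evaluates on the development set after each epoch and stores the best model in the given base path. If you want to re-run the evaluation on the test split yourself, here is a short sketch using Flair's standard evaluate API:
# evaluate the trained tagger on the test split of the corpus
result = tagger.evaluate(corpus.test, gold_label_type="ner", mini_batch_size=16)
print(result.detailed_results)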
Once the model is trained, you can use it to predict tags for new sentences by calling the model's predict method.
# load the model you trained
model = SequenceTagger.load("taggers/ncbi-disease/best-model.pt")
# create example sentence
from flair.data import Sentence
sentence = Sentence("Women who smoke 20 cigarettes a day are four times more likely to develop breast cancer.")
# predict tags and print
model.predict(sentence)
print(sentence.to_tagged_string())
If the model works well, it will correctly tag "breast cancer" as a disease in this example:
Women who smoke 20 cigarettes a day are four times more likely to develop breast <B-Disease> cancer <E-Disease> .
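Instead of just printing the tagged string, you can also inspect the predictions programmatically by iterating over the predicted spans (a small sketch using Flair's span API):
# print each predicted entity span with its label and confidence score
for entity in sentence.get_spans("ner"):
    print(entity.text, entity.tag, f"{entity.score:.2f}")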
Train a biomedical NER model: multiple entity types#
If you are dealing with multiple entity types, e.g. "Disease" and "Chemical", you can opt
to train a single model capable of handling all entity types at once.
This can be achieved by using the PrefixedSequenceTagger()
class, which implements the method described in [1].
This model uses prompting, i.e. it prepends a prefix string (hence the name) to each input sentence specifying the
entity types requested for tagging: [Tag <entity-type-0>, <entity-type-1>, ...].
This is especially useful for training, as it allows you to combine multiple corpora even if they cover different subsets of entity types.
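To make the prompting idea concrete, here is a minimal sketch of the string transformation it amounts to (sentence and entity types are invented for illustration; in the actual code below this step is performed by the augmentation strategy):
# illustrative only: a task prompt is prepended to the input sentence
entity_types = ["chemicals", "diseases"]
prefix = f"[Tag {', '.join(entity_types)}]"
text = "Aspirin can reduce the risk of heart attacks."
print(prefix, text)
# [Tag chemicals, diseases] Aspirin can reduce the risk of heart attacks.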
# 1. get the corpora
from flair.datasets.biomedical import HUNER_ALL_CDR, HUNER_CHEMICAL_NLM_CHEM
corpora = (HUNER_ALL_CDR(), HUNER_CHEMICAL_NLM_CHEM())
# 2. add prefixed strings to each corpus by prepending its tagged entity
# types "[Tag <entity-type-0>, <entity-type-1>, ...]"
from flair.data import MultiCorpus
from flair.models.prefixed_tagger import EntityTypeTaskPromptAugmentationStrategy
from flair.datasets.biomedical import (
    BIGBIO_NER_CORPUS,
    CELL_LINE_TAG,
    CHEMICAL_TAG,
    DISEASE_TAG,
    GENE_TAG,
    SPECIES_TAG,
)
mapping = {
    CELL_LINE_TAG: "cell lines",
    CHEMICAL_TAG: "chemicals",
    DISEASE_TAG: "diseases",
    GENE_TAG: "genes",
    SPECIES_TAG: "species",
}
prefixed_corpora = []
all_entity_types = set()
for corpus in corpora:
    entity_types = sorted(
        {mapping[tag] for tag in corpus.get_entity_type_mapping().values()}
    )
    all_entity_types.update(entity_types)
    print(f"Entity types in {corpus}: {entity_types}")

    augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy(entity_types)
    prefixed_corpora.append(augmentation_strategy.augment_corpus(corpus))

corpora = MultiCorpus(prefixed_corpora)
all_entity_types = sorted(all_entity_types)
# 3. make the tag dictionary from the combined corpora
tag_dictionary = corpora.make_label_dictionary(label_type="ner")
# 4. by default, the final model will predict all the entity types seen
#    in the training corpora, i.e. diseases and chemicals here
augmentation_strategy = EntityTypeTaskPromptAugmentationStrategy(all_entity_types)
# 5. initialize embeddings
from flair.embeddings import TransformerWordEmbeddings
embeddings: TransformerWordEmbeddings = TransformerWordEmbeddings(
    "michiyasunaga/BioLinkBERT-base",
    layers="-1",
    subtoken_pooling="first",
    fine_tune=True,
    use_context=True,
    model_max_length=512,
)
# 6. initialize sequence tagger
from flair.models.prefixed_tagger import PrefixedSequenceTagger
tagger: PrefixedSequenceTagger = PrefixedSequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_format="BIOES",
    tag_type="ner",
    use_crf=False,
    use_rnn=False,
    reproject_embeddings=False,
    augmentation_strategy=augmentation_strategy,
)
# 7. initialize trainer
from flair.trainers import ModelTrainer
trainer: ModelTrainer = ModelTrainer(tagger, corpora)
trainer.fine_tune(
    base_path="taggers/cdr_nlm_chem",
    train_with_dev=False,
    max_epochs=16,
    learning_rate=2.0e-5,
    mini_batch_size=16,
    shuffle=False,
)
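After training, prediction works just like in the single-entity-type example: the PrefixedSequenceTagger stores its augmentation strategy, so it is designed to handle the prompting internally and you can pass in plain sentences. A short sketch, assuming the model was saved under the base path used above:
# load the fine-tuned multi-entity-type model
from flair.data import Sentence
from flair.models.prefixed_tagger import PrefixedSequenceTagger
model = PrefixedSequenceTagger.load("taggers/cdr_nlm_chem/best-model.pt")
# predict chemical and disease mentions in a new sentence
sentence = Sentence("Aspirin is recommended for patients with cardiovascular disease.")
model.predict(sentence)
print(sentence.to_tagged_string())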
Training HunFlair2 from scratch#
HunFlair2 uses the PrefixedSequenceTagger()
class as described above, but is trained on the following corpora instead, prefixed and combined into a MultiCorpus exactly as shown in the previous section:
from flair.datasets.biomedical import (
    HUNER_ALL_BIORED, HUNER_GENE_NLM_GENE,
    HUNER_GENE_GNORMPLUS, HUNER_ALL_SCAI,
    HUNER_CHEMICAL_NLM_CHEM, HUNER_SPECIES_LINNEAUS,
    HUNER_SPECIES_S800, HUNER_DISEASE_NCBI,
)

corpora = (
    HUNER_ALL_BIORED(), HUNER_GENE_NLM_GENE(),
    HUNER_GENE_GNORMPLUS(), HUNER_ALL_SCAI(),
    HUNER_CHEMICAL_NLM_CHEM(), HUNER_SPECIES_LINNEAUS(),
    HUNER_SPECIES_S800(), HUNER_DISEASE_NCBI(),
)
References#
[1] Luo, L., Wei, C. H., Lai, P. T., Leaman, R., Chen, Q., & Lu, Z. (2023). AIONER: all-in-one scheme-based biomedical named entity recognition using deep learning. Bioinformatics, 39(5), btad310.