Train a text classifier
This tutorial shows you how to train your own text classification models with Flair. For instance, you could train a sentiment analysis model or an offensive language detection model.
Training a text classification model with transformers
For text classification, you can reach state-of-the-art scores by fine-tuning a transformer.
Training a model is straightforward: load the appropriate corpus, create a label dictionary, and then fine-tune a `TextClassifier` model with the `ModelTrainer.fine_tune()` method. See the example script below:
```python
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# 1. get the corpus
corpus: Corpus = TREC_6()

# 2. what label do we want to predict?
label_type = 'question_class'

# 3. create the label dictionary
label_dict = corpus.make_label_dictionary(label_type=label_type)

# 4. initialize transformer document embeddings (many models are available)
document_embeddings = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=True)

# 5. create the text classifier
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict, label_type=label_type)

# 6. initialize trainer
trainer = ModelTrainer(classifier, corpus)

# 7. run training with fine-tuning
trainer.fine_tune('resources/taggers/question-classification-with-transformer',
                  learning_rate=5.0e-5,
                  mini_batch_size=4,
                  max_epochs=10,
                  )
```
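To make steps 3 and 7 more concrete, here is a toy, pure-Python sketch (not Flair's implementation) of what a label dictionary holds and of how `mini_batch_size` and `max_epochs` together determine the number of parameter updates. The example sentences, their TREC-style labels, and the helper names `build_label_index` and `optimizer_steps` are hypothetical, and Flair's own `Dictionary` class differs in details. The step count assumes every epoch visits all mini-batches, including a final partial one.

```python
import math

# Hypothetical toy training data: (text, label) pairs standing in for TREC_6 examples.
train_examples = [
    ("Who built the Eiffel Tower ?", "HUM"),
    ("What is the capital of France ?", "LOC"),
    ("How many legs does a spider have ?", "NUM"),
    ("Who wrote Hamlet ?", "HUM"),
    ("What does NASA stand for ?", "ABBR"),
]


def build_label_index(examples):
    """Hypothetical helper: map each distinct label to an integer index (first-seen order)."""
    label_to_index = {}
    for _, label in examples:
        if label not in label_to_index:
            label_to_index[label] = len(label_to_index)
    return label_to_index


def optimizer_steps(num_sentences, mini_batch_size, max_epochs):
    """Hypothetical helper: updates per epoch (partial last batch counts) times the epochs."""
    return math.ceil(num_sentences / mini_batch_size) * max_epochs


label_index = build_label_index(train_examples)
print(label_index)
print(optimizer_steps(len(train_examples), mini_batch_size=4, max_epochs=10))
```

With five sentences and `mini_batch_size=4`, each epoch has two mini-batches, so ten epochs give 20 updates. A real corpus such as TREC_6 has thousands of sentences, so the count is correspondingly larger.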
Once the model is trained, you can load it and predict the class of new sentences by calling the model's `predict` method:
```python
from flair.data import Sentence
from flair.models import TextClassifier

# load the trained model
classifier = TextClassifier.load('resources/taggers/question-classification-with-transformer/final-model.pt')
# create example sentence
sentence = Sentence('Who built the Eiffel Tower ?')
# predict class and print
classifier.predict(sentence)
print(sentence.labels)
```
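The printed labels carry a class value together with a confidence score. As a rough illustration of where such a score can come from, the sketch below assumes a single-label classifier that turns its raw output scores into probabilities with a softmax and keeps the most probable class. The label list, the score values, and the helper names `softmax` and `predict_label` are hypothetical, and this is not Flair's internal code.

```python
import math

# Hypothetical TREC-style label set and made-up raw scores (logits) for one sentence.
labels = ["ABBR", "DESC", "ENTY", "HUM", "LOC", "NUM"]
logits = [0.1, 0.3, 0.2, 4.0, 0.5, 0.1]


def softmax(scores):
    """Convert raw scores into probabilities that sum to 1 (shifted by the max for stability)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(scores, label_names):
    """Return the most probable label and its probability."""
    probs = softmax(scores)
    best = max(range(len(probs)), key=probs.__getitem__)
    return label_names[best], probs[best]


label, confidence = predict_label(logits, labels)
print(f"{label} ({confidence:.4f})")
```

Subtracting the maximum score before exponentiating does not change the resulting probabilities but avoids overflow for large scores.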