A simple CRF module written in PyTorch. The implementation is based on the AllenNLP implementation of CRF.

pip install pytorch-text-crf==0.1


PyTorch Text CRF

This package contains a simple wrapper for using conditional random fields (CRF). This code is based on the excellent AllenNLP implementation of CRF.


pip install pytorch-text-crf


from crf.crf import ConditionalRandomField

# Initialization
# n_tags is the number of tags; idx2tag maps each tag index to its tag string.
crf = ConditionalRandomField(
    n_tags,
    idx2tag={0: "B-GEO", 1: "I-GEO", 2: "O"},  # Index to tag mapping ("O" = outside, not zero)
)

# Likelihood estimation
# logits: (batch_size, seq_length, num_tags); mask: (batch_size, seq_length)
log_likelihood = crf(logits, tags, mask)

# Decoding
best_tag_sequence = crf.best_viterbi_tag(logits, mask)
top_5_viterbi_tags = crf.viterbi_tags(logits, mask, top_k=5)

LSTM CRF Implementation

Refer to the project repository (the link is missing from this copy) for a complete working implementation.

import torch
from torch import nn

from crf.crf import ConditionalRandomField


class LSTMCRF(nn.Module):
    """An example implementation of a CRF layer on top of an LSTM.

    ``n_class`` (number of tags) and ``idx2tag`` (dict mapping a tag index to
    its tag string) are placeholders that must be defined before the class is
    instantiated.
    """

    def __init__(self, lstm=None, pad_idx=0):
        """Build the model.

        Args:
            lstm: module mapping ``inputs`` to logits of shape
                (batch_size, seq_length, num_tags). It must be supplied (or
                assigned to ``self.lstm``) before ``forward`` is called.
            pad_idx: padding value used in ``inputs``. Positions equal to it
                are excluded from the CRF through the mask.
        """
        # Subclassing nn.Module (and initialising it before assigning
        # sub-modules) is what makes ``LSTMCRF()(inputs, tags)`` dispatch to
        # ``forward`` and registers the CRF parameters.
        super().__init__()
        self.lstm = lstm
        self.pad_idx = pad_idx
        # Initialize the CRF model
        self.crf = ConditionalRandomField(
            n_class,  # Number of tags
            label_encoding="BIO",  # Label encoding format
            idx2tag=idx2tag,  # Dict mapping index to a tag
        )

    def forward(self, inputs, tags):
        """Compute the CRF loss and one-hot Viterbi predictions for a batch.

        Args:
            inputs: padded batch, (batch_size, seq_length). Assumed to hold
                token indices padded with ``self.pad_idx``.
            tags: gold tag indices, presumably (batch_size, seq_length).

        Returns:
            dict with ``"loss"`` (negative log likelihood, not divided by the
            batch size) and ``"class_probabilities"`` (one-hot encoding of the
            best Viterbi path, same shape as the logits).
        """
        logits = self.lstm(inputs)  # logits dim: (batch_size, seq_length, num_tags)
        # Mask for ignoring pad tokens. Compare against the pad index:
        # comparing a tensor to a string such as "<pad token>" does not yield
        # an elementwise mask. mask dim: (batch_size, seq_length)
        mask = inputs != self.pad_idx
        log_likelihood = self.crf(logits, tags, mask)
        # Log likelihood is not normalized (it is not divided by the batch size).
        loss = -log_likelihood

        # To obtain the best sequence using Viterbi decoding
        best_tag_sequence = self.crf.best_viterbi_tag(logits, mask)

        # One-hot encode the decoded path so the output resembles an LSTM's
        # per-token prediction. zeros_like (rather than ``logits * 0.0``)
        # gives a fresh tensor outside the autograd graph and avoids
        # propagating NaN/inf values from the logits.
        class_probabilities = torch.zeros_like(logits)
        for i, instance_tags in enumerate(best_tag_sequence):
            # NOTE(review): assumes each entry is a list of (tag_ids, score)
            # pairs with the best path first, as this indexing implies.
            for j, tag_id in enumerate(instance_tags[0][0]):
                class_probabilities[i, j, int(tag_id)] = 1
        return {"loss": loss, "class_probabilities": class_probabilities}

 # Training: sentences / tags are placeholder batches of token and tag indices.
lstm_crf = LSTMCRF()
output = lstm_crf(sentences, tags)
loss = output["loss"]
loss.backward()  # then call optimizer.step() / optimizer.zero_grad() as usual