GeNN (generative neural networks) is a high-level interface for text applications using PyTorch RNNs.
- Parsing txt, json, and csv files.
- NLTK, regex, and spaCy tokenization support.
- GloVe and fastText pretrained embeddings, with the ability to fine-tune for your data.
- Architectures and customization:
  - GPT2 with small, medium, and large variants.
  - LSTM and GRU, with variable size.
  - Variable number of layers and batches.
- Text generation:
  - Random seed sampling from the first n tokens in all instances, or the most frequent token.
  - Top-K sampling for next token prediction with variable K.
  - Nucleus sampling for next token prediction with variable probability threshold.
- Text summarization:
  - All GPT2 variants can be trained to perform text summarization.
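The two next-token sampling strategies listed above can be illustrated with a minimal, library-independent sketch. This is not GeNN's own implementation: the function names (`top_k_filter`, `nucleus_filter`, `sample_token`) and the toy probability list are assumptions made for illustration, and tokens are represented simply by their index in the list.

```python
import random


def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize their probabilities."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = order[:k]
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}


def nucleus_filter(probs, p):
    """Keep the smallest set of most probable tokens whose cumulative
    probability reaches the threshold p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}


def sample_token(filtered, rng=random):
    """Draw one token index from a filtered {index: probability} mapping."""
    tokens = list(filtered)
    weights = [filtered[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


if __name__ == "__main__":
    # Hypothetical next-token distribution over a 4-token vocabulary.
    probs = [0.5, 0.3, 0.15, 0.05]
    print(sorted(top_k_filter(probs, 2)))      # token ids kept: [0, 1]
    print(sorted(nucleus_filter(probs, 0.9)))  # token ids kept: [0, 1, 2]
    print(sample_token(nucleus_filter(probs, 0.9)))
```

With Top-K the number of candidate tokens is fixed, while with nucleus sampling it grows or shrinks depending on how peaked the distribution is.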
How to install
Use the package manager pip to install genn:
pip install genn
Dependencies:
- PyTorch 1.4.0
pip install torch==1.4.0
- Pytorch Transformers
pip install pytorch_transformers
- NumPy
pip install numpy
- fastText
pip install fasttext
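To confirm that the packages named in the pip commands of this section are present, a small check along these lines may help. It is only a sketch: the helper name `installed_versions` is hypothetical, the distribution names are assumed to match the names used with pip, and it relies on `importlib.metadata`, which requires Python 3.8 or later.

```python
from importlib import metadata

# Distribution names as they appear in the pip commands (assumed to match).
REQUIRED = ["genn", "torch", "pytorch_transformers", "numpy", "fasttext"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(REQUIRED).items():
        print(f"{name}: {version or 'not installed'}")
```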
RNNs (you can swap LSTMGenerator for GRUGenerator):
from genn import Preprocessing, LSTMGenerator, GRUGenerator
# LSTM example
ds = Preprocessing("data.txt")
gen = LSTMGenerator(ds, nLayers=2, batchSize=16, embSize=64, lstmSize=16, epochs=20)
# Train the model
gen.run()
# Generate 5 new documents
print(gen.generate_document(5))
# GPT2 example
from genn import GPT2
gen = GPT2("data.txt", taskToken="Movie:", epochs=7, variant="medium")
# Train the model
gen.run()
# Generate 10 new documents
print(gen.generate_document(10))
# GPT2 Summarizer example
from genn import GPT2Summarizer
summ = GPT2Summarizer("data.txt", epochs=3, batch_size=8)
# Train the model
summ.run()
# Create 5 summaries of a source document
src_doc = "This is the source document to summarize"
print(summ.summarize_document(n=5, source=src_doc))
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
Distributed under the MIT License. See LICENSE for more information.