transformer-korean

Text-to-Text Transformer for Korean QA Task


License: MIT
Install: pip install transformer-korean==0.0.3

Documentation

Text-to-Text Transformer

This repository provides a Transformer model for the Korean QA task in the text-to-text format of Google's T5 (Text-To-Text Transfer Transformer). The overall architecture follows the vanilla Transformer model.

pip install transformer-korean
  • 2019.12.25, version 0.0.3: fixed bugs in load_data_txt and load_data_csv
  • 2019.12.23, version 0.0.1: initial release

0. Pre-trained Models

  • Text-to-Text Transformer-Base, Korean Model: 12-layer, 768-hidden, 12-heads (not publicly released)
  • Text-to-Text Transformer-Small, Korean Model: 6-layer, 512-hidden, 8-heads (not publicly released)

Base: the baseline model, whose hyperparameters are described in Section 3.1.1 of the T5 paper. It has roughly 220 million parameters. Small: a smaller model that scales the baseline down by using d_model = 512, d_ff = 2,048, 8-headed attention, and only 6 layers each in the encoder and decoder. This variant has about 60 million parameters.
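
For orientation, the Small configuration above maps onto this package's Transformer class (used in the training examples below) roughly as follows. This is a sketch only: the vocabulary size and dropout rate are placeholders, and the pre-trained checkpoints themselves are not released.

from transformer_korean.transformer import Transformer

# Sketch of the Small configuration: d_model=512, dff=2048,
# 8 attention heads, 6 encoder/decoder layers.
small_transformer = Transformer(d_model=512,
                                num_heads=8,
                                num_layers=6,
                                vocab_size=32000,  # placeholder; use vocab.vocab_size in practice
                                dff=2048,
                                enc_activation='gelu',
                                dec_activation='relu',
                                rate=0.1)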

1. Pre-training

1.1 Unsupervised objective

Pre-training sentences are built with the BERT-style objective, which the T5 paper reports as performing best. As in BERT, 15% of the input tokens are masked at random; of the masked positions, 80% are replaced with the <MASK> token, 10% with a random token from the vocabulary, and the remaining 10% keep the original word.
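
For illustration, here is a minimal sketch of this 15% / 80-10-10 corruption rule applied to a list of token ids. This is not the package's preprocessing code, just the rule described above.

import random

def bert_style_mask(token_ids, mask_id, vocab_size, mask_prob=0.15):
    """Sketch of the BERT-style corruption: 15% of positions are chosen;
    of those, 80% become <MASK>, 10% become a random vocabulary token,
    and 10% are left unchanged."""
    corrupted = list(token_ids)
    for i in range(len(corrupted)):
        if random.random() < mask_prob:
            r = random.random()
            if r < 0.8:
                corrupted[i] = mask_id                       # replace with <MASK>
            elif r < 0.9:
                corrupted[i] = random.randrange(vocab_size)  # random token
            # else: keep the original token
    return corrupted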

1.2 Example sentences

Input sentence: In 1900, <MASK> <MASK> was 'variously' adapted into Puccini's opera Tosca. (BERT style)


Target sentence: In 1900, Sardou's play was newly adapted into Puccini's opera Tosca. (original text)

1.3 Unlabeled dataset

ํ•™์Šต ๋ฐ์ดํ„ฐ๋Š” ํ•œ๊ตญ์–ด ์œ„ํ‚ค๋ฐ์ดํ„ฐ(2019.01 dump file ๊ธฐ์ค€, ์•ฝ 350๋งŒ ๋ฌธ์žฅ) ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ•™์Šต์„ ์ง„ํ–‰ํ–ˆ์œผ๋ฉฐ, ํ•™์Šต ๋ฌธ์žฅ ๊ตฌ์„ฑ์€ ์•„๋ž˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค.

La Tosca is a play written in 1887 by the French playwright Sardou for the actress Sarah Bernhardt.
It was first staged in Paris in 1887.
In 1990 it was restaged in New York, USA, with Bernhardt in the lead role.
It is set in Rome, Italy, in mid-June 1800, and the story unfolds against the backdrop of that period.
In 1900, Sardou's play was newly adapted into Puccini's opera Tosca.
Verdi recommended revising the "abrupt ending" of Sardou's script, but Sardou refused.
Later, Puccini also proposed revising the "abrupt ending" of Sardou's script, but in the end failed to persuade him.

1.4 Training example

from transformer_korean.run_training import Trainer
from transformer_korean.transformer import Transformer
from transformer_korean.preprocess import DataProcessor
from transformer_korean.custom_scheduler import CustomSchedule
import tensorflow as tf

path = "ko-wiki_20190621.txt"
# Data Processing
print('Loading Pre-training data')
data_preprocess = DataProcessor(txt_path=path,
                                batch_size=64,
                                pre_train=True,
                                max_length=128)
train = data_preprocess.load_data_txt()

print('Loading Vocab File')
vocab = data_preprocess.load_vocab_file(vocab_filename="vocab")

print('Create train dataset')
train_dataset = data_preprocess.preprocess(train)

EPOCHS = 100
num_layers = 6
d_model = 128
dff = 512
num_heads = 8
vocab_size = vocab.vocab_size
dropout_rate = 0.1
encoder_activation = 'gelu'
decoder_activation = 'relu'

# Custom Scheduler
learning_rate = CustomSchedule(d_model, warmup_steps=4000)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Transformer
transformer = Transformer(d_model=d_model,
                          num_heads=num_heads,
                          num_layers=num_layers,
                          vocab_size=vocab_size,
                          dff=dff,
                          enc_activation=encoder_activation,
                          dec_activation=decoder_activation,
                          rate=dropout_rate)

# Trainer
trainer = Trainer(train_dataset=train_dataset,
                  learning_rate=learning_rate,
                  optimizer=optimizer,
                  transformer=transformer,
                  epochs=EPOCHS,
                  checkpoint_path='./checkpoints/',
                  load_checkpoints=False,
                  save_checkpoints_epochs=10)
trainer.train()
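
For reference, CustomSchedule above appears (from its name and warmup_steps argument) to follow the learning-rate schedule of the original Transformer paper, lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5). A standalone sketch of that schedule, under that assumption; it is not guaranteed to match the package internals exactly.

import tensorflow as tf

class InverseSqrtWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Sketch of the Transformer schedule:
    lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)."""

    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)                     # decay after warmup
        arg2 = step * (self.warmup_steps ** -1.5)      # linear warmup
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)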

2. Fine-Tuning (QA Task)

2.1 Labeled dataset

For the QA task, fine-tuning uses KorQuAD 1.1, a Korean QA dataset. The data is organized as below: the input is the question and the target is the answer.

Q: After reading Goethe's Faust, what did Wagner intend to write?
A: A symphony
Q: How far had Wagner written the symphony before he abandoned it?
A: The first movement
Q: Which work influenced Wagner while he was writing the Faust Overture?
A: Beethoven's Symphony No. 9
Q: Which book did Wagner intend to use as the subject of a symphony in 1839?
A: Faust
Q: Which Beethoven work influenced the D minor key of the Faust Overture?
A: The Choral Symphony
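
KorQuAD is distributed as SQuAD-style JSON, so the two CSV files read in section 2.2 below have to be produced first. The following is a hypothetical sketch of that conversion; the input file name, column names, and single-column CSV layout are assumptions, not part of this package.

import json
import pandas as pd

# Hypothetical conversion: flatten SQuAD-style KorQuAD JSON into one
# question CSV and one answer CSV with the file names used in section 2.2.
# The exact layout expected by load_data_csv() is an assumption.
with open("KorQuAD_v1.0_train.json", encoding="utf-8") as f:  # placeholder file name
    korquad = json.load(f)

questions, answers = [], []
for article in korquad["data"]:
    for paragraph in article["paragraphs"]:
        for qa in paragraph["qas"]:
            questions.append(qa["question"])
            answers.append(qa["answers"][0]["text"])  # first gold answer

pd.DataFrame({"question": questions}).to_csv("KorQuAD_train_q.csv", index=False)
pd.DataFrame({"answer": answers}).to_csv("KorQuAD_train_a.csv", index=False)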

2.2 Training example

from transformer_korean.run_training import Trainer
from transformer_korean.transformer import Transformer
from transformer_korean.preprocess import DataProcessor
from transformer_korean.custom_scheduler import CustomSchedule

import tensorflow as tf

question = "KorQuAD_train_q.csv"
answer = "KorQuAD_train_a.csv"

# Data Processing
print('Loading fine-tuning data')
data_preprocess = DataProcessor(csv_path=[question, answer],
                                batch_size=64,
                                pre_train=False,
                                max_length=128)
train = data_preprocess.load_data_csv()

print('Loading Vocab File')
vocab = data_preprocess.load_vocab_file(vocab_filename="vocab")

print('Create train dataset')
train_dataset = data_preprocess.preprocess(train)

EPOCHS = 100
num_layers = 6
d_model = 128
dff = 512
num_heads = 8
vocab_size = vocab.vocab_size
dropout_rate = 0.1
encoder_activation = 'gelu'
decoder_activation = 'relu'

# Custom Scheduler
learning_rate = CustomSchedule(d_model, warmup_steps=4000)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# Transformer
transformer = Transformer(d_model=d_model,
                          num_heads=num_heads,
                          num_layers=num_layers,
                          vocab_size=vocab_size,
                          dff=dff,
                          enc_activation=encoder_activation,
                          dec_activation=decoder_activation,
                          rate=dropout_rate)

# Trainer
trainer = Trainer(train_dataset=train_dataset,
                  learning_rate=learning_rate,
                  optimizer=optimizer,
                  transformer=transformer,
                  epochs=EPOCHS,
                  checkpoint_path='./checkpoints/',
                  load_checkpoints=True,
                  save_checkpoints_epochs=10)

trainer.train()

3. Activation Function

In addition to the default relu activation function, four more activation functions are provided, and the encoder and decoder blocks can each use a different activation function.

  1. gelu
import numpy as np
import tensorflow as tf

def gelu(x):
    # tanh approximation of the Gaussian error linear unit
    cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
    return x * cdf
  2. swish
def swish(x):
    return x * tf.nn.sigmoid(x)
  3. swish_beta
def swish_beta(x):
    beta = tf.Variable(initial_value=1.0, trainable=True, name='swish_beta')
    return x * tf.nn.sigmoid(beta * x)  # trainable parameter beta
  4. mish
def mish(x):
    return x * tf.math.tanh(tf.math.softplus(x))
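
Activations are selected by name when constructing the model, as in the training examples above (enc_activation='gelu', dec_activation='relu'); presumably 'swish', 'swish_beta', and 'mish' are passed the same way.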

4. Requirements

Python == 3.x
tensorflow >= 2.0
tensorflow-datasets >= 1.3.2
pandas >= 0.24.2
numpy >= 1.16.3
six >= 1.12.0

5. To-Do

  • TPU and multi-GPU support planned
  • Dropout fix planned
  • Predict module to be added

6. Reference