Kochat
- ์ฑ๋ด ๋น๋๋ ์ฑ์ ์์ฐจ๊ณ , ์์ ๋ง์ ๋ฅ๋ฌ๋ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์ ์ ๋ง๋์๊ณ ์ถ์ผ์ ๊ฐ์?
- Kochat์ ์ด์ฉํ๋ฉด ์์ฝ๊ฒ ์์ ๋ง์ ๋ฅ๋ฌ๋ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์ ์ ๋น๋ํ ์ ์์ต๋๋ค.
# 1. ๋ฐ์ดํฐ์
๊ฐ์ฒด ์์ฑ
dataset = Dataset(ood=True)
# 2. ์๋ฒ ๋ฉ ํ๋ก์ธ์ ์์ฑ
emb = GensimEmbedder(model=embed.FastText())
# 3. ์๋(Intent) ๋ถ๋ฅ๊ธฐ ์์ฑ
clf = DistanceClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CenterLoss(dataset.intent_dict)
)
# 4. ๊ฐ์ฒด๋ช
(Named Entity) ์ธ์๊ธฐ ์์ฑ
rcn = EntityRecognizer(
model=entity.LSTM(dataset.entity_dict),
loss=CRFLoss(dataset.entity_dict)
)
# 5. ๋ฅ๋ฌ๋ ์ฑ๋ด RESTful API ํ์ต & ๋น๋
kochat = KochatApi(
dataset=dataset,
embed_processor=(emb, True),
intent_classifier=(clf, True),
entity_recognizer=(rcn, True),
scenarios=[
weather, dust, travel, restaurant
]
)
# 6. View ์์คํ์ผ๊ณผ ์ฐ๊ฒฐ
@kochat.app.route('/')
def index():
return render_template("index.html")
# 7. ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์๋ฒ ๊ฐ๋
if __name__ == '__main__':
kochat.app.template_folder = kochat.root_dir + 'templates'
kochat.app.static_folder = kochat.root_dir + 'static'
kochat.app.run(port=8080, host='0.0.0.0')
Why Kochat?
- ํ๊ตญ์ด๋ฅผ ์ง์ํ๋ ์ต์ด์ ์คํ์์ค ๋ฅ๋ฌ๋ ์ฑ๋ด ํ๋ ์์ํฌ์ ๋๋ค. (๋น๋์๋ ๋ค๋ฆ ๋๋ค.)
- ๋ค์ํ Pre built-in ๋ชจ๋ธ๊ณผ Lossํจ์๋ฅผ ์ง์ํฉ๋๋ค. NLP๋ฅผ ์ ๋ชฐ๋ผ๋ ์ฑ๋ด์ ๋ง๋ค ์ ์์ต๋๋ค.
- ์์ ๋ง์ ์ปค์คํ ๋ชจ๋ธ, Lossํจ์๋ฅผ ์ ์ฉํ ์ ์์ต๋๋ค. NLP ์ ๋ฌธ๊ฐ์๊ฒ ๋์ฑ ์ ์ฉํฉ๋๋ค.
- ์ฑ๋ด์ ํ์ํ ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ, ๋ชจ๋ธ, ํ์ต ํ์ดํ๋ผ์ธ, RESTful API๊น์ง ๋ชจ๋ ๋ถ๋ถ์ ์ ๊ณตํฉ๋๋ค.
- ๊ฐ๊ฒฉ ๋ฑ์ ์ ๊ฒฝ์ธ ํ์ ์์ผ๋ฉฐ, ์์ผ๋ก๋ ์ญ ์คํ์์ค ํ๋ก์ ํธ๋ก ์ ๊ณตํ ์์ ์ ๋๋ค.
- ์๋์ ๊ฐ์ ๋ค์ํ ์ฑ๋ฅ ํ๊ฐ ๋ฉํธ๋ฆญ๊ณผ ๊ฐ๋ ฅํ ์๊ฐํ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
Table of contents
- 1. Kochat ์ด๋?
- 2. About Chatbot
- 3. Getting Started
- 4. Usage
-
5. Visualization Support
- 5.1. Train/Test Accuracy
- 5.2. Train/Test Recall (macro average)
- 5.3. Train/Test Precision (macro average)
- 5.4. Train/Test F1-Score (macro average)
- 5.5. Train/Test Confusion Matrix
- 5.6. Train/Test Classification Performance Report
- 5.7. Train/Test Fallback Detection Performance Report
- 5.8. Feature Space Visualization
- 6. Performance Issue
- 7. Demo Application
- 8. Contributor
- 9. TODO List
- 10. Reference
-
11. License
1. Kochat ์ด๋?
Kochat์ ํ๊ตญ์ด ์ ์ฉ ์ฑ๋ด ๊ฐ๋ฐ ํ๋ ์์ํฌ๋ก, ๋จธ์ ๋ฌ๋ ๊ฐ๋ฐ์๋ผ๋ฉด
๋๊ตฌ๋ ๋ฌด๋ฃ๋ก ์์ฝ๊ฒ ํ๊ตญ์ด ์ฑ๋ด์ ๊ฐ๋ฐ ํ ์ ์๋๋ก ๋๋ ์คํ์์ค ํ๋ ์์ํฌ์
๋๋ค.
๋จ์ Chit-chat์ด ์๋ ์ฌ์ฉ์์๊ฒ ์ฌ๋ฌ ๊ธฐ๋ฅ์ ์ ๊ณตํ๋ ์์ฉํ ๋ ๋ฒจ์ ์ฑ๋ด ๊ฐ๋ฐ์
๋จ์ผ ๋ชจ๋ธ๋ง์ผ๋ก ๊ฐ๋ฐ๋๋ ๊ฒฝ์ฐ๋ณด๋ค ๋ค์ํ ๋ฐ์ดํฐ, configuration, ML๋ชจ๋ธ,
Restful Api ๋ฐ ์ ํ๋ฆฌ์ผ์ด์
, ๋ ์ด๋ค์ ์ ๊ธฐ์ ์ผ๋ก ์ฐ๊ฒฐํ ํ์ดํ๋ผ์ธ์ ๊ฐ์ถ์ด์ผ ํ๋๋ฐ
์ด ๊ฒ์ ์ฒ์๋ถํฐ ๊ฐ๋ฐ์๊ฐ ์ค์ค๋ก ๊ตฌํํ๋ ๊ฒ์ ๊ต์ฅํ ๋ฒ๊ฑฐ๋กญ๊ณ ์์ด ๋ง์ด ๊ฐ๋ ์์
์
๋๋ค.
์ค์ ๋ก ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์ ์ง์ ๊ตฌํํ๋ค๋ณด๋ฉด ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ์ค์ง์ ์ผ๋ก ๋ชจ๋ธ ๊ฐ๋ฐ๋ณด๋ค๋
์ด๋ฐ ๋ถ๋ถ๋ค์ ํจ์ฌ ์๊ฐ๊ณผ ๋
ธ๋ ฅ์ด ๋ง์ด ํ์ํฉ๋๋ค.
Kochat์ ์ด๋ฌํ ๋ถ๋ถ์ ํด๊ฒฐํ๊ธฐ ์ํด ์ ์๋์์ต๋๋ค.
๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ, ์ํคํ
์ฒ, ๋ชจ๋ธ๊ณผ์ ํ์ดํ๋ผ์ธ, ์คํ ๊ฒฐ๊ณผ ์๊ฐํ,
์ฑ๋ฅํ๊ฐ ๋ฑ์ Kochat์ ๊ตฌ์ฑ์ ์ฌ์ฉํ๋ฉด์ ๊ฐ๋ฐ์๊ฐ ์ํ๋ ๋ชจ๋ธ์ด๋ Lossํจ์,
๋ฐ์ดํฐ ์
๋ฑ๋ง ๊ฐ๋จํ๊ฒ ์์ฑํ์ฌ ๋ด๊ฐ ์ํ๋ ๋ชจ๋ธ์ ์ฑ๋ฅ์ ๋น ๋ฅด๊ฒ ์คํํ ์ ์๊ฒ ๋์์ค๋๋ค.
๋ํ ํ๋ฆฌ ๋นํธ์ธ ๋ชจ๋ธ๋ค๊ณผ Loss ํจ์๋ฑ์ ์ง์ํ์ฌ ๋ฅ๋ฌ๋์ด๋ ์์ฐ์ด์ฒ๋ฆฌ์ ๋ํด ์ ๋ชจ๋ฅด๋๋ผ๋
ํ๋ก์ ํธ์ ๋ฐ์ดํฐ๋ง ์ถ๊ฐํ๋ฉด ์์ฝ๊ฒ ์๋นํ ๋์ ์ฑ๋ฅ์ ์ฑ๋ด์ ๊ฐ๋ฐํ ์ ์๊ฒ ๋์์ค๋๋ค.
์์ง์ ์ด๊ธฐ๋ ๋ฒจ์ด๊ธฐ ๋๋ฌธ์ ๋ง์ ๋ชจ๋ธ๊ณผ ๊ธฐ๋ฅ์ ์ง์ํ์ง๋ ์์ง๋ง ์ ์ฐจ ๋ชจ๋ธ๊ณผ
๊ธฐ๋ฅ์ ๋๋ ค๋๊ฐ ๊ณํ์
๋๋ค.
1.1. ๊ธฐ์กด ์ฑ๋ด ๋น๋์์ ์ฐจ์ด์
-
๊ธฐ์กด์ ์์ฉํ๋ ๋ง์ ์ฑ๋ด ๋น๋์ Kochat์ ํ๊น์ผ๋ก ํ๋ ์ฌ์ฉ์๊ฐ ๋ค๋ฆ ๋๋ค. ์์ฉํ๋ ์ฑ๋ด ๋น๋๋ค์ ๋งค์ฐ ๊ฐํธํ ์น ๊ธฐ๋ฐ์ UX/UI๋ฅผ ์ ๊ณตํ๋ฉฐ ์ผ๋ฐ์ธ์ ํ๊น์ผ๋ก ํฉ๋๋ค. ๊ทธ์ ๋ฐํด Kochat์ ์ฑ๋ด๋น๋ ๋ณด๋ค๋ ๊ฐ๋ฐ์๋ฅผ ํ๊น์ผ๋กํ๋ ํ๋ ์์ํฌ์ ๊ฐ๊น์ต๋๋ค. ๊ฐ๋ฐ์๋ ์์ค์ฝ๋๋ฅผ ์์ฑํจ์ ๋ฐ๋ผ์ ํ๋ ์์ํฌ์ ๋ณธ์ธ๋ง์ ๋ชจ๋ธ์ ์ถ๊ฐํ ์ ์๊ณ , Loss ํจ์๋ฅผ ๋ฐ๊พธ๊ฑฐ๋ ๋ณธ์ธ์ด ์ํ๋ฉด ์์ ์๋ก์ด ๊ธฐ๋ฅ์ ์ฒจ๊ฐํ ์๋ ์์ต๋๋ค.
-
Kochat์ ์คํ์์ค ํ๋ก์ ํธ์ ๋๋ค. ๋ฐ๋ผ์ ๋ง์ ์ฌ๋์ด ์ฐธ์ฌํด์ ํจ๊ป ๊ฐ๋ฐํ ์ ์๊ณ ๋ง์ฝ ์๋ก์ด ๋ชจ๋ธ์ ๊ฐ๋ฐํ๊ฑฐ๋ ์๋ก์ด ๊ธฐ๋ฅ์ ์ถ๊ฐํ๊ณ ์ถ๋ค๋ฉด ์ผ๋ง๋ ์ง ๋ ํฌ์งํ ๋ฆฌ์ ์ปจํธ๋ฆฌ๋ทฐ์ ํ ์ ์์ต๋๋ค.
-
Kochat์ ๋ฌด๋ฃ์ ๋๋ค. ๋งค๋ฌ ์ฌ์ฉ๋ฃ๋ฅผ ๋ด์ผํ๋ ์ฑ๋ด ๋น๋๋ค์ ๋นํด ์์ฒด์ ์ธ ์๋ฒ๋ง ๊ฐ์ง๊ณ ์๋ค๋ฉด ๋น์ฉ์ ์ฝ ์์ด ์ผ๋ง๋ ์ง ์ฑ๋ด์ ๊ฐ๋ฐํ๊ณ ์๋น์ค ํ ์ ์์ต๋๋ค. ์์ง์ ๊ธฐ๋ฅ์ด ๋ฏธ์ฝํ์ง๋ง ์ถํ์๋ ์ ๋ง ์ฌ๋งํ ์ฑ๋ด ๋น๋๋ค ๋ณด๋ค ๋ ๋ค์ํ ๊ธฐ๋ฅ์ ๋ฌด๋ฃ๋ก ์ ๊ณตํ ์์ ์ ๋๋ค.
1.2. Kochat ์ ์ ๋๊ธฐ
์ด์ ์ ์ฌ๊ธฐ์ ๊ธฐ์ ์ฝ๋๋ฅผ ๊ธ์ด๋ชจ์์ ๋ง๋ , ์์ค ๋ฎ์ ์ ๋ฅ๋ฌ๋ chatbot ๋ ํฌ์งํ ๋ฆฌ๊ฐ
์๊ฐ๋ณด๋ค ํฐ ๊ด์ฌ์ ๋ฐ์ผ๋ฉด์, ํ๊ตญ์ด๋ก ๋ ๋ฅ๋ฌ๋ ์ฑ๋ด ๊ตฌํ์ฒด๊ฐ ์ ๋ง ๋ง์ด ์๋ค๋ ๊ฒ์ ๋๊ผ์ต๋๋ค.
ํ์ฌ ๋๋ถ๋ถ์ ์ฑ๋ด ๋น๋๋ค์ ๋๋ถ๋ถ ์ผ๋ฐ์ธ์ ๊ฒจ๋ฅํ๊ธฐ ๋๋ฌธ์ ์น์์์ ์์ฌ์ด UX/UI
๊ธฐ๋ฐ์ผ๋ก ์๋น์ค ์ค์
๋๋ค. ์ผ๋ฐ์ธ ์ฌ์ฉ์๋ ์ฌ์ฉํ๊ธฐ ํธ๋ฆฌํ๊ฒ ์ง๋ง, ์ ์ ๊ฐ์ ๊ฐ๋ฐ์๋ค์
๋ชจ๋ธ๋ ์ปค์คํฐ๋ง์ด์ง ํ๊ณ ์ถ๊ณ , ๋ก์คํจ์๋ ๋ฐ๊ฟ๋ณด๊ณ ์ถ๊ณ , ์๊ฐํ๋ ํ๋ฉด์ ๋์ฑ ๋์ ์ฑ๋ฅ์
์ถ๊ตฌํ๊ณ ์ถ์ง๋ง ์์ฝ๊ฒ๋ ํ๊ตญ์ด ์ฑ๋ด ๋น๋ ์ค์์ ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ์ ์๋ ค์ง ๊ฒ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ์ค, ๋ฏธ๊ตญ์ RASA๋ผ๋ ์ฑ๋ด ํ๋ ์์ํฌ๋ฅผ ๋ณด๊ฒ ๋์์ต๋๋ค.
RASA๋ ๊ฐ๋ฐ์๊ฐ ์ง์ ์์ค์ฝ๋๋ฅผ ์์ ํ ์ ์๊ธฐ ๋๋ฌธ์ ๋ค์ํ ๋ถ๋ถ์ ์ปค์คํฐ๋ง์ด์ง ํ ์ ์์ต๋๋ค.
๊ทธ๋ฌ๋ ํ๊ตญ์ด๋ฅผ ์ ๋๋ก ์ง์ํ์ง ์์์, ์ ์ฉ ํ ํฌ๋์ด์ ๋ฅผ ์ถ๊ฐํ๋ ๋ฑ ๋งค์ฐ ๋ฒ๊ฑฐ๋ก์ด ์์
์ด
ํ์ํ๊ณ ์ค์ ๋ก ๋๋ฌด ๋ค์ํ ์ปดํฌ๋ํธ๊ฐ ์กด์ฌํ์ฌ ์ต์ํด์ง๋๋ฐ ์กฐ๊ธ ์ด๋ ค์ด ํธ์
๋๋ค.
๋๋ฌธ์ ๋๊ตฐ๊ฐ ํ๊ตญ์ด ๊ธฐ๋ฐ์ด๋ฉด์ ์กฐ๊ธ ๋ ์ปดํฉํธํ ์ฑ๋ด ํ๋ ์์ํฌ๋ฅผ ์ ์ํ๋ค๋ฉด
์ฑ๋ด์ ๊ฐ๋ฐํด์ผํ๋ ๊ฐ๋ฐ์๋ค์๊ฒ ์ ๋ง ์ ์ฉํ ๊ฒ์ด๋ผ๊ณ ์๊ฐ๋์๊ณ ์ง์ ์ด๋ฌํ ํ๋ ์์ํฌ๋ฅผ
๋ง๋ค์ด๋ณด์๋ ์๊ฐ์ Kochat์ ์ ์ํ๊ฒ ๋์์ต๋๋ค.
Kochat์ ํ๊ตญ์ด(Korean)์ ์๊ธ์์ธ Ko์ ์ ์ด๋ฆ ์ ๊ธ์์ธ Ko๋ฅผ ๋ฐ์์ ์ง์์ต๋๋ค. Kochat์ ์์ผ๋ก๋ ๊ณ์ ์คํ์์ค ํ๋ก์ ํธ๋ก ์ ์ง๋ ๊ฒ์ด๋ฉฐ, ์ ์ด๋ 1~2๋ฌ์ 1๋ฒ ์ด์์ ์๋ก์ด ๋ชจ๋ธ์ ์ถ๊ฐํ๊ณ , ๊ธฐ์กด ์์ค์ฝ๋์ ๋ฒ๊ทธ๋ฅผ ์์ ํ๋ ๋ฑ ์ ์ง๋ณด์ ์์ ์ ์ด์ด๊ฐ ๊ฒ์ด๋ฉฐ ์ฒ์์๋ ๋ฏธ์ฒํ ์ค๋ ฅ์ธ ์ ๊ฐ ์์ํ์ง๋ง, ๊ทธ ๋์ RASA์ฒ๋ผ ์ ๋ง ์ ์ฉํ๊ณ ๋์ ์ฑ๋ฅ์ ๋ณด์ฌ์ฃผ๋ ์์ค๋์ ์คํ์์ค ํ๋ ์์ํฌ๊ฐ ๋์์ผ๋ฉด ์ข๊ฒ ์ต๋๋ค. :)
2. About Chatbot
์ด ์ฑํฐ์์๋ ์ฑ๋ด์ ๋ถ๋ฅ์ ๊ตฌํ๋ฐฉ๋ฒ, Kochat์ ์ด๋ป๊ฒ ์ฑ๋ด์ ๊ตฌํํ๊ณ ์๋์ง์ ๋ํด
๊ฐ๋จํ๊ฒ ์๊ฐํฉ๋๋ค.
2.1. ์ฑ๋ด์ ๋ถ๋ฅ
์ฑ๋ด์ ํฌ๊ฒ ๋น๋ชฉ์ ๋ํ๋ฅผ ์ํ Open domain ์ฑ๋ด๊ณผ ๋ชฉ์ ๋ํ๋ฅผ ์ํ Close domain ์ฑ๋ด์ผ๋ก ๋๋ฉ๋๋ค.
Open domain ์ฑ๋ด์ ์ฃผ๋ก ์ก๋ด ๋ฑ์ ์ํํ๋ ์ฑ๋ด์ ์๋ฏธํ๋๋ฐ,
์ฌ๋ฌ๋ถ์ด ์ ์๊ณ ์๋ ์ฌ์ฌ์ด ๋ฑ์ด ์ฑ๋ด์ด ๋ํ์ ์ธ Open domain ์ฑ๋ด์ด๋ฉฐ Chit-chat์ด๋ผ๊ณ ๋ ๋ถ๋ฆฝ๋๋ค.
Close domain ์ฑ๋ด์ด๋ ํ์ ๋ ๋ํ ๋ฒ์ ์์์ ์ฌ์ฉ์๊ฐ ์ํ๋ ๋ชฉ์ ์ ๋ฌ์ฑํ๊ธฐ ์ํ ์ฑ๋ด์ผ๋ก
์ฃผ๋ก ๊ธ์ต์๋ด๋ด, ์๋น์์ฝ๋ด ๋ฑ์ด ์ด์ ํด๋นํ๋ฉฐ Goal oriented ์ฑ๋ด์ด๋ผ๊ณ ๋ ๋ถ๋ฆฝ๋๋ค.
์์ฆ ์ถ์๋๋ ์๋ฆฌ๋ ๋น
์ค๋น ๊ฐ์ ์ธ๊ณต์ง๋ฅ ๋น์, ์ธ๊ณต์ง๋ฅ ์คํผ์ปค๋ค์ ํน์ ๊ธฐ๋ฅ๋ ์ํํด์ผํ๊ณ
์ฌ์ฉ์์ ์ก๋ด๋ ์ ํด์ผํ๋ฏ๋ก Open domain ์ฑ๋ด๊ณผ Close domain ์ฑ๋ด์ด ๋ชจ๋ ํฌํจ๋์ด ์๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ต๋๋ค.
2.2. ์ฑ๋ด์ ๊ตฌํ
์ฑ๋ด์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ํฌ๊ฒ ํต๊ณ๊ธฐ๋ฐ์ ์ฑ๋ด๊ณผ ๋ฅ๋ฌ๋ ๊ธฐ๋ฐ์ ์ฑ๋ด์ผ๋ก ๋๋ฉ๋๋ค.
์ฌ๊ธฐ์์๋ ๋ฅ๋ฌ๋ ๊ธฐ๋ฐ์ ์ฑ๋ด๋ง ์๊ฐํ๋๋ก ํ๊ฒ ์ต๋๋ค.
2.2.1. Open domain ์ฑ๋ด
๋จผ์ Open domain ์ฑ๋ด์ ๊ฒฝ์ฐ๋ ๋ฅ๋ฌ๋ ๋ถ์ผ์์๋ ๋๋ถ๋ถ, End to End
์ ๊ฒฝ๋ง ๊ธฐ๊ณ๋ฒ์ญ ๋ฐฉ์(Seq2Seq)์ผ๋ก ๊ตฌํ๋์ด์์ต๋๋ค. Seq2Seq์ ํ ๋ฌธ์ฅ์ ๋ค๋ฅธ ๋ฌธ์ฅ์ผ๋ก
๋ณํ/๋ฒ์ญํ๋ ๋ฐฉ์์
๋๋ค. ๋ฒ์ญ๊ธฐ์๊ฒ "๋๋ ๋ฐฐ๊ณ ํ๋ค"๋ผ๋ ์
๋ ฅ์ด ์ฃผ์ด์ง๋ฉด "I'm Hungry"๋ผ๊ณ
๋ฒ์ญํด๋ด๋ฏ์ด, ์ฑ๋ด Seq2Seq๋ "๋๋ ๋ฐฐ๊ณ ํ๋ค"๋ผ๋ ์
๋ ฅ์ด ์ฃผ์ด์ง ๋, "๋ง์ด ๋ฐฐ๊ณ ํ์ ๊ฐ์?" ๋ฑ์ ๋๋ต์ผ๋ก ๋ฒ์ญํฉ๋๋ค.
์ต๊ทผ์ ๋ฐํ๋ Google์ Meena
๊ฐ์ ๋ชจ๋ธ์ ๋ณด๋ฉด, ๋ณต์กํ ๋ชจ๋ธ ์ํคํ
์ฒ๋ ํ์ต ํ๋ ์์ํฌ ์์ด End to End (Seq2Seq) ๋ชจ๋ธ๋ง์ผ๋ก๋
๋งค์ฐ ๋ฐฉ๋ํ ๋ฐ์ดํฐ์
๊ณผ ๋์ ์ฑ๋ฅ์ ์ปดํจํ
๋ฆฌ์์ค๋ฅผ ํ์ฉํ๋ฉด ์ ๋ง ์ฌ๋๊ณผ ๊ทผ์ ํ ์์ค์ผ๋ก ๋ํํ ์ ์๋ค๋ ๊ฒ์ผ๋ก ์๋ ค์ ธ์์ต๋๋ค.
(๊ทธ๋ฌ๋ ํ์ฌ๋ฒ์ ํ๋ ์์ํฌ์์๋ Close domain ๋ง ์ง์ํฉ๋๋ค. ์ฐจํ ๋ฒ์ ์์ ๋ค์ํ Seq2Seq ๋ชจ๋ธ๋ ์ถ๊ฐํ ์์ ์
๋๋ค.)
2.2.2. Close domain ์ฑ๋ด
Close domain ์ฑ๋ด์ ๋๋ถ๋ถ Slot Filling ๋ฐฉ์์ผ๋ก ๊ตฌํ๋์ด ์์ต๋๋ค. ๋ฌผ๋ก Close domain ์ฑ๋ด๋
Open domain์ฒ๋ผ End to end๋ก ๊ตฌํํ๋ ค๋ ๋ค์ํ
์๋ ๋ค๋
์กด์ฌ ํ์์ผ๋, ๋
ผ๋ฌธ์์ ์ ์ํ๋
๋ฐ์ดํฐ์
์์๋ง ์ ์๋ํ๊ณ , ์ค์ ๋ค๋ฅธ ๋ฐ์ดํฐ ์
(Task6์ DSTC dataset)์ ์ ์ฉํ๋ฉด ๊ทธ ์ ๋์
์ฑ๋ฅ์ด ๋์ค์ง ์์๊ธฐ ๋๋ฌธ์ ํ์
์ ์ ์ฉ๋๊ธฐ๋ ์ด๋ ค์์ด ์์ต๋๋ค. ๋๋ฌธ์ ํ์ฌ๋ ๋๋ถ๋ถ์ ๋ชฉ์ ์งํฅ
์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์ด ๊ธฐ์กด ๋ฐฉ์์ธ Slot Filling ๋ฐฉ์์ผ๋ก ๊ตฌํ๋๊ณ ์์ต๋๋ค.
Slot Filling ๋ฐฉ์์ ๋ฏธ๋ฆฌ ๊ธฐ๋ฅ์ ์ํํ ์ ๋ณด๋ฅผ ๋ด๋ '์ฌ๋กฏ'์ ๋จผ์ ์ ์ํ ๋ค์,
์ฌ์ฉ์์ ๋ง์ ๋ฃ๊ณ ์ด๋ค ์ฌ๋กฏ์ ์ ํํ ์ง ์ ํ๊ณ , ํด๋น ์ฌ๋กฏ์ ์ฑ์๋๊ฐ๋ ๋ฐฉ์์
๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ด๋ฌํ Slot Filling ๋ฐฉ์ ์ฑ๋ด์ ๊ตฌํ์ ์ํด '์ธํ
ํธ'์ '์ํฐํฐ'๋ผ๋ ๊ฐ๋
์ด ๋ฑ์ฅํฉ๋๋ค.
๋ง๋ก๋ง ์ค๋ช
ํ๋ฉด ์ด๋ ค์ฐ๋ ์์๋ฅผ ๋ด
์๋ค. ๊ฐ์ฅ ๋จผ์ ์ฐ๋ฆฌ๊ฐ ์ฌํ ์ ๋ณด ์๋ฆผ ์ฑ๋ด์ ๋ง๋ ๋ค๊ณ ๊ฐ์ ํ๊ณ ,
์ฌํ์ ๋ณด ์ ๊ณต์ ์ํด "๋ ์จ ์ ๋ณด์ ๊ณต", "๋ฏธ์ธ๋จผ์ง ์ ๋ณด์ ๊ณต", "๋ง์ง ์ ๋ณด์ ๊ณต", "์ฌํ์ง ์ ๋ณด์ ๊ณต"์ด๋ผ๋ 4๊ฐ์ง
ํต์ฌ ๊ธฐ๋ฅ์ ๊ตฌํํด์ผํ๋ค๊ณ ํฉ์๋ค.
2.2.2.1. ์ธํ ํธ(์๋) ๋ถ๋ฅํ๊ธฐ : ์ฌ๋กฏ ๊ณ ๋ฅด๊ธฐ
๊ฐ์ฅ ๋จผ์ ์ฌ์ฉ์์๊ฒ ๋ฌธ์ฅ์ ์
๋ ฅ๋ฐ์์ ๋, ์ฐ๋ฆฌ๋ ์ 4๊ฐ์ง ์ ๋ณด์ ๊ณต ๊ธฐ๋ฅ ์ค
์ด๋ค ๊ธฐ๋ฅ์ ์คํํด์ผํ๋์ง ์์์ฑ์ผํฉ๋๋ค. ์ด ๊ฒ์ ์ธํ
ํธ(Intent)๋ถ๋ฅ. ์ฆ, ์๋ ๋ถ๋ฅ๋ผ๊ณ ํฉ๋๋ค.
์ฌ์ฉ์๋ก๋ถํฐ "์์์ผ ๋ถ์ฐ ๋ ์จ ์ด๋ ๋?"๋ผ๋ ๋ฌธ์ฅ์ด ์
๋ ฅ๋๋ฉด 4๊ฐ์ง ๊ธฐ๋ฅ ์ค ๋ ์จ ์ ๋ณด์ ๊ณต ๊ธฐ๋ฅ์
์ํํด์ผ ํ๋ค๋ ๊ฒ์ ์์๋ด์ผํฉ๋๋ค. ๋๋ฌธ์ ๋ฌธ์ฅ ๋ฒกํฐ๊ฐ ์
๋ ฅ๋๋ฉด, Text Classification์ ์ํํ์ฌ
์ด๋ค API๋ฅผ ์ฌ์ฉํด์ผํ ์ง ์์๋
๋๋ค.
2.2.2.2. ํด๋ฐฑ ๊ฒ์ถํ๊ธฐ : ๋ชจ๋ฅด๊ฒ ์ผ๋ฉด ๋ชจ๋ฅธ๋ค๊ณ ๋งํ๊ธฐ
๊ทธ๋ฌ๋ ์ฌ๊ธฐ์ ์ ๊ฒฝ์จ์ผํ ๋ถ๋ถ์ด ํ ๋ถ๋ถ ์กด์ฌํฉ๋๋ค. ์ผ๋ฐ์ ์ธ ๋ฅ๋ฌ๋ ๋ถ๋ฅ๋ชจ๋ธ์ ๋ชจ๋ธ์ด ํ์ตํ ํด๋์ค ๋ด์์๋ง ๋ถ๋ฅ๊ฐ ๊ฐ๋ฅํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ฌ์ฉ์๊ฐ 4๊ฐ์ง์ ๋ฐํ์๋ ์์์๋ง ๋งํ ๊ฒ์ด๋ผ๋ ๋ณด์ฅ์ ์์ต๋๋ค. ๋ง์ฝ ์์ฒ๋ผ "๋ ์จ ์ ๋ณด์ ๊ณต", "๋ฏธ์ธ๋จผ์ง ์ ๋ณด์ ๊ณต", "๋ง์ง ์ ๋ณด์ ๊ณต", "์ฌํ์ง ์ ๋ณด์ ๊ณต"์ ๋ฐ์ดํฐ๋ง ํ์ตํ ์ธํ ํธ ๋ถ๋ฅ๋ชจ๋ธ์ "์๋ ๋ฐ๊ฐ๋ค."๋ผ๋ ๋ง์ ํ๊ฒ ๋๋ฉด ์ด๋ป๊ฒ ๋ ๊น์? ์ 4๊ฐ์ง์ ์ํ์ง ์์ ๋ฐํ ์๋์ธ "์ธ์ฌ"์ ํด๋นํ์ง๋ง ๋ชจ๋ธ์ ์ธ์ฟ๋ง์ ํ๋ฒ๋ ๋ณธ์ ์ด ์๊ธฐ ๋๋ฌธ์ ์ด๊ฒ๋ ์ญ์ 4๊ฐ์ง ์ค ํ๋๋ก ๋ถ๋ฅํ๊ฒ ๋ฉ๋๋ค. ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ์๋ ๋ถ๋ฅ๋ชจ๋ธ์๋ ๋ฐ๋์ ํด๋ฐฑ (Fallback) ๊ฒ์ถ ์ ๋ต์ด ํฌํจ๋์ด์ผํฉ๋๋ค.
๋ณดํต์ ์ฑ๋ด๋น๋๋ค์ ์ ๋ ฅ ๋จ์ด๋ค์ ์๋ฒ ๋ฉ์ธ ๋ฌธ์ฅ ๋ฒกํฐ์ ๊ธฐ์กด ๋ฐ์ดํฐ์ ์ ์๋ ๋ฌธ์ฅ ๋ฒกํฐ๋ค์ Cosine ์ ์ฌ๋๋ฅผ ๋น๊ตํฉ๋๋ค. ์ด ๋ ๊ฐ์ฅ ๊ฐ๊น์ด ์ธ์ ํด๋์ค์์ ๊ฐ๋๊ฐ ์๊ณ์น ์ด์์ด๋ฉด Fallback์ด๊ณ , ๊ทธ๋ ์ง ์์ผ๋ฉด ๊ฐ์ฅ ๊ฐ๊น์ด ์ธ์ ํด๋์ค๋ก ๋ฐ์ดํฐ ์ํ์ ๋ถ๋ฅํ๊ฒ ๋ฉ๋๋ค. ์๋ ๊ทธ๋ฆผ์ ๋ณด๋ฉด ์ผ๋ฐ์ ์ธ ์ฑ๋ด ๋น๋๋ค์ด ์ด๋ค์์ผ๋ก Fallback์ ๊ฒ์ถํ๋์ง ์ ์ ์์ต๋๋ค.
Kochat์ ์ด๋ ๊ฒ ๋จ์ํ ๋ฌธ์ฅ๋ค์ ๋ฒกํฐ Cosine ์ ์ฌ๋๋ฅผ ๋น๊ตํ์ง ์๊ณ
๋์ฑ ๊ณ ์ฐจ์์ ์ธ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ์ฌ Fallback ๋ํ
์
์ ๋ณด๋ค ๋ ์ ์ํํ๋๋ก
์ค๊ณํ์๋๋ฐ ์ด์ ๋ํ ์์ธํ ๋ด์ฉ์ ์๋์ Usage์์ ์์ธํ ์ธ๊ธํ๋๋ก ํ๊ฒ ์ต๋๋ค.
2.2.2.3. ์ํฐํฐ(๊ฐ์ฒด๋ช ) ์ธ์ํ๊ธฐ : ์ฌ๋กฏ ์ฑ์ฐ๊ธฐ
๊ทธ ๋ค์ ํด์ผํ ์ผ์ ๋ฐ๋ก ๊ฐ์ฒด๋ช
์ธ์ (Named Entity Recognition)์
๋๋ค.
์ด๋ค API๋ฅผ ํธ์ถํ ์ง ์์๋๋ค๋ฉด, ์ด์ ๊ทธ API๋ฅผ ํธ์ถํ๊ธฐ ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ์ฐพ์์ผํฉ๋๋ค.
๋ง์ฝ ๋ ์จ API์ ์คํ์ ์ํ ํ๋ผ๋ฏธํฐ๊ฐ "์ง์ญ"๊ณผ "๋ ์จ"๋ผ๋ฉด ์ฌ์ฉ์์ ์
๋ ฅ ๋ฌธ์ฅ์์ "์ง์ญ"์ ๊ด๋ จ๋ ์ ๋ณด์
"๋ ์จ"์ ๊ด๋ จ๋ ์ ๋ณด๋ฅผ ์ฐพ์๋ด์ ํด๋น ์ฌ๋กฏ์ ์ฑ์๋๋ค. ๋ง์ฝ ์ฌ์ฉ์๊ฐ "์์์ผ ๋ ์จ ์๋ ค์ค"๋ผ๊ณ ๋ง ๋งํ๋ค๋ฉด,
์ง์ญ์ ๊ด๋ จ๋ ์ ๋ณด๋ ์์ง ์ฐพ์๋ด์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ ๋ค์ ๋๋ฌผ์ด์ ์ฐพ์๋ด์ผํฉ๋๋ค.
2.2.2.4. API ํธ์ถํ๊ธฐ : ๋๋ต ์์ฑํ๊ธฐ
์ฌ๋กฏ์ด ๋ชจ๋ ์ฑ์์ก๋ค๋ฉด API๋ฅผ ์คํ์์ผ์ ์ธ๋ถ๋ก๋ถํฐ ์ ๋ณด๋ฅผ ์ ๊ณต๋ฐ์ต๋๋ค.
API๋ก๋ถํฐ ๊ฒฐ๊ณผ๊ฐ ๋์ฐฉํ๋ฉด, ๋ฏธ๋ฆฌ ๋ง๋ค์ด๋ ํ
ํ๋ฆฟ ๋ฌธ์ฅ์ ํด๋น ์คํ ๊ฒฐ๊ณผ๋ฅผ ์ฝ์
ํ์ฌ ๋๋ต์ ๋ง๋ค์ด๋ด๊ณ ,
์ด ๋๋ต์ ์ฌ์ฉ์์๊ฒ responseํฉ๋๋ค. ์ด API๋ ์์ ๋กญ๊ฒ ์ํ๋ API๋ฅผ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค.
์์ ์ ํ๋ฆฌ์ผ์ด์
์์๋ ์ฃผ๋ก ์น ํฌ๋กค๋ง์ ์ด์ฉํ์ฌ API๋ฅผ ๊ตฌ์ฑํ์๊ณ , ํฌ๋กค๋ฌ ๊ตฌํ ์ํคํ
์ฒ์ ๋ํด์๋ ํ์ ํ๋๋ก ํ๊ฒ ์ต๋๋ค.
Slot Filling ๋ฐฉ์์ ์ฑ๋ด์ ์์ ๊ฐ์ ํ๋ฆ์ผ๋ก ์งํ๋ฉ๋๋ค. ๋ฐ๋ผ์ ์ด๋ฌํ ๋ฐฉ์์ ์ฑ๋ด์ ๊ตฌํํ๋ ค๋ฉด ์ต์ํ 3๊ฐ์ง์ ๋ชจ๋์ด ํ์ํฉ๋๋ค. ์ฒซ๋ฒ์งธ๋ก ์ธํ ํธ ๋ถ๋ฅ๋ชจ๋ธ, ์ํฐํฐ ์ธ์๋ชจ๋ธ, ๊ทธ๋ฆฌ๊ณ ๋๋ต ์์ฑ๋ชจ๋(์์ ์์๋ ํฌ๋กค๋ง)์ ๋๋ค. Kochat์ ์ด ์ธ๊ฐ์ง ๋ชจ๋๊ณผ ์ด๋ฅผ ์๋นํ Restful API๊น์ง ๋ชจ๋ ํฌํจํ๊ณ ์์ต๋๋ค. ์ด์ ๋ํด์๋ ์๋์ Usage ์ฑํฐ์์ ๊ฐ๊ฐ ๋ชจ๋ธ์ด ์ด๋ป๊ฒ ๊ตฌํ๋์ด ์๋์ง ์์ธํ ์ค๋ช ํฉ๋๋ค.
3. Getting Started
3.1. Requirements
Kochat์ ์ด์ฉํ๋ ค๋ฉด ๋ฐ๋์ ๋ณธ์ธ์ OS์ ๋จธ์ ์ ๋ง๋ Pytorch๊ฐ ์ค์น ๋์ด์์ด์ผํฉ๋๋ค. ๋ง์ฝ Pytorch๋ฅผ ์ค์นํ์ง ์์ผ์ จ๋ค๋ฉด ์ฌ๊ธฐ ์์ ๋ค์ด๋ก๋ ๋ฐ์์ฃผ์ธ์. (Kochat์ ์ค์นํ๋ค๊ณ ํด์ Pytorch๊ฐ ํจ๊ป ์ค์น๋์ง ์์ต๋๋ค. ๋ณธ์ธ ๋ฒ์ ์ ๋ง๋ Pytorch๋ฅผ ๋ค์ด๋ก๋ ๋ฐ์์ฃผ์ธ์)
3.2. pip install
pip๋ฅผ ์ด์ฉํด Kochat์ ๊ฐ๋จํ๊ฒ ๋ค์ด๋ก๋ํ๊ณ ์ฌ์ฉํ ์ ์์ต๋๋ค. ์๋ ๋ช ๋ น์ด๋ฅผ ํตํด์ kochat์ ๋ค์ด๋ก๋ ๋ฐ์์ฃผ์ธ์.
pip install kochat
3.3 Dependencies
ํจํค์ง๋ฅผ ๊ตฌํํ๋๋ฐ ์ฌ์ฉ๋ ๋ํ๋์๋ ์๋์ ๊ฐ์ต๋๋ค. (Kochat ์ค์น์ ํจ๊ป ์ค์น๋ฉ๋๋ค.)
matplotlib==3.2.1
pandas==1.0.4
gensim==3.8.3
konlpy==0.5.2
numpy==1.18.5
joblib==0.15.1
scikit-learn==0.23.1
pytorch-crf==0.7.2
requests==2.24.0
flask==1.1.2
3.4 Configuration ํ์ผ ์ถ๊ฐํ๊ธฐ
pip๋ฅผ ์ด์ฉํด Kochat์ ๋ค์ด๋ก๋ ๋ฐ์๋ค๋ฉด ํ๋ก์ ํธ์, kochat์ configuration ํ์ผ์ ์ถ๊ฐํด์ผํฉ๋๋ค. kochat_config.zip ์ ๋ค์ด๋ก๋ ๋ฐ๊ณ ์์ถ์ ํ์ด์ interpreter์ working directory์ ๋ฃ์ต๋๋ค. (kochat api๋ฅผ ์คํํ๋ ํ์ผ๊ณผ ๋์ผํ ๊ฒฝ๋ก์ ์์ด์ผํฉ๋๋ค. ์์ธํ ์์๋ ์๋ ๋ฐ๋ชจ์์ ํ์ธํ์ค ์ ์์ต๋๋ค.) config ํ์ผ์๋ ๋ค์ํ ์ค์ ๊ฐ๋ค์ด ์กด์ฌํ๋ ํ์ธํ๊ณ ์ ๋ง๋๋ก ๋ณ๊ฒฝํ์๋ฉด ๋ฉ๋๋ค.
3.5 ๋ฐ์ดํฐ์ ๋ฃ๊ธฐ
์ด์ ์ฌ๋ฌ๋ถ์ด ํ์ต์ํฌ ๋ฐ์ดํฐ์
์ ๋ฃ์ด์ผํฉ๋๋ค.
๊ทธ ์ ์ ๋ฐ์ดํฐ์
์ ํฌ๋งท์ ๋ํด์ ๊ฐ๋จํ๊ฒ ์์๋ด
์๋ค.
Kochat์ ๊ธฐ๋ณธ์ ์ผ๋ก Slot filling์ ๊ธฐ๋ฐ์ผ๋ก
ํ๊ณ ์๊ธฐ ๋๋ฌธ์ Intent์ Entity ๋ฐ์ดํฐ์
์ด ํ์ํฉ๋๋ค.
๊ทธ๋ฌ๋ ์ด ๋๊ฐ์ง ๋ฐ์ดํฐ์
์ ๋ฐ๋ก ๋ง๋ค๋ฉด ์๋นํ ๋ฒ๊ฑฐ๋ก์์ง๊ธฐ ๋๋ฌธ์
ํ๊ฐ์ง ํฌ๋งท์ผ๋ก ๋๊ฐ์ง ๋ฐ์ดํฐ๋ฅผ ์๋์ผ๋ก ์์ฑํฉ๋๋ค.
์๋ ๋ฐ์ดํฐ์
๊ท์น๋ค์ ๋ง์ถฐ์ ๋ฐ์ดํฐ๋ฅผ ์์ฑํด์ฃผ์ธ์
3.5.1. ๋ฐ์ดํฐ ํฌ๋งท
๊ธฐ๋ณธ์ ์ผ๋ก intent์ entity๋ฅผ ๋๋๋ ค๋ฉด, ๋๊ฐ์ง๋ฅผ ๋ชจ๋ ๊ตฌ๋ถํ ์ ์์ด์ผํฉ๋๋ค.
๊ทธ๋์ ์ ํํ ๋ฐฉ์์ ์ธํ
ํธ๋ ํ์ผ๋ก ๊ตฌ๋ถ, ์ํฐํฐ๋ ๋ผ๋ฒจ๋ก ๊ตฌ๋ถํ๋ ๊ฒ์ด์์ต๋๋ค.
์ถํ ๋ฆด๋ฆฌ์ฆ ๋ฒ์ ์์๋ Rasa์ฒ๋ผ ํจ์ฌ ์ฌ์ด ๋ฐฉ์์ผ๋ก ๋ณ๊ฒฝํ๋ ค๊ณ ํฉ๋๋ค๋ง, ์ด๊ธฐ๋ฒ์ ์์๋
๋ค์ ๋ถํธํ๋๋ผ๋ ์๋์ ํฌ๋งท์ ๋ฐ๋ผ์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
- weather.csv
question,label
๋ ์จ ์๋ ค์ฃผ์ธ์,O O
์์์ผ ์ธ์ ๋น์ค๋,S-DATE S-LOCATION O
๊ตฐ์ฐ ๋ ์จ ์ถ์ธ๊น ์ ๋ง,S-LOCATION O O O
๊ณก์ฑ ๋น์ฌ๊น,S-LOCATION O
๋ด์ผ ๋จ์ ๋ ์ค๊ฒ ์ง ์๋ง,S-DATE S-LOCATION O O O
๊ฐ์๋ ์ถ์ฒ ๊ฐ๋๋ฐ ์ค๋ ๋ ์จ ์๋ ค์ค,B-LOCATION E-LOCATION O S-DATE O O
์ ๋ถ ๊ตฐ์ฐ ๊ฐ๋๋ฐ ํ์์ผ ๋ ์จ ์๋ ค์ค๋,B-LOCATION E-LOCATION O S-DATE O O
์ ์ฃผ ์๊ทํฌ ๊ฐ๋ ค๋๋ฐ ํ์์ผ ๋ ์จ ์๋ ค์ค,B-LOCATION E-LOCATION O S-DATE O O
์ค๋ ์ ์ฃผ๋ ๋ ์จ ์๋ ค์ค,S-DATE S-LOCATION O O
... (์๋ต)
- travel.csv
question,label
์ด๋ ๊ด๊ด์ง ๊ฐ๊ฒ ๋,O O O
ํ์ฃผ ์ ๋ช
ํ ๊ณต์ฐ์ฅ ์๋ ค์ค,S-LOCATION O S-PLACE O
์ฐฝ์ ์ฌํ ๊ฐ๋งํ ๋ฐ๋ค,S-LOCATION O O S-PLACE
ํํ ๊ฐ๋งํ ์คํค์ฅ ์ฌํ ํด๋ณด๊ณ ์ถ๋ค,S-LOCATION O S-PLACE O O O
์ ์ฃผ๋ ํ
ํ์คํ
์ด ์ฌํ ๊ฐ ๋ฐ ์ถ์ฒํด ์ค,S-LOCATION S-PLACE O O O O O
์ ์ฃผ ๊ฐ๊น์ด ๋ฐ๋ค ๊ด๊ด์ง ๋ณด์ฌ์ค ๋ด์,S-LOCATION O S-PLACE O O O
์ฉ์ธ ๊ฐ๊น์ด ์ถ๊ตฌ์ฅ ์ด๋จ์ด,S-LOCATION O S-PLACE O
๋ถ๋น๋ ๊ด๊ด์ง,O O
์ฒญ์ฃผ ๊ฐ์ ํ๊ฒฝ ์์ ์ฐ ๊ฐ๋ณด๊ณ ์ถ์ด,S-LOCATION S-DATE O O S-PLACE O O
... (์๋ต)
์ ์ฒ๋ผ question,label์ด๋ผ๋ ํค๋(์ปฌ๋ผ๋ช )์ ๊ฐ์ฅ ์์ค์ ์์น์ํค๊ณ , ๊ทธ ์๋๋ก ๋๊ฐ์ ์ปฌ๋ฆผ question๊ณผ label์ ํด๋นํ๋ ๋ด์ฉ์ ์์ฑํฉ๋๋ค. ๊ฐ ๋จ์ด ๋ฐ ์ํฐํฐ๋ ๋์ด์ฐ๊ธฐ๋ก ๊ตฌ๋ถ๋ฉ๋๋ค. ๋ฐ๋ชจ ๋ฐ์ดํฐ๋ BIOํ๊น ์ ๊ฐ์ ํ BIOESํ๊น ์ ์ฌ์ฉํ์ฌ ๋ผ๋ฒจ๋งํ๋๋ฐ, ์ํฐํฐ ํ๊น ๋ฐฉ์์ ์์ ๋กญ๊ฒ ๊ณ ๋ฅด์ ๋ ๋ฉ๋๋ค. (config์์ ์ค์ ๊ฐ๋ฅํฉ๋๋ค.) ์ํฐํฐ ํ๊น ์คํค๋ง์ ๊ด๋ จ๋ ์์ธํ ๋ด์ฉ์ ์ฌ๊ธฐ ๋ฅผ ์ฐธ๊ณ ํ์ธ์.
3.5.2. ๋ฐ์ดํฐ์ ์ ์ฅ ๊ฒฝ๋ก
๋ฐ์ดํฐ์ ์ ์ฅ๊ฒฝ๋ก๋ ๊ธฐ๋ณธ์ ์ผ๋ก configํ์ผ์ด ์๋ ๊ณณ์ root๋ก ์๊ฐํ์ ๋, "root/data/raw"์ ๋๋ค. ์ด ๊ฒฝ๋ก๋ config์ DATA ์ฑํฐ์์ ๋ณ๊ฒฝ ๊ฐ๋ฅํฉ๋๋ค.
root
|_data
|_raw
|_weather.csv
|_dust.csv
|_retaurant.csv
|_...
3.5.3. ์ธํ ํธ ๋จ์๋ก ํ์ผ ๋ถํ
๊ฐ ์ธํ ํธ ๋จ์๋ก ํ์ผ์ ๋ถํ ํฉ๋๋ค. ์ด ๋, ํ์ผ๋ช ์ด ์ธํ ํธ๋ช ์ด ๋ฉ๋๋ค. ํ์ผ๋ช ์ ํ๊ธ๋ก ํด๋ ์๊ด ์๊ธด ํ์ง๋ง, ๋ฆฌ๋ ์ค ์ด์์ฒด์ ์ ๊ฒฝ์ฐ ์๊ฐํ์ matplotlib์ ํ๊ธํฐํธ๊ฐ ์ค์น๋์ด์์ง ์๋ค๋ฉด ๊ธ์๊ฐ ๊นจ์ง๋, ๊ฐ๊ธ์ ์ด๋ฉด ์๊ฐํ๋ฅผ ์ํด ์์ด๋ก ํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค. (๋ง์ฝ ๊ธ์๊ฐ ๊นจ์ง์ง ์์ผ๋ฉด ํ๊ธ๋ก ํด๋ ๋ฌด๋ฐฉํ๋, ํ๊ธ๋ก ํ๋ ค๋ฉด ํฐํธ๋ฅผ ์ค์นํด์ฃผ์ธ์.)
root
|_data
|_raw
|_weather.csv โ intent : weather
|_dust.csv โ intent : dust
|_retaurant.csv โ intent : restaurant
|_...
3.5.4. ํ์ผ์ ํค๋(์ปฌ๋ผ๋ช ) ์ค์
ํ์ผ์ ํค๋(์ปฌ๋ผ๋ช )์ ๋ฐ๋์ question๊ณผ label๋ก ํด์ฃผ์ธ์. ํค๋๋ฅผ config์์ ๋ฐ๊ฟ ์ ์๊ฒ ํ ๊น๋ ์๊ฐํ์ง๋ง, ๋ณ๋ก ํฐ ์๋ฏธ๊ฐ ์๋ ๊ฒ ๊ฐ์์ ์ฐ์ ์ ๊ณ ์ ๋ ๊ฐ์ธ question๊ณผ label๋ก ์ค์ ํ์์ต๋๋ค.
question,label โ ์ค์ !!!
... (์๋ต)
3.5.5. ๋ผ๋ฒจ๋ง ์ค์ ๊ฒ์ถ
์ํ ๋น question์ ๋จ์ด ๊ฐฏ์์ label์ ์ํฐํฐ ๊ฐฏ์๋ ๋์ผํด์ผํ๋ฉฐ config์ ์ ์ํ ์ํฐํฐ๋ง ์ฌ์ฉ ๊ฐ๋ฅํฉ๋๋ค. ์ด๋ฌํ ๋ผ๋ฒจ๋ง ์ค์๋ Kochat์ด ๋ฐ์ดํฐ๋ฅผ ๋ณํํ ๋ ๊ฒ์ถํด์ ์ด๋๊ฐ ํ๋ ธ๋์ง ์๋ ค์ค๋๋ค.
case 1: ๋ผ๋ฒจ๋ง ๋งค์นญ ์ค์ ๋ฐฉ์ง
question = ์ ์ฃผ ๋ ์ฌ๊น (size : 3)
label = S-LOCATION O O O (size : 4)
โ ์๋ฌ ๋ฐ์! (question๊ณผ label์ ์๊ฐ ๋ค๋ฆ)
case 2: ๋ผ๋ฒจ๋ง ์คํ ๋ฐฉ์ง
(in kochat_config.py)
DATA = {
... (์๋ต)
'NER_categories': ['DATE', 'LOCATION', 'RESTAURANT', 'PLACE'], # ์ฌ์ฉ์ ์ ์ ํ๊ทธ
'NER_tagging': ['B', 'E', 'I', 'S'], # NER์ BEGIN, END, INSIDE, SINGLE ํ๊ทธ
'NER_outside': 'O', # NER์ Oํ๊ทธ (Outside๋ฅผ ์๋ฏธ)
}
question = ์ ์ฃผ ๋ ์ฌ๊น
label = Z-LOC O O
โ ์๋ฌ ๋ฐ์! (์ ์๋์ง ์์ ์ํฐํฐ : Z-LOC)
NER_tagging + '-' + NER_categories์ ํํ๊ฐ ์๋๋ฉด ์๋ฌ๋ฅผ ๋ฐํํฉ๋๋ค.
3.5.6. OOD ๋ฐ์ดํฐ์
OOD๋ Out of distribution์ ์ฝ์๋ก, ๋ถํฌ ์ธ ๋ฐ์ดํฐ์ ์ ์๋ฏธํฉ๋๋ค. ์ฆ, ํ์ฌ ์ฑ๋ด์ด ์ง์ํ๋ ๊ธฐ๋ฅ ์ด์ธ์ ๋ฐ์ดํฐ๋ฅผ ์๋ฏธํฉ๋๋ค. OOD ๋ฐ์ดํฐ์ ์ด ์์ด๋ Kochat์ ์ด์ฉํ๋๋ฐ์๋ ์๋ฌด๋ฐ ๋ฌธ์ ๊ฐ ์์ง๋ง, OOD ๋ฐ์ดํฐ์ ์ ๊ฐ์ถ๋ฉด ๋งค์ฐ ๊ท์ฐฎ์ ๋ช๋ช ๋ถ๋ถ๋ค์ ํจ๊ณผ์ ์ผ๋ก ์๋ํ ํ ์ ์์ต๋๋ค. (์ฃผ๋ก Fallback Detection threshold ์ค์ ) OOD ๋ฐ์ดํฐ์ ์ ์๋์ฒ๋ผ "root/data/ood"์ ์ถ๊ฐํฉ๋๋ค.
root
|_data
|_raw
|_weather.csv
|_dust.csv
|_retaurant.csv
|_...
|_ood
|_ood_data_1.csv โ data/oodํด๋์ ์์นํ๊ฒ ํฉ๋๋ค.
|_ood_data_2.csv โ data/oodํด๋์ ์์นํ๊ฒ ํฉ๋๋ค.
OOD ๋ฐ์ดํฐ์ ์ ์๋์ ๊ฐ์ด question๊ณผ OOD์ ์๋๋ก ๋ผ๋ฒจ๋งํฉ๋๋ค. ๋ฐ๋ชจ ๋ฐ์ดํฐ์ ์ ์ ๋ถ ์๋๋๋ก ๋ผ๋ฒจ๋งํ์ง๋ง, ์ด ์๋๊ฐ์ ์ฌ์ฉํ์ง ์๊ธฐ ๋๋ฌธ์ ๊ทธ๋ฅ ์๋ฌด๊ฐ์ผ๋ก๋ ๋ผ๋ฒจ๋งํด๋ ์ฌ์ค ๋ฌด๊ดํฉ๋๋ค.
๋ฐ๋ชจ_ood_๋ฐ์ดํฐ.csv
question,label
์ต๊ทผ ์๋์ผ ์ต๊ทผ ์ด์ ์๋ ค์ค,๋ด์ค์ด์
์ต๊ทผ ํซํ๋ ๊ฒ ์๋ ค์ค,๋ด์ค์ด์
๋ํํ
์ข์ ๋ช
์ธํด์ค ์ ์๋,๋ช
์ธ
๋ ์ข์ ๋ช
์ธ ์ข ๋ค๋ ค์ฃผ๋ผ,๋ช
์ธ
์ข์ ๋ช
์ธ ์ข ํด๋ด,๋ช
์ธ
๋ฐฑ์ฌ๋ฒ ๋
ธ๋ ๋ค์๋์,์์
๋น ๋
ธ๋ ๊นก ๋ฃ๊ณ ์ถ๋ค,์์
์ํ ost ์ถ์ฒํด์ค,์์
์ง๊ธ ์๊ฐ ์ข ์๋ ค๋ฌ๋ผ๊ณ ,๋ ์ง์๊ฐ
์ง๊ธ ์๊ฐ ์ข ์๋ ค์ค,๋ ์ง์๊ฐ
์ง๊ธ ๋ช ์ ๋ช ๋ถ์ธ์ง ์๋,๋ ์ง์๊ฐ
๋ช
์ ์คํธ๋ ์ค ใ
ใ
,์ก๋ด
๋ญํ๊ณ ๋์ง ใ
ใ
,์ก๋ด
๋๋ ๋์์ฃผ๋ผ ์ข,์ก๋ด
๋ญํ๊ณ ์ด์ง,์ก๋ด
... (์๋ต)
์ด๋ ๊ฒ ๋ผ๋ฒจ๋ง ํด๋ ๋์ง๋ง ์ด์ฐจํผ ๋ผ๋ฒจ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ง ์๊ธฐ ๋๋ฌธ์ ์๋์ฒ๋ผ ๋ผ๋ฒจ๋งํด๋ ๋ฌด๊ดํฉ๋๋ค.
๋ฐ๋ชจ_ood_๋ฐ์ดํฐ.csv
question,label
์ต๊ทผ ์๋์ผ ์ต๊ทผ ์ด์ ์๋ ค์ค,OOD
์ต๊ทผ ํซํ๋ ๊ฒ ์๋ ค์ค,OOD
๋ํํ
์ข์ ๋ช
์ธํด์ค ์ ์๋,OOD
๋ ์ข์ ๋ช
์ธ ์ข ๋ค๋ ค์ฃผ๋ผ,OOD
์ข์ ๋ช
์ธ ์ข ํด๋ด,OOD
๋ฐฑ์ฌ๋ฒ ๋
ธ๋ ๋ค์๋์,OOD
๋น ๋
ธ๋ ๊นก ๋ฃ๊ณ ์ถ๋ค,OOD
์ํ ost ์ถ์ฒํด์ค,OOD
์ง๊ธ ์๊ฐ ์ข ์๋ ค๋ฌ๋ผ๊ณ ,OOD
์ง๊ธ ์๊ฐ ์ข ์๋ ค์ค,OOD
์ง๊ธ ๋ช ์ ๋ช ๋ถ์ธ์ง ์๋,OOD
๋ช
์ ์คํธ๋ ์ค ใ
ใ
,OOD
๋ญํ๊ณ ๋์ง ใ
ใ
,OOD
๋๋ ๋์์ฃผ๋ผ ์ข,OOD
๋ญํ๊ณ ์ด์ง,OOD
... (์๋ต)
OOD ๋ฐ์ดํฐ๋ ๋ฌผ๋ก ๋ง์ผ๋ฉด ์ข๊ฒ ์ง๋ง ๋ง๋๋ ๊ฒ ์์ฒด๊ฐ ๋ถ๋ด์ด๊ธฐ ๋๋ฌธ์ ์ ์ ์๋ง ๋ฃ์ด๋ ๋ฉ๋๋ค.
๋ฐ๋ชจ ๋ฐ์ดํฐ์ ๊ฒฝ์ฐ๋ ์ด 3000๋ผ์ธ์ ๋ฐ์ดํฐ ์ค 600๋ผ์ธ์ ๋์ OOD ๋ฐ์ดํฐ๋ฅผ ์ฝ์
ํ์์ต๋๋ค.
๋ฐ์ดํฐ๊น์ง ๋ชจ๋ ์ฝ์
ํ์
จ๋ค๋ฉด kochat์ ์ด์ฉํ ์ค๋น๊ฐ ๋๋ฌ์ต๋๋ค. ์๋ ์ฑํฐ์์๋
์์ธํ ์ฌ์ฉ๋ฒ์ ๋ํด ์๋ ค๋๋ฆฌ๊ฒ ์ต๋๋ค.
4. Usage
from kochat.data
4.1. kochat.data
ํจํค์ง์๋ Dataset
ํด๋์ค๊ฐ ์์ต๋๋ค. Dataset
ํด๋์ค๋
๋ถ๋ฆฌ๋ raw ๋ฐ์ดํฐ ํ์ผ๋ค์ ํ๋๋ก ํฉ์ณ์ ํตํฉ intentํ์ผ๊ณผ ํตํฉ entityํ์ผ๋ก ๋ง๋ค๊ณ ,
embedding, intent, entity, inference์ ๊ด๋ จ๋ ๋ฐ์ดํฐ์
์ ๋ฏธ๋๋ฐฐ์น๋ก ์๋ผ์
pytorch์ DataLoader
ํํ๋ก ์ ๊ณตํฉ๋๋ค.
๋ํ ๋ชจ๋ธ, Loss ํจ์ ๋ฑ์ ์์ฑํ ๋ ํ๋ผ๋ฏธํฐ๋ก ์
๋ ฅํ๋ label_dict
๋ฅผ ์ ๊ณตํฉ๋๋ค.
Dataset
ํด๋์ค๋ฅผ ์์ฑํ ๋ ํ์ํ ํ๋ผ๋ฏธํฐ์ธ ood
๋ OOD ๋ฐ์ดํฐ์
์ฌ์ฉ ์ฌ๋ถ์
๋๋ค.
True๋ก ์ค์ ํ๋ฉด ood ๋ฐ์ดํฐ์
์ ์ฌ์ฉํฉ๋๋ค.
- Dataset ๊ธฐ๋ฅ 1. ๋ฐ์ดํฐ์ ์์ฑ
from kochat.data import Dataset
# ํด๋์ค ์์ฑ์ rawํ์ผ๋ค์ ๊ฒ์ฆํ๊ณ ํตํฉํฉ๋๋ค.
dataset = Dataset(ood=True, naver_fix=True)
# ์๋ฒ ๋ฉ ๋ฐ์ดํฐ์
์์ฑ
embed_dataset = dataset.load_embed()
# ์ธํ
ํธ ๋ฐ์ดํฐ์
์์ฑ (์๋ฒ ๋ฉ ํ๋ก์ธ์ ํ์)
intent_dataset = dataset.load_intent(emb)
# ์ํฐํฐ ๋ฐ์ดํฐ์
์์ฑ (์๋ฒ ๋ฉ ํ๋ก์ธ์ ํ์)
entity_dataset = dataset.load_entity(emb)
# ์ถ๋ก ์ฉ ๋ฐ์ดํฐ์
์์ฑ (์๋ฒ ๋ฉ ํ๋ก์ธ์ ํ์)
predict_dataset = dataset.load_predict("์์ธ ๋ง์ง ์ถ์ฒํด์ค", emb)
- Dataset ๊ธฐ๋ฅ 2. ๋ผ๋ฒจ ๋์ ๋๋ฆฌ ์์ฑ
from kochat.data import Dataset
# ํด๋์ค ์์ฑ์ rawํ์ผ๋ค์ ๊ฒ์ฆํ๊ณ ํตํฉํฉ๋๋ค.
dataset = Dataset(ood=True, naver_fix=True)
# ์ธํ
ํธ ๋ผ๋ฒจ ๋์
๋๋ฆฌ๋ฅผ ์์ฑํฉ๋๋ค.
intent_dict = dataset.intent_dict
# ์ํฐํฐ ๋ผ๋ฒจ ๋์
๋๋ฆฌ๋ฅผ ์์ฑํฉ๋๋ค.
entity_dict = dataset.entity_dict
โ Warning
Dataset
ํด๋์ค๋ ์ ์ฒ๋ฆฌ์ ํ ํฐํ๋ฅผ ์ํํ ๋,
ํ์ต/ํ
์คํธ ๋ฐ์ดํฐ๋ ๋์ด์ฐ๊ธฐ๋ฅผ ๊ธฐ์ค์ผ๋ก ํ ํฐํ๋ฅผ ์ํํ๊ณ , ์ค์ ์ฌ์ฉ์์ ์
๋ ฅ์
์ถ๋ก ํ ๋๋ ๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ์ Konlpy ํ ํฌ๋์ด์ ๋ฅผ ์ฌ์ฉํ์ฌ ํ ํฐํ๋ฅผ ์ํํฉ๋๋ค.
๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ๋ฅผ ์ฌ์ฉํ๋ฉด ์ฑ๋ฅ์ ๋์ฑ ํฅ์๋๊ฒ ์ง๋ง, ์์
์ ์ผ๋ก ์ด์ฉ์ ๋ฌธ์ ๊ฐ
๋ฐ์ํ ์ ์๊ณ , ์ด์ ๋ํด ๊ฐ๋ฐ์๋ ์ด๋ ํ ์ฑ
์๋ ์ง์ง ์์ต๋๋ค.
๋ง์ฝ Kochat์ ์์
์ ์ผ๋ก ์ด์ฉํ์๋ ค๋ฉด Dataset
์์ฑ์ naver_fix
ํ๋ผ๋ฏธํฐ๋ฅผ
False
๋ก ์ค์ ํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค. False
์ค์ ์์๋ Konlpy ํ ํฐํ๋ง ์ํํ๋ฉฐ,
์ถํ ๋ฒ์ ์์๋ ๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ๋ฅผ ์์ฒด์ ์ธ ๋์ด์ฐ๊ธฐ ๊ฒ์ฌ๋ชจ๋ ๋ฑ์ผ๋ก
๊ต์ฒดํ ์์ ์
๋๋ค.
from kochat.model
4.2. model
ํจํค์ง๋ ์ฌ์ ์ ์๋ ๋ค์ํ built-in ๋ชจ๋ธ๋ค์ด ์ ์ฅ๋ ํจํค์ง์
๋๋ค.
ํ์ฌ ๋ฒ์ ์์๋ ์๋ ๋ชฉ๋ก์ ํด๋นํ๋ ๋ชจ๋ธ๋ค์ ์ง์ํฉ๋๋ค. ์ถํ ๋ฒ์ ์ด ์
๋ฐ์ดํธ ๋๋ฉด
์ง๊ธ๋ณด๋ค ํจ์ฌ ๋ค์ํ built-in ๋ชจ๋ธ์ ์ง์ํ ์์ ์
๋๋ค. ์๋ ๋ชฉ๋ก์ ์ฐธ๊ณ ํ์ฌ ์ฌ์ฉํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
4.2.1. embed ๋ชจ๋ธ
from kochat.model import embed
# 1. Gensim์ Word2Vec ๋ชจ๋ธ์ Wrapper์
๋๋ค.
# (OOV ํ ํฐ์ ๊ฐ์ config์์ ์ค์ ๊ฐ๋ฅํฉ๋๋ค.)
word2vec = embed.Word2Vec()
# 2. Gensim์ FastText ๋ชจ๋ธ์ Wrapper์
๋๋ค.
fasttext = embed.FastText()
4.2.2. intent ๋ชจ๋ธ
from kochat.model import intent
# 1. Residual Learning์ ์ง์ํ๋ 1D CNN์
๋๋ค.
cnn = intent.CNN(label_dict=dataset.intent_dict, residual=True)
# 2. Bidirectional์ ์ง์ํ๋ LSTM์
๋๋ค.
lstm = intent.LSTM(label_dict=dataset.intent_dict, bidirectional=True)
4.2.3. entity ๋ชจ๋ธ
from kochat.model import entity
# 1. Bidirectional์ ์ง์ํ๋ LSTM์
๋๋ค.
lstm = entity.LSTM(label_dict=dataset.entity_dict, bidirectional=True)
4.2.4. ์ปค์คํ ๋ชจ๋ธ
Kochat์ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ํฉ๋๋ค.
Gensim์ด๋ Pytorch๋ก ์์ฑํ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ ํ์ต์ํค๊ธฐ๊ณ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์
์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ง์ฝ ์ปค์คํ
๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด ์๋์ ๋ช๊ฐ์ง ๊ท์น์ ๋ฐ๋์
๋ฐ๋ผ์ผํฉ๋๋ค.
4.2.4.1. ์ปค์คํ Gensim embed ๋ชจ๋ธ
์๋ฒ ๋ฉ์ ๊ฒฝ์ฐ ํ์ฌ๋ Gensim ๋ชจ๋ธ๋ง ์ง์ํฉ๋๋ค. ์ถํ์ Pytorch๋ก ๋
์๋ฒ ๋ฉ ๋ชจ๋ธ(ELMO, BERT)๋ฑ๋ ์ง์ํ ๊ณํ์
๋๋ค.
Gensim Embedding ๋ชจ๋ธ์ ์๋์ ๊ฐ์ ํํ๋ก ๊ตฌํํด์ผํฉ๋๋ค.
-
@gensim
๋ฐ์ฝ๋ ์ดํฐ ์ค์ -
BaseWordEmbeddingsModel
๋ชจ๋ธ ์ค ํ ๊ฐ์ง ์์๋ฐ๊ธฐ -
super().__init__()
์ ํ๋ผ๋ฏธํฐ ์ฝ์ ํ๊ธฐ (self.XXX๋ก ์ ๊ทผ๊ฐ๋ฅ)
from gensim.models import FastText
from kochat.decorators import gensim
# 1. @gensim ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ GENSIM์ ์๋ ๋ชจ๋ ๋ฐ์ดํฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@gensim
class FastText(FastText):
# 2. BaseWordEmbeddingsModel ๋ชจ๋ธ์ค ํ ๊ฐ์ง๋ฅผ ์์๋ฐ์ต๋๋ค.
def __init__(self):
# 3. `super().__init__()`์ ํ์ํ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ฃ์ด์ ์ด๊ธฐํํด์ค๋๋ค.
super().__init__(size=self.vector_size,
window=self.window_size,
workers=self.workers,
min_count=self.min_count,
iter=self.iter)
4.2.4.2. ์ปค์คํ Intent ๋ชจ๋ธ
์ธํ
ํธ ๋ชจ๋ธ์ torch๋ก ๊ตฌํํฉ๋๋ค.
์ธํ
ํธ ๋ชจ๋ธ์๋ self.label_dict
๊ฐ ๋ฐ๋์ ์กด์ฌํด์ผํฉ๋๋ค.
๋ํ ์ต์ข
output ๋ ์ด์ด๋ ์๋์์ฑ๋๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํ๋ฉด ๋ฉ๋๋ค.
๋์ฑ ์ธ๋ถ์ ์ธ ๊ท์น์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
-
@intent
๋ฐ์ฝ๋ ์ดํฐ ์ค์ -
torch.nn.Module
์์๋ฐ๊ธฐ - ํ๋ผ๋ฏธํฐ๋ก label_dict๋ฅผ ์
๋ ฅ๋ฐ๊ณ
self.label_dict
์ ํ ๋นํ๊ธฐ -
forward()
ํจ์์์ feature๋ฅผ [batch_size, -1] ๋ก ๋ง๋ค๊ณ ๋ฆฌํด
from torch import nn
from torch import Tensor
from kochat.decorators import intent
from kochat.model.layers.convolution import Convolution
# 1. @intent ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ INTENT์ ์๋ ๋ชจ๋ ์ค์ ๊ฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@intent
class CNN(nn.Module):
# 2. torch.nn์ Module์ ์์๋ฐ์ต๋๋ค.
def __init__(self, label_dict: dict, residual: bool = True):
super(CNN, self).__init__()
self.label_dict = label_dict
# 3. intent๋ชจ๋ธ์ ๋ฐ๋์ ์์ฑ์ผ๋ก self.label_dict๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผํฉ๋๋ค.
self.stem = Convolution(self.vector_size, self.d_model, kernel_size=1, residual=residual)
self.hidden_layers = nn.Sequential(*[
Convolution(self.d_model, self.d_model, kernel_size=1, residual=residual)
for _ in range(self.layers)])
def forward(self, x: Tensor) -> Tensor:
x = x.permute(0, 2, 1)
x = self.stem(x)
x = self.hidden_layers(x)
return x.view(x.size(0), -1)
# 4. feature๋ฅผ [batch_size, -1]๋ก ๋ง๋ค๊ณ ๋ฐํํฉ๋๋ค.
# ์ต์ข
output ๋ ์ด์ด๋ kochat์ด ์๋ ์์ฑํ๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํฉ๋๋ค.
import torch
from torch import nn, autograd
from torch import Tensor
from kochat.decorators import intent
# 1. @intent ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ INTENT์ ์๋ ๋ชจ๋ ์ค์ ๊ฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@intent
class LSTM(nn.Module):
# 2. torch.nn์ Module์ ์์๋ฐ์ต๋๋ค.
def __init__(self, label_dict: dict, bidirectional: bool = True):
super().__init__()
self.label_dict = label_dict
# 3. intent๋ชจ๋ธ์ ๋ฐ๋์ ์์ฑ์ผ๋ก self.label_dict๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผํฉ๋๋ค.
self.direction = 2 if bidirectional else 1
self.lstm = nn.LSTM(input_size=self.vector_size,
hidden_size=self.d_model,
num_layers=self.layers,
batch_first=True,
bidirectional=bidirectional)
def init_hidden(self, batch_size: int) -> autograd.Variable:
param1 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
param2 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
return autograd.Variable(param1), autograd.Variable(param2)
def forward(self, x: Tensor) -> Tensor:
b, l, v = x.size()
out, (h_s, c_s) = self.lstm(x, self.init_hidden(b))
# 4. feature๋ฅผ [batch_size, -1]๋ก ๋ง๋ค๊ณ ๋ฐํํฉ๋๋ค.
# ์ต์ข
output ๋ ์ด์ด๋ kochat์ด ์๋ ์์ฑํ๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํฉ๋๋ค.
return h_s[0]
4.2.4.3. ์ปค์คํ Entity ๋ชจ๋ธ
์ํฐํฐ ๋ชจ๋ธ๋ ์ญ์ torch๋ก ๊ตฌํํฉ๋๋ค.
์ํฐํฐ ๋ชจ๋ธ์๋ ์ญ์ self.label_dict
๊ฐ ๋ฐ๋์ ์กด์ฌํด์ผํ๋ฉฐ,
๋ํ ์ต์ข
output ๋ ์ด์ด๋ ์๋์์ฑ๋๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํ๋ฉด ๋ฉ๋๋ค.
๋์ฑ ์ธ๋ถ์ ์ธ ๊ท์น์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
-
@entity
๋ฐ์ฝ๋ ์ดํฐ ์ค์ -
torch.nn.Module
์์๋ฐ๊ธฐ - ํ๋ผ๋ฏธํฐ๋ก label_dict๋ฅผ ์
๋ ฅ๋ฐ๊ณ
self.label_dict
์ ํ ๋นํ๊ธฐ -
forward()
ํจ์์์ feature๋ฅผ [batch_size, max_len, -1] ๋ก ๋ง๋ค๊ณ ๋ฆฌํด
import torch
from torch import nn, autograd
from torch import Tensor
from kochat.decorators import entity
# 1. @entity ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ์ค์ ํ๋ฉด
# config์ ENTITY์ ์๋ ๋ชจ๋ ์ค์ ๊ฐ์ ์ ๊ทผ ๊ฐ๋ฅํฉ๋๋ค.
@entity
class LSTM(nn.Module):
# 2. torch.nn์ Module์ ์์๋ฐ์ต๋๋ค.
def __init__(self, label_dict: dict, bidirectional: bool = True):
super().__init__()
self.label_dict = label_dict
# 3. entity๋ชจ๋ธ์ ๋ฐ๋์ ์์ฑ์ผ๋ก self.label_dict๋ฅผ ๊ฐ์ง๊ณ ์์ด์ผํฉ๋๋ค.
self.direction = 2 if bidirectional else 1
self.lstm = nn.LSTM(input_size=self.vector_size,
hidden_size=self.d_model,
num_layers=self.layers,
batch_first=True,
bidirectional=bidirectional)
def init_hidden(self, batch_size: int) -> autograd.Variable:
param1 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
param2 = torch.randn(self.layers * self.direction, batch_size, self.d_model).to(self.device)
return torch.autograd.Variable(param1), torch.autograd.Variable(param2)
def forward(self, x: Tensor) -> Tensor:
b, l, v = x.size()
out, _ = self.lstm(x, self.init_hidden(b))
# 4. feature๋ฅผ [batch_size, max_len, -1]๋ก ๋ง๋ค๊ณ ๋ฐํํฉ๋๋ค.
# ์ต์ข
output ๋ ์ด์ด๋ kochat์ด ์๋ ์์ฑํ๊ธฐ ๋๋ฌธ์ feature๋ง ์ถ๋ ฅํฉ๋๋ค.
return out
from kochat.proc
4.3. proc
์ Procssor์ ์ค์๋ง๋ก, ๋ค์ํ ๋ชจ๋ธ๋ค์
ํ์ต/ํ
์คํธ์ ์ํํ๋ ํจ์์ธ fit()
๊ณผ
์ถ๋ก ์ ์ํํ๋ ํจ์์ธ predict()
๋ฑ์ ์ํํ๋ ํด๋์ค ์งํฉ์
๋๋ค.
ํ์ฌ ์ง์ํ๋ ํ๋ก์ธ์๋ ์ด 4๊ฐ์ง๋ก ์๋์์ ์์ธํ๊ฒ ์ค๋ช
ํฉ๋๋ค.
from kochat.proc import GensimEmbedder
4.3.1. GensimEmbedder๋ Gensim์ ์๋ฒ ๋ฉ ๋ชจ๋ธ์ ํ์ต์ํค๊ณ , ํ์ต๋ ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ฌธ์ฅ์ ์๋ฒ ๋ฉํ๋ ํด๋์ค์ ๋๋ค. ์์ธํ ์ฌ์ฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import GensimEmbedder
from kochat.model import embed
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
emb = GensimEmbedder(
model=embed.FastText()
)
# ๋ชจ๋ธ ํ์ต
emb.fit(dataset.load_embed())
# ๋ชจ๋ธ ์ถ๋ก (์๋ฒ ๋ฉ)
user_input = emb.predict("์์ธ ํ๋ ๋ง์ง ์๋ ค์ค")
from kochat.proc import SoftmaxClassifier
4.3.2. SoftmaxClassifier
๋ ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ๋ถ๋ฅ ํ๋ก์ธ์์
๋๋ค.
์ด๋ฆ์ด SoftmaxClassifier์ธ ์ด์ ๋ Softmax Score๋ฅผ ์ด์ฉํด Fallback Detection์ ์ํํ๊ธฐ ๋๋ฌธ์
์ด๋ ๊ฒ ๋ช
๋ช
ํ๊ฒ ๋์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ช๋ช ๋
ผ๋ฌธ
์์ Calibrate๋์ง ์์ Softmax Score์ ๋ง์น Confidence์ฒ๋ผ
์ฐฉ๊ฐํด์ ์ฌ์ฉํ๋ฉด ์ฌ๊ฐํ ๋ฌธ์ ๊ฐ ๋ฐ์ํ ์ ์๋ค๋ ๊ฒ์ ๋ณด์ฌ์ฃผ์์ต๋๋ค.
์์ ๊ทธ๋ฆผ์ MNIST ๋ถ๋ฅ๋ชจ๋ธ์์ 0.999 ์ด์์ Softmax Score๋ฅผ ๊ฐ์ง๋ ์ด๋ฏธ์ง๋ค์
๋๋ค.
์ค์ ๋ก 0 ~ 9๊น์ง์ ์ซ์์๋ ์ ํ ์๊ด์๋ ์ด๋ฏธ์ง๋ค์ด๊ธฐ ๋๋ฌธ์ ๋ฎ์ Softmax Score๋ฅผ
๊ฐ์ง ๊ฒ์ด๋ผ๊ณ ์๊ฐ๋์ง๋ง ์ค์ ๋ก๋ ๊ทธ๋ ์ง ์์ต๋๋ค.
์ฌ์ค SoftmaxClassifier
๋ฅผ ์ค์ ์ฑ๋ด์ Intent Classification ๊ธฐ๋ฅ์ ์ํด
์ฌ์ฉํ๋ ๊ฒ์ ์ ์ ํ์ง ๋ชปํฉ๋๋ค. SoftmaxClassifier
๋ ์๋ ํ์ ํ DistanceClassifier
์์ ์ฑ๋ฅ ๋น๊ต๋ฅผ ์ํด ๊ตฌํํ์์ต๋๋ค. ์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import SoftmaxClassifier
from kochat.model import intent
from kochat.loss import CrossEntropyLoss
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
clf = SoftmaxClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CrossEntropyLoss(dataset.intent_dict)
)
# ๋๋๋ก์ด๋ฉด SoftmaxClassifier๋ CrossEntropyLoss๋ฅผ ์ด์ฉํด์ฃผ์ธ์
# ๋ค๋ฅธ Loss ํจ์๋ค์ ๊ฑฐ๋ฆฌ ๊ธฐ๋ฐ์ Metric Learning์ ์ํํ๊ธฐ ๋๋ฌธ์
# Softmax Classifiaction์ ์ ์ ํ์ง ๋ชปํ ์ ์์ต๋๋ค.
# ๋ชจ๋ธ ํ์ต
clf.fit(dataset.load_intent(emb))
# ๋ชจ๋ธ ์ถ๋ก (์ธํ
ํธ ๋ถ๋ฅ)
clf.predict(dataset.load_predict("์ค๋ ์์ธ ๋ ์จ ์ด๋จ๊น", emb))
from kochat.proc import DistanceClassifier
4.3.3. DistanceClassifier
๋ SoftmaxClassifier
์๋ ๋ค๋ฅด๊ฒ ๊ฑฐ๋ฆฌ๊ธฐ๋ฐ์ผ๋ก ์๋ํ๋ฉฐ,
์ผ์ข
์ Memory Network์
๋๋ค. [batch_size, -1] ์ ์ฌ์ด์ฆ๋ก ์ถ๋ ฅ๋ ์ถ๋ ฅ๋ฒกํฐ์
๊ธฐ์กด ๋ฐ์ดํฐ์
์ ์๋ ๋ฌธ์ฅ ๋ฒกํฐ๋ค ์ฌ์ด์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ์ฌ ๋ฐ์ดํฐ์
์์ ๊ฐ์ฅ ๊ฐ๊น์ด
K๊ฐ์ ์ํ์ ์ฐพ๊ณ ์ต๋ค ์ํ ํด๋์ค๋ก ๋ถ๋ฅํ๋ ์ต๊ทผ์ ์ด์ Retrieval ๊ธฐ๋ฐ์ ๋ถ๋ฅ ๋ชจ๋ธ์
๋๋ค.
์ด ๋ ๋ค๋ฅธ ํด๋์ค๋ค์ ๋ฉ๋ฆฌ, ๊ฐ์ ํด๋์ค๋ผ๋ฆฌ๋ ๊ฐ๊น์ด ์์ด์ผ ๋ถ๋ฅํ๊ธฐ์ ์ข๊ธฐ ๋๋ฌธ์ ์ฌ์ฉ์๊ฐ ์ค์ ํ Lossํจ์(์ฃผ๋ก Margin ๊ธฐ๋ฐ Loss)๋ฅผ ์ ์ฉํด Metric Learning์ ์ํํด์ ํด๋์ค ๊ฐ์ Margin์ ์ต๋์น๋ก ๋ฒ๋ฆฌ๋ ๋ฉ์ปค๋์ฆ์ด ๊ตฌํ๋์ด์์ต๋๋ค. ๋ํ ์ต๊ทผ์ ์ด์ ์๊ณ ๋ฆฌ์ฆ์ K๊ฐ์ config์์ ์ง์ ์ง์ ํ ์๋ ์๊ณ GridSearch๋ฅผ ์ ์ฉํ์ฌ ์๋์ผ๋ก ์ต์ ์ K๊ฐ์ ์ฐพ์ ์ ์๊ฒ ์ค๊ณํ์์ต๋๋ค.
์ต๊ทผ์ ์ด์์ ์ฐพ์ ๋ Brute force๋ก ์ง์ ๊ฑฐ๋ฆฌ๋ฅผ ์ผ์ผ์ด ๋ค ๊ตฌํ๋ฉด ๊ต์ฅํ ๋๋ฆฌ๊ธฐ
๋๋ฌธ์ ๋ค์ฐจ์ ๊ฒ์ํธ๋ฆฌ์ธ KDTree
ํน์ BallTree
(KDTree์ ๊ฐ์ ํํ)๋ฅผ ํตํด์
๊ฑฐ๋ฆฌ๋ฅผ ๊ณ์ฐํ๋ฉฐ ๊ฒฐ๊ณผ๋ก ๋ง๋ค์ด์ง ํธ๋ฆฌ ๊ตฌ์กฐ๋ฅผ ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅํฉ๋๋ค. ๊ฒ์ํธ๋ฆฌ์ ์ข
๋ฅ,
๊ฑฐ๋ฆฌ ๋ฉํธ๋ฆญ(์ ํด๋ฆฌ๋์ธ, ๋งจํํผ ๋ฑ..)์ ์ ๋ถ GridSearch๋ก ์๋ํ ์ํฌ ์ ์์ผ๋ฉฐ
์ด์ ๋ํ ์ค์ ์ config์์ ๊ฐ๋ฅํฉ๋๋ค. ํธ๋ฆฌ๊ธฐ๋ฐ์ ๊ฒ์ ์๊ณ ๋ฆฌ์ฆ์ ์ฌ์ฉํ๊ธฐ ๋๋ฌธ์
SoftmaxClassifier
์ ๊ฑฐ์ ๋น์ทํ ์๋๋ก ํ์ต ๋ฐ ์ถ๋ก ์ด ๊ฐ๋ฅํฉ๋๋ค.
์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import DistanceClassifier
from kochat.model import intent
from kochat.loss import CenterLoss
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
clf = DistanceClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CenterLoss(dataset.intent_dict)
)
# ๋๋๋ก์ด๋ฉด DistanceClassifier๋ Margin ๊ธฐ๋ฐ์ Loss ํจ์๋ฅผ ์ด์ฉํด์ฃผ์ธ์
# ํ์ฌ๋ CenterLoss, COCOLoss, Cosface, GausianMixture ๋ฑ์
# ๊ฑฐ๋ฆฌ๊ธฐ๋ฐ Metric Learning ์ ์ฉ Lossํจ์๋ฅผ ์ง์ํฉ๋๋ค.
# ๋ชจ๋ธ ํ์ต
clf.fit(dataset.load_intent(emb))
# ๋ชจ๋ธ ์ถ๋ก (์ธํ
ํธ ๋ถ๋ฅ)
clf.predict(dataset.load_predict("์ค๋ ์์ธ ๋ ์จ ์ด๋จ๊น", emb))
FallbackDetector
4.3.4. SoftmaxClassifier
์ DistanceClassifier
๋ชจ๋ Fallback Detection ๊ธฐ๋ฅ์ ๊ตฌํ๋์ด์์ต๋๋ค.
Fallback Detection ๊ธฐ๋ฅ์ ์ด์ฉํ๋ ๋ฐฉ๋ฒ์ ์๋์ ๊ฐ์ด ๋ ๊ฐ์ง ๋ฐฉ๋ฒ์ ์ ๊ณตํฉ๋๋ค.
1. OOD ๋ฐ์ดํฐ๊ฐ ์๋ ๊ฒฝ์ฐ : ์ง์ config์ Threshold๋ฅผ ๋ง์ถฐ์ผํฉ๋๋ค.
2. OOD ๋ฐ์ดํฐ๊ฐ ์๋ ๊ฒฝ์ฐ : ๋จธ์ ๋ฌ๋์ ์ด์ฉํ์ฌ Threshold๋ฅผ ์๋ ํ์ตํฉ๋๋ค.
๋ฐ๋ก ์ฌ๊ธฐ์์ OOD ๋ฐ์ดํฐ์
์ด ์ฌ์ฉ๋ฉ๋๋ค.
SoftmaxClassifier
๋ out distribution ์ํ๋ค๊ณผ in distribution ์ํ๊ฐ์
maximum softmax score (size = [batch_size, 1])๋ฅผ feature๋ก ํ์ฌ
๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ๊ณ ,
DistanceClassifier
๋ out distribution ์ํ๋ค๊ณผ in distribution ์ํ๋ค์
K๊ฐ์ ์ต๊ทผ์ ์ด์์ ๊ฑฐ๋ฆฌ (size = [batch_size, K])๋ฅผ feature๋ก ํ์ฌ
๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํฉ๋๋ค.
์ด๋ฌํ ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ FallbackDetector
๋ผ๊ณ ํฉ๋๋ค. FallbackDetector
๋ ๊ฐ
Classifier์์ ๋ด์ฅ ๋์ด์๊ธฐ ๋๋ฌธ์ ๋ณ๋ค๋ฅธ ์ถ๊ฐ ์์ค์ฝ๋ ์์ด Dataset
์ ood
ํ๋ผ๋ฏธํฐ๋ง True
๋ก ์ค์ ๋์ด์๋ค๋ฉด Classifierํ์ต์ด ๋๋๊ณ ๋์ ์๋์ผ๋ก ํ์ต๋๊ณ ,
predict()
์ ์ ์ฅ๋ FallbackDetector
๊ฐ ์๋ค๋ฉด ์๋์ผ๋ก ๋์ํฉ๋๋ค.
๋ํ FallbackDetector
๋ก ์ฌ์ฉํ ๋ชจ๋ธ์ ์๋์ฒ๋ผ config์์ ์ฌ์ฉ์๊ฐ ์ง์ ์ค์ ํ ์ ์์ผ๋ฉฐ
GridSearch๋ฅผ ์ง์ํ์ฌ ์ฌ๋ฌ๊ฐ์ ๋ชจ๋ธ์ ๋ฆฌ์คํธ์ ๋ฃ์ด๋๋ฉด Kochat ํ๋ ์์ํฌ๊ฐ
ํ์ฌ ๋ฐ์ดํฐ์
์ ๊ฐ์ฅ ์ ํฉํ FallbackDetector
๋ฅผ ์๋์ผ๋ก ๊ณจ๋ผ์ค๋๋ค.
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
INTENT = {
# ... (์๋ต)
# ํด๋ฐฑ ๋ํ
ํฐ ํ๋ณด (์ ํ ๋ชจ๋ธ์ ์ถ์ฒํฉ๋๋ค)
'fallback_detectors': [
LogisticRegression(max_iter=30000),
LinearSVC(max_iter=30000)
# ๊ฐ๋ฅํ max_iter๋ฅผ ๋๊ฒ ์ค์ ํด์ฃผ์ธ์
# sklearn default๊ฐ max_iter=100์ด๋ผ์ ์๋ ด์ด ์๋ฉ๋๋ค...
]
}
Fallback Detection ๋ฌธ์ ๋ Fallback ๋ฉํธ๋ฆญ(๊ฑฐ๋ฆฌ or score)๊ฐ ์ผ์ ์๊ณ์น๋ฅผ ๋์ด๊ฐ๋ฉด
์ํ์ in / out distribution ์ํ๋ก ๋ถ๋ฅํ๋๋ฐ ๊ทธ ์๊ณ์น๋ฅผ ํ์ฌ ๋ชจ๋ฅด๋ ์ํฉ์ด๋ฏ๋ก
์ ํ ๋ฌธ์ ๋ก ํด์ํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ FallbackDetector๋ก๋ ์ ์ฒ๋ผ ์ ํ ๋ชจ๋ธ์ธ
์ ํ SVM, ๋ก์ง์คํฑ ํ๊ท ๋ฑ์ ์ฃผ๋ก ์ด์ฉํฉ๋๋ค. ๋ฌผ๋ก ์์ ๋ฆฌ์คํธ์
RandomForestClassifier()
๋ BernoulliNB()
, GradientBoostingClassifier()
๋ฑ
๋ค์ํ sklearn ๋ชจ๋ธ์ ์
๋ ฅํด๋ ๋์์ ํ์ง๋ง, ์ผ๋ฐ์ ์ผ๋ก ์ ํ๋ชจ๋ธ์ด ๊ฐ์ฅ ์ฐ์ํ๊ณ
์์ ์ ์ธ ์ฑ๋ฅ์ ๋ณด์์ต๋๋ค.
์ด๋ ๊ฒ Fallback์ ๋ฉํธ๋ฆญ์ผ๋ก ๋จธ์ ๋ฌ๋ ๋ชจ๋ธ์ ํ์ตํ๋ฉด Threshold๋ฅผ ์ง์ ์ ์ ๊ฐ ์ค์ ํ์ง ์์๋ ๋ฉ๋๋ค. OOD ๋ฐ์ดํฐ์ ์ด ํ์ํ๋ค๋ ์น๋ช ์ ์ธ ๋จ์ ์ด ์์ง๋ง, ์ฐจํ ๋ฒ์ ์์๋ BERT์ Markov Chain์ ์ด์ฉํด OOD ๋ฐ์ดํฐ์ ์ ์๋์ผ๋ก ๋น ๋ฅด๊ฒ ์์ฑํ๋ ๋ชจ๋ธ์ ๊ตฌํํ์ฌ ์ถ๊ฐํ ์์ ์ ๋๋ค. (์ด ์ ๋ฐ์ดํธ ์ดํ๋ถํฐ๋ OOD ๋ฐ์ดํฐ์ ์ด ํ์ ์์ด์ง๋๋ค.)
๊ทธ๋ฌ๋ ์์ง OOD ๋ฐ์ดํฐ์ ์์ฑ๊ธฐ๋ฅ์ ์ง์ํ์ง ์๊ธฐ ๋๋ฌธ์ ํ์ฌ ๋ฒ์ ์์๋ ๋ง์ฝ OOD ๋ฐ์ดํฐ์ ์ด ์๋ค๋ฉด ์ฌ์ฉ์๊ฐ ์ง์ Threshold๋ฅผ ์ค์ ํด์ผ ํ๋ฏ๋ก ๋์ผ๋ก ์ํ๋ค์ด ์ด๋์ ๋ score ํน์ ๊ฑฐ๋ฆฌ๋ฅผ ๊ฐ๋์ง ํ์ธํด์ผํฉ๋๋ค. ๋ฐ๋ผ์ Kochat์ Calibrate ๋ชจ๋๋ฅผ ์ง์ํฉ๋๋ค.
while True:
user_input = dataset.load_predict(input(), emb)
# ํฐ๋ฏธ๋์ ์ง์ ood๋ก ์๊ฐ๋ ๋งํ ์ํ์ ์
๋ ฅํด์
# ๋์ผ๋ก ๊ฒฐ๊ณผ๋ฅผ ์ง์ ํ์ธํ๊ณ , threshold๋ฅผ ์ง์ ์กฐ์ ํฉ๋๋ค.
result = clf.predict(user_input, calibrate=True)
print("classification result : {}".format(result))
# DistanceClassifier
>>> '=====================CALIBRATION_MODE====================='
'ํ์ฌ ์
๋ ฅํ์ ๋ฌธ์ฅ๊ณผ ๊ธฐ์กด ๋ฌธ์ฅ๋ค ์ฌ์ด์ ๊ฑฐ๋ฆฌ ํ๊ท ์ 2.912์ด๊ณ '
'๊ฐ๊น์ด ์ํ๋ค๊ณผ์ ๊ฑฐ๋ฆฌ๋ [2.341, 2.351, 2.412, 2.445 ...]์
๋๋ค.'
'์ด ์์น๋ฅผ ๋ณด๊ณ Config์ fallback_detection_threshold๋ฅผ ๋ง์ถ์ธ์.'
'criteria๋ ๊ฑฐ๋ฆฌํ๊ท (mean) / ์ต์๊ฐ(min)์ผ๋ก ์ค์ ํ ์ ์์ต๋๋ค.'
# SoftmaxClassifier
>>> '=====================CALIBRATION_MODE====================='
'ํ์ฌ ์
๋ ฅํ์ ๋ฌธ์ฅ์ softmax logits์ 0.997์
๋๋ค.'
'์ด ์์น๋ฅผ ๋ณด๊ณ Config์ fallback_detection_threshold๋ฅผ ๋ง์ถ์ธ์.'
์ด๋ ๊ฒ calibrate ๋ชจ๋๋ฅผ ์ฌ๋ฌ๋ฒ ์งํํ์ ์ ์ค์ค๋ก ๊ณ์ฐํ threshold์ ์ํ๋ criteria๋ฅผ ์๋์ฒ๋ผ config์ ์ค์ ํ๋ฉด ood ๋ฐ์ดํฐ์ ์์ด๋ FallbackDetector๋ฅผ ์ด์ฉํ ์ ์์ต๋๋ค.
INTENT = {
'distance_fallback_detection_criteria': 'mean', # or 'min'
# [auto, min, mean], auto๋ OOD ๋ฐ์ดํฐ ์์๋๋ง ๊ฐ๋ฅ
'distance_fallback_detection_threshold': 3.2,
# mean ํน์ min ์ ํ์ ์๊ณ๊ฐ
'softmax_fallback_detection_criteria': 'other',
# [auto, other], auto๋ OOD ๋ฐ์ดํฐ ์์๋๋ง ๊ฐ๋ฅ
'softmax_fallback_detection_threshold': 0.88,
# other ์ ํ์ fallback์ด ๋์ง ์๋ ์ต์ ๊ฐ
}
๊ทธ๋ฌ๋ ์ง๊ธ ๋ฒ์ ์์๋ ๊ฐ๊ธ์ OOD ๋ฐ์ดํฐ์
์ ์ถ๊ฐํด์ ์ด์ฉํด์ฃผ์ธ์.
์ ์์ผ์๋ฉด ์ ๊ฐ ๋ฐ๋ชจ ํด๋์ ๋ฃ์ด๋์ ๋ฐ์ดํฐ๋ผ๋ ๋ฃ์ด์ ์๋ํํด์ ์ฐ๋๊ฒ
ํจ์ฌ ์ฑ๋ฅ์ด ์ข์ต๋๋ค. ๋ช๋ช ๋น๋๋ค์ ์ด ์๊ณ์น๋ฅผ ์ง์ ์ ํ๊ฒ ํ๊ฑฐ๋ ๊ทธ๋ฅ ์์๋ก
fixํด๋๋๋ฐ, ๊ฐ์ธ์ ์ผ๋ก ์ด๊ฑธ ๊ทธ๋ฅ ์์๋ก fix ํด๋๊ฑฐ๋ ์ ์ ๋ณด๊ณ ์ง์ ์ ํ๊ฒ ํ๋๊ฑด
์ฑ๋ด ๋น๋๋ก์, ํน์ ํ๋ ์์ํฌ๋ก์ ๋ฌด์ฑ
์ํ ๊ฒ ์๋๊ฐ ์ถ์ต๋๋ค.
from kochat.proc import EntityRecongnizer
4.3.5. EntityRecongnizer
๋ ์ํฐํฐ ๊ฒ์ถ์ ๋ด๋นํ๋ Entity ๋ชจ๋ธ๋ค์ ํ์ต/ํ
์คํธ ์ํค๊ณ ์ถ๋ก ํ๋
ํด๋์ค์
๋๋ค. Entity ๊ฒ์ฌ์ ๊ฒฝ์ฐ ๋ฌธ์ฅ 1๊ฐ๋น ๋ผ๋ฒจ์ด ์ฌ๋ฌ๊ฐ(๋จ์ด ๊ฐฏ์์ ๋์ผ)์
๋๋ค.
๋ฌธ์ ๋ Outside ํ ํฐ์ธ 'O'๊ฐ ๋๋ถ๋ถ์ด๊ธฐ ๋๋ฌธ์ ์ ๋ถ๋ค 'O'๋ผ๊ณ ๋ง ์์ธกํด๋ ๊ฑฐ์ 90% ์ก๋ฐํ๋
์ ํ๋๊ฐ ๋์ค๊ฒ ๋ฉ๋๋ค. ๋ํ, ํจ๋์ํ์ฑํ ๋ถ๋ถ๋ 'O'๋ก ์ฒ๋ฆฌ ๋์ด์๋๋ฐ, ์ด ๋ถ๋ถ๋ ๋ง์๊ฒ์ผ๋ก
์๊ฐํ๊ณ Loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด Kochat์ F1 Score, Recall, Precision ๋ฑ NER์ ์ฑ๋ฅ์ ๋ณด๋ค ์ ํํ๊ฒ ํ๊ฐ ํ ์ ์๋ ๊ฐ๋ ฅํ Validation ๋ฐ ์๊ฐํ ์ง์๊ณผ Loss ํจ์ ๊ณ์ฐ์ PAD๋ถ๋ถ์ masking์ ์ ์ฉํ ์ ์์ต๋๋ค. (mask ์ ์ฉ ์ฌ๋ถ ์ญ์ config์์ ์ค์ ๊ฐ๋ฅํฉ๋๋ค.) ์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ์ต๋๋ค.
from kochat.data import Dataset
from kochat.proc import EntityRecognizer
from kochat.model import entity
from kochat.loss import CRFLoss
dataset = Dataset(ood=True)
# ํ๋ก์ธ์ ์์ฑ
rcn = EntityRecognizer(
model=entity.LSTM(dataset.intent_dict),
loss=CRFLoss(dataset.intent_dict)
# Conditional Random Field๋ฅผ Lossํจ์๋ก ์ง์ํฉ๋๋ค.
)
# ๋ชจ๋ธ ํ์ต
rcn.fit(dataset.load_entity(emb))
# ๋ชจ๋ธ ์ถ๋ก (์ํฐํฐ ๊ฒ์ถ)
rcn.predict(dataset.load_predict("์ค๋ ์์ธ ๋ ์จ ์ด๋จ๊น", emb))
from kochat.loss
4.4. loss
ํจํค์ง๋ ์ฌ์ ์ ์๋ ๋ค์ํ built-in Loss ํจ์๋ค์ด ์ ์ฅ๋ ํจํค์ง์
๋๋ค.
ํ์ฌ ๋ฒ์ ์์๋ ์๋ ๋ชฉ๋ก์ ํด๋นํ๋ Loss ํจ์๋ค์ ์ง์ํฉ๋๋ค. ์ถํ ๋ฒ์ ์ด ์
๋ฐ์ดํธ ๋๋ฉด
์ง๊ธ๋ณด๋ค ํจ์ฌ ๋ค์ํ built-in Loss ํจ์๋ฅผ ์ง์ํ ์์ ์
๋๋ค. ์๋ ๋ชฉ๋ก์ ์ฐธ๊ณ ํ์ฌ ์ฌ์ฉํด์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
4.4.1. intent loss ํจ์
Intent Loss ํจ์๋ ๊ธฐ๋ณธ์ ์ธ CrossEntropyLoss์ ๋ค์ํ Distance ๊ธฐ๋ฐ์ Lossํจ์๋ฅผ ํ์ฉํ ์ ์์ต๋๋ค. CrossEntropy๋ ํ์ ํ Softmax ๊ธฐ๋ฐ์ IntentClassifier์ ์ฃผ๋ก ํ์ฉํ๊ณ , Distance ๊ธฐ๋ฐ์ Loss ํจ์๋ค์ Distance ๊ธฐ๋ฐ์ IntentClassifier์ ํ์ฉํ ์ ์์ต๋๋ค. Distance ๊ธฐ๋ฐ์ Lossํจ์๋ค์ ์ปดํจํฐ ๋น์ ์์ญ (์ฃผ๋ก ์ผ๊ตด์ธ์) ๋ถ์ผ์์ ์ ์๋ ํจ์๋ค์ด์ง๋ง Intent ๋ถ๋ฅ์ Fallback ๋ํ ์ ์๋ ๋งค์ฐ ์ฐ์ํ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค.
from kochat.loss import CrossEntropyLoss
from kochat.loss import CenterLoss
from kochat.loss import GaussianMixture
from kochat.loss import COCOLoss
from kochat.loss import CosFace
# 1. ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ Cross Entropy Loss ํจ์์
๋๋ค.
cross_entropy = CrossEntropyLoss(label_dict=dataset.intent_dict)
# 2. Intra Class ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ขํ ์ ์๋ Center Loss ํจ์์
๋๋ค.
center_loss = CenterLoss(label_dict=dataset.intent_dict)
# 3. Intra Class ๊ฐ์ ๊ฑฐ๋ฆฌ๋ฅผ ์ขํ ์ ์๋ Large Margin Gaussian Mixture Loss ํจ์์
๋๋ค.
lmgl = GaussianMixture(label_dict=dataset.intent_dict)
# 4. Inter Class ๊ฐ์ Cosine ๋ง์ง์ ํค์ธ ์ ์๋ COCO (Congenerous Cosine) Loss ํจ์์
๋๋ค.
coco_loss = COCOLoss(label_dict=dataset.intent_dict)
# 5. Inter Class ๊ฐ์ Cosine ๋ง์ง์ ํค์ธ ์ ์๋ Cosface (Large Margin Cosine) Lossํจ์์
๋๋ค.
cosface = CosFace(label_dict=dataset.intent_dict)
4.4.2. entity loss ํจ์
Entity Loss ํจ์๋ ๊ธฐ๋ณธ์ ์ธ CrossEntropyLoss์ ํ๋ฅ ์ ๋ชจ๋ธ์ธ
Conditional Random Field (์ดํ CRF) Loss๋ฅผ ์ง์ํฉ๋๋ค.
CRF Loss๋ฅผ ์ ์ฉํ๋ฉด, EntityRecognizer์ ์ถ๋ ฅ ๊ฒฐ๊ณผ๋ฅผ ๋ค์ํ๋ฒ ๊ต์ ํ๋
ํจ๊ณผ๋ฅผ ๋ณผ ์ ์์ผ๋ฉฐ CRF Loss๋ฅผ ์ ์ฉํ๋ฉด, ์ถ๋ ฅ ๋์ฝ๋ฉ์ Viterbi ์๊ณ ๋ฆฌ์ฆ์
ํตํด ์ํํฉ๋๋ค.
from kochat.loss import CrossEntropyLoss
from kochat.loss import CRFLoss
# 1. ๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ cross entropy ๋ก์ค ํจ์์
๋๋ค.
cross_entropy = CrossEntropyLoss(label_dict=dataset.intent_dict)
# 2. CRF Loss ํจ์์
๋๋ค.
center_loss = CRFLoss(label_dict=dataset.intent_dict)
4.4.3. ์ปค์คํ loss ํจ์
Kochat์ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ํฉ๋๋ค.
Pytorch๋ก ์์ฑํ ์ปค์คํ
๋ชจ๋ธ์ ์ง์ ํ์ต์ํค๊ธฐ๊ณ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์
์ฌ์ฉํ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ง์ฝ ์ปค์คํ
๋ชจ๋ธ์ ์ฌ์ฉํ๋ ค๋ฉด ์๋์ ๋ช๊ฐ์ง ๊ท์น์ ๋ฐ๋์
๋ฐ๋ผ์ผํฉ๋๋ค.
- forward ํจ์์์ ํด๋น loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
- compute_loss ํจ์์์ ๋ผ๋ฒจ๊ณผ ๋น๊ตํ์ฌ ์ต์ข
loss๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
์๋์ ๊ตฌํ ์์ ๋ฅผ ๋ณด๋ฉด ๋์ฑ ์ฝ๊ฒ ์ดํดํ ์ ์์ต๋๋ค.
@intent
class CosFace(BaseLoss):
def __init__(self, label_dict: dict):
super(CosFace, self).__init__()
self.classes = len(label_dict)
self.centers = nn.Parameter(torch.randn(self.classes, self.d_loss))
def forward(self, feat: Tensor, label: Tensor) -> Tensor:
# 1. forward ํจ์์์ ํ์ฌ lossํจ์์ loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
batch_size = feat.shape[0]
norms = torch.norm(feat, p=2, dim=-1, keepdim=True)
nfeat = torch.div(feat, norms)
norms_c = torch.norm(self.centers, p=2, dim=-1, keepdim=True)
ncenters = torch.div(self.centers, norms_c)
logits = torch.matmul(nfeat, torch.transpose(ncenters, 0, 1))
y_onehot = torch.FloatTensor(batch_size, self.classes)
y_onehot.zero_()
y_onehot = Variable(y_onehot).cuda()
y_onehot.scatter_(1, torch.unsqueeze(label, dim=-1), self.cosface_m)
margin_logits = self.cosface_s * (logits - y_onehot)
return margin_logits
def compute_loss(self, label: Tensor, logits: Tensor, feats: Tensor, mask: nn.Module = None) -> Tensor:
# 2. compute loss์์ ์ต์ข
loss๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
mlogits = self(feats, label)
# ์๊ธฐ ์์ ์ forward ํธ์ถ
return F.cross_entropy(mlogits, label)
@intent
class CenterLoss(BaseLoss):
def __init__(self, label_dict: dict):
super(CenterLoss, self).__init__()
self.classes = len(label_dict)
self.centers = nn.Parameter(torch.randn(self.classes, self.d_loss))
self.center_loss_function = CenterLossFunction.apply
def forward(self, feat: Tensor, label: Tensor) -> Tensor:
# 1. forward ํจ์์์ ํ์ฌ lossํจ์์ loss๋ฅผ ๊ณ์ฐํฉ๋๋ค.
batch_size = feat.size(0)
feat = feat.view(batch_size, 1, 1, -1).squeeze()
if feat.size(1) != self.d_loss:
raise ValueError("Center's dim: {0} should be equal to input feature's dim: {1}"
.format(self.d_loss, feat.size(1)))
return self.center_loss_function(feat, label, self.centers)
def compute_loss(self, label: Tensor, logits: Tensor, feats: Tensor, mask: nn.Module = None) -> Tensor:
# 2. compute loss์์ ์ต์ข
loss๊ฐ์ ๊ณ์ฐํฉ๋๋ค.
nll_loss = F.cross_entropy(logits, label)
center_loss = self(feats, label)
# ์๊ธฐ ์์ ์ forward ํธ์ถ
return nll_loss + self.center_factor * center_loss
from kochat.app
4.5. app
ํจํค์ง๋ kochat ๋ชจ๋ธ์ ์ ํ๋ฆฌ์ผ์ด์
์ผ๋ก ๋ฐฐํฌํ ์ ์๊ฒ๋ ํด์ฃผ๋
RESTful API์ธ KochatApi
ํด๋์ค์ API ํธ์ถ์ ๊ด๋ จ๋ ์๋๋ฆฌ์ค๋ฅผ
์์ฑํ ์ ์๊ฒ๋ ํ๋ Scenario
ํด๋์ค๋ฅผ ์ ๊ณตํฉ๋๋ค.
from kochat.app import Scenario
4.5.1 Scenario
ํด๋์ค๋ ์ด๋ค intent์์๋ ์ด๋ค entity๊ฐ ํ์ํ๊ณ ,
์ด๋ค api๋ฅผ ํธ์ถํ๋์ง ์ ์ํ๋ ์ผ์ข
์ ๋ช
์ธ์์ ๊ฐ์ต๋๋ค.
์๋๋ฆฌ์ค ์์ฑ์ ์๋์ ๊ฐ์ ๋ช๊ฐ์ง ์ฃผ์์ฌํญ์ด ์์ต๋๋ค.
- intent๋ ๋ฐ๋์ raw๋ฐ์ดํฐ ํ์ผ ๋ช ๊ณผ ๋์ผํ๊ฒ ์ค์ ํ๊ธฐ
- api๋ ํจ์ ๊ทธ ์์ฒด๋ฅผ ๋ฃ์ต๋๋ค (๋ฐ๋์ callable ํด์ผํฉ๋๋ค.)
- scenario ๋์ ๋๋ฆฌ ์ ์์์ KEY๊ฐ์ api ํจ์์ ์์/์ฒ ์๊ฐ ๋์ผํด์ผํฉ๋๋ค.
- scenario ๋์ ๋๋ฆฌ ์ ์์์ KEY๊ฐ์ config์ NER_categories์ ์ ์๋ ์ํฐํฐ๋ง ํ์ฉ๋ฉ๋๋ค.
- ๊ธฐ๋ณธ๊ฐ(default) ์ค์ ์ ์ํ๋ฉด scenario ๋์
๋๋ฆฌ์ ๋ฆฌ์คํธ์ ๊ฐ์ ์ฒจ๊ฐํฉ๋๋ค.
- kocrawl (๋ ์จ) ์์
from kochat.app import Scenario
from kocrawl.weather import WeatherCrawler
# kocrawl์ kochat์ ๋ง๋ค๋ฉด์ ํจ๊ป ๊ฐ๋ฐํ ํฌ๋กค๋ฌ์
๋๋ค.
# (https://github.com/gusdnd852/kocrawl)
# 'pip install kocrawl'๋ก ์์ฝ๊ฒ ์ค์นํ ์ ์์ต๋๋ค.
weather_scenario = Scenario(
intent='weather', # intent๋ ์ธํ
ํธ ๋ช
์ ์ ์ต๋๋ค (raw ๋ฐ์ดํฐ ํ์ผ๋ช
๊ณผ ๋์ผํด์ผํฉ๋๋ค)
api=WeatherCrawler().request, # API๋ ํจ์ ์ด๋ฆ ์์ฒด๋ฅผ ๋ฃ์ต๋๋ค. (callableํด์ผํฉ๋๋ค)
scenario={
'LOCATION': [],
# ๊ธฐ๋ณธ์ ์ผ๋ก 'KEY' : []์ ํํ๋ก ๋ง๋ญ๋๋ค.
'DATE': ['์ค๋']
# entity๊ฐ ๊ฒ์ถ๋์ง ์์์ ๋ default ๊ฐ์ ์ง์ ํ๊ณ ์ถ์ผ๋ฉด ๋ฆฌ์คํธ ์์ ์ํ๋ ๊ฐ์ ๋ฃ์ต๋๋ค.
# [์ ์ฃผ, ๋ ์จ, ์๋ ค์ค] => [S-LOCATION, O, O] => api('์ค๋', S-LOCATION) call
# ๋ง์ฝ ['์ค๋', 'ํ์ฌ']์ฒ๋ผ 2๊ฐ ์ด์์ default๋ฅผ ๋ฃ์ผ๋ฉด ๋๋ค์ผ๋ก ์ ํํด์ default ๊ฐ์ผ๋ก ์ง์ ํฉ๋๋ค.
}
# ์๋๋ฆฌ์ค ๋์
๋๋ฆฌ๋ฅผ ์ ์ํฉ๋๋ค.
# ์ฃผ์์ 1 : scenario ํค๊ฐ(LOCATION, DATE)์ ์์๋ API ํจ์์ ํ๋ผ๋ฏธํฐ ์์์ ๋์ผํด์ผํฉ๋๋ค.
# ์ฃผ์์ 2 : scenario ํค๊ฐ(LOCATION, DATE)์ ์ฒ ์๋ API ํจ์์ ํ๋ผ๋ฏธํฐ ์ฒ ์์ ๋์ผํด์ผํฉ๋๋ค.
# ์ฃผ์์ 3 : raw ๋ฐ์ดํฐ ํ์ผ์ ๋ผ๋ฒจ๋งํ ์ํฐํฐ๋ช
๊ณผ scenario ํค๊ฐ์ ๋์ผํด์ผํฉ๋๋ค.
# ์ฆ config์ NER_categories์ ๋ฏธ๋ฆฌ ์ ์๋ ์ํฐํฐ๋ง ์ฌ์ฉํ์
์ผํฉ๋๋ค.
# B-, I- ๋ฑ์ BIOํ๊ทธ๋ ์๋ตํฉ๋๋ค. (S-DATE โ DATE๋ก ์๊ฐ)
# ๋/์๋ฌธ์๊น์ง ๋์ผํ ํ์๋ ์๊ณ , ์ฒ ์๋ง ๊ฐ์ผ๋ฉด ๋ฉ๋๋ค. (๋ชจ๋ lowercase ์ํ์์ ๋น๊ต)
# ๋ค์ ๊ท์ฐฎ๋๋ผ๋ ์ ํํ ๊ฐ ์ ๋ฌ์ ์ํด ์ผ๋ถ๋ฌ ๋ง๋ ์ธ ๊ฐ์ง ์ ํ์ฌํญ์ด๋ ๋ฐ๋ผ์ฃผ์๊ธธ ๋ฐ๋๋๋ค.
# WeatherCrawler().request์ ํ๋ผ๋ฏธํฐ๋ WeatherCrawler().request(location, date)์
๋๋ค.
# APIํ๋ผ๋ฏธํฐ์ ์์/์ด๋ฆ์ด ๋์ผํ๋ฉฐ, ๋ฐ๋ชจ ๋ฐ์ดํฐ ํ์ผ์ ์๋ ์ํฐํฐ์ธ LOCATION, DATE์ ๋์ผํฉ๋๋ค.
# ๋ง์ฝ ํ๋ฆฌ๋ฉด ์ด๋์ ํ๋ ธ๋์ง ์๋ฌ ๋ฉ์์ง๋ก ์๋ ค๋๋ฆฝ๋๋ค.
)
- ๋ ์คํ ๋ ์์ฝ ์๋๋ฆฌ์ค
from kochat.app import Scenario
reservation_scenario = Scenario(
intent='reservation',
api=reservation_check,
# reservation_check(num_people, reservation_time)์ ๊ฐ์
# ํจ์๋ฅผ ํธ์ถํ์ง ๋ง๊ณ ๊ทธ ์์ฒด๋ฅผ ํ๋ผ๋ฏธํฐ๋ก ์
๋ ฅํฉ๋๋ค.
# ํจ์๋ฅผ ๋ฐ์์ ์ ์ฅํด๋๋ค๊ฐ ์์ฒญ ๋ฐ์์ Api ๋ด๋ถ์์ call ํฉ๋๋ค
scenario={
'NUM_PEOPLE': [4],
# NUM_PEOPLE์ default๋ฅผ 4๋ช
์ผ๋ก ์ค์ ํ์ต๋๋ค.
'RESERVATION_TIME': []
# API(reservation_check(num_people, reservation_time)์ ํ๋ผ๋ฏธํฐ์ ์์/์ฒ ์๊ฐ ์ผ์นํฉ๋๋ค.
# ์ด ๋, ๋ฐ๋์ NER_categories์ NUM_PEOPLE๊ณผ RESERVATION_TIME์ด ์ ์๋์ด ์์ด์ผํ๋ฉฐ,
# ์ค์ raw๋ฐ์ดํฐ์ ๋ผ๋ฒจ๋ง๋ ๋ ์ด๋ธ๋ ์์ ์ด๋ฆ์ ์ฌ์ฉํด์ผํฉ๋๋ค.
}
)
from kochat.app import KochatApi
4.5.2. KochatApi
๋ Flask๋ก ๊ตฌํ๋์์ผ๋ฉฐ restful api๋ฅผ ์ ๊ณตํ๋ ํด๋์ค์
๋๋ค.
์ฌ์ค ์๋ฒ๋ก ๊ตฌ๋ํ ๊ณํ์ด๋ผ๋ฉด ์์์ ์ค๋ช
ํ ๊ฒ ๋ณด๋ค ํจ์ฌ ์ฝ๊ฒ ํ์ตํ ์ ์์ต๋๋ค.
(ํ์ต์ ๋ง์ ๋ถ๋ถ๋ค์ด KochatApi
์์ ์๋ํ ๋๊ธฐ ๋๋ฌธ์ ํ๋ผ๋ฏธํฐ ์ ๋ฌ๋ง์ผ๋ก ํ์ต์ด ๊ฐ๋ฅํฉ๋๋ค.)
KochatApi
ํด๋์ค๋ ์๋์ ๊ฐ์ ๋ฉ์๋๋ค์ ์ง์ํ๋ฉฐ ์ฌ์ฉ๋ฒ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
from kochat.app import KochatApi
# kochat api ๊ฐ์ฒด๋ฅผ ์์ฑํฉ๋๋ค.
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=(emb, True), # ์๋ฒ ๋ฉ ํ๋ก์ธ์, ํ์ต์ฌ๋ถ
intent_classifier=(clf, True), # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ, ํ์ต์ฌ๋ถ
entity_recognizer=(rcn, True), # ์ํฐํฐ ๊ฒ์ถ๊ธฐ, ํ์ต์ฌ๋ถ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
# kochat.app์ FLask ๊ฐ์ฒด์
๋๋ค.
# Flask์ ์ฌ์ฉ๋ฒ๊ณผ ๋์ผํ๊ฒ ์ฌ์ฉํ๋ฉด ๋ฉ๋๋ค.
@kochat.app.route('/')
def index():
return render_template("index.html")
# ์ ํ๋ฆฌ์ผ์ด์
์๋ฒ๋ฅผ ๊ฐ๋ํฉ๋๋ค.
if __name__ == '__main__':
kochat.app.template_folder = kochat.root_dir + 'templates'
kochat.app.static_folder = kochat.root_dir + 'static'
kochat.app.run(port=8080, host='0.0.0.0')
์์ ๊ฐ์ด kochat ์๋ฒ๋ฅผ ์คํ์ํฌ ์ ์์ต๋๋ค. (์ฌ๋งํ๋ฉด ์์ ๊ฐ์ด template๊ณผ static์ ๋ช ์์ ์ผ๋ก ์ ์ด์ฃผ์ธ์.) ์ ์์์ฒ๋ผ ๋ทฐ๋ฅผ ์ง์ ์๋ฒ์ ์ฐ๊ฒฐํด์ ํ๋์ ์๋ฒ์์ ๋ทฐ์ ๋ฅ๋ฌ๋ ์ฝ๋๋ฅผ ๋ชจ๋ ๊ตฌ๋์ํฌ ์๋ ์๊ณ , ๋ง์ฝ Micro Service Architecture๋ฅผ ๊ตฌ์ฑํด์ผํ๋ค๋ฉด, ์ฑ๋ด ์๋ฒ์ index route ('/')๋ฑ์ ์ค์ ํ์ง ์๊ณ ๋ฅ๋ฌ๋ ๋ฐฑ์๋ ์๋ฒ๋ก๋ ์ถฉ๋ถํ ํ์ฉํ ์ ์์ต๋๋ค. ๋ง์ฝ ํ์ต์ ์ํ์ง ์์ ๋๋ ์๋์ฒ๋ผ ๊ตฌํํฉ๋๋ค.
# 1. Tuple์ ๋๋ฒ์งธ ์ธ์์ False ์
๋ ฅ
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=(emb, False), # ์๋ฒ ๋ฉ ํ๋ก์ธ์, ํ์ต์ฌ๋ถ
intent_classifier=(clf, False), # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ, ํ์ต์ฌ๋ถ
entity_recognizer=(rcn, False), # ์ํฐํฐ ๊ฒ์ถ๊ธฐ, ํ์ต์ฌ๋ถ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
# 2. Tuple์ ํ๋ก์ธ์๋ง ์
๋ ฅ
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=(emb), # ์๋ฒ ๋ฉ ํ๋ก์ธ์
intent_classifier=(clf), # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ
entity_recognizer=(rcn), # ์ํฐํฐ ๊ฒ์ถ๊ธฐ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
# 3. ๊ทธ๋ฅ ํ๋ก์ธ์๋ง ์
๋ ฅ
kochat = KochatApi(
dataset=dataset, # ๋ฐ์ดํฐ์
๊ฐ์ฒด
embed_processor=emb, # ์๋ฒ ๋ฉ ํ๋ก์ธ์
intent_classifier=clf, # ์ธํ
ํธ ๋ถ๋ฅ๊ธฐ
entity_recognizer=rcn, # ์ํฐํฐ ๊ฒ์ถ๊ธฐ
scenarios=[ #์๋๋ฆฌ์ค ๋ฆฌ์คํธ
weather, dust, travel, restaurant
]
)
์๋์์๋ Kochat ์๋ฒ์ url ํจํด์ ๋ํด ์์ธํ๊ฒ ์ค๋ช ํฉ๋๋ค. ํ์ฌ kochat api๋ ๋ค์๊ณผ ๊ฐ์ 4๊ฐ์ url ํจํด์ ์ง์ํ๋ฉฐ, ์ด url ํจํด๋ค์ config์ API ์ฑํฐ์์ ๋ณ๊ฒฝ ๊ฐ๋ฅํฉ๋๋ค.
API = {
'request_chat_url_pattern': 'request_chat', # request_chat ๊ธฐ๋ฅ url pattern
'fill_slot_url_pattern': 'fill_slot', # fill_slot ๊ธฐ๋ฅ url pattern
'get_intent_url_pattern': 'get_intent', # get_intent ๊ธฐ๋ฅ url pattern
'get_entity_url_pattern': 'get_entity' # get_entity ๊ธฐ๋ฅ url pattern
}
4.5.2.1. request_chat
๊ฐ์ฅ ๊ธฐ๋ณธ์ ์ธ ํจํด์ธ request_chat์
๋๋ค. intent๋ถ๋ฅ, entity๊ฒ์ถ, api์ฐ๊ฒฐ์ ํ๋ฒ์ ์งํํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://0.0.0.0/request_chat//
case 1. state SUCCESS
๋ชจ๋ entity๊ฐ ์ ์์ ์ผ๋ก ์
๋ ฅ๋ ๊ฒฝ์ฐ state 'SUCCESS'๋ฅผ ๋ฐํํฉ๋๋ค.
>>> ์ ์ gusdnd852 : ๋ชจ๋ ๋ถ์ฐ ๋ ์จ ์ด๋
https://123.456.789.000:1234/request_chat/gusdnd852/๋ชจ๋ ๋ถ์ฐ ๋ ์จ ์ด๋
โ {
'input': [๋ชจ๋ , ๋ถ์ฐ, ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-DATE, S-LOCATION, O, O]
'state': 'SUCCESS',
'answer': '๋ถ์ฐ์ ๋ ์จ ์ ๋ณด๋ฅผ ์ ํด๋๋ฆด๊ฒ์. ๐
๋ชจ๋ ๋ถ์ฐ์ง์ญ์ ์ค์ ์๋ ์ญ์จ 19๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์. ์คํ์๋ ์ญ์จ 26๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์.'
}
case 2. state REQUIRE_XXX
๋ง์ฝ default๊ฐ์ด ์๋ ์ํฐํฐ๊ฐ ์
๋ ฅ๋์ง ์์ ๊ฒฝ์ฐ state 'REQUIRE_XXX'๋ฅผ ๋ฐํํฉ๋๋ค.
๋๊ฐ ์ด์์ ์ํฐํฐ๊ฐ ๋ชจ์๋ผ๋ฉด state 'REQUIRE_XXX_YYY'๊ฐ ๋ฐํ๋ฉ๋๋ค.
>>> ์ ์ minqukanq : ๋ชฉ์์ผ ๋ ์จ ์ด๋
e.g. https://123.456.789.000:1234/request_chat/minqukanq/๋ชฉ์์ผ ๋ ์จ ์ด๋
โ {
'input': [๋ชฉ์์ผ, ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-DATE, O, O]
'state': 'REQUIRE_LOCATION',
'answer': None
}
case 3. state FALLBACK
์ธํ
ํธ ๋ถ๋ฅ์ FALLBACK์ด ๋ฐ์ํ๋ฉด FALLBACK์ ๋ฐํํฉ๋๋ค.
>>> ์ ์ sangji11 : ๋ชฉ์์ผ ์น๊ตฌ ์์ผ์ด๋ค
e.g. https://123.456.789.000:1234/request_chat/sangji11/๋ชฉ์์ผ ์น๊ตฌ ์์ผ์ด๋ค
โ {
'input': [๋ชฉ์์ผ, ์น๊ตฌ, ์์ผ์ด๋ค],
'intent': 'FALLBACK',
'entity': [S-DATE, O, O]
'state': 'FALLBACK',
'answer': None
}
4.5.2.2. fill_slot
๊ฐ์ฅ request์ REQUIRE_XXX๊ฐ ๋์ฌ๋, ์ฌ์ฉ์์๊ฒ ๋๋ฌป๊ณ ๊ธฐ์กด ๋์
๋๋ฆฌ์ ์ถ๊ฐํด์ api๋ฅผ ํธ์ถํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://0.0.0.0/fill_slot//
>>> ์ ์ gusdnd852 : ๋ชจ๋ ๋ ์จ ์๋ ค์ค โ REQUIRE_LOCATION
>>> ๋ด : ์ด๋ ์ง์ญ์ ์๋ ค๋๋ฆด๊น์?
>>> ์ ์ gusdnd852 : ๋ถ์ฐ
https://123.456.789.000:1234/fill_slot/gusdnd852/๋ถ์ฐ
โ {
'input': [๋ถ์ฐ] + [๋ชจ๋ , ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-LOCATION] + [S-DATE, O, O]
'state': 'SUCCESS',
'answer': '๋ถ์ฐ์ ๋ ์จ ์ ๋ณด๋ฅผ ์ ํด๋๋ฆด๊ฒ์. ๐
๋ชจ๋ ๋ถ์ฐ์ง์ญ์ ์ค์ ์๋ ์ญ์จ 19๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์. ์คํ์๋ ์ญ์จ 26๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์.'
}
>>> ์ ์ gusdnd852 : ๋ ์จ ์๋ ค์ค โ REQUIRE_DATE_LOCATION
>>> ๋ด : ์ธ์ ์ ์ด๋ ์ง์ญ์ ๋ ์จ๋ฅผ ์๋ ค๋๋ฆด๊น์?
>>> ์ ์ gusdnd852 : ๋ถ์ฐ ๋ชจ๋
https://123.456.789.000:1234/fill_slot/gusdnd852/๋ถ์ฐ ๋ชจ๋
โ {
'input': [๋ถ์ฐ, ๋ชจ๋ ] + [๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': [S-LOCATION, S-DATE] + [O, O]
'state': 'SUCCESS',
'answer': '๋ถ์ฐ์ ๋ ์จ ์ ๋ณด๋ฅผ ์ ํด๋๋ฆด๊ฒ์. ๐
๋ชจ๋ ๋ถ์ฐ์ง์ญ์ ์ค์ ์๋ ์ญ์จ 19๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์. ์คํ์๋ ์ญ์จ 26๋์ด๋ฉฐ, ์๋ง ํ๋์ด ๋ง์ ๊ฒ ๊ฐ์์.'
}
4.5.2.3. get_intent
intent๋ง ์๊ณ ์ถ์๋ ํธ์ถํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://0.0.0.0/get_intent/
https://123.456.789.000:1234/get_intent/์ ์ฃผ ๋ ์จ ์ด๋
โ {
'input': [์ ์ฃผ, ๋ ์จ, ์ด๋],
'intent': 'weather',
'entity': None,
'state': 'REQUEST_INTENT',
'answer': None
}
4.5.2.4. get_entity
entity๋ง ์๊ณ ์ถ์๋ ํธ์ถํฉ๋๋ค.
๊ธฐ๋ณธ ํจํด : https://0.0.0.0/get_entity/
https://123.456.789.000:1234/get_entity/์ ์ฃผ ๋ ์จ ์ด๋
โ {
'input': [์ ์ฃผ, ๋ ์จ, ์ด๋],
'intent': None,
'entity': [S-LOCATION, O, O],
'state': 'REQUEST_ENTITY',
'answer': None
}
5. Visualization Support
Kochat์ ์๋์ ๊ฐ์ด ๋ค์ํ ์๊ฐํ ๊ธฐ๋ฅ์ ์ง์ํฉ๋๋ค.
Feature Space๋ ์ผ์ Epoch๋ง๋ค ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅ๋๊ณ ,
๊ทธ ์ธ์ ์๊ฐํ ์๋ฃ๋ ๋งค Epoch๋ง๋ค ๊ณ์ ์
๋ฐ์ดํธ ๋๋ฉฐ
"root/saved"์ ๋ชจ๋ธ ์ ์ฅํ์ผ๊ณผ ํจ๊ป ์ ์ฅ๋ฉ๋๋ค.
์๊ฐํ ์๋ฃ ๋ฐ ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก๋
config์์ ๋ณ๊ฒฝํ ์ ์์ต๋๋ค.
5.1. Train/Test Accuracy
5.2. Train/Test Recall (macro average)
5.3. Train/Test Precision (macro average)
5.4. Train/Test F1-Score (macro average)
5.5. Train/Test Confusion Matrix
Confusion Matrix์ ๊ฒฝ์ฐ๋ X์ถ(์๋)๊ฐ Prediction, Y์ถ(์ผ์ชฝ)์ด Label์
๋๋ค.
๋ค์ ๋ฒ์ ์์ xticks์ yticks๋ฅผ ์ถ๊ฐํ ์์ ์
๋๋ค.
5.6. Train/Test Classification Performance Report
Accuracy, Precision, Recall, F1 Score ๋ฑ ๋ชจ๋ธ์ ๋ค์ํ ๋ฉํธ๋ฆญ์ผ๋ก ํ๊ฐํ๊ณ , ํ ํํ๋ก ์ด๋ฏธ์งํ์ผ์ ๋ง๋ค์ด์ค๋๋ค.
์์์ ๋ช๋ฒ์งธ ๊น์ง ๋ฐ์ฌ๋ฆผํด์ ๋ณด์ฌ์ค์ง config์์ ์ค์ ํ ์ ์์ต๋๋ค.
PROC = {
# ...(์๋ต)
'logging_precision': 5, # ๊ฒฐ๊ณผ ์ ์ฅ์ ๋ฐ์ฌ๋ฆผ ์์์ n๋ฒ์งธ์์ ๋ฐ์ฌ๋ฆผ
}
5.7. Train/Test Fallback Detection Performance Report
Fallback Detection์ Intent Classification์ ์์ญ์ ๋๋ค. Intent Classification๋ง ์ง์ํฉ๋๋ค. (Fallback Detection ์ฑ๋ฅ ํ๊ฐ๋ฅผ ์ํด์๋ ๋ฐ๋์ ood=True์ฌ์ผํฉ๋๋ค.)
5.8. Feature Space Visualization
Feature Space๋ Distance ๊ธฐ๋ฐ์ Metric Learning Lossํจ์๊ฐ ์ ์๋ํ๊ณ ์๋์ง ํ์ธํ๊ธฐ ์ํ๊ฒ์ผ๋ก Intent Classification๋ง ์ง์ํฉ๋๋ค. ๋ํ ์๊ฐํ ์ฐจ์์ config์ d_loss์ ๋ฐ๋ผ ๊ฒฐ์ ๋ฉ๋๋ค.
- d_loss = 2์ธ ๊ฒฝ์ฐ : 2์ฐจ์์ผ๋ก ์๊ฐํ
- d_loss = 3์ธ ๊ฒฝ์ฐ : 3์ฐจ์์ผ๋ก ์๊ฐํ
- d_loss > 3์ธ ๊ฒฝ์ฐ : Incremetal PCA๋ฅผ ํตํด 3์ฐจ์์ผ๋ก ์ฐจ์ ๊ฐ์ ํ ์๊ฐํ
Feature Space Visualization์ PCA๋ฅผ ์คํํ๊ธฐ ๋๋ฌธ์ ๋น์ฉ์ด ์๋นํ ํฝ๋๋ค. ๋ค๋ฅธ ์๊ฐํ๋ ๋งค Epoch๋ง๋ค ์ํํ์ง๋ง, Feature Space Visulization์ ๋ช Epoch๋ง๋ค ์ํํ ์ง ๊ฒฐ์ ํ ์ ์์ต๋๋ค.
PROC = {
# ...(์๋ต)
'visualization_epoch': 50, # ์๊ฐํ ๋น๋ (์ ํญ๋ง๋ค ์๊ฐํ ์ํ)
}
6. Performance Issue
์ด ์ฑํฐ๋ Kochat์ ๋ค์ํ ์ฑ๋ฅ ์ด์์ ๋ํด ๊ธฐ๋กํฉ๋๋ค.
6.1. ์ผ๊ตด์ธ์ ์์ญ์์ ์ฐ์ด๋ Loss ํจ์๋ค์ Fallback ๋ํ ์ ์ ํจ๊ณผ์ ์ด๋ค.
์ฌ์ค CenterLoss๋ CosFace ๊ฐ์ Margin Lossํจ์๋ค์ด ์ปดํจํฐ ๋น์ ์ ์ผ๊ตด์ธ์ ์์ญ์์
๋ง์ด ์ฐ์ธ๋ค๊ณ ๋ ํ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ชจ๋ Retrieval ๋ฌธ์ ์ ์ ์ฉํ ์ ์๋ Lossํจ์์
๋๋ค.
Kochat์ DistanceClassifier๋ ๊ฑฐ๋ฆฌ๊ธฐ๋ฐ์ Retrieval์ ์ํํ๊ธฐ ๋๋ฌธ์ ์ด๋ฌํ
Lossํจ์๋ฅผ ๋งค์ฐ ํจ๊ณผ์ ์ผ๋ก ํ์ฉํ ์ ์์ต๋๋ค. ์ค์ ๋ก ๋ฐ๋ชจ ๋ฐ์ดํฐ์
์ ์ ์ฉํ์ ๋
CrossEntropyLoss๋ก๋ 70% ์ธ์ ๋ฆฌ์ธ FallbackDetection ์ฑ๋ฅ์ด CenterLoss, CosFace
๋ฑ์ ์ ์ฉํ๋ฉด 90~95%๊น์ง ํฅ์๋์์ต๋๋ค. (120๊ฐ์ OOD ์ํ ํ
์คํธ)
- SoftmaxClassifier + CrossEntropyLoss + CNN (d_model=512, layers=1)
- DistanceClassifier + CrossEntropyLoss + CNN (d_model=512, layers=1)
- DistanceClassifier + CenterLoss + CNN (d_model=512, layers=1)
6.2. Retrieval Feature๋ก๋ LSTM๋ณด๋ค CNN์ด ๋ ์ข๋ค.
Retrieval ๊ธฐ๋ฐ์ Distance Classification์ ๊ฒฝ์ฐ LSTM๋ณด๋ค CNN์ Feature๋ค์ด
ํด๋์ค๋ณ๋ก ํจ์ฌ ์ ๊ตฌ๋ถ๋๋ ๊ฒ์ ํ์ธํ์ต๋๋ค. Feature Extraction ๋ฅ๋ ฅ ์์ฒด๋
CNN์ด ์ข๋ค๊ณ ์๋ ค์ง ๊ฒ์ฒ๋ผ ์๋ฌด๋๋ CNN์ด Feature๋ฅผ ๋ ์ ๋ฝ์๋ด๋ ๊ฒ ๊ฐ์ต๋๋ค.
Feature Space์์ ๊ตฌ๋ถ์ด ์ ๋๋ค๋ ๊ฒ์ OOD ์ฑ๋ฅ์ด ์ฐ์ํ๋ค๋ ๊ฒ๊ณผ ๋์น์ด๋ฏ๋ก,
DistanceClassifier ์ฌ์ฉ์ LSTM๋ณด๋จ CNN์ ์ฌ์ฉํ๋ ๊ฒ์ด ๋์ฑ ๋ฐ๋์งํด๋ณด์
๋๋ค.
- ์ข : LSTM (d_model=512, layers=1) + CosFace, 500 Epoch ํ์ต (์๋ ดํจ)
- ์ฐ : CNN (d_model=512, layers=1) + CosFace, 500 Epoch ํ์ต (์๋ ดํจ)
6.3. CRF Loss์ ์๋ ด ์๋๋ CrossEntropy๋ณด๋ค ๋๋ฆฌ๋ค.
EntityRecognizer์ ๊ฒฝ์ฐ ๋์ผ ์ฌ์ด์ฆ, ๋์ผ Layer์์ CRF Loss๋ฅผ ์ฌ์ฉํ๋ฉด
ํ์คํ ์ฑ๋ฅ์ ๋์ฑ ์ฐ์ํด์ง๋, ์กฐ๊ธ ๋ ๋ ๋๋ฆฌ๊ฒ ์๋ ดํ๋ ๊ฒ์ ํ์ธํ์ต๋๋ค.
CRF Loss์ ๊ฒฝ์ฐ ์กฐ๊ธ ๋ ๋ง์ ํ์ต ์๊ฐ์ ์ค์ผ ์ ์ฑ๋ฅ์ ๋ด๋ ๊ฒ ๊ฐ์ต๋๋ค.
- ์ข : LSTM (d_model=512, layers=1) + CrossEntropy โ Epoch 300์ f1-score 90% ๋๋ฌ
- ์ฐ : LSTM (d_model=512, layers=1) + CRFLoss โ Epoch 450์ f1-score 90% ๋๋ฌ
6.4. FallbackDetector์ max_iter๋ ๋๊ฒ ์ค์ ํด์ผํ๋ค.
Fallback Detector๋ sklearn ๋ชจ๋ธ๋ค์ ํ์ฉํ๋๋ฐ ๊ธฐ์กด sklearn๋ชจ๋ธ๋ค์
max_iter์ default๊ฐ์ด 100์ผ๋ก ์ค์ ๋์ด ์๋ ดํ๊ธฐ ์ ์ ํ์ต์ด ๋๋๋ฒ๋ฆฝ๋๋ค.
๋๋ฌธ์ Fallback Detector๋ฅผ config์ ์ ์ํ ๋ max_iter๋ฅผ ๋๊ฒ ์ค์ ํด์ผ
์ถฉ๋ถํ ํ์ต์๊ฐ์ ๋ณด์ฅ๋ฐ์ ์ ์์ต๋๋ค.
7. Demo Application
์ด ์ฑํฐ์์๋ Demo ์ ํ๋ฆฌ์ผ์ด์
์ ๋ํด ์๊ฐํฉ๋๋ค.
๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์
์ ์ฌํ์ ๋ณด๋ฅผ ์๊ฐํ๋ ์ฑ๋ด ์ ํ๋ฆฌ์ผ์ด์
์ผ๋ก,
๋ ์จ, ๋ฏธ์ธ๋จผ์ง, ๋ง์ง ์ฌํ์ง ์ ๋ณด๋ฅผ ์๋ ค์ฃผ๋ ๊ธฐ๋ฅ์ ๋ณด์ ํ๊ณ ์์ต๋๋ค.
Api๋ Kochat์ ๋ง๋ค๋ฉด์ ํจ๊ป ๋ง๋ Kocrawl
์ ์ฌ์ฉํ์ต๋๋ค.
7.1. View (HTML)
Html๊ณผ CSS๋ฅผ ์ฌ์ฉํ์ฌ View๋ฅผ ๊ตฌํํ์์ต๋๋ค. ์ ๊ฐ ๋์์ธ ํ ๊ฒ์ ์๋๊ณ ์ฌ๊ธฐ ์์ ์ ๊ณต๋๋ ๋ถํธ์คํธ๋ฉ ํ ๋ง๋ฅผ ์ฌ์ฉํ์์ต๋๋ค.
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Kochat ๋ฐ๋ชจ</title>
<script src="{{ url_for('static', filename="js/jquery.js") }}" type="text/javascript"></script>
<script src="{{ url_for('static', filename="js/bootstrap.js") }}" type="text/javascript"></script>
<script src="{{ url_for('static', filename="js/main.js") }}" type="text/javascript"></script>
<link href="{{ url_for('static', filename="css/bootstrap.css") }}" rel="stylesheet" id="bootstrap-css">
<link href="{{ url_for('static', filename="css/main.css") }}" rel="stylesheet" id="main-css">
<script>
greet();
onClickAsEnter();
</script>
</head>
<body>
<div class="chat_window">
<div class="top_menu">
<div class="buttons">
<div class="button close_button"></div>
<div class="button minimize"></div>
<div class="button maximize"></div>
</div>
<div class="title">Kochat ๋ฐ๋ชจ</div>
</div>
<ul class="messages"></ul>
<div class="bottom_wrapper clearfix">
<div class="message_input_wrapper">
<input class="message_input"
onkeyup="return onClickAsEnter(event)"
placeholder="๋ด์ฉ์ ์
๋ ฅํ์ธ์."/>
</div>
<div class="send_message"
id="send_message"
onclick="onSendButtonClicked()">
<div class="icon"></div>
<div class="text">๋ณด๋ด๊ธฐ</div>
</div>
</div>
</div>
<div class="message_template">
<li class="message">
<div class="avatar"></div>
<div class="text_wrapper">
<div class="text"></div>
</div>
</li>
</div>
</body>
</html>
7.2. ๋ฅ๋ฌ๋ ๋ชจ๋ธ ๊ตฌ์ฑ
์๋์ ๊ฐ์ ๋ชจ๋ธ ๊ตฌ์ฑ์ ์ฌ์ฉํ์์ต๋๋ค.
dataset = Dataset(ood=True)
emb = GensimEmbedder(model=embed.FastText())
clf = DistanceClassifier(
model=intent.CNN(dataset.intent_dict),
loss=CenterLoss(dataset.intent_dict)
)
rcn = EntityRecognizer(
model=entity.LSTM(dataset.entity_dict),
loss=CRFLoss(dataset.entity_dict)
)
kochat = KochatApi(
dataset=dataset,
embed_processor=(emb, True),
intent_classifier=(clf, True),
entity_recognizer=(rcn, True),
scenarios=[
weather, dust, travel, restaurant
]
)
@kochat.app.route('/')
def index():
return render_template("index.html")
if __name__ == '__main__':
kochat.app.template_folder = kochat.root_dir + 'templates'
kochat.app.static_folder = kochat.root_dir + 'static'
kochat.app.run(port=8080, host='0.0.0.0')
7.3. ์๋๋ฆฌ์ค ๊ตฌ์ฑ
Kocrawl์ ์ด์ฉํด 4๊ฐ์ง ์๋์ ๋ง๋ ์๋๋ฆฌ์ค๋ฅผ ๊ตฌ์ฑํ์์ต๋๋ค.
weather = Scenario(
intent='weather',
api=WeatherCrawler().request,
scenario={
'LOCATION': [],
'DATE': ['์ค๋']
}
)
dust = Scenario(
intent='dust',
api=DustCrawler().request,
scenario={
'LOCATION': [],
'DATE': ['์ค๋']
}
)
restaurant = Scenario(
intent='restaurant',
api=RestaurantCrawler().request,
scenario={
'LOCATION': [],
'RESTAURANT': ['์ ๋ช
ํ']
}
)
travel = Scenario(
intent='travel',
api=MapCrawler().request,
scenario={
'LOCATION': [],
'PLACE': ['๊ด๊ด์ง']
}
)
7.4. Javascript ๊ตฌํ (+ Ajax)
๋ง์ง๋ง์ผ๋ก ๋ฒํผ์ ๋๋ฅด๋ฉด ๋ฉ์์ง๊ฐ ๋์์ง๋ ์ ๋๋ฉ์ด์ ๊ณผ Ajax๋ฅผ ํตํด Kochat ์๋ฒ์ ํต์ ํ๋ ์์ค์ฝ๋๋ฅผ ์์ฑํ์์ต๋๋ค. ๊ฐ๋จํ chit chat ๋ํ 3๊ฐ์ง (์๋ , ๊ณ ๋ง์, ์์ด)๋ ๊ท์น๊ธฐ๋ฐ์ผ๋ก ๊ตฌํํ์์ต๋๋ค. ์ถํ์ Seq2Seq ๊ธฐ๋ฅ์ ์ถ๊ฐํ์ฌ ์ด ๋ถ๋ถ๋ ๋จธ์ ๋ฌ๋ ๊ธฐ๋ฐ์ผ๋ก ๋ณ๊ฒฝํ ์์ ์ ๋๋ค.
// variables
let userName = null;
let state = 'SUCCESS';
// functions
function Message(arg) {
this.text = arg.text;
this.message_side = arg.message_side;
this.draw = function (_this) {
return function () {
let $message;
$message = $($('.message_template').clone().html());
$message.addClass(_this.message_side).find('.text').html(_this.text);
$('.messages').append($message);
return setTimeout(function () {
return $message.addClass('appeared');
}, 0);
};
}(this);
return this;
}
function getMessageText() {
let $message_input;
$message_input = $('.message_input');
return $message_input.val();
}
function sendMessage(text, message_side) {
let $messages, message;
$('.message_input').val('');
$messages = $('.messages');
message = new Message({
text: text,
message_side: message_side
});
message.draw();
$messages.animate({scrollTop: $messages.prop('scrollHeight')}, 300);
}
function greet() {
setTimeout(function () {
return sendMessage("Kochat ๋ฐ๋ชจ์ ์ค์ ๊ฑธ ํ์ํฉ๋๋ค.", 'left');
}, 1000);
setTimeout(function () {
return sendMessage("์ฌ์ฉํ ๋๋ค์์ ์๋ ค์ฃผ์ธ์.", 'left');
}, 2000);
}
function onClickAsEnter(e) {
if (e.keyCode === 13) {
onSendButtonClicked()
}
}
function setUserName(username) {
if (username != null && username.replace(" ", "" !== "")) {
setTimeout(function () {
return sendMessage("๋ฐ๊ฐ์ต๋๋ค." + username + "๋. ๋๋ค์์ด ์ค์ ๋์์ต๋๋ค.", 'left');
}, 1000);
setTimeout(function () {
return sendMessage("์ ๋ ๊ฐ์ข
์ฌํ ์ ๋ณด๋ฅผ ์๋ ค์ฃผ๋ ์ฌํ๋ด์
๋๋ค.", 'left');
}, 2000);
setTimeout(function () {
return sendMessage("๋ ์จ, ๋ฏธ์ธ๋จผ์ง, ์ฌํ์ง, ๋ง์ง ์ ๋ณด์ ๋ํด ๋ฌด์์ด๋ ๋ฌผ์ด๋ณด์ธ์!", 'left');
}, 3000);
return username;
} else {
setTimeout(function () {
return sendMessage("์ฌ๋ฐ๋ฅธ ๋๋ค์์ ์ด์ฉํด์ฃผ์ธ์.", 'left');
}, 1000);
return null;
}
}
function requestChat(messageText, url_pattern) {
$.ajax({
url: "http://0.0.0.0:8080/" + url_pattern + '/' + userName + '/' + messageText,
type: "GET",
dataType: "json",
success: function (data) {
state = data['state'];
if (state === 'SUCCESS') {
return sendMessage(data['answer'], 'left');
} else if (state === 'REQUIRE_LOCATION') {
return sendMessage('์ด๋ ์ง์ญ์ ์๋ ค๋๋ฆด๊น์?', 'left');
} else {
return sendMessage('์ฃ์กํฉ๋๋ค. ๋ฌด์จ๋ง์ธ์ง ์ ๋ชจ๋ฅด๊ฒ ์ด์.', 'left');
}
},
error: function (request, status, error) {
console.log(error);
return sendMessage('์ฃ์กํฉ๋๋ค. ์๋ฒ ์ฐ๊ฒฐ์ ์คํจํ์ต๋๋ค.', 'left');
}
});
}
function onSendButtonClicked() {
let messageText = getMessageText();
sendMessage(messageText, 'right');
if (userName == null) {
userName = setUserName(messageText);
} else {
if (messageText.includes('์๋
')) {
setTimeout(function () {
return sendMessage("์๋
ํ์ธ์. ์ ๋ Kochat ์ฌํ๋ด์
๋๋ค.", 'left');
}, 1000);
} else if (messageText.includes('๊ณ ๋ง์')) {
setTimeout(function () {
return sendMessage("์ฒ๋ง์์. ๋ ๋ฌผ์ด๋ณด์ค ๊ฑด ์๋์?", 'left');
}, 1000);
} else if (messageText.includes('์์ด')) {
setTimeout(function () {
return sendMessage("๊ทธ๋ ๊ตฐ์. ์๊ฒ ์ต๋๋ค!", 'left');
}, 1000);
} else if (state.includes('REQUIRE')) {
return requestChat(messageText, 'fill_slot');
} else {
return requestChat(messageText, 'request_chat');
}
}
}
7.5. ์คํ ๊ฒฐ๊ณผ
Warning
๋ฐ๋ชจ ๋ฐ์ดํฐ์
์ ์์ด ์ ๊ธฐ ๋๋ฌธ์ ๋ค์ํ ์ง๋ช
์ด๋ ๋ค์ํ
์์, ๋ค์ํ ์ฌํ์ง ๋ฑ์ ์์ ๋ฃ์ง ๋ชปํฉ๋๋ค. (๋ฐ๋ชจ์์์ ์ํด
์ผ๋ถ ์์ธ ์ง์ญ ์์ฃผ๋ก๋ง ๋ฐ์ดํฐ์
์ ์์ฑํ์ต๋๋ค.) ๋ฐ๋ชจ๋ฐ์ดํฐ์
์
๋ฐ๋ชจ์์์ ์ฐ๊ธฐ ์ํ ์์ฃผ ์์ dev ๋ฐ์ดํฐ ์
์
๋๋ค.
์ค์ ๋ก ๋ค์ํ ๋์๋ ๋ค์ํ ์์ ๋ฑ์ ์์ ๋ค์ ์ ๋๋ก ๋ํ๋ฅผ ๋๋๋ ค๋ฉด ๋ฐ๋ชจ ๋ฐ์ดํฐ์
๋ณด๋ค
์์ฒด์ ์ธ ๋ฐ์ดํฐ ์
์ ๋ง์ด ์ฝ์
ํ์
์ผ ๋์ฑ ์ข์ ์ฑ๋ฅ์ ๊ธฐ๋ํ ์ ์์ ๊ฒ์
๋๋ค.
ํ๋ฃจ๋นจ๋ฆฌ Pretrain ๋ชจ๋ธ์ ์ง์ํ์ฌ ์ด๋ฌํ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๋๋ก ํ๊ฒ ์ต๋๋ค.
๋ชจ๋ ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์
์์ค์ฝ๋๋
์ฌ๊ธฐ ๋ฅผ ์ฐธ๊ณ ํด์ฃผ์ธ์
8. Contributor
๋ง์ฝ ๋ณธ์ธ์ด ์ํ๋ ๊ธฐ๋ฅ์ Kocchat์ ์ถ๊ฐํ๊ณ ์ถ์ผ์๋ค๋ฉด ์ธ์ ๋ ์ง ์ปจํธ๋ฆฌ๋ทฐ์
ํ ์ ์์ต๋๋ค.
9. TODO List
- ver 1.0 : ์ํฐํฐ ํ์ต์ CRF ๋ฐ ๋ก์ค ๋ง์คํน ์ถ๊ฐํ๊ธฐ
- ver 1.0 : ์์ธํ README ๋ฌธ์ ์์ฑ ๋ฐ PyPI ๋ฐฐํฌํ๊ธฐ
- ver 1.0 : ๊ฐ๋จํ ์น ์ธํฐํ์ด์ค ๊ธฐ๋ฐ ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์ ์ ์ํ๊ธฐ
- ver 1.0 : Jupyter Note Example ์์ฑํ๊ธฐ + Colab ์คํ ํ๊ฒฝ
- ver 1.1 : ๋ฐ์ดํฐ์ ํฌ๋งท RASA์ฒ๋ผ markdown์ ๋๊ดํธ ํํ๋ก ๋ณ๊ฒฝ
- ver 1.2 : Pretrain Embedding ์ ์ฉ ๊ฐ๋ฅํ๊ฒ ๋ณ๊ฒฝ (Gensim)
- ver 1.3 : Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ ์ถ๊ฐ (Etri BERT, SK BERT)
- ver 1.3 : Pytorch Embedding ๋ชจ๋ธ ์ถ๊ฐ + Pretrain ์ ์ฉ ๊ฐ๋ฅํ๊ฒ
- ver 1.4 : Seq2Seq ์ถ๊ฐํด์ Fallback์ ๋์ฒํ ์ ์๊ฒ ๋ง๋ค๊ธฐ (LSTM, SK GPT2)
- ver 1.5 : ๋ค์ด๋ฒ ๋ง์ถค๋ฒ ๊ฒ์ฌ๊ธฐ ์ ๊ฑฐํ๊ณ , ์์ฒด์ ์ธ ๋์ด์ฐ๊ธฐ ๊ฒ์ฌ๋ชจ๋ ์ถ๊ฐ
- ver 1.6 : BERT์ Markov ์ฒด์ธ์ ์ด์ฉํ ์๋ OOD ๋ฐ์ดํฐ ์์ฑ๊ธฐ๋ฅ ์ถ๊ฐ
-
ver 1.7 : ๋ํ ํ๋ฆ๊ด๋ฆฌ๋ฅผ ์ํ Story ๊ด๋ฆฌ ๊ธฐ๋ฅ ๊ตฌํํด์ ์ถ๊ฐํ๊ธฐ
10. Reference
- ์ฑ๋ด ๋ถ๋ฅ ๊ทธ๋ฆผ
- seq2seq ๊ทธ๋ฆผ
- Fallback Detection ๊ทธ๋ฆผ
- ๋ฐ๋ชจ ์ ํ๋ฆฌ์ผ์ด์ ํ ํ๋ฆฟ
- ๊ทธ ์ธ์ ๊ทธ๋ฆผ ๋ฐ ์์ค์ฝ๋ : ๋ณธ์ธ ์ ์
11. License
Copyright 2020 Kochat.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.