Korean AI Project


License: MIT
Install: pip install koai==0.0.1.3

Documentation

KoAI: Korean AI Project. An artificial intelligence project for the Korean language.

$ pip install koai

Fine-Tuning

This example fine-tunes a model and evaluates it on the KLUE benchmark. The model can come from the Hugging Face Hub (huggingface-hub), or it can be a local checkpoint that the Hugging Face libraries can load.

from koai import finetune

# finetuning and evaluating on klue-sts dataset
finetune(
    task_name="klue-sts", 
    model_name_or_path="klue/bert-base", 
    do_train=True, 
    do_eval=True, 
    num_train_epochs=5, 
    evaluation_strategy="epoch",
    save_strategy="no",
    logging_strategy="epoch"
)

# fine-tune and evaluate on every KLUE task (except 'wos')
# if finetune_model_across_the_tasks is True, a single model is trained on all the KLUE tasks;
# if it is False (the default), the language model is fine-tuned on each task individually.
finetune(
    "klue", 
    "klue/bert-base", 
    do_train=True, 
    do_eval=True, 
    num_train_epochs=5, 
    evaluation_strategy="epoch",
    save_strategy="no",
    logging_strategy="epoch"
)
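The two modes controlled by finetune_model_across_the_tasks can be pictured with the toy sketch below. It is not koai's code: run_tasks, make_model and train_fn are hypothetical stand-ins, and the sketch assumes the tasks are visited one after another. With the default (False), every task starts from a freshly initialized model. With True, one model object accumulates the training from every task.

```python
from typing import Callable, Dict, List


def run_tasks(
    tasks: List[str],
    make_model: Callable[[], dict],
    train_fn: Callable[[dict, str], None],
    finetune_model_across_the_tasks: bool = False,
) -> Dict[str, dict]:
    """Toy illustration of the two training modes (hypothetical, not koai's code)."""
    shared = make_model() if finetune_model_across_the_tasks else None
    models: Dict[str, dict] = {}
    for task in tasks:
        # False: every task starts from a freshly initialized model.
        # True: the same model object keeps training from task to task.
        model = shared if shared is not None else make_model()
        train_fn(model, task)
        models[task] = model
    return models


def make_model() -> dict:
    # Stand-in "model": it only records which tasks it was trained on.
    return {"seen": []}


def train_fn(model: dict, task: str) -> None:
    model["seen"].append(task)


tasks = ["klue-sts", "klue-nli"]
separate = run_tasks(tasks, make_model, train_fn)
joint = run_tasks(tasks, make_model, train_fn, finetune_model_across_the_tasks=True)
print(separate["klue-nli"]["seen"])  # ['klue-nli']
print(joint["klue-nli"]["seen"])     # ['klue-sts', 'klue-nli']
```

In the joint mode every entry of the returned dict points at the same object, which is the point of training "one model across the tasks".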
  • task_name: str, the name of the task. Pass "klue" to test every KLUE task, or a specific task name such as "klue-mrc". Subtask names such as "mrc" follow the dataset names on the Hugging Face Hub.

  • model_name_or_path: str, the model's name on the Hugging Face Hub or its local path.

  • remove_columns: bool = True, whether to remove the original columns once the loaded data has been processed into model inputs.

  • custom_task_infolist: Optional[List[TaskInfo]] = None, define your own TaskInfo objects and pass them in a list to run a benchmark test on them.

  • max_source_length: int = 512, the maximum length of the input text.

  • max_target_length: Optional[int] = None, the maximum length of the output text (if the task has one).

  • padding: str = "longest", the padding strategy (same as the 'padding' argument of transformers.PreTrainedTokenizerBase.__call__).

  • save_model: bool = False, whether to save the trained model.

  • return_models: bool = False, whether the function returns the trained models.

  • output_dir: str = "runs/", the directory to save to (when save_model=True).

  • finetune_model_across_the_tasks: bool = False, when fine-tuning on several benchmarks, whether to keep one model instead of re-initializing it for each benchmark (if True, a single model is trained on all the benchmarks).

  • add_sp_tokens_to_unused: bool, whether to replace the task's special tokens with the tokenizer's unused tokens.

(Any other argument of Hugging Face's transformers.TrainingArguments can also be passed.)
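The padding strategies named above can be illustrated in pure Python. This is not the transformers implementation: pad_batch is a hypothetical helper, the token ids are made up, and 0 is assumed to be the pad id. "longest" (or True) pads to the longest sequence in the batch, "max_length" pads every sequence to a fixed length, and False / "do_not_pad" leaves sequences unchanged. Unlike the real tokenizer, the sketch does no truncation.

```python
from typing import List, Union

PAD_ID = 0  # assumption: 0 is the tokenizer's pad token id


def pad_batch(
    batch: List[List[int]],
    padding: Union[bool, str] = "longest",
    max_length: int = 512,
) -> List[List[int]]:
    """Pure-Python illustration of tokenizer padding strategies (hypothetical helper)."""
    if padding is True or padding == "longest":
        target = max(len(ids) for ids in batch)  # pad to the longest sequence in this batch
    elif padding == "max_length":
        target = max_length  # pad up to a fixed length; longer sequences are left as-is here
    elif padding is False or padding == "do_not_pad":
        return [list(ids) for ids in batch]  # no padding at all
    else:
        raise ValueError(f"unknown padding strategy: {padding!r}")
    return [ids + [PAD_ID] * (target - len(ids)) for ids in batch]


batch = [[101, 7, 8, 102], [101, 9, 102]]
print(pad_batch(batch))                   # [[101, 7, 8, 102], [101, 9, 102, 0]]
print(pad_batch(batch, "max_length", 6))  # [[101, 7, 8, 102, 0, 0], [101, 9, 102, 0, 0, 0]]
```

Padding to the longest sequence keeps batches short when inputs vary in length, which is presumably why "longest" is the default rather than padding everything to max_source_length.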

Custom Task

You can define a custom task and fine-tune a model on it. For example:

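from datasets import load_dataset

from koai import finetune
from koai.benchmarks.finetune_utils import TaskInfo

# `dataset` needs "train" and "test" splits with "text" and "label" columns.
# The CSV file names below are placeholders; substitute your own data.
dataset = load_dataset("csv", data_files={"train": "train.csv", "test": "test.csv"})

custom_task = TaskInfo(
    task="custom-task",
    task_type="sequence-classification",
    text_column="text",
    label_column="label",
    num_labels=3,
    eval_split="test",
    custom_train_dataset=dataset["train"],
    custom_eval_dataset=dataset["test"],
    metric_name="f1",
)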

models = finetune(
    "custom-task", 
    "klue/roberta-base", 
    custom_task_infolist=[custom_task], 
    do_train=True, 
    do_eval=True,
    num_train_epochs=10,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    save_strategy="no",
    padding=True,
    return_models=True,
    output_dir="runs"
)
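Before handing a custom dataset to TaskInfo, a quick check of the rows against the text_column, label_column and num_labels fields can catch data problems early. The helper below is hypothetical, not part of koai, and assumes a sequence-classification task whose labels are integer class ids in range(num_labels). It accepts any iterable of dict-like rows, which is what a datasets.Dataset yields when iterated.

```python
from typing import Iterable, Mapping


def check_classification_split(
    rows: Iterable[Mapping[str, object]],
    text_column: str = "text",
    label_column: str = "label",
    num_labels: int = 3,
) -> None:
    """Hypothetical pre-flight check for a sequence-classification split."""
    for i, row in enumerate(rows):
        missing = {text_column, label_column} - set(row)
        if missing:
            raise KeyError(f"row {i} is missing columns: {sorted(missing)}")
        if not isinstance(row[text_column], str):
            raise TypeError(f"row {i}: {text_column!r} must be a string")
        label = row[label_column]
        # Assumption: labels are integer class ids in [0, num_labels).
        if not isinstance(label, int) or not 0 <= label < num_labels:
            raise ValueError(f"row {i}: label {label!r} not in range({num_labels})")


train_rows = [{"text": "great service", "label": 2}, {"text": "not good", "label": 0}]
check_classification_split(train_rows)  # passes silently
```

With a real dataset the call would look like check_classification_split(dataset["train"]), using the same column names given to TaskInfo.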

Available Tasks

  • GLUE (except "glue-mnli_matched", "glue-mnli_mismatched", and "glue-ax")
  • KLUE (except "klue-wos")
  • Custom tasks are also supported through koai.benchmarks.finetune_utils.TaskInfo.
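How a task_name such as "klue" or "klue-mrc" might be expanded into concrete tasks, given the exclusions listed above, can be sketched as follows. expand_task_name is hypothetical and not koai's API; the subset names are the configurations of the klue and glue datasets on the Hugging Face Hub. Custom tasks defined with TaskInfo are outside the scope of this sketch.

```python
from typing import List

# Configuration (subset) names of the "klue" and "glue" datasets on the Hugging Face Hub.
SUBTASKS = {
    "klue": ["ynat", "sts", "nli", "ner", "re", "dp", "mrc", "wos"],
    "glue": ["cola", "sst2", "mrpc", "qqp", "stsb", "mnli", "mnli_mismatched",
             "mnli_matched", "qnli", "rte", "wnli", "ax"],
}
# Tasks listed above as unsupported.
EXCLUDED = {"klue-wos", "glue-mnli_matched", "glue-mnli_mismatched", "glue-ax"}


def expand_task_name(task_name: str) -> List[str]:
    """Hypothetical helper: turn "klue" or "klue-mrc" into concrete task names."""
    benchmark, _, subtask = task_name.partition("-")  # split on the first "-" only
    if benchmark not in SUBTASKS:
        raise ValueError(f"unknown benchmark: {benchmark!r}")
    if subtask:
        if subtask not in SUBTASKS[benchmark] or task_name in EXCLUDED:
            raise ValueError(f"unsupported task: {task_name!r}")
        return [task_name]
    # A bare benchmark name selects every supported subtask.
    return [f"{benchmark}-{s}" for s in SUBTASKS[benchmark]
            if f"{benchmark}-{s}" not in EXCLUDED]


print(expand_task_name("klue-mrc"))  # ['klue-mrc']
print(expand_task_name("klue"))
# ['klue-ynat', 'klue-sts', 'klue-nli', 'klue-ner', 'klue-re', 'klue-dp', 'klue-mrc']
```

Splitting only on the first "-" keeps underscore subset names such as "mnli_matched" intact.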

Issue

  • This project is still under development. More benchmarks will be added in the future.
  • Much of the source code is based on and adapted from https://github.com/huggingface/transformers/.