๐Ÿ Python & library/HuggingFace

[HuggingFace] How to Use Trainer

๋ณต๋งŒ 2022. 7. 23. 15:27

Official Docs: https://huggingface.co/docs/transformers/v4.19.2/en/main_classes/trainer

 

From the official docs: "When using gradient accumulation, one step is counted as one step with backward pass. Therefore, logging, evaluation, save will be conducted every gradient_accumulation_steps * xxx_step training examples."

 

The Trainer class provides an API that handles everything from model training to evaluation in one go. The usage example below makes it intuitive to understand.

 

from transformers import Trainer

#initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
)

#train
trainer.train()

#save
trainer.save_model()

#eval
metrics = trainer.evaluate(eval_dataset=eval_dataset)
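
If the saved model needs to be reloaded later, the directory written by trainer.save_model() (which defaults to args.output_dir) can be passed to from_pretrained(). A minimal sketch, assuming the model was saved to ./results and is a sequence classification model:

#reload a model previously saved by trainer.save_model()
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("./results")
#the tokenizer is saved alongside the model only if one was passed to Trainer
tokenizer = AutoTokenizer.from_pretrained("./results")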

 

 

Initialize Trainer

from transformers import Trainer

#initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

 

  • ๊ธฐ๋ณธ์ ์œผ๋กœ ์œ„์™€ ๊ฐ™์ด Trainer์„ ์„ ์–ธํ•  ์ˆ˜ ์žˆ๋‹ค. (์ด์™ธ์—๋„ ๋” ๋งŽ์€ argument๋“ค์ด ์กด์žฌํ•œ๋‹ค) model์€ HuggingFace ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ์—์„œ ์ œ๊ณต๋˜๋Š” PretrainedModel์„ ์‚ฌ์šฉํ•ด๋„ ๋˜์ง€๋งŒ, torch.nn.Module์„ ์‚ฌ์šฉํ•  ์ˆ˜๋„ ์žˆ๋‹ค. ๋ชจ๋ธ์„ ์ง€์ •ํ•˜๋Š” ๋ฐฉ๋ฒ•์€ ์œ„์™€ ๊ฐ™์ด model argument๋กœ ์ค„ ์ˆ˜๋„ ์žˆ๊ณ , ํ˜น์€, ์•„๋ž˜์™€ ๊ฐ™์ด callableํ•œ ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜๋ฅผ model_init argument๋กœ ์ค„ ์ˆ˜๋„ ์žˆ๋‹ค. ๋งŒ์•ฝ model_init์„ ์ด์šฉํ•ด ๋ชจ๋ธ์„ ์ง€์ •ํ•ด์ฃผ๋ฉด, ๋งค train() method๊ฐ€ ํ˜ธ์ถœ๋  ๋•Œ๋งˆ๋‹ค ๋ชจ๋ธ์ด ์ƒˆ๋กญ๊ฒŒ ์ดˆ๊ธฐํ™”(์ƒ์„ฑ)๋œ๋‹ค. 
from transformers import Trainer, AutoModelForSequenceClassification

#initialize Trainer with a callable model_init instead of a model instance
trainer = Trainer(
    model_init=lambda: AutoModelForSequenceClassification.from_pretrained(model_name),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)
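
Passing a callable through model_init (rather than calling from_pretrained directly) is also what enables trainer.hyperparameter_search(), since each trial has to start from a freshly initialized model. A minimal sketch, assuming a search backend such as optuna is installed (the n_trials value is arbitrary):

#each trial calls model_init() again and trains with a new set of hyperparameters
best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="optuna",
    n_trials=10,
)
print(best_run.hyperparameters)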

 

  • args is the collection of parameters needed for training, and can be given through TrainingArguments. It lets you specify the optimizer type, learning rate, number of epochs, scheduler, whether to use half precision, and so on; the full list of parameters can be found here. An example follows.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=10,   # batch size for evaluation
    warmup_steps=1000,               # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,               # How often to print logs
    do_train=True,                   # Perform training
    do_eval=True,                    # Perform evaluation
    evaluation_strategy="epoch",     # evaluate after each epoch
    gradient_accumulation_steps=64,  # total number of steps before back propagation
    fp16=True,                       # Use mixed precision
    fp16_opt_level="O2",             # mixed precision mode (Apex optimization level)
    run_name="ProBert-BFD-MS",       # experiment name
    seed=3                           # seed for experiment reproducibility
)
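
One thing worth noting about the example above: with gradient accumulation, the effective batch size per optimizer step is per_device_train_batch_size × gradient_accumulation_steps × number of devices, and step-based intervals such as logging_steps count optimizer steps after accumulation. A quick sanity check with the values above:

#effective batch size per device: 1 * 64 = 64
effective_batch_size = (training_args.per_device_train_batch_size
                        * training_args.gradient_accumulation_steps)
print(effective_batch_size)  # 64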

 

  • train_dataset and eval_dataset are the torch.utils.data.Dataset objects used for training and for validation/test, respectively. They do not have to be specified at initialization time; see the sketch below.
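
Trainer expects each dataset item to be a dict whose keys match the model's forward() arguments (e.g. input_ids, attention_mask, labels). A minimal sketch of such a Dataset, assuming encodings is the dict returned by a tokenizer and labels is a list of ints:

import torch
from torch.utils.data import Dataset

class SimpleDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings   # e.g. tokenizer(texts, truncation=True, padding=True)
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        #each item is a dict matching the model's forward() signature
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item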

 

  • compute_metrics is the function that computes the metrics used for evaluation. It must take an EvalPrediction (the model's output) as input and return the metrics as a dictionary. An example follows.
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

def compute_metrics(pred):
    labels = pred.label_ids                  # ground-truth labels from EvalPrediction
    preds = pred.predictions.argmax(-1)      # predicted class = argmax over logits
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    auc = roc_auc_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall,
        'auroc': auc
    }
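
When trainer.evaluate() runs, the keys returned by compute_metrics are reported with an eval_ prefix (together with eval_loss and runtime statistics). A small usage sketch, assuming the Trainer was built with this compute_metrics:

metrics = trainer.evaluate(eval_dataset=eval_dataset)
#e.g. {'eval_loss': ..., 'eval_accuracy': ..., 'eval_f1': ..., 'eval_auroc': ...}
print(metrics)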
๋ฐ˜์‘ํ˜•