Multilabel classification

Intro

The fastai library simplifies training fast and accurate neural nets using modern best practices. See the fastai website to get started. The library is based on research into deep learning best practices undertaken at fast.ai, and includes “out of the box” support for vision, text, tabular, and collab (collaborative filtering) models.

Multilabel

Load the civil_comments dataset and keep only the first 1% of its training split so that training runs quickly:

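library(fastai)
library(magrittr)  # provides the %>% pipe
library(zeallot)   # provides the %<-% multiple-assignment operator

# first 1% of the civil_comments training split
df = HF_load_dataset('civil_comments', split='train[:1%]')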

Preprocess

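Convert the data to a data.table and select the six label columns. Each column holds a score between 0 and 1, so rounding to zero digits turns it into a 0/1 label, which is then stored as an integer: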

df = data.table::as.data.table(df)

lbl_cols = c('severe_toxicity',
'obscene',
'threat',
'insult',
'identity_attack',
'sexual_explicit')

# round each score to 0 or 1 and store it as an integer; := updates df in place
df[, (lbl_cols) := lapply(.SD, function(x) as.integer(round(x))), .SDcols = lbl_cols]
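The same binarization can be sketched in base R for readers not using data.table. The data frame below is hypothetical toy data, not taken from civil_comments; it also shows that R's round() sends an exact 0.5 to the even digit, so such a score becomes 0 rather than 1.

```r
# Hypothetical toy scores standing in for two civil_comments label columns
toy <- data.frame(
  text    = c("a", "b", "c"),
  obscene = c(0.10, 0.50, 0.80),
  insult  = c(0.70, 0.49, 0.51)
)
toy_lbl_cols <- c("obscene", "insult")

# Base-R equivalent of the data.table step: round each score, store as integer
toy[toy_lbl_cols] <- lapply(toy[toy_lbl_cols], function(x) as.integer(round(x)))

print(toy)  # obscene: 0 0 1 (0.5 rounds to even), insult: 1 0 1
```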

Pretrained model

Load the pretrained distilroberta-base model. Setting num_labels in the config to the number of label columns gives the classification head one output per label:

task = HF_TASKS_ALL()$SequenceClassification
pretrained_model_name = "distilroberta-base"
config = AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels = length(lbl_cols)

c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name,
                                                                   task=task,
                                                                   config=config)

Datablock

Create the data blocks. HF_TextBlock tokenizes the text column with the Hugging Face tokenizer, and MultiCategoryBlock(encoded=TRUE, vocab=lbl_cols) reads the six label columns as an already one-hot encoded target:

blocks = list(
  HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer),
  MultiCategoryBlock(encoded=TRUE, vocab=lbl_cols)
)

dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('text'),
                   get_y=ColReader(lbl_cols),
                   splitter=RandomSplitter())

dls = dblock %>% dataloaders(df, bs=8)

dls %>% one_batch()

[[1]]
[[1]]$input_ids
tensor([[    0, 24268,  5257,  ...,     1,     1,     1],
        [    0,   287,  4505,  ...,     1,     1,     1],
        [    0,    38,   437,  ...,     1,     1,     1],
        ...,
        [    0,   152,  1129,  ...,     1,     1,     1],
        [    0,    85,    18,  ...,     1,     1,     1],
        [    0, 22014,    31,  ...,     1,     1,     1]], device='cuda:0')
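With encoded=TRUE, MultiCategoryBlock treats each target as a 0/1 vector whose positions follow vocab, which is what the integer label columns from the preprocessing step provide. A minimal base-R sketch of that correspondence; encode_labels and decode_labels are hypothetical helpers for illustration, not fastai functions:

```r
lbl_cols <- c('severe_toxicity', 'obscene', 'threat',
              'insult', 'identity_attack', 'sexual_explicit')

# Hypothetical helper: label names -> 0/1 vector ordered like vocab
encode_labels <- function(labels, vocab) as.integer(vocab %in% labels)

# Hypothetical helper: 0/1 vector -> label names, in vocab order
decode_labels <- function(onehot, vocab) vocab[onehot == 1]

y <- encode_labels(c("insult", "threat"), lbl_cols)
print(y)
print(decode_labels(y, lbl_cols))  # returned in vocab order: threat, insult
```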

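The parameter counts that learn %>% summary() reports for this model can be checked by hand. This base-R sketch assumes the standard layer formulas (a linear layer has in × out weights plus out biases; a LayerNorm has one scale and one shift per feature). The embedding row counts used here (50,265, 514 and 1) are not printed anywhere; they are an assumption obtained by dividing the printed embedding parameter counts by the hidden size of 768.

```r
# Hand check of the parameter counts in the distilroberta-base summary
hidden   <- 768  # hidden size, from the "8 x 391 x 768" output shapes
n_labels <- 6    # length(lbl_cols)
n_layers <- 6    # number of repeated transformer blocks in the summary

linear_params    <- function(n_in, n_out) n_in * n_out + n_out  # weights + biases
layernorm_params <- function(n) 2 * n                           # scale + shift

# Assumption: embedding row counts inferred as printed count / 768
embedding_params <- c(tokens = 50265, positions = 514, token_types = 1) * hidden

# Frozen part of one transformer layer: Q, K, V and attention-output
# projections (768 -> 768), then the feed-forward 768 -> 3072 -> 768
per_layer_frozen <- 4 * linear_params(hidden, hidden) +
  linear_params(hidden, 4 * hidden) +
  linear_params(4 * hidden, hidden)

# Trainable parts: every LayerNorm (1 after the embeddings, 2 per layer)
# plus the classification head (768 -> 768, then 768 -> n_labels)
cls_head  <- linear_params(hidden, hidden) + linear_params(hidden, n_labels)
trainable <- layernorm_params(hidden) * (1 + 2 * n_layers) + cls_head
frozen    <- sum(embedding_params) + n_layers * per_layer_frozen

cat("trainable:", format(trainable, big.mark = ","), "\n")
cat("frozen:   ", format(frozen, big.mark = ","), "\n")
cat("total:    ", format(trainable + frozen, big.mark = ","), "\n")
```

The three totals should match the "Total trainable params", "Total non-trainable params" and "Total params" lines of the summary.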
learn$freeze()

learn %>% summary()

The summary confirms that, after freezing, only the LayerNorm layers and the classification head remain trainable:

HF_BaseModelWrapper (Input shape: 8 x 391)
================================================================
Layer (type)         Output Shape         Param #    Trainable
================================================================
Embedding            8 x 391 x 768        38,603,520 False
Embedding            8 x 391 x 768        394,752    False
Embedding            8 x 391 x 768        768        False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Dropout              8 x 12 x 391 x 391   0          False
Linear               8 x 391 x 768        590,592    False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 3072       2,362,368  False
Linear               8 x 391 x 768        2,360,064  False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Dropout              8 x 12 x 391 x 391   0          False
Linear               8 x 391 x 768        590,592    False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 3072       2,362,368  False
Linear               8 x 391 x 768        2,360,064  False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Dropout              8 x 12 x 391 x 391   0          False
Linear               8 x 391 x 768        590,592    False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 3072       2,362,368  False
Linear               8 x 391 x 768        2,360,064  False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Dropout              8 x 12 x 391 x 391   0          False
Linear               8 x 391 x 768        590,592    False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 3072       2,362,368  False
Linear               8 x 391 x 768        2,360,064  False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Dropout              8 x 12 x 391 x 391   0          False
Linear               8 x 391 x 768        590,592    False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 3072       2,362,368  False
Linear               8 x 391 x 768        2,360,064  False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Linear               8 x 391 x 768        590,592    False
Dropout              8 x 12 x 391 x 391   0          False
Linear               8 x 391 x 768        590,592    False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 391 x 3072       2,362,368  False
Linear               8 x 391 x 768        2,360,064  False
LayerNorm            8 x 391 x 768        1,536      True
Dropout              8 x 391 x 768        0          False
Linear               8 x 768              590,592    True
Dropout              8 x 768              0          False
Linear               8 x 6                4,614      True
________________________________________________________________

Total params: 82,123,014
Total trainable params: 615,174
Total non-trainable params: 81,507,840

Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fee7e8166a8>)
Loss function: FlattenedLoss of BCEWithLogitsLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback

Conclusion

Finally, find a learning rate and fit the model:

lrs = learn %>% lr_find(suggestions=TRUE)

learn %>% fit_one_cycle(1, lr_max=1e-2)

epoch  train_loss  valid_loss  accuracy_multi  time
-----  ----------  ----------  --------------  -----
0      0.040617    0.034286    0.993257        01:21

Predict on a new comment. Here the decision threshold is lowered to 0.02, so every label whose predicted probability exceeds 0.02 is returned:

learn$loss_func$thresh = 0.02

learn %>% predict("Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes. No enchiladas for them!")

$probabilities
severe_toxicity     obscene       threat     insult identity_attack sexual_explicit
1    9.302437e-07 0.004268706 0.0007849637 0.02687055     0.003282947      0.00232468

$labels
[1] "insult"
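The multilabel pieces used above reduce to a little arithmetic: the model emits one logit per label, a BCE-with-logits loss applies a sigmoid and averages the binary cross-entropy over all (comment, label) entries, and both the accuracy metric and the decoded prediction compare sigmoid probabilities with a threshold. The base-R sketch below mirrors that logic on hypothetical logits; it is an illustration of the technique, not fastai's code, and acc_multi and labels_above are hypothetical names.

```r
sigmoid <- function(x) 1 / (1 + exp(-x))

# Binary cross-entropy on logits, averaged over every (comment, label) entry.
# Uses the numerically stable form max(x, 0) - x * y + log(1 + exp(-|x|)).
bce_with_logits <- function(logits, targets) {
  mean(pmax(logits, 0) - logits * targets + log1p(exp(-abs(logits))))
}

# Share of entries where (sigmoid(logit) > thresh) agrees with the 0/1 target
acc_multi <- function(logits, targets, thresh = 0.5) {
  mean((sigmoid(logits) > thresh) == (targets == 1))
}

# Names of the labels whose probability exceeds the threshold
labels_above <- function(probs, vocab, thresh) vocab[probs > thresh]

lbl_cols <- c('severe_toxicity', 'obscene', 'threat',
              'insult', 'identity_attack', 'sexual_explicit')

# Hypothetical logits and targets: 2 comments x 6 labels
logits <- matrix(c(-4, -2, -5,  1, -3, -6,
                   -5, -1, -4, -3, -2, -5), nrow = 2, byrow = TRUE)
targets <- matrix(c(0, 0, 0, 1, 0, 0,
                    0, 1, 0, 0, 0, 0), nrow = 2, byrow = TRUE)

loss <- bce_with_logits(logits, targets)
acc  <- acc_multi(logits, targets, thresh = 0.5)
cat("loss:", round(loss, 4), " accuracy:", round(acc, 4), "\n")

# A lower threshold (like learn$loss_func$thresh = 0.02) returns more labels
probs <- sigmoid(logits[2, ])
print(labels_above(probs, lbl_cols, thresh = 0.5))  # no label passes 0.5
print(labels_above(probs, lbl_cols, thresh = 0.02))
```

Because the accuracy averages over every (comment, label) pair, predicting 0 everywhere already scores highly when positive labels are rare, so a value such as the 0.993 above is best read alongside the loss.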