Multilabel classification

Intro

The fastai library simplifies training fast and accurate neural nets using modern best practices. See the fastai website to get started. The library is based on research into deep learning best practices undertaken at fast.ai, and includes “out of the box” support for vision, text, tabular, and collab (collaborative filtering) models.

Multilabel

Grab the data and take 1% of it for fast training:

library(fastai)
library(magrittr)
library(zeallot)
df = HF_load_dataset('civil_comments', split='train[:1%]')

Preprocess

Convert the data to a data.table and select the label columns; the six toxicity scores will be the targets:

df = data.table::as.data.table(df)

lbl_cols = c('severe_toxicity',
             'obscene',
             'threat',
             'insult',
             'identity_attack',
             'sexual_explicit')

# round the continuous toxicity scores to 0/1, then store them as integers
df[, (lbl_cols) := round(.SD, 0), .SDcols = lbl_cols]
df[, (lbl_cols) := lapply(.SD, as.integer), .SDcols = lbl_cols]
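
The rounded scores are heavily skewed toward zero. A quick check of the positive rate per label (a minimal sketch on the same data.table) makes the sparsity explicit:

df[, lapply(.SD, mean), .SDcols = lbl_cols]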

Pretrained model

Load DistilRoBERTa:

task = HF_TASKS_ALL()$SequenceClassification

pretrained_model_name = "distilroberta-base"
config = AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels = length(lbl_cols)

c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name,
                                                                  task=task,
                                                                  config=config)

Downloading: 100%|██████████| 899k/899k [00:00<00:00, 961kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 597kB/s]
Downloading: 100%|██████████| 331M/331M [03:26<00:00, 1.61MB/s]
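
Before building the data pipeline, it is worth confirming the returned objects look right (a minimal sketch; the exact fields queried are assumptions about the underlying Hugging Face objects):

hf_arch                  # architecture name, e.g. "roberta"
hf_config$num_labels     # should equal length(lbl_cols), i.e. 6
hf_tokenizer$vocab_size  # vocabulary size of the pretrained tokenizer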

Datablock

Create the blocks and assemble the DataBlock:

blocks = list(
  HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer),
  MultiCategoryBlock(encoded=TRUE, vocab=lbl_cols)
)

dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('text'), get_y=ColReader(lbl_cols),
                   splitter=RandomSplitter())

dls = dblock %>% dataloaders(df, bs=8)

Inspect a single batch:

dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[    0, 24268,  5257,  ...,     1,     1,     1],
        [    0,   287,  4505,  ...,     1,     1,     1],
        [    0,    38,   437,  ...,     1,     1,     1],
        ...,
        [    0,   152,  1129,  ...,     1,     1,     1],
        [    0,    85,    18,  ...,     1,     1,     1],
        [    0, 22014,    31,  ...,     1,     1,     1]], device='cuda:0')

[[1]]$attention_mask
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')


[[2]]
TensorMultiCategory([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]], device='cuda:0')
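
To check that the tokenization round-trips, you can decode the first row of input_ids back to text with the tokenizer (a sketch; moving the tensor through numpy and reticulate is an assumption that may need adjusting):

batch = dls %>% one_batch()
ids = batch[[1]]$input_ids$cpu()$numpy()  # one row of token ids per example
# skip_special_tokens drops <s>, </s>, and the padding tokens
hf_tokenizer$decode(as.integer(ids[1, ]), skip_special_tokens = TRUE)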

Model

model = HF_BaseModelWrapper(hf_model)

learn = Learner(dls,
                model,
                opt_func=partial(Adam),
                loss_func=BCEWithLogitsLossFlat(),
                metrics=partial(accuracy_multi, thresh=0.2),
                cbs=HF_BaseModelCallback(),
                splitter=hf_splitter())

learn$loss_func$thresh = 0.2
learn$create_opt()   # create layer groups based on the "splitter" function
learn$freeze()       # train only the last layer group (plus normalization layers) at first

See the summary of frozen and trainable layers:

learn %>% summary()

HF_BaseModelWrapper (Input shape: 8 x 391)
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Embedding            8 x 391 x 768        38,603,520 False     
________________________________________________________________
Embedding            8 x 391 x 768        394,752    False     
________________________________________________________________
Embedding            8 x 391 x 768        768        False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
Dropout              8 x 12 x 391 x 391   0          False     
________________________________________________________________
Linear               8 x 391 x 768        590,592    False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 391 x 3072       2,362,368  False     
________________________________________________________________
Linear               8 x 391 x 768        2,360,064  False     
________________________________________________________________
LayerNorm            8 x 391 x 768        1,536      True      
________________________________________________________________
Dropout              8 x 391 x 768        0          False     
________________________________________________________________
Linear               8 x 768              590,592    True      
________________________________________________________________
Dropout              8 x 768              0          False     
________________________________________________________________
Linear               8 x 6                4,614      True      
________________________________________________________________

Total params: 82,123,014
Total trainable params: 615,174
Total non-trainable params: 81,507,840

Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fee7e8166a8>)
Loss function: FlattenedLoss of BCEWithLogitsLoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - HF_BaseModelCallback
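
The 615,174 trainable parameters are exactly the LayerNorm layers plus the classification head; everything else is frozen. This can be confirmed by walking the model's parameters through reticulate (a minimal sketch):

library(reticulate)
params = iterate(learn$model$parameters())
sum(sapply(params, function(p) if (p$requires_grad) p$numel() else 0))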

Train and predict

Finally, run the learning rate finder, then fit the model for one epoch:

lrs = learn %>% lr_find(suggestions=TRUE)

learn %>% fit_one_cycle(1, lr_max=1e-2)
epoch   train_loss   valid_loss   accuracy_multi   time  
------  -----------  -----------  ---------------  ------
0       0.040617     0.034286     0.993257         01:21 
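
Since only the head and the normalization layers were trained so far, a natural next step is to unfreeze the whole model and fine-tune for another epoch at a lower learning rate (a sketch; the learning rate is an illustrative choice, not a tuned value):

learn$unfreeze()
learn %>% fit_one_cycle(1, lr_max = 1e-5)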

Lower the decision threshold so that weaker signals are reported, then predict on a new comment:

learn$loss_func$thresh = 0.02

learn %>% predict("Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes.
No enchiladas for them!")
$probabilities
  severe_toxicity     obscene       threat     insult identity_attack sexual_explicit
1    9.302437e-07 0.004268706 0.0007849637 0.02687055     0.003282947      0.00232468

$labels
[1] "insult"