The fastai library simplifies training fast and accurate neural nets using modern best practices. See the fastai website to get started. The library is based on research into deep learning best practices undertaken at fast.ai, and includes “out of the box” support for vision, text, tabular, and collab (collaborative filtering) models.
Grab data for binary classification:
library(fastai)
library(magrittr)
library(zeallot)
URLs_IMDB_SAMPLE()
Define task:
HF_TASKS_AUTO = HF_TASKS_AUTO()
task = HF_TASKS_AUTO$SequenceClassification
pretrained_model_name = "roberta-base" # "distilbert-base-uncased" "bert-base-uncased"
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(pretrained_model_name, task=task)
Downloading: 100%|██████████| 481/481 [00:00<00:00, 277kB/s]
Downloading: 100%|██████████| 899k/899k [00:01<00:00, 580kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 471kB/s]
Downloading: 100%|██████████| 501M/501M [03:11<00:00, 2.62MB/s]
Create Learner with Hugging Face data blocks:
imdb_df = data.table::fread('imdb_sample/texts.csv')
blocks = list(HF_TextBlock(hf_arch=hf_arch, hf_tokenizer=hf_tokenizer), CategoryBlock())
dblock = DataBlock(blocks=blocks,
                   get_x=ColReader('text'),
                   get_y=ColReader('label'),
                   splitter=ColSplitter(col='is_valid'))
dls = dblock %>% dataloaders(imdb_df, bs=4)
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 0, 4833, 3009, ..., 1916, 6, 2],
[ 0, 1876, 13856, ..., 7, 47, 2],
[ 0, 2647, 6, ..., 6, 61, 2],
[ 0, 20, 2091, ..., 5779, 30, 2]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]], device='cuda:0')
[[2]]
TensorCategory([0, 1, 0, 0], device='cuda:0')
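The batch above pairs `input_ids` with an `attention_mask` of the same shape. As a rough illustration of how the two relate, the base R sketch below pads two made-up token sequences to a common length and derives the mask from the original lengths. The token ids, the helper name `pad_batch`, and the pad id of 1 (the id RoBERTa's tokenizer assigns to `<pad>`) are assumptions for illustration; this is not how the library builds batches internally.

```r
# Hypothetical helper: right-pad integer token id vectors to a common length
# and build the matching attention mask (1 = real token, 0 = padding).
pad_batch <- function(seqs, pad_id = 1L) {
  max_len <- max(lengths(seqs))
  # One row per sequence, padded on the right with pad_id.
  input_ids <- t(vapply(seqs, function(s) {
    c(s, rep(pad_id, max_len - length(s)))
  }, integer(max_len)))
  # One row per sequence: ones for real tokens, zeros for padding.
  attention_mask <- t(vapply(lengths(seqs), function(n) {
    c(rep(1L, n), rep(0L, max_len - n))
  }, integer(max_len)))
  list(input_ids = input_ids, attention_mask = attention_mask)
}

# Made-up sequences; 0 and 2 mimic the start/end ids visible in the batch above.
seqs <- list(c(0L, 4833L, 3009L, 6L, 2L), c(0L, 20L, 2L))
batch <- pad_batch(seqs)
print(batch$input_ids)
print(batch$attention_mask)
```

In the batch printed above every mask entry shown is 1, which is consistent with reviews long enough to fill the 512-token input without padding.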
Wrap model:
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
                model,
                opt_func=partial(Adam, decouple_wd=TRUE),
                loss_func=CrossEntropyLossFlat(),
                metrics=accuracy,
                cbs=HF_BaseModelCallback(),
                splitter=hf_splitter())
learn$create_opt()
learn$freeze()
learn %>% summary()
epoch train_loss valid_loss accuracy time
------ ----------- ----------- --------- ------
HF_BaseModelWrapper (Input shape: 4 x 512)
================================================================
Layer (type) Output Shape Param # Trainable
================================================================
Embedding 4 x 512 x 768 38,603,520 False
________________________________________________________________
Embedding 4 x 512 x 768 394,752 False
________________________________________________________________
Embedding 4 x 512 x 768 768 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
Dropout 4 x 12 x 512 x 512 0 False
________________________________________________________________
Linear 4 x 512 x 768 590,592 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 512 x 3072 2,362,368 False
________________________________________________________________
Linear 4 x 512 x 768 2,360,064 False
________________________________________________________________
LayerNorm 4 x 512 x 768 1,536 True
________________________________________________________________
Dropout 4 x 512 x 768 0 False
________________________________________________________________
Linear 4 x 768 590,592 True
________________________________________________________________
Dropout 4 x 768 0 False
________________________________________________________________
Linear 4 x 2 1,538 True
________________________________________________________________
Total params: 124,647,170
Total trainable params: 630,530
Total non-trainable params: 124,016,640
Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fd850db18c8>, decouple_wd=True)
Loss function: FlattenedLoss of CrossEntropyLoss()
Model frozen up to parameter group #2
Callbacks:
- TrainEvalCallback
- Recorder
- ProgressCallback
- HF_BaseModelCallback
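The totals in the summary can be checked by hand from the per-layer rows. The sketch below reproduces them in base R using only counts printed above. The grouping into embeddings, 12 encoder blocks, and a classification head is read off the repeating pattern of rows, and the variable names are illustrative.

```r
# Parameter counts copied from the summary rows above.
embeddings  <- 38603520 + 394752 + 768  # the three Embedding rows
layer_norm  <- 1536                     # each LayerNorm row (marked trainable)
attn_linear <- 590592                   # each 768 -> 768 Linear row
ffn_up      <- 2362368                  # Linear row with output 4 x 512 x 3072
ffn_down    <- 2360064                  # Linear row mapping back to 768
n_layers    <- 12                       # repeated encoder blocks in the summary

# Each block has four 768 -> 768 Linear rows, two feed-forward rows, two LayerNorms.
per_layer <- 4 * attn_linear + ffn_up + ffn_down + 2 * layer_norm
head      <- 590592 + 1538              # the final two Linear rows (trainable)

total <- embeddings + layer_norm + n_layers * per_layer + head
# While frozen, only the LayerNorm rows and the head are trainable.
trainable <- (1 + 2 * n_layers) * layer_norm + head

# These match the three totals reported by summary().
stopifnot(total == 124647170)
stopifnot(trainable == 630530)
stopifnot(total - trainable == 124016640)
cat("total:", total, "trainable:", trainable, "frozen:", total - trainable, "\n")
```

The check makes the effect of `learn$freeze()` concrete: about 0.5% of the weights are updated until the model is unfrozen.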
Train and predict:
result = learn %>% fit_one_cycle(3, lr_max=1e-3)
learn %>% predict(imdb_df$text[1:4])
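To give a feel for what `lr_max=1e-3` means in `fit_one_cycle`, the base R sketch below traces a one-cycle learning-rate curve: a cosine warm-up to `lr_max` followed by a cosine decay. The helper names and the values `div = 25`, `div_final = 1e5`, and `pct_start = 0.25` are assumptions based on the defaults of the underlying Python fastai library; they are not stated in this document and may differ between versions.

```r
# Cosine interpolation from `start` to `end` as `pos` goes from 0 to 1.
sched_cos <- function(start, end, pos) {
  start + (1 + cos(pi * (1 - pos))) * (end - start) / 2
}

# Hypothetical helper: learning rate at training progress `pct` in [0, 1].
# div, div_final and pct_start are assumed defaults (see the note above).
one_cycle_lr <- function(pct, lr_max = 1e-3, div = 25, div_final = 1e5,
                         pct_start = 0.25) {
  ifelse(pct < pct_start,
         # Warm-up: lr_max / div up to lr_max.
         sched_cos(lr_max / div, lr_max, pct / pct_start),
         # Decay: lr_max down to lr_max / div_final.
         sched_cos(lr_max, lr_max / div_final,
                   (pct - pct_start) / (1 - pct_start)))
}

pct <- seq(0, 1, by = 0.25)
lrs <- one_cycle_lr(pct)
print(data.frame(progress = pct, lr = signif(lrs, 3)))
```

Under these assumptions the rate starts at `lr_max / 25`, peaks at `lr_max` a quarter of the way through training, and decays to a value close to zero by the end of the third epoch.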