Callbacks examples
Elenchiamo alcuni esempi di callback da usare assieme al learner framework.
Callback di esempio
Conta i batch elaborati durante il fit e, al termine, stampa il totale.
class CompletionCB(Callback):
    """Count the batches seen during a fit and report the total at the end."""

    def before_fit(self, learn):
        """Reset the batch counter at the start of every fit."""
        self.count = 0

    def after_batch(self, learn):
        """Record one more completed batch."""
        self.count = self.count + 1

    def after_fit(self, learn):
        """Print how many batches were processed during the fit."""
        print(f'Completed {self.count} batches')
DeviceCB (CUDA Callback)
Callback utile per spostare il modello e i batch sul device (tipicamente la GPU).
Nota come il modello viene spostato una sola volta all'inizio, prima del training loop, mentre ogni batch viene spostato sul device subito prima del forward().
from typing import Mapping
# Default device: use the GPU when CUDA is available, otherwise the CPU.
def_device = 'cuda' if torch.cuda.is_available() else 'cpu'


def to_device(x, device=def_device):
    """Recursively move the tensors contained in `x` to `device`.

    Args:
        x: A tensor, a mapping, or an iterable container (e.g. list or tuple)
            whose elements are themselves valid inputs to this function.
        device: Target device, forwarded to `.to()`. Defaults to `def_device`.

    Returns:
        A tensor for tensor input, a plain `dict` for mapping input, and
        otherwise a new container of the same type as `x` built from the
        moved elements.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, Mapping):
        # Recurse rather than calling `v.to()` directly, so that mapping
        # values may themselves be lists/tuples/dicts of tensors.
        return {k: to_device(v, device) for k, v in x.items()}
    if hasattr(x, 'to'):
        # Non-tensor leaves exposing `.to()` were accepted as mapping values
        # before the mapping branch recursed; keep moving them the same way.
        return x.to(device)
    # Any other input is treated as an iterable and rebuilt with its own type.
    return type(x)(to_device(o, device) for o in x)
class DeviceCB(Callback):
    """Place the model and every batch on a target device."""

    def __init__(self, device=def_device):
        # store_attr() keeps the constructor argument as `self.device`.
        fc.store_attr()

    def before_fit(self, learn):
        """Move the model to the device before the training loop starts."""
        model = learn.model
        # Models without a `.to` method are left where they are.
        if not hasattr(model, 'to'):
            return
        model.to(self.device)

    def before_batch(self, learn):
        """Replace the current batch with a copy moved to the device."""
        moved = to_device(learn.batch, device=self.device)
        learn.batch = moved
MetricsCB
per stampare alcune metriche:
#!pip install torcheval
from copy import copy
from torcheval.metrics import MulticlassAccuracy,Mean
class MetricsCB(Callback):
    """Accumulate metrics and the mean loss per epoch and log them.

    Metrics may be passed positionally (keyed by their class name) or as
    keyword arguments (keyed by the keyword).
    """

    def __init__(self, *ms, **metrics):
        # Positional metrics are registered under their class name.
        metrics.update({type(m).__name__: m for m in ms})
        self.metrics = metrics
        # `all_metrics` is the user metrics plus a running mean of the loss.
        self.loss = Mean()
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss

    def _log(self, d):
        """Emit one log record; other callbacks may replace this hook."""
        print(d)

    def before_fit(self, learn):
        """Expose this callback on the learner as `learn.metrics`."""
        learn.metrics = self

    def before_epoch(self, learn):
        """Reset every metric (including the loss) for the new epoch."""
        for metric in self.all_metrics.values():
            metric.reset()

    def after_epoch(self, learn):
        """Compute all metrics, format them, and hand them to `_log`."""
        log = {}
        for name, metric in self.all_metrics.items():
            log[name] = f'{metric.compute():.3f}'
        log['epoch'] = learn.epoch
        if learn.model.training:
            log['train'] = 'train'
        else:
            log['train'] = 'eval'
        self._log(log)

    def after_batch(self, learn):
        """Update each metric and the running loss with the current batch."""
        # NOTE(review): assumes the first two batch items are inputs and
        # targets; any further items are ignored.
        cpu_batch = to_cpu(learn.batch)
        x, y, *_ = cpu_batch
        for metric in self.metrics.values():
            metric.update(to_cpu(learn.preds), y)
        # The loss contribution of this batch is weighted by len(x).
        self.loss.update(to_cpu(learn.loss), weight=len(x))
ProgressCB
stampa le metriche con progressbar
from fastprogress import progress_bar,master_bar
class ProgressCB(Callback):
    """Show progress bars, a metrics table and (optionally) a loss plot."""

    # Run after MetricsCB so that `learn.metrics` exists in `before_fit`.
    order = MetricsCB.order+1

    def __init__(self, plot=False):
        self.plot = plot

    def before_fit(self, learn):
        """Wrap the epoch iterator in a master bar and hook into the metrics log."""
        self.mbar = master_bar(learn.epochs)
        learn.epochs = self.mbar
        self.first = True
        # Redirect the metrics callback's logging into the master bar table.
        if hasattr(learn, 'metrics'):
            learn.metrics._log = self._log
        self.losses, self.val_losses = [], []

    def _log(self, d):
        """Write one metrics row, preceded by the header on the first call."""
        header, row = list(d), list(d.values())
        if self.first:
            self.mbar.write(header, table=True)
            self.first = False
        self.mbar.write(row, table=True)

    def before_epoch(self, learn):
        """Wrap the current dataloader in a child progress bar."""
        learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)

    def after_batch(self, learn):
        """Show the current loss and, when plotting, extend the training curve."""
        learn.dl.comment = f'{learn.loss:.3f}'
        plotting = self.plot and hasattr(learn, 'metrics') and learn.training
        if not plotting:
            return
        self.losses.append(learn.loss.item())
        # NOTE(review): the graph refresh is assumed to be nested under the
        # plotting condition; the original indentation was lost - confirm.
        if self.val_losses:
            self._update_graph(learn, learn.epoch)

    def after_epoch(self, learn):
        """After a validation epoch, record the validation loss and redraw."""
        if learn.training or not (self.plot and hasattr(learn, 'metrics')):
            return
        self.val_losses.append(learn.metrics.all_metrics['loss'].compute())
        self._update_graph(learn, learn.epoch+1)

    def _update_graph(self, learn, n_epochs):
        """Redraw the training and validation loss curves on the master bar."""
        train_curve = [fc.L.range(self.losses), self.losses]
        # Validation points sit at the batch index where each epoch ended.
        val_x = fc.L.range(n_epochs).map(lambda e: (e+1)*len(learn.dls.train))
        self.mbar.update_graph([train_curve, [val_x, self.val_losses]])