Learner Framework

L'implementazione e gestione del training loop può complicarsi da subito. Cambiamenti, aggiunte di codice, necessità di debugging, monitoraggio dei parametri etc. rendono difficilmente gestibile le logiche dentro la funzione fit().
Per questi motivi si rende necessaria la creazione di un vero e proprio framework per addestrare le reti neurali, che sia flessibile e che sappia adattarsi ad ogni modello applicato.
L'idea è di creare una classe Learner che implementa funzioni per gestire il training in modo ordinato e facilmente modificabile.

Callbacks

Quando vogliamo stampare dei dati all'interno del training loop, oppure vogliamo debuggare, se poniamo il codice di queste funzionalità "accessorie", separato ed esterno al training loop, riusciamo a semplificarne di molto la gestione.
Poniamo il codice che esegue la funzionalità accessoria in una funzione esterna e la eseguiamo passandola come argomento ad un'altra funzione che viene eseguita nel training loop.
Definiamo cioè, di fatto, delle callback richiamate all'interno del training loop, in alcuni punti chiave predeterminati, come per esempio:

  • before_fit (prima dell'esecuzione del training loop)
  • before_epoch (prima di ogni iterazione sull'intero Dataset)
  • before_batch (prima di eseguire forward() su un singolo batch)
  • after_batch (dopo calcolata la loss function e aggiornati i pesi se si è in modalità training, per un singolo batch)
  • after_epoch (dopo l'iterazione sull'intero Dataset)
  • after_fit (dopo l'esecuzione del training loop)

Possiamo ovviamente individuare altri punti, se necessario. (es.: after_predict, after_loss etc.) Possiamo eseguire una lista di callback in ognuno di questi punti individuati, definendo una funzione di raccolta run_cbs() che le richiama.
E' importante definire un sistema per stabilire l'eventuale ordine (order) di esecuzione delle callback della lista e, nel caso in cui le callback si influenzino tra loro, un elenco di eccezioni sollevate durante l'esecuzione di una specifica callback, che blocca l'esecuzione di altre callback.

class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass

class Callback(): order = 0

def run_cbs(cbs, method_name, learn=None):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_name, None)
        if method is not None: method(learn)

Passiamo alla definizione del framework, strutturato dalla funzione fit():

class Learner():
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        self.preds = self.model(self.batch[0])
        self.loss = self.loss_func(self.preds, self.batch[1])
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()

    def one_epoch(self, train):
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        try:
            self.callback('before_epoch')
            for self.iter,self.batch in enumerate(self.dl):
                try:
                    self.callback('before_batch')
                    self.one_batch()
                    self.callback('after_batch')
                except CancelBatchException: pass
            self.callback('after_epoch')
        except CancelEpochException: pass
    
    def fit(self, n_epochs):
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback('before_fit')
            for self.epoch in self.epochs:
                self.one_epoch(True)
                self.one_epoch(False)
            self.callback('after_fit')
        except CancelFitException: pass

    def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)

    @property
    def training(self): return self.model.training

Abbiamo così un sistema flessibile e facilmente ampliabile, dove risulta molto semplice aggiungere funzionalità.