Training loop

Generalmente l'implementazione di base di un training loop in pytorch prevede una funzione fit() che si occupa del training e della valutazione del modello.

All'interno del training loop viene eseguito per ogni epoch, prima un addestramento usando tutti i batch di valid_dl e poi un'inferenza usando tutti i batch di valid_dl.

Nel dettaglio il training loop itera per ogni epoch attraverso i seguenti passaggi:

Addestramento

  • il modello viene settato in modalità train
  • vengono presi tutti i batch del training set e su ogni batch viene internamente eseguito il forward() del modello, producendo una previsione in output.
  • Viene calcolata la loss function, mettendo a confronto la previsione prodotta dal modello con l'effettiva e reale label (target) del batch di training
  • Viene eseguito il backward() per calcolare tutti i gradienti dei parametri del modello
  • Viene fatto l'update dei parametri (aggiornamento dei pesi o step()) sottraendo i gradienti appena calcolati
  • Vengono azzerati i gradienti appena calcolati

Inferenza

Arrivati a questo punto, con i valori dei pesi del modello settati dopo l'iterazione su TUTTI i batch di training, possiamo valutare il modello con i dati di validazione usando quei particolari pesi. Quindi la stessa logica va applicata in modo simile per il validation dataloader, avendo accortezza di non calcolare i gradienti:

  • il modello viene settato in modalità eval
  • vengono presi tutti i batch del validation set e su ogni batch viene internamente eseguito il forward() del modello che è stato addestrato con tutti i batch del dataloader di training, producendo una previsione in output.
  • Viene calcolata la loss function, mettendo a confronto la previsione prodotta dal modello con l'effettiva e reale label (target) del batch di validazione
  • Vengono eseguiti i calcoli per determinare il valore totale della loss function e l'accuratezza del modello e vengono stampati.
def accuracy(out, yb): return (out.argmax(dim=1)==yb).float().mean() # utility to print accuracy value

def fit(epochs, model, optimizer, loss_fn, train_dl, valid_dl):
    
    for epoch in range(epochs):
        model.train()

        for batch_index, (X_batch, y_batch) in enumerate(train_dl):
            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        with torch.no_grad():  
            tot_loss, tot_accuracy, count = 0.,0.,0
            
            for X_valid, y_valid in valid_dl:        
                
                val_probs = model(X_valid)
                val_loss = loss_fn(val_probs, y_valid) # labels.view(val_probs.shape)
                
                n = len(X_valid)
                count += n
                
                tot_loss += loss_fn(val_probs, y_valid).item()*n
                tot_accuracy  += accuracy (val_probs, y_valid).item()*n

        print(f'{epoch} loss = {tot_loss/count}, accuracy = {tot_accuracy/count}')

    return tot_loss/count, tot_accuracy/count

Per eseguire la funzione fit(), consideriamo di avere già un dataloader, un modello, un ottimizzatore e una loss function:

learning_rate = 0.1
train_dl, valid_dl # train dataloader e validation dataloader
model = cnn # cnn è una rete pytorch
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn.parameters(), lr=learning_rate)

n_epochs = 100

fit(n_epochs, model0, optimizer, loss_fn, train_dl, valid_dl)

# print output d'esempio: 
# 0 loss = 0.5227519570827485, accuracy = 0.8274
# 1 loss = 0.2518924878358841, accuracy = 0.9262
# 2 loss = 0.18806885157227515, accuracy = 0.9428
# ....
# 46 loss = 0.06776001966204494, accuracy = 0.9803  # nota come diminuisce il loss ed aumenta l'accuratezza...