Implementing Vectoring Calcs
Alcuni esempi di implementazione considerando i tensori:
Consideriamo la parte finale del forward pass della rete makemore:
# forward pass
# ...
# ...
# n è la dimensione del batch = 32
# Yb è l'array degli output [32]
loss = -logprobs[range(n), Yb].mean()
logprobs ha dim. [32, 27] e dlogprobs dovrà avere stesse dimensioni.
La media contribuisce alla derivata parziale di un elemento dentro logprobs per 1/n
e logprobs mette segno negativo davanti all'elemento.
Quindi possiamo dire che:
dlogprobs = -1/n
In realtà gli unici elementi che saranno impattati sono quelli corrispondenti a [range(n), Yb], mentre tutti gli altri non saranno rilevanti per la derivata.
Possiamo inizializzare dlogprobs a zero e renderlo = -1/n solo nei punti che ci interessano
dlogprobs = torch.zeros_like(logprobs) # inizializza un tensore di stessa forma di logprobs a zero
dlogprobs[range(n), Yb] = -1/n
moltiplicazione tra tensori
Consideriamo due tensori, a[3, 3] * b[3, 1] e moltiplichiamoli tra loro:
c = a * b,
Per moltiplicare a e b viene prima replicata l'unica colonna di b n volte (implicit broadcasting),
quante sono le colonne di a, in modo da coprire completamente a,
e quindi viene fatto il prodotto di ogni elemento:
a11*b1 a12*b1 a13*b1
a21*b2 a22*b2 a23*b2
a31*b3 a32*b3 a33*b3
A causa della replica, abbiamo la situazione in cui, il nodo b1, (o b2 o b3) viene usato più volte nella rete.
Per il calcolo del gradiente, quando un nodo viene usato più volte, i gradienti di quel nodo vanno sommati tra di loro,
in questo caso riga per riga.
Quindi, la derivata parziale del tensore a rispetto a c è:
b.sum(1, keepdim=True)
esempio: vogliamo calcolare dcounts_sum_inv a partire da questa situazione:
# forward pass:
probs = counts * counts_sum_inv
logprobs = probs.log()
loss = -logprobs[range(n), Yb].mean()
immaginando di aver già calcolato dlogprobs e dprobs, per calcolare dcounts_sum_inv, bisogna considerare solo l'espressione:
probs = counts * counts_sum_inv
e le forme dei tensori:
probs.shape # [32, 27]
counts.shape # [32, 27]
counts_sum_inv.shape # [32, 1]
Nella moltiplicazione si ha una replica di counts_sum_inv, questo comporterà una somma dei gradienti riga per riga.
la derivata parziale di counts_sum_inv = counts .
Possiamo poi applicare la chain rule moltiplicando per dprobs.
Alla fine applichiamo la somma dei gradienti riga per riga.
dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
Allo stesso modo, per calcolare dcounts, considerando sempre probs = counts * counts_sum_inv:
dcounts = counts_sum_inv
Applichiamo sempre la chain rule, anche.
La replica NON influisce su dcounts, ma su dcounts_sum_inv, in quanto non replichiamo counts ma count_sum_inv.
dcounts = counts_sum_inv * dprobs
utilizzo dello stesso nodo in più equazioni del forward pass.
Quando un nodo è usato più volte nella rete, i gradienti di tutti gli usi si sommano tra di loro.
continuando l'esempio di prima:
counts = norm_logits.exp()
counts_sum = counts.sum(1, keepdims=True) # <--- 1
counts_sum_inv = counts_sum**-1
probs = counts * counts_sum_inv # <--- 2
il nodo counts è a destra del segno di uguale in 2 espressioni, è usato, cioè due volte nella rete.
Il gradiente dcounts sarà dato dalla somma del gradiente dell'espressione 1 e del gradiente dell'espressione 2.
Continuando l'esempio di calcolo che stiamo facendo, al momento abbiamo calcolato solo il contributo dcounts2 dell'espressione 2:
dcounts2 = counts_sum_inv * dprobs # contributo dell'espressione 2
Per il contributo dell'espressione 1 counts_sum = counts.sum(1, keepdims=True), le forme sono:
counts.shape # [32, 27]
counts_sum.shape # [32, 1]
La colonna counts_sum va sommata riga per riga.
Possiamo implementare dcounts1 per l'espressione 1 in questo modo:
dcounts1 = torch.ones_like(counts) * dcounts_sum
avremo che dcounts è dato da:
dcounts = dcounts1 + dcounts2