Cum să imprimați numărul de parametri de model în PyTorch

Cum Sa Imprimati Numarul De Parametri De Model In Pytorch



PyTorch este un cadru popular care este folosit în deep learning. Oferă funcții multiple pentru generarea de rețele neuronale complexe (NN). Utilizatorii pot efectua operațiuni de antrenament model cu acest cadru. Cu toate acestea, utilizatorii trebuie să fie familiarizați cu numărul de parametri înainte de a antrena modelul.

Acest blog va descrie:

Care sunt parametrii din PyTorch?

În PyTorch, „ nn.Modul ” este folosită pentru definirea modelelor. Include toate operațiunile și straturile care alcătuiesc modelul. Fiecare strat conține un set de parametri. Parametrii sunt practic actualizați în timpul antrenamentului pentru a minimiza eroarea dintre valorile reale și predicțiile modelului.







De ce trebuie utilizatorii să verifice parametrii modelului?

În timp ce antrenează modelul, utilizatorii trebuie să știe despre numărul de parametri ai modelului lor, deoarece necesită multă memorie și putere de procesare. Dacă sunt familiarizați cu numărul de parametri ai modelului, ei pot evalua cu ușurință cantitatea de memorie care va fi necesară și cât timp va dura antrenamentul, ceea ce îi ajută pe utilizatori să-și optimizeze procesul de antrenament și să prevină epuizarea sistemului. spaţiu.



Cum să afișați numărul de parametri de model în PyTorch?

nn.Modul „clasa are „ parametri() ” metoda care este utilizată pentru a vizualiza numărul de parametri de model din modelul PyTorch. Pentru a obține toate elementele, „ num1() ” se folosește metoda.



Pentru a înțelege conceptul discutat anterior, să aruncăm o privire la codul furnizat:





import torță. nn la fel de nn

clasă NNModel ( nn. Modul ) :
def __Fierbinte__ ( de sine ) :
super ( NNModel , de sine ) . __Fierbinte__ ( )
de sine . fc1 = nn. Liniar ( 10 , cincizeci )
de sine . fc2 = nn. Liniar ( cincizeci , 1 )

def redirecţiona ( de sine , i ) :
i = de sine . fc1 ( i )
i = de sine . fc2 ( i )
întoarcere i

modelul_meu = NNModel ( )
t_params = sumă ( p. numel ( ) pentru p în modelul_meu. parametrii ( ) )
imprimare ( f „Numărul total de parametri: {t_params}” )

În codul de mai sus:

  • În primul rând, definim un model care are două straturi liniare.
  • Apoi, generați instanța modelului și utilizați „ parametri() ” pentru a prelua toți parametrii.
  • Apoi, aplicăm expresia generatorului pentru a calcula toți parametrii prin însumarea numărului de elemente ale fiecărui parametru.
  • În cele din urmă, sunați la „ imprimare() ” instrucțiune pentru a afișa valorile rezultate pe ecran:



În codul descris mai sus, am afișat doar numărul total de parametri, dacă doriți să obțineți numele și dimensiunea parametrului, se pot folosi următoarele rânduri de cod:

pentru Nume , param în modelul_meu. state_dict ( ) . articole ( ) :

imprimare ( Nume , param. mărimea ( ) )

Aici:

  • stat_dict() ” este obiectul dicționar Python care este utilizat pentru stocarea și încărcarea modelelor din PyTorch.
  • articol() ” este utilizată pentru a returna lista cu toate cheile de dicționar împreună cu valorile.
  • imprimare() ” instrucțiunea este utilizată pentru a tipări numele și dimensiunea parametrului prin trecerea „ mărimea() ” metoda și parametrul:

Asta e tot! Am compilat cel mai simplu mod de a imprima numărul de parametri de model în PyTorch.

Concluzie

În PyTorch, „ nn.Modul ” se folosește pentru definirea modelelor care includ toate operațiunile și straturile care alcătuiesc modelul. „ nn.Modul „clasa are „ parametri() ” metoda care este utilizată pentru a vizualiza numărul de parametri de model din modelul PyTorch. Acest articol a demonstrat metoda de imprimare a numărului de parametri de model în PyTorch.