Acest blog va exemplifica metoda de utilizare a metodei „torch.argmax()” în PyTorch.
Cum se utilizează metoda „torch.argmax()” în PyTorch?
Metoda „torch.argmax()” ia orice tensor 1D sau 2D ca intrare și returnează un tensor care conține indici/indici ai valorilor maxime de-a lungul dimensiunii date.
Sintaxa metodei „torch.argmax()” este prezentată mai jos:
torță. argmax ( < input_tensor > )
Pentru a utiliza această metodă în PyTorch, parcurgeți următoarele exemple pentru o mai bună înțelegere:
Exemplul 1: Folosiți metoda „torch.argmax()” cu tensorul 1D
În primul exemplu, vom crea un tensor 1D și vom folosi metoda „torch.argmax()” cu acesta. Să urmăm procedura pas cu pas de mai jos:
Pasul 1: importați biblioteca PyTorch
Mai întâi, importați „ torță ” pentru a utiliza metoda „torch.argmax()”:
import torțăPasul 2: Creați tensorul 1D
Apoi, creați un tensor 1D și imprimați elementele acestuia. Aici, creăm următorul „ Zeci1 ” tensor dintr-o listă folosind “ torță.tensor() ”funcție:
Zeci1 = torță. tensor ( [ 5 , 0 , - 8 , 1 , 9 , 7 ] )
imprimare ( Zeci1 )
Acest lucru a creat un tensor 1D, așa cum se vede mai jos:
Pasul 3: Găsiți indici de valoare maximă
Acum, utilizați „ torch.argmax() ” pentru a găsi indicele/indicii valorii maxime din “ Zeci1 ” tensor:
T1_ind = torță. argmax ( Zeci1 )Pasul 4: Imprimați indicele valorii maxime
În cele din urmă, afișați indicele valorii maxime în tensorul de intrare:
imprimare ( „Indici:” , T1_ind )Ieșirea de mai jos arată indicele valorii maxime în „ Zeci1 ” tensor, adică 4. Înseamnă că cea mai mare valoare a tensorului se află la al 4-lea indice care este „ 9 ”:
Exemplul 2: Utilizați metoda „torch.argmax()” cu tensorul 2D
În al doilea exemplu, vom crea un tensor 2D și vom folosi metoda „torch.argmax()” cu acesta. Să urmăm pașii furnizați:
Pasul 1: importați biblioteca PyTorch
Mai întâi, importați „ torță ” pentru a utiliza metoda „torch.argmax()”:
import torțăPasul 2: Creați Tensor 2D
Apoi, utilizați „ torță.tensor() ” pentru a crea un tensor 2D și a imprima elementele acestuia. Aici, creăm următorul „ Zeci2 „Tensor 2D:
Zeci2 = torță. tensor ( [ [ 4 , 1 , - 7 ] , [ cincisprezece , 6 , 0 ] , [ - 7 , 9 , 2 ] ] )imprimare ( Zeci2 )
Aceasta a creat un tensor 2D, așa cum se vede mai jos:
Pasul 3: Găsiți indici de valoare maximă
Acum, găsiți indicele valorii maxime în „ Zeci2 ” tensor prin utilizarea “ torch.argmax() ”funcție:
T2_ind = torță. argmax ( Zeci2 )Pasul 4: Imprimați indicele valorii maxime
În cele din urmă, afișați indicele valorii maxime în tensorul de intrare:
imprimare ( „Indici:” , T2_ind )Conform rezultatului de mai jos, indicele valorii maxime în „ Zeci2 „tensorul este „3”. Înseamnă că cea mai mare valoare a tensorului se află la al treilea indice care este „ cincisprezece ”:
Pasul 5: Găsiți indici de valoare maximă de-a lungul coloanelor
Mai mult, utilizatorii pot găsi, de asemenea, indicii/indicii valorilor maxime de-a lungul fiecărei coloane a unui tensor. De exemplu, putem folosi „ dim=0 ” argument cu funcția „torch.argmax()”. Găsește indicii valorilor maxime de-a lungul coloanelor din „ Zeci2 ” tensor și apoi imprimă acești indici:
col_index = torță. argmax ( Zeci2 , dim = 0 )imprimare ( „Indici în coloane:” , col_index )
Rezultatul de mai jos arată indicii valorilor maxime de-a lungul fiecărei coloane a tensorului:
Pasul 6: Găsiți indici de valoare maximă de-a lungul rândurilor
În mod similar, utilizatorii pot găsi, de asemenea, indici/indici ai valorilor maxime de-a lungul fiecărui rând al unui tensor. De exemplu, utilizați „ dim=1 ” cu funcția „torch.argmax()” pentru a găsi indicii valorilor maxime de-a lungul rândurilor din tensorul „Tens2” și apoi tipăriți acești indici:
index_rând = torță. argmax ( Zeci2 , dim = 1 )imprimare ( „Indici în rânduri:” , index_rând )
Indicii valorii maxime de-a lungul fiecărui rând al unui tensor „Tens2” pot fi văzuți mai jos:
Am explicat eficient metoda de utilizare a metodei „torch.argmax()” în PyTorch.
Notă : Puteți accesa blocnotesul nostru Google Colab aici legătură .
Concluzie
Pentru a utiliza metoda „torch.argmax()” în PyTorch, mai întâi importați „ torță ” bibliotecă. Apoi, creați tensorul 1D sau 2D dorit și vizualizați elementele acestuia. Apoi, utilizați „ torch.argmax() ” metoda de a afla/calcula indicii/indicii valorilor maxime din tensor. Mai mult decât atât, utilizatorii pot găsi, de asemenea, indecșii valorii maxime de-a lungul fiecărui rând sau coloană din tensor folosind „ dim ' argument. În cele din urmă, afișați indicele valorii maxime în tensorul de intrare. Acest blog a exemplificat metoda de utilizare a metodei „torch.argmax()” în PyTorch.