PyTorch: Get prediction from classification model output


The output of a classification model is usually a tensor with one score per class. Depending on the model, these scores are either raw logits or probabilities; in both cases the highest value marks the most likely class.
For example, if a model classifies into 3 classes, its output for a single image could be:
torch tensor [0.5, 0.7, 0.1]
class 0 - score 0.5
class 1 - score 0.7
class 2 - score 0.1
To get the most probable class for each image in a batch, consider the following code:
outputs = model(images)
pred = outputs.max(1).indices
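Here outputs is assumed to have shape (batch size, number of classes), and the max method (the same operation as torch.max) is called with dim (dimension) equal to 1. That means the maximum is taken along dimension 1, the class dimension, separately for each image in the batch.
When dim is given, max returns a named tuple (values, indices).
We take only indices, which holds the predicted class (category) for each image.
values holds the corresponding maximum score for each image.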
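To make this concrete, here is a minimal sketch that runs without a trained model. The outputs tensor is made-up data standing in for model(images) (an assumption: 2 images, 3 classes), and the softmax step is not part of the original snippet; it is included only to show that converting scores to probabilities does not change the predicted class.

```python
import torch

# Hypothetical batch of model outputs: 2 images, 3 classes.
# In real code this would come from `outputs = model(images)`.
outputs = torch.tensor([[0.5, 0.7, 0.1],
                        [2.0, -1.0, 0.3]])

# Reduce along dim 1 (the class dimension): one result per image.
values, indices = outputs.max(1)
pred = indices  # same tensor as outputs.max(1).indices

# Softmax turns scores into probabilities without changing their order,
# so the most probable class is the same either way.
probs = torch.softmax(outputs, dim=1)

print(pred.tolist())                 # [1, 0]
print(probs.argmax(dim=1).tolist())  # [1, 0]
```

If only the class indices are needed, outputs.argmax(dim=1) gives the same result as outputs.max(1).indices.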
Written by

Maxim Vasiliev
Backend developer