PyTorch: Get prediction from classification model output

Maxim Vasiliev

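A classification model's output is usually a tensor with one score per class. Many PyTorch models return raw scores (logits) rather than probabilities; softmax converts the scores into probabilities, and in either case the highest score marks the most likely class.
For example, if a model classifies inputs into 3 classes, its output for one input could be: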

tensor([0.5, 0.7, 0.1])

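class 0: score 0.5
class 1: score 0.7
class 2: score 0.1

These values sum to 1.3, so they are raw scores rather than normalized probabilities. Class 1 has the highest score, so it is the predicted class.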
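The same example can be written as a PyTorch tensor. This is a minimal sketch, assuming the model returns a batch of shape (batch_size, num_classes), so the single output is wrapped in a batch of one. It also shows that softmax turns the scores into probabilities without changing which class ranks highest.

```python
import torch

# The example output above, wrapped in a batch of one image.
# Shape is (batch_size, num_classes) = (1, 3).
outputs = torch.tensor([[0.5, 0.7, 0.1]])
print(outputs.shape)  # torch.Size([1, 3])

# The raw scores do not sum to 1, so strictly they are not probabilities.
# Softmax over the class dimension (dim=1) normalizes each row to sum to 1.
probs = torch.softmax(outputs, dim=1)
print(probs.sum(dim=1))

# Softmax preserves the ordering within a row, so class 1 is still the top class.
print(probs.argmax(dim=1))
```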

To get the most probable class, use the following code:

outputs = model(images)  # shape: (batch_size, num_classes)
pred = outputs.max(1).indices  # index of the highest-scoring class for each image
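For a version of the snippet above that runs on its own, here is a sketch. The tiny `nn.Sequential` model and the random `images` batch are hypothetical stand-ins, not the author's model or data; swap in your own. It also checks that `outputs.argmax(dim=1)` yields the same indices, which is a shorter option when the max values are not needed.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny 3-class classifier and a batch of 4
# random 1x28x28 "images". Replace them with your own model and data.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 3))
images = torch.randn(4, 1, 28, 28)

model.eval()  # switch off training-only behaviour such as dropout
with torch.no_grad():  # gradients are not needed for inference
    outputs = model(images)  # shape: (4, 3), one row of scores per image

pred = outputs.max(1).indices  # index of the highest score in each row
print(pred.shape)  # torch.Size([4])

# argmax returns the same indices when the max values themselves are not needed.
assert torch.equal(pred, outputs.argmax(dim=1))
```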

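Here torch.max is called with dim=1 (the dimension to reduce over). The model output has shape (batch_size, num_classes), so dim=1 is the class dimension: the maximum is taken across the classes of each input, giving one result per image.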
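To see why the dimension matters, compare reducing over dim=1 and dim=0 on a small batch. The two-image tensor below is made up for illustration; only dim=1 gives a class prediction per image.

```python
import torch

# Two images (rows), three classes (columns).
outputs = torch.tensor([[0.5, 0.7, 0.1],
                        [0.9, 0.2, 0.3]])

# dim=1 reduces across the columns: the best class for each image.
print(outputs.max(dim=1).indices)  # tensor([1, 0])

# dim=0 reduces across the rows: for each class, which image scored highest.
# That is not a class prediction.
print(outputs.max(dim=0).indices)  # tensor([1, 0, 1])
```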

When called with dim, torch.max returns a named tuple (values, indices).
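Both fields of the named tuple can be read by name, or the result can be unpacked like a regular tuple. A short sketch on a made-up two-image batch:

```python
import torch

outputs = torch.tensor([[0.5, 0.7, 0.1],
                        [0.9, 0.2, 0.3]])

result = outputs.max(dim=1)

# Fields can be read by name...
print(result.values)   # tensor([0.7000, 0.9000])
print(result.indices)  # tensor([1, 0])

# ...or the result can be unpacked like an ordinary tuple.
values, indices = result
```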

We keep only indices: the index of the highest-scoring class, which is the predicted category for each input.

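values holds the maximum score for each input image. These are probabilities only if the model's output has already been passed through softmax; otherwise they are raw scores (logits).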
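If you also want a probability for each prediction, one common approach is to apply softmax over the class dimension before taking the max. This sketch assumes the model returns raw logits; the two-image tensor is made up for illustration. Because softmax preserves the ordering within each row, the predicted classes stay the same and only the values change meaning.

```python
import torch

# Raw scores (logits) for two images and three classes. Calling max() directly
# on these would return the top logit, not a probability.
logits = torch.tensor([[0.5, 0.7, 0.1],
                       [0.9, 0.2, 0.3]])

# Softmax over the class dimension turns each row into probabilities.
probs = torch.softmax(logits, dim=1)
confidence, pred = probs.max(dim=1)

# Softmax keeps the order within a row, so the predicted class is unchanged.
assert torch.equal(pred, logits.max(dim=1).indices)
print(pred)        # tensor([1, 0])
print(confidence)  # probability of the predicted class for each image
```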


Written by Maxim Vasiliev, Backend developer