Unlocking Unit Testing in Python

Unit tests are designed to be simple and efficient, focusing on small, isolated pieces of code. They should execute quickly and, ideally, should not make any database or API calls. However, I often see developers write tightly coupled code, which makes writing tests difficult.

Let me walk through an example in which the code is initially tightly coupled to database calls, and then refactor it step by step until it is completely unit testable.

The OrderReport class

The OrderReport class generates a report of fulfilled orders. It skips orders paid in cash and computes two flags for the rest: z_value is 1 when the order's location is Mumbai or Delhi, and x_value is 1 when the amount exceeds 1500. Each record in the report holds the order's id, product, z_value, and x_value.

from typing import List
from models import Order


class OrderReport:

    def get_records(self) -> List:
        records = []
        orders = Order.objects.filter(is_fulfilled=True)

        for order in orders:
            if order.payment_method == 'CASH':
                continue

            z_value = 0
            if order.location in ['Mumbai', 'Delhi']:
                z_value = 1

            x_value = 0
            if order.amount > 1500:
                x_value = 1

            record = {
                'id': order.id,
                'product': order.product,
                'z_value': z_value,
                'x_value': x_value,
            }
            records.append(record)

        return records

Many developers will recognize this pattern: all of the logic lives in a single method that also queries the database, so the filtering and flag rules cannot be tested without a database (or without patching the ORM), which makes testing a daunting task.
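
To see why, here is a minimal sketch of what testing the original version involves. It assumes a hypothetical stand-in Order class in place of the Django model from models, and condenses get_records() into a plain function so the snippet runs on its own. The only way to keep the database out of the test is to patch Order.objects with unittest.mock.patch.object, which ties the test to the exact ORM call chain.

```python
from types import SimpleNamespace
from unittest.mock import patch


class Order:
    """Hypothetical stand-in for the Django model imported from `models`."""
    objects = None  # the real model exposes its ORM manager here


def get_records():
    # Condensed copy of the original, tightly coupled OrderReport.get_records().
    records = []
    for order in Order.objects.filter(is_fulfilled=True):
        if order.payment_method == 'CASH':
            continue
        records.append({
            'id': order.id,
            'product': order.product,
            'z_value': 1 if order.location in ['Mumbai', 'Delhi'] else 0,
            'x_value': 1 if order.amount > 1500 else 0,
        })
    return records


fake_order = SimpleNamespace(id=1, product='Product A', payment_method='CARD',
                             location='Mumbai', amount=2000)

# The test must know the code calls Order.objects.filter(is_fulfilled=True).
# Chaining another queryset call (e.g. .order_by()) would break this fake.
with patch.object(Order, 'objects') as mock_objects:
    mock_objects.filter.return_value = [fake_order]
    records = get_records()

mock_objects.filter.assert_called_once_with(is_fulfilled=True)
```

It works, but every test now depends on how the query is written rather than on the report rules themselves.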

The first refactor

Let's refactor the class into a more modular and organized structure.

from typing import List
from typing import Dict
from models import Order


class OrderReport:

    def _get_orders(self):
        return Order.objects.filter(is_fulfilled=True)

    def _is_cash_payment(self, method: str) -> bool:
        return method == 'CASH'

    def _get_z_value(self, location: str) -> int:
        return 1 if location in ['Mumbai', 'Delhi'] else 0

    def _get_x_value(self, amount: int) -> int:
        return 1 if amount > 1500 else 0

    def _build_record(self, order) -> Dict:
        return {
            'id': order.id,
            'product': order.product,
            'z_value': self._get_z_value(location=order.location),
            'x_value': self._get_x_value(amount=order.amount),
        }

    def get_records(self) -> List:
        records = []

        for order in self._get_orders():
            if self._is_cash_payment(method=order.payment_method):
                continue

            records.append(self._build_record(order=order))

        return records
  • The code is divided into multiple private methods to encapsulate specific functionalities.

  • Each private method is responsible for a specific task, such as retrieving orders, checking if the payment method is cash, calculating z_value and x_value, and building records.

  • The database query is isolated in the private _get_orders() method.

  • The get_records() method serves as the entry point and orchestrates the execution by calling the private methods.

This separates concerns by splitting the logic into smaller, reusable private methods. However, get_records() still calls _get_orders(), which queries the database, so the class as a whole remains hard to unit test.
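
Still, _get_orders() is already a seam. As an interim option (not the approach this article settles on), a test can replace just that method with unittest.mock.patch.object and keep the rest of the logic real. The sketch below condenses the refactored class and swaps the Django query for a hypothetical RuntimeError so it runs without a database.

```python
from types import SimpleNamespace
from typing import Dict, List
from unittest.mock import patch


class OrderReport:
    """Condensed copy of the first refactor."""

    def _get_orders(self):
        # Hypothetical stand-in for Order.objects.filter(is_fulfilled=True).
        raise RuntimeError('this would query the database')

    def _is_cash_payment(self, method: str) -> bool:
        return method == 'CASH'

    def _get_z_value(self, location: str) -> int:
        return 1 if location in ['Mumbai', 'Delhi'] else 0

    def _get_x_value(self, amount: int) -> int:
        return 1 if amount > 1500 else 0

    def _build_record(self, order) -> Dict:
        return {
            'id': order.id,
            'product': order.product,
            'z_value': self._get_z_value(location=order.location),
            'x_value': self._get_x_value(amount=order.amount),
        }

    def get_records(self) -> List[Dict]:
        # Same behavior as the loop version: skip cash, build the rest.
        return [self._build_record(order=order)
                for order in self._get_orders()
                if not self._is_cash_payment(method=order.payment_method)]


orders = [
    SimpleNamespace(id=1, product='Product A', payment_method='CARD',
                    location='Delhi', amount=900),
    SimpleNamespace(id=2, product='Product B', payment_method='CASH',
                    location='Mumbai', amount=3000),
]

# Only the database seam is faked; filtering and flag logic run for real.
with patch.object(OrderReport, '_get_orders', return_value=orders):
    records = OrderReport().get_records()
```

The trade-off is that every test has to remember to patch a private method by name; the abstract base class in the next section makes that substitution explicit in the class hierarchy instead.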

Decoupling the data layer

In the given class, the database call is encapsulated within the _get_orders() method. To address the testing challenge, one possible approach is to create an abstract class that contains all the logic, with _get_orders() defined as an abstract method. Let's call this abstract class OrderReportABC.

# order_report > __init__.py
from typing import List
from typing import Dict
from typing import Iterable
from abc import ABC, abstractmethod


class OrderReportABC(ABC):

    @abstractmethod
    def _get_orders(self) -> Iterable:
        pass

    def _is_cash_payment(self, method: str) -> bool:
        return method == 'CASH'

    def _get_z_value(self, location: str) -> int:
        return 1 if location in ['Mumbai', 'Delhi'] else 0

    def _get_x_value(self, amount: int) -> int:
        return 1 if amount > 1500 else 0

    def _build_record(self, order) -> Dict:
        return {
            'id': order.id,
            'product': order.product,
            'z_value': self._get_z_value(location=order.location),
            'x_value': self._get_x_value(amount=order.amount),
        }

    def get_records(self) -> List[Dict]:
        records = []

        for order in self._get_orders():
            if self._is_cash_payment(method=order.payment_method):
                continue

            records.append(self._build_record(order=order))

        return records
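
Declaring _get_orders() with @abstractmethod means Python refuses to instantiate any subclass that forgets to implement it. A minimal sketch with a trimmed-down base class (the names ReportBase, IncompleteReport, and ListReport are hypothetical) shows the TypeError raised at construction time, before any report logic runs.

```python
from abc import ABC, abstractmethod
from typing import Iterable, List


class ReportBase(ABC):
    """Trimmed-down, hypothetical stand-in for OrderReportABC."""

    @abstractmethod
    def _get_orders(self) -> Iterable:
        ...

    def get_records(self) -> List:
        # Shared logic only ever talks to the abstract data source.
        return list(self._get_orders())


class IncompleteReport(ReportBase):
    pass  # forgot to implement _get_orders()


class ListReport(ReportBase):
    def _get_orders(self) -> Iterable:
        return ['order-1', 'order-2']


try:
    IncompleteReport()
    error = None
except TypeError as exc:
    # The exact message wording differs between Python versions,
    # but it names the missing abstract method.
    error = exc

records = ListReport().get_records()
```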

Now, we can create a class called OrderReport that inherits from the abstract class and provides the implementation of the _get_orders() method specific to the database.

# order_report > order_report.py
from typing import Iterable
from . import OrderReportABC
from models import Order


class OrderReport(OrderReportABC):

    def _get_orders(self) -> Iterable:
        return Order.objects.filter(is_fulfilled=True)

Next, we create another subclass called OrderReportMock, which accepts a list of orders in its constructor and returns it from _get_orders(). This class exists purely for unit testing: since it never queries the database, tests can supply whatever orders they need.

# order_report > order_report_mock.py
from typing import List
from typing import Iterable
from . import OrderReportABC


class OrderReportMock(OrderReportABC):

    def __init__(self, orders: List) -> None:
        self._orders = orders
        super().__init__()

    def _get_orders(self) -> Iterable:
        return self._orders
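
Subclassing is one way to swap the data source. A common alternative, shown here only for comparison and not as the method used in this article, is to pass the data source into the constructor. In this sketch the hypothetical InjectedOrderReport receives a zero-argument get_orders callable; production code could pass something like lambda: Order.objects.filter(is_fulfilled=True), while a test passes a lambda that returns a plain list.

```python
from types import SimpleNamespace
from typing import Callable, Dict, Iterable, List


class InjectedOrderReport:
    """Hypothetical variant: the order source is injected, not inherited."""

    def __init__(self, get_orders: Callable[[], Iterable]) -> None:
        self._get_orders = get_orders

    def get_records(self) -> List[Dict]:
        records = []
        for order in self._get_orders():
            if order.payment_method == 'CASH':
                continue  # cash orders are excluded, as in OrderReportABC
            records.append({
                'id': order.id,
                'product': order.product,
                'z_value': 1 if order.location in ['Mumbai', 'Delhi'] else 0,
                'x_value': 1 if order.amount > 1500 else 0,
            })
        return records


# In a test, the "data layer" is just a lambda returning in-memory orders.
orders = [SimpleNamespace(id=7, product='Product G', payment_method='UPI',
                          location='Pune', amount=1800)]
report = InjectedOrderReport(get_orders=lambda: orders)
records = report.get_records()
```

Injection avoids writing a dedicated test subclass, while the subclass approach keeps the database query in a named class (OrderReport) that callers can construct without any arguments.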

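To recap the steps we have taken: the report logic now lives in OrderReportABC, OrderReport supplies orders from the database, and OrderReportMock supplies them from an in-memory list. The tests can therefore exercise all of the logic through OrderReportMock.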

The unit tests

from unittest import TestCase
from unittest.mock import MagicMock
from order_report.order_report_mock import OrderReportMock


class TestOrderReport(TestCase):

    def test_is_cash_payment(self):
        order_report = OrderReportMock([])
        self.assertTrue(order_report._is_cash_payment('CASH'))
        self.assertFalse(order_report._is_cash_payment('CARD'))

    def test_get_z_value(self):
        order_report = OrderReportMock([])
        self.assertEqual(order_report._get_z_value('Mumbai'), 1)
        self.assertEqual(order_report._get_z_value('Delhi'), 1)
        self.assertEqual(order_report._get_z_value('Bangalore'), 0)

    def test_get_x_value(self):
        order_report = OrderReportMock([])
        self.assertEqual(order_report._get_x_value(2000), 1)
        self.assertEqual(order_report._get_x_value(1000), 0)
        self.assertEqual(order_report._get_x_value(2500), 1)

    def test_build_record(self):
        order_report = OrderReportMock([])
        order = MagicMock(
            location='Mumbai',
            amount=2000,
            id=1,
            product='Product A')
        record = order_report._build_record(order)
        expected_record = {
            'id': 1,
            'product': 'Product A',
            'z_value': 1,
            'x_value': 1
        }
        self.assertEqual(record, expected_record)

    def test_get_records_no_cash_payment(self):
        orders = [
            MagicMock(
                payment_method='CARD',
                location='Mumbai',
                amount=2000,
                id=1,
                product='Product A'
            ),
            MagicMock(
                payment_method='CARD',
                location='Delhi',
                amount=1000,
                id=2,
                product='Product B'
            ),
            MagicMock(
                payment_method='CARD',
                location='Bangalore',
                amount=2500,
                id=3,
                product='Product C'
            ),
        ]

        order_report = OrderReportMock(orders)
        records = order_report.get_records()

        expected_records = [
            {
                'id': 1,
                'product': 'Product A',
                'z_value': 1,
                'x_value': 1
            },
            {
                'id': 2,
                'product': 'Product B',
                'z_value': 1,
                'x_value': 0
            },
            {
                'id': 3,
                'product': 'Product C',
                'z_value': 0,
                'x_value': 1
            },
        ]

        self.assertEqual(records, expected_records)

    def test_get_records_with_cash_payment(self):
        orders = [
            MagicMock(
                payment_method='CARD',
                location='Mumbai',
                amount=2000,
                id=1,
                product='Product A'
            ),
            MagicMock(
                payment_method='CASH',
                location='Delhi',
                amount=1000,
                id=2,
                product='Product B'
            ),
            MagicMock(
                payment_method='CARD',
                location='Bangalore',
                amount=2500,
                id=3,
                product='Product C'
            ),
        ]

        order_report = OrderReportMock(orders)
        records = order_report.get_records()

        expected_records = [
            {
                'id': 1,
                'product': 'Product A',
                'z_value': 1,
                'x_value': 1
            },
            {
                'id': 3,
                'product': 'Product C',
                'z_value': 0,
                'x_value': 1
            },
        ]

        self.assertEqual(records, expected_records)

The TestOrderReport class tests the logic defined in OrderReportABC through the OrderReportMock subclass, so no test touches the database. Here's a brief explanation of each test case:

  1. test_is_cash_payment: Verifies that the _is_cash_payment method returns True for the 'CASH' payment method and False for other methods.

  2. test_get_z_value: Ensures that the _get_z_value method returns the expected values based on the provided location.

  3. test_get_x_value: Checks that the _get_x_value method returns the expected values based on the provided amount.

  4. test_build_record: Validates that the _build_record method constructs the record dictionary correctly using the given order.

  5. test_get_records_no_cash_payment: Tests the get_records method when there are no cash payments in the list of orders. It verifies that the generated records match the expected records.

  6. test_get_records_with_cash_payment: Tests the get_records method when there are cash payments in the list of orders. It ensures that the generated records exclude the orders with cash payment and match the expected records.

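These test cases exercise every method defined in OrderReportABC, including both outcomes of each condition. The only code left untested is OrderReport._get_orders(), the one-line database query, which is better covered by an integration test.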
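
Exercising both outcomes of a condition does not pin down its boundary: no test above uses an amount of exactly 1500, where the code returns an x_value of 0 because the rule is a strict amount > 1500. A table-driven test with subTest is one way to cover that. The sketch below repeats a condensed OrderReportABC and OrderReportMock so it runs standalone, and uses types.SimpleNamespace for orders: unlike MagicMock, it raises AttributeError for a missing or misspelled attribute instead of silently creating one.

```python
import unittest
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Dict, Iterable, List


class OrderReportABC(ABC):
    """Condensed copy of the article's base class."""

    @abstractmethod
    def _get_orders(self) -> Iterable:
        ...

    def _get_x_value(self, amount: int) -> int:
        return 1 if amount > 1500 else 0

    def _get_z_value(self, location: str) -> int:
        return 1 if location in ['Mumbai', 'Delhi'] else 0

    def get_records(self) -> List[Dict]:
        return [
            {
                'id': order.id,
                'product': order.product,
                'z_value': self._get_z_value(order.location),
                'x_value': self._get_x_value(order.amount),
            }
            for order in self._get_orders()
            if order.payment_method != 'CASH'
        ]


class OrderReportMock(OrderReportABC):
    def __init__(self, orders: List) -> None:
        self._orders = orders

    def _get_orders(self) -> Iterable:
        return self._orders


class TestXValueBoundary(unittest.TestCase):

    def test_get_x_value_around_1500(self):
        report = OrderReportMock([])
        # The rule is a strict "amount > 1500", so 1500 itself maps to 0.
        for amount, expected in [(1499, 0), (1500, 0), (1501, 1)]:
            with self.subTest(amount=amount):
                self.assertEqual(report._get_x_value(amount), expected)

    def test_get_records_at_exactly_1500(self):
        # A SimpleNamespace order fails loudly if a field is missing.
        order = SimpleNamespace(id=4, product='Product D', payment_method='CARD',
                                location='Chennai', amount=1500)
        records = OrderReportMock([order]).get_records()
        self.assertEqual(records, [
            {'id': 4, 'product': 'Product D', 'z_value': 0, 'x_value': 0},
        ])
```

Saved in a test module, these run with python -m unittest alongside the existing TestOrderReport cases.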

In conclusion, it pays to prioritize making code unit testable, because that makes tests easy to write; when writing unit tests is hard, developers are less likely to invest the effort. By isolating the data layer behind a single abstract method, we made all of the report logic testable with fast unit tests that never touch the database.

Embracing unit testing as an integral part of our development process empowers us to build robust and resilient software applications.

I hope you found this helpful. Thanks for reading!

Written by

Akshay Suresh Thekkath