Unlocking Unit Testing in Python
Unit tests are designed to be simple and efficient, focusing on testing small, isolated code components. They should execute quickly and ideally should not involve any database calls or any API calls. However, I often observe developers writing highly coupled code, which poses challenges when it comes to writing tests.
Allow me to provide an example where the code is initially tightly coupled with database calls. I will guide you through the process of refactoring the code to achieve a state where it becomes completely unit testable.
The OrderReport class
The OrderReport class generates a report of fulfilled orders by filtering out cash payments and applying conditions based on location and amount. It constructs a list of order records with relevant details for the report.
from typing import List

from models import Order


class OrderReport:
    def get_records(self) -> List:
        records = []
        orders = Order.objects.filter(is_fulfilled=True)
        for order in orders:
            if order.payment_method == 'CASH':
                continue
            z_value = 0
            if order.location in ['Mumbai', 'Delhi']:
                z_value = 1
            x_value = 0
            if order.amount > 1500:
                x_value = 1
            record = {
                'id': order.id,
                'product': order.product,
                'z_value': z_value,
                'x_value': x_value,
            }
            records.append(record)
        return records
Many developers can relate to the challenge of writing code where the logic is tightly coupled within a single function and involves database calls, making testing a daunting task.
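To make the difficulty concrete, here is a rough sketch of what a test for this version tends to look like. The Order class below is a hypothetical stand-in for the real model (so the snippet runs without a database), and OrderReport is a condensed copy of the class above. Notice that the test has to patch Order.objects and therefore needs to know exactly how get_records() queries the database.

```python
from types import SimpleNamespace
from unittest.mock import MagicMock, patch


class Order:
    """Hypothetical stand-in for the ORM model imported from `models`."""
    objects = None  # a real model would expose a query manager here


class OrderReport:
    """Condensed copy of the tightly coupled class above."""

    def get_records(self):
        records = []
        for order in Order.objects.filter(is_fulfilled=True):
            if order.payment_method == 'CASH':
                continue
            records.append({
                'id': order.id,
                'product': order.product,
                'z_value': 1 if order.location in ['Mumbai', 'Delhi'] else 0,
                'x_value': 1 if order.amount > 1500 else 0,
            })
        return records


fake_orders = [
    SimpleNamespace(id=1, product='Product A', payment_method='CARD',
                    location='Mumbai', amount=2000),
    SimpleNamespace(id=2, product='Product B', payment_method='CASH',
                    location='Delhi', amount=1000),
]

# The test must mirror the query made inside get_records().
fake_manager = MagicMock()
fake_manager.filter.return_value = fake_orders

with patch.object(Order, 'objects', fake_manager):
    records = OrderReport().get_records()

fake_manager.filter.assert_called_once_with(is_fulfilled=True)
print(records)
```

The patching works, but every rule (cash filter, location, amount) can only be checked through this one entry point, and the test breaks as soon as the query inside get_records() changes shape.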
The first refactor
Let's restructure the class in a more modular and organized way.
from typing import List
from typing import Dict

from models import Order


class OrderReport:
    def _get_orders(self):
        return Order.objects.filter(is_fulfilled=True)

    def _is_cash_payment(self, method: str) -> bool:
        return method == 'CASH'

    def _get_z_value(self, location: str) -> int:
        return 1 if location in ['Mumbai', 'Delhi'] else 0

    def _get_x_value(self, amount: int) -> int:
        return 1 if amount > 1500 else 0

    def _build_record(self, order) -> Dict:
        return {
            'id': order.id,
            'product': order.product,
            'z_value': self._get_z_value(location=order.location),
            'x_value': self._get_x_value(amount=order.amount),
        }

    def get_records(self) -> List:
        records = []
        for order in self._get_orders():
            if self._is_cash_payment(method=order.payment_method):
                continue
            records.append(self._build_record(order=order))
        return records
The code is divided into multiple private methods to encapsulate specific functionalities.
Each private method is responsible for a specific task, such as retrieving orders, checking if the payment method is cash, calculating z_value and x_value, and building records.
Database calls are abstracted within the private _get_orders() method.
The get_records() method serves as the entry point and orchestrates the execution by calling the private methods.
It separates concerns by dividing the logic into smaller, reusable private methods. However, one aspect that remains is the presence of database calls, which can pose a challenge for unit testing this class.
Decoupling the data layer
In the given class, the database call is encapsulated within the _get_orders() method. To address the testing challenge, one possible approach is to create an abstract class that contains all the logic, with _get_orders() defined as an abstract method. Let's call this abstract class OrderReportABC.
# order_report > __init__.py
from typing import List
from typing import Dict
from typing import Iterable
from abc import ABC, abstractmethod


class OrderReportABC(ABC):
    @abstractmethod
    def _get_orders(self) -> Iterable:
        pass

    def _is_cash_payment(self, method: str) -> bool:
        return method == 'CASH'

    def _get_z_value(self, location: str) -> int:
        return 1 if location in ['Mumbai', 'Delhi'] else 0

    def _get_x_value(self, amount: int) -> int:
        return 1 if amount > 1500 else 0

    def _build_record(self, order) -> Dict:
        return {
            'id': order.id,
            'product': order.product,
            'z_value': self._get_z_value(location=order.location),
            'x_value': self._get_x_value(amount=order.amount),
        }

    def get_records(self) -> List[Dict]:
        records = []
        for order in self._get_orders():
            if self._is_cash_payment(method=order.payment_method):
                continue
            records.append(self._build_record(order=order))
        return records
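A useful side effect of declaring _get_orders() abstract is that Python refuses to instantiate the base class until a subclass supplies that method. The sketch below shows this with a deliberately tiny stand-in; ReportBase and InMemoryReport are hypothetical names used only for illustration, not part of the package above.

```python
from abc import ABC, abstractmethod
from typing import Iterable


class ReportBase(ABC):
    """Minimal stand-in for OrderReportABC with one abstract hook."""

    @abstractmethod
    def _get_orders(self) -> Iterable:
        """Return the orders to report on."""

    def count(self) -> int:
        # Shared logic depends only on the hook, not on where data comes from.
        return sum(1 for _ in self._get_orders())


class InMemoryReport(ReportBase):
    """Concrete subclass that supplies the data source."""

    def _get_orders(self) -> Iterable:
        return ['order-1', 'order-2']


error = None
try:
    ReportBase()  # fails: the abstract method has no implementation
except TypeError as exc:
    error = str(exc)

print(error)
print(InMemoryReport().count())
```

The exact wording of the TypeError message varies between Python versions, so it is safer not to assert on the full text in tests.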
Now, we can create a class called OrderReport that inherits from the abstract class and provides the implementation of the _get_orders() method specific to the database.
# order_report > order_report.py
from typing import Iterable

from . import OrderReportABC
from models import Order


class OrderReport(OrderReportABC):
    def _get_orders(self) -> Iterable:
        return Order.objects.filter(is_fulfilled=True)
We have also created another class called OrderReportMock, which accepts a list of mock orders in its constructor. This class is specifically designed for unit testing purposes as it eliminates the need for database calls.
# order_report > order_report_mock.py
from typing import List
from typing import Iterable

from . import OrderReportABC


class OrderReportMock(OrderReportABC):
    def __init__(self, orders: List) -> None:
        self._orders = orders
        super().__init__()

    def _get_orders(self) -> Iterable:
        return self._orders
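To recap the steps so far: OrderReportABC holds all the report logic, OrderReport supplies the database-backed _get_orders(), and OrderReportMock supplies an in-memory version for tests.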
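As a quick illustration of how the mock class is meant to be used, the following self-contained sketch feeds two in-memory orders through get_records(). The abstract class is a condensed copy of the one above, and the use of types.SimpleNamespace for the fake orders is my own choice for this example rather than something the package requires.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace
from typing import Dict, Iterable, List


class OrderReportABC(ABC):
    """Condensed copy of the abstract class above."""

    @abstractmethod
    def _get_orders(self) -> Iterable:
        ...

    def _build_record(self, order) -> Dict:
        return {
            'id': order.id,
            'product': order.product,
            'z_value': 1 if order.location in ['Mumbai', 'Delhi'] else 0,
            'x_value': 1 if order.amount > 1500 else 0,
        }

    def get_records(self) -> List[Dict]:
        # Cash payments are skipped, everything else becomes a record.
        return [
            self._build_record(order)
            for order in self._get_orders()
            if order.payment_method != 'CASH'
        ]


class OrderReportMock(OrderReportABC):
    def __init__(self, orders: List) -> None:
        self._orders = orders

    def _get_orders(self) -> Iterable:
        return self._orders


orders = [
    SimpleNamespace(id=1, product='Product A', payment_method='CARD',
                    location='Mumbai', amount=2000),
    SimpleNamespace(id=2, product='Product B', payment_method='CASH',
                    location='Delhi', amount=1000),
]

records = OrderReportMock(orders).get_records()
print(records)
```

One reason to consider SimpleNamespace here: reading an attribute that was never set raises AttributeError, whereas a MagicMock silently returns another mock, which can hide a misspelled field name.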
The unit tests
from unittest import TestCase
from unittest.mock import MagicMock

# OrderReportMock lives in the order_report_mock module of the package
from order_report.order_report_mock import OrderReportMock


class TestOrderReport(TestCase):
    def test_is_cash_payment(self):
        order_report = OrderReportMock([])
        self.assertTrue(order_report._is_cash_payment('CASH'))
        self.assertFalse(order_report._is_cash_payment('CARD'))

    def test_get_z_value(self):
        order_report = OrderReportMock([])
        self.assertEqual(order_report._get_z_value('Mumbai'), 1)
        self.assertEqual(order_report._get_z_value('Delhi'), 1)
        self.assertEqual(order_report._get_z_value('Bangalore'), 0)

    def test_get_x_value(self):
        order_report = OrderReportMock([])
        self.assertEqual(order_report._get_x_value(2000), 1)
        self.assertEqual(order_report._get_x_value(1000), 0)
        self.assertEqual(order_report._get_x_value(2500), 1)

    def test_build_record(self):
        order_report = OrderReportMock([])
        order = MagicMock(
            location='Mumbai',
            amount=2000,
            id=1,
            product='Product A')
        record = order_report._build_record(order)
        expected_record = {
            'id': 1,
            'product': 'Product A',
            'z_value': 1,
            'x_value': 1
        }
        self.assertEqual(record, expected_record)

    def test_get_records_no_cash_payment(self):
        orders = [
            MagicMock(
                payment_method='CARD',
                location='Mumbai',
                amount=2000,
                id=1,
                product='Product A'
            ),
            MagicMock(
                payment_method='CARD',
                location='Delhi',
                amount=1000,
                id=2,
                product='Product B'
            ),
            MagicMock(
                payment_method='CARD',
                location='Bangalore',
                amount=2500,
                id=3,
                product='Product C'
            ),
        ]
        order_report = OrderReportMock(orders)
        records = order_report.get_records()
        expected_records = [
            {
                'id': 1,
                'product': 'Product A',
                'z_value': 1,
                'x_value': 1
            },
            {
                'id': 2,
                'product': 'Product B',
                'z_value': 1,
                'x_value': 0
            },
            {
                'id': 3,
                'product': 'Product C',
                'z_value': 0,
                'x_value': 1
            },
        ]
        self.assertEqual(records, expected_records)

    def test_get_records_with_cash_payment(self):
        orders = [
            MagicMock(
                payment_method='CARD',
                location='Mumbai',
                amount=2000,
                id=1,
                product='Product A'
            ),
            MagicMock(
                payment_method='CASH',
                location='Delhi',
                amount=1000,
                id=2,
                product='Product B'
            ),
            MagicMock(
                payment_method='CARD',
                location='Bangalore',
                amount=2500,
                id=3,
                product='Product C'
            ),
        ]
        order_report = OrderReportMock(orders)
        records = order_report.get_records()
        expected_records = [
            {
                'id': 1,
                'product': 'Product A',
                'z_value': 1,
                'x_value': 1
            },
            {
                'id': 3,
                'product': 'Product C',
                'z_value': 0,
                'x_value': 1
            },
        ]
        self.assertEqual(records, expected_records)
The TestOrderReport class is a unit test class that tests various methods of the OrderReportMock class. Here's a brief explanation of each test case:
test_is_cash_payment: Verifies that the _is_cash_payment method correctly identifies the 'CASH' payment method as True and other methods as False.
test_get_z_value: Ensures that the _get_z_value method returns the expected values based on the provided location.
test_get_x_value: Checks that the _get_x_value method returns the expected values based on the provided amount.
test_build_record: Validates that the _build_record method constructs the record dictionary correctly using the given order.
test_get_records_no_cash_payment: Tests the get_records method when there are no cash payments in the list of orders. It verifies that the generated records match the expected records.
test_get_records_with_cash_payment: Tests the get_records method when there are cash payments in the list of orders. It ensures that the generated records exclude the orders with cash payment and match the expected records.
These test cases cover a range of scenarios and exercise every logic method of the report class, both in isolation and through get_records(), without a single database call.
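Exercising every method is not the same as checking every edge. The amount rule uses a strict comparison (amount > 1500), so the value exactly at the threshold is worth pinning down. Below is a minimal sketch of such a boundary test using unittest's subTest; get_x_value is a standalone copy of the _get_x_value rule so the snippet runs on its own, and the test class name is hypothetical.

```python
import unittest


def get_x_value(amount: int) -> int:
    """Standalone copy of the _get_x_value rule: 1 only above 1500."""
    return 1 if amount > 1500 else 0


class TestXValueBoundary(unittest.TestCase):
    def test_values_around_threshold(self):
        # (amount, expected x_value) just below, at, and just above 1500
        cases = [(1499, 0), (1500, 0), (1501, 1)]
        for amount, expected in cases:
            with self.subTest(amount=amount):
                self.assertEqual(get_x_value(amount), expected)


# Run the test case programmatically so the outcome can be inspected.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestXValueBoundary)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())
```

With subTest, a failure for one amount is reported individually while the remaining cases still run, which keeps boundary tables like this easy to extend.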
In conclusion, it is crucial to prioritize making the code unit testable, as it simplifies the process of writing tests. When writing unit tests becomes challenging, developers are less likely to invest the effort. By decoupling the data layer and making our code unit testable, we have unlocked the power of efficient and reliable unit testing in Python.
Embracing unit testing as an integral part of our development process empowers us to build robust and resilient software applications.
I hope you found this helpful. Thanks for reading!
Written by Akshay Suresh Thekkath