Micro library for retries in cats-effect using Scala extensions

This blog entry will show how to add a simple retry mechanism to cats-effect (CE) using Scala 3 extensions. The idea is that whenever we define our IO instances we can also specify that, if they fail and certain conditions are met, they will be automatically run again. Our goal is to showcase a not-that-trivial example of how to use extensions.

💡
At the time of writing there is no retry mechanism in the latest version of CE (3.5.4). However, this is likely to change in future versions. Once one is included in CE, if you need a retry mechanism, I advise using the one provided by CE.

All the code presented in this post is available in this gist.


Setting up a retry library can look deceptively easy. After all, the basic idea is simple: catch exceptions and run the task again if an error is caught. But when we start thinking about possible requirements, things get a bit more complex. Should we set a limit on the number of retries (probably so!)? Should we check for any error or only for specific exceptions? Should we take into account the history of previous errors? Should we add some time between retries too? If so, is that time constant? Possible calls to our library could look like this:

val io: IO[A] = ???
io.retryTimes(3) // Retry up to 3 times, after that re-throw the error
io.retryIfError[MyException] // Retry only if 'MyException' is caught
io.retryForever // Retry until io succeeds
io.retryTimes(3, 1.second) // Retry up to 3 times, waiting one second between retries
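For comparison, this is roughly what the "deceptively easy" approach looks like. It is a sketch of a hypothetical naiveRetry helper (not part of CE) that re-runs an IO up to n more times using handleErrorWith. It retries on any error, never waits, and keeps no state, which is exactly what the questions above are about. The flakyUntil effect is also made up, just to have something that fails a couple of times.

```scala
//> using scala "3.4.2"
//> using dep "org.typelevel::cats-effect::3.5.4"

import cats.effect.{IO, Ref}

// Hypothetical helper, for illustration only: re-runs `io` up to `n` extra times
// on ANY error, with no delay between attempts and no memory of previous failures.
def naiveRetry[A](io: IO[A], n: Int): IO[A] =
  io.handleErrorWith { e =>
    if n > 0 then naiveRetry(io, n - 1) // try again, one retry fewer left
    else IO.raiseError(e)               // out of retries: re-throw the last error
  }

// Hypothetical flaky effect: fails until it has been run `succeedOn` times.
def flakyUntil(runs: Ref[IO, Int], succeedOn: Int): IO[String] =
  runs.updateAndGet(_ + 1).flatMap { k =>
    if k < succeedOn then IO.raiseError(new RuntimeException(s"failure $k"))
    else IO.pure("ok")
  }

// Returns the final result and how many times the flaky effect was run.
val demo: IO[(String, Int)] =
  for
    runs   <- Ref.of[IO, Int](0)
    result <- naiveRetry(flakyUntil(runs, 3), 5)
    total  <- runs.get
  yield (result, total)
```

Running demo should produce ("ok", 3): two failed runs plus the successful third one. Note that naiveRetry would just as happily retry a programming bug, and it would hit a struggling service again with no pause at all.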

We'll go for a not-so-simple solution, likely not universal but still powerful enough to meet many requirements. Our 'retry' function will allow us to set:

  1. The type of errors to watch for; any other error is immediately re-thrown.

  2. A condition to check, so that the retry is run only if the condition is met. This condition is checked by a function that takes as input the error instance and a state S. The state can be anything the developer wants to keep track of between retries (e.g. a counter of the tries so far). An initial value for the state, initState, can be set too. We must also take into account that the condition may not be a pure function, so it returns its result wrapped in an effect.

  3. Time between retries, if any. This time is computed by the condition function each time an error is caught.

With the requirements set above, the type of our condition function looks like this:

// Let 'F' be the effect type (e.g. IO), 'T' the type of the exceptions
// we want to catch and 'S' the type of the state to keep between retries
type Condition[F[_], S, T] = (S, T) => F[Either[T, (S, FiniteDuration)]]

The returned Either signals whether we must re-throw the error (if we get Left) or run a retry (if we get Right). In the Right case we may have to wait for some time before retrying, so it carries a FiniteDuration (of course, Duration.Zero signals no waiting time), together with the next value of S, which will be used if the retry fails too. With that in mind we can code our retry function as follows:

import cats.MonadThrow
import cats.effect.Temporal
import cats.syntax.all.*

import scala.concurrent.duration.*
import scala.reflect.Typeable

extension[F[_]: MonadThrow: Temporal, A, S](f: F[A])
  def retry[T <: Throwable : Typeable](initState: S, condition: (S, T) => F[Either[T, (S, FiniteDuration)]]): F[A] =
    MonadThrow[F].attempt(f).flatMap:
      case Right(a) => a.pure[F] // All good! Nothing to do
      case Left(t: T) => // Error of type `T`, let's check what to do now:
        condition(initState, t).flatMap:
          case Right((newState, waitTime)) => // Retry!
            Temporal[F].sleep(waitTime) >> retry(newState, condition)
          case Left(t) => MonadThrow[F].raiseError(t) // Re-throw immediately ¯\_(ツ)_/¯
      case Left(e) => MonadThrow[F].raiseError(e) // This error is not a `T`, re-throw it

Let's explain some of the 'quirks' of this code:

  1. What are all those types in the extension definition [F[_]: MonadThrow: Temporal, A, S]? Well, let's go one by one:

    1. F[_] makes our code polymorphic: although in many cases we will be using CE's IO, we want this retry functionality to be available for other effect types too.

    2. : MonadThrow signals that F is a monad that can raise errors and, interestingly, also handle them (see the MonadThrow[F].attempt call).

    3. : Temporal lets us 'sleep' between retries (see the Temporal[F].sleep call). Strictly speaking, Temporal already extends MonadThrow, so this bound alone would be enough; listing MonadThrow too just makes explicit that we rely on error handling.

    4. A is the type of the value produced by f when it is run (wrapped in F, of course).

    5. S is the type of the state kept between retries.

  2. Why does retry have such a complex signature? Well, let's go step by step:

    1. We need to define [T <: Throwable : Typeable] to set the type T of the errors we want to look for. Because it represents errors, it must be a subclass of Throwable. And we have to add the extra : Typeable to avoid the warning "the type test for T cannot be checked at runtime". This is caused by JVM type erasure, i.e. at runtime we cannot know which type T is. Typeable is the magic wand that fixes the issue in Scala: it provides a runtime type test for T, so the case Left(t: T) pattern really checks the type of the error.

    2. retry accepts two parameters: the initial state (an instance of S) and the condition function, which has the signature we explained before.
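To see what Typeable buys us in isolation, here is a small standalone sketch in plain Scala 3 (no cats involved; the onlyErrorsOf name is hypothetical). The context bound lets the pattern case t: T perform a real runtime check on the erased type parameter, which is what the case Left(t: T) branch of retry relies on.

```scala
import scala.reflect.Typeable

// Hypothetical helper: returns Some(e) only if `e`'s runtime class is T or a subclass of T.
// Without the `Typeable` context bound, `case t: T` would be an unchecked pattern
// (T is erased at runtime) and the compiler would warn about it.
def onlyErrorsOf[T <: Throwable: Typeable](e: Throwable): Option[T] =
  e match
    case t: T => Some(t) // real runtime type test, supplied by Typeable[T]
    case _    => None

final case class CustomError(msg: String) extends Exception(msg)

// onlyErrorsOf[CustomError](CustomError("boom"))               == Some(CustomError("boom"))
// onlyErrorsOf[CustomError](new IllegalStateException("nope")) == None
```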

The implementation is pretty straightforward: we run f, and if no error is caught we return the resulting a immediately; otherwise we check the error type. If the error is an instance of T, we invoke the condition to know what to do next: depending on the result, we either re-throw the error or try again (after waiting for some time). If the error is not of type T, it is re-thrown.
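The condition is where the actual retry policy lives. To make its contract concrete, here is the pure core of a hypothetical exponential backoff policy (the name backoffDecision and its parameters are assumptions of this sketch, not part of the post's code). The state is the number of retries done so far, a Left means "give up and re-throw", and a Right carries the next state plus the delay before the next attempt. Wrapping its result with .pure[F] would turn it into a condition for retry.

```scala
import scala.concurrent.duration.*

// Hypothetical policy: allow up to `maxRetries` retries, waiting base, 2*base, 4*base, ...
// The state is the number of retries performed so far.
def backoffDecision[T <: Throwable](maxRetries: Int, base: FiniteDuration)(
    retriesSoFar: Int,
    error: T
): Either[T, (Int, FiniteDuration)] =
  if retriesSoFar >= maxRetries then Left(error) // Left: give up, `error` is re-thrown
  else Right((retriesSoFar + 1, base * (1L << retriesSoFar))) // Right: (next state, delay)

// Usage, with maxRetries = 3 and base = 100.millis:
val err    = new RuntimeException("boom")
val decide = backoffDecision[RuntimeException](3, 100.millis)
// decide(0, err) == Right((1, 100.millis))
// decide(1, err) == Right((2, 200.millis))
// decide(2, err) == Right((3, 400.millis))
// decide(3, err) == Left(err)
```

A production policy would probably also cap the delay (and maybe add some jitter); both are left out to keep the sketch short.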

Our retry function is powerful but arguably cumbersome. Fortunately, we can use it to code simpler functions that cover many typical use cases for retries. For example:

extension[F[_]: MonadThrow: Temporal, A, S](f: F[A])

  // retry as defined previously
  def retry[T <: Throwable : Typeable](initState: S, condition: (S, T) => F[Either[T, (S, FiniteDuration)]]): F[A] = ???

  // retry up to 'n' times, fixed time between retries
  def retryN[T <: Throwable : Typeable](n: Int, timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] =
    require(n >= 0)
    retry[T](0,
      (counter, t) =>
        // asLeft/asRight (from cats.syntax.all) type the values as Either rather than
        // Left/Right, which matters because F is invariant
        if counter >= n then t.asLeft[(Int, FiniteDuration)].pure[F]
        else (counter + 1, timeBetweenRetries).asRight[T].pure[F]
    )

  // retry while some condition is met, fixed time between retries
  def retryWhile[T <: Throwable : Typeable](cond: T => F[Boolean], timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] =
    retry[T]((),
      (_, t) =>
        MonadThrow[F].ifF(cond(t))(((), timeBetweenRetries).asRight[T], t.asLeft[(Unit, FiniteDuration)])
    )

  // retry until successful (no error caught), fixed time between retries
  def retryUntilSuccessful[T <: Throwable : Typeable](timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] =
    retryWhile[T](_ => true.pure[F], timeBetweenRetries)
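Following the same pattern, other policies take just a few lines on top of retry. As an illustration, here is a hypothetical retryWithBackoff helper (not part of the post's code) that doubles the waiting time after each failure. To keep the block self-contained it restates a condensed retry; note that, as an assumption of this sketch, S is a type parameter of retry itself rather than of the extension, so both type arguments are passed explicitly. Flaky and flakyUntil are made up for the demo.

```scala
//> using scala "3.4.2"
//> using dep "org.typelevel::cats-effect::3.5.4"

import cats.effect.{IO, Ref, Temporal}
import cats.syntax.all.*

import scala.concurrent.duration.*
import scala.reflect.Typeable

extension [F[_]: Temporal, A](fa: F[A])

  // Condensed restatement of the post's `retry`. Assumption of this sketch:
  // S is a type parameter of the method, so callers pass [T, S] explicitly.
  def retry[T <: Throwable: Typeable, S](
      initState: S,
      condition: (S, T) => F[Either[T, (S, FiniteDuration)]]
  ): F[A] =
    Temporal[F].attempt(fa).flatMap {
      case Right(a) => a.pure[F]
      case Left(t: T) =>
        condition(initState, t).flatMap {
          case Right((next, delay)) => Temporal[F].sleep(delay) >> fa.retry[T, S](next, condition)
          case Left(e)              => Temporal[F].raiseError[A](e)
        }
      case Left(e) => Temporal[F].raiseError[A](e) // not a T: re-throw immediately
    }

  // Hypothetical helper: up to `maxRetries` retries, waiting base, 2*base, 4*base, ...
  def retryWithBackoff[T <: Throwable: Typeable](maxRetries: Int, base: FiniteDuration): F[A] =
    fa.retry[T, Int](
      0,
      (done, t) =>
        if done >= maxRetries then t.asLeft[(Int, FiniteDuration)].pure[F]
        else (done + 1, base * (1L << done)).asRight[T].pure[F]
    )

final case class Flaky(msg: String) extends Exception(msg)

// Hypothetical flaky effect: fails with `Flaky` until it has run `succeedOn` times.
def flakyUntil(runs: Ref[IO, Int], succeedOn: Int): IO[String] =
  runs.updateAndGet(_ + 1).flatMap { k =>
    if k < succeedOn then IO.raiseError(Flaky(s"run $k")) else IO.pure("ok")
  }

// Returns the final result and how many times the flaky effect was run.
val demo: IO[(String, Int)] =
  for
    runs   <- Ref.of[IO, Int](0)
    result <- flakyUntil(runs, 3).retryWithBackoff[Flaky](maxRetries = 5, base = 10.millis)
    total  <- runs.get
  yield (result, total)
```

Running demo should return ("ok", 3), after waiting roughly 10 and then 20 milliseconds between the attempts.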

Now we can use our retries micro-library! Thanks to the power of Scala extensions, invoking it is fairly trivial. See an example using scala-cli (download the full code from this gist):

//> using scala "3.4.2"
//> using dep "org.typelevel::cats-effect::3.5.4"

import cats.MonadThrow
import cats.effect.{IO, IOApp, Temporal}
import cats.syntax.all.*

import scala.concurrent.duration.*
import scala.reflect.Typeable

extension[F[_]: MonadThrow: Temporal, A, S](f: F[A])
  // retry methods as defined previously
  def retry[T <: Throwable : Typeable](initState: S, condition: (S, T) => F[Either[T, (S, FiniteDuration)]]): F[A] = ???
  def retryN[T <: Throwable : Typeable](n: Int, timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] = ???
  def retryWhile[T <: Throwable : Typeable](cond: T => F[Boolean], timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] = ???
  def retryUntilSuccessful[T <: Throwable : Typeable](timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] = ???

object Main extends IOApp.Simple:

  case class CustomError(msg: String) extends Exception(msg)

  val ioa: IO[Unit] = IO.println("Hi, I'm going to fail!") >> IO.raiseError[Unit](CustomError("Oops"))

  // retry if `CustomError` is caught, up to 3 times, waiting 1 second between retries
  override def run: IO[Unit] = ioa.retryN[CustomError](3, 1.second)
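A quick way to check the "any other error is immediately re-thrown" rule is to count how many times the effect actually runs. The sketch below is self-contained, so it restates a condensed retry and retryN (assuming, as an illustration, that S is a type parameter of retry); countRuns is a hypothetical helper that runs an always-failing effect under retryN[CustomError](3).

```scala
//> using scala "3.4.2"
//> using dep "org.typelevel::cats-effect::3.5.4"

import cats.effect.{IO, Ref, Temporal}
import cats.syntax.all.*

import scala.concurrent.duration.*
import scala.reflect.Typeable

extension [F[_]: Temporal, A](fa: F[A])

  // Condensed restatement of the post's `retry` (assumption: S is a method type parameter).
  def retry[T <: Throwable: Typeable, S](
      initState: S,
      condition: (S, T) => F[Either[T, (S, FiniteDuration)]]
  ): F[A] =
    Temporal[F].attempt(fa).flatMap {
      case Right(a) => a.pure[F]
      case Left(t: T) =>
        condition(initState, t).flatMap {
          case Right((next, delay)) => Temporal[F].sleep(delay) >> fa.retry[T, S](next, condition)
          case Left(e)              => Temporal[F].raiseError[A](e)
        }
      case Left(e) => Temporal[F].raiseError[A](e) // not a T: re-throw immediately
    }

  // Same logic as the post's retryN.
  def retryN[T <: Throwable: Typeable](n: Int, timeBetweenRetries: FiniteDuration = Duration.Zero): F[A] =
    fa.retry[T, Int](
      0,
      (counter, t) =>
        if counter >= n then t.asLeft[(Int, FiniteDuration)].pure[F]
        else (counter + 1, timeBetweenRetries).asRight[T].pure[F]
    )

final case class CustomError(msg: String) extends Exception(msg)

// Hypothetical helper: runs an effect that always fails with `error` under
// retryN[CustomError](3) and returns how many times the effect was executed.
def countRuns(error: Throwable): IO[Int] =
  for
    runs <- Ref.of[IO, Int](0)
    _    <- (runs.update(_ + 1) >> IO.raiseError[Unit](error)).retryN[CustomError](3).attempt
    n    <- runs.get
  yield n
```

countRuns(CustomError("boom")) should give 4 (the first run plus three retries), while countRuns(new IllegalStateException("boom")) should give 1, because an IllegalStateException is not a CustomError and is re-thrown without any retry.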

Written by

Luis Rodero-Merino

Dev interested in functional programming solutions for Scala.