Franklin Pezzuti Dyer


Simple dependency injection using monads

I've been mulling over the idea of this blog post for a long time, probably since I wrote this older post describing a couple of examples of monads. Monads are infamously tricky to grok, which is why I've been looking for examples that help me understand the commonest Haskell monads - examples that are as simple as possible, yet at the same time well-motivated (no Foo Bar Baz or any of that gibberish).

This post consists of a family of examples illustrating how monads allow us to easily implement dependency injection in Haskell, which is a design pattern whereby an abstractly-written function receives its dependencies "from outside", making it very easy to drastically change the function's behavior by manipulating the dependencies that it is given.

As a warning - this post isn't very beginner-friendly, and it's not an introduction to monads. The intended audience is someone (like me) who has been toying with monads for a little while but is struggling to get a "big picture" understanding of how they are powerful. If you have not programmed in Haskell and worked with monads before, you will probably have a hard time following what comes below, but I encourage you to download the code and load it into a REPL to tinker with regardless of your skill level.

The function into which we'll be "injecting dependencies" is a very simple one - we will just be writing an implementation of the following mathematical function:

$$f(x) = \frac{1}{\sqrt{x}} + \frac{1}{\sqrt{1 - x}}$$

As a disclaimer, when I said that I wanted an example to "motivate" monads, that was coming from the mouth (fingers?) of someone who feels more comfortable in pure mathematics than software engineering, so not everyone will find this example as compelling as I do. The above function is somewhat arbitrary, but I picked it specifically because it involves a couple of different mathematical operations that can be problematic, such as division and taking a square root. In a moment we'll see how various monads can be used to define more robust versions of this function without rewriting any of its definition.

A naive implementation of this function in Haskell looks like this:


fNaive :: Float -> Float
fNaive x = 1 / (sqrt x) + 1 / (sqrt (1 - x))

Below is a more general definition of f that returns not a Float, but a t Float, where t is an arbitrary monad. This version of the function is written in do-notation purely in terms of the monadic interface, so it works for an arbitrary monad, but its exact functionality depends on the particular monad being used. In the examples below, we'll see how we can "inject" new functionality into this function without touching its definition at all, simply by changing the monad that it uses.

The function f will be what is known as a Kleisli morphism with respect to whatever monad t we use in each case. To build up this function monadically, we need Kleisli versions of each of the floating-point operations that are used to define f, namely addition, reciprocals, and square roots. So we start by defining an interface for monads supporting each of these operations as Kleisli morphisms, and then proceed to write our more modular definition of the function f:


-- NB: running all of the code in this post also requires the FlexibleInstances
-- language extension (for instance heads like Either MyFloatError), plus imports
-- from Data.Functor.Identity, Control.Monad, Control.Monad.Reader,
-- Control.Monad.Writer and Control.Monad.State (mtl), System.IO, and System.Process.
class Monad t => MonadicFloatOps t where
    madd :: Float -> Float -> t Float
    madd = (return .) . (+)    -- default implementation: ordinary (+), wrapped with return
    msqrt :: Float -> t Float
    mrecip :: Float -> t Float

f :: MonadicFloatOps t => Float -> t Float
f x = do
    x1 <- msqrt x
    x2 <- (madd 1 (-x)) >>= msqrt
    y1 <- mrecip x1
    y2 <- mrecip x2
    madd y1 y2

These are the definitions that we'll use in what follows.

The Identity monad

The identity monad effectively does not change the structure of a type at all, meaning that values of Identity Float are precisely in correspondence with values of Float. So if we let t = Identity in our definition of f, we will have something that is functionally equivalent to fNaive, except that its type signature is Float -> Identity Float instead of Float -> Float. Its outputs are not exactly Float values, but you can think of them as "wrapped floats".
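
For reference, Identity lives in Data.Functor.Identity, and its definition amounts to nothing more than a newtype wrapper:


newtype Identity a = Identity { runIdentity :: a }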

We can redefine fNaive using f with a minimal amount of extra work: we just need to let Haskell know how our three floating-point operations are defined as Kleisli morphisms returning Identity Float values rather than Float values. That is, we need to instance Identity into MonadicFloatOps. This can be done easily as follows:


instance MonadicFloatOps Identity where
    madd = (return .) . (+)
    msqrt = return . sqrt
    mrecip = return . (1/)

Then we can redefine fNaive without making any changes to f, just by telling Haskell that it's a special case of f where t is the Identity functor:


fNaive = f :: Float -> Identity Float

And here are some sample outputs, run in GHCi:


ghci> fNaive 0.3
Identity 3.0209703
ghci> fNaive 0.5
Identity 2.828427
ghci> fNaive 0
Identity Infinity
ghci> fNaive 2
Identity NaN
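
If we ever want a plain Float back, the runIdentity accessor (from Data.Functor.Identity) unwraps it:


ghci> runIdentity (fNaive 0.5)
2.828427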

The Maybe monad

We've just seen how this function can return Infinity or NaN in some cases. These outputs can occur when one of the two reciprocals involved in calculating f results in a division by zero, or when one of the square roots is a complex square root. If we wanted to safely and conservatively handle these undefined cases without relying on the dubious conventions of Infinity and NaN established by the IEEE, we could make use of the Maybe monad. This monad is perfect for expressing the results of computations that may fail.

The Maybe functor already has a built-in Monad instance, so we just need to instance it into MonadicFloatOps by defining "safe" versions of our three floating-point operations. We can accomplish this as follows:


instance MonadicFloatOps Maybe where
    madd = (return .) . (+)
    msqrt x
        | x >= 0 = Just (sqrt x)
        | otherwise = Nothing
    mrecip x
        | x == 0 = Nothing
        | otherwise = Just (1 / x)

Note that for the madd addition operation, our definition is the same as before, because madd cannot fail on its own. That is, the sum of two Maybe Float values shouldn't be Nothing unless one of the two arguments is Nothing. For msqrt and mrecip, we need to do a little more work: in the former case, we need to check that the argument is nonnegative, and in the latter case we just need to check that the argument is nonzero.
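
This propagation is exactly what the Maybe monad's bind operation (>>=) provides: in the standard instance, Nothing >>= k evaluates to Nothing, so once any step of f's do-block fails, the remaining steps are skipped. For instance:


ghci> msqrt 4 >>= mrecip :: Maybe Float
Just 0.5
ghci> msqrt (-4) >>= mrecip :: Maybe Float
Nothing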

Again, we can define a Maybe Float variant of the function f just by indicating to Haskell that we want to use t = Maybe as our monad:


fMaybe = f :: Float -> Maybe Float

And here are the same test cases:


ghci> fMaybe 0.3
Just 3.0209703
ghci> fMaybe 0.5
Just 2.828427
ghci> fMaybe 0
Nothing
ghci> fMaybe 2
Nothing

The Either monads

Next suppose that we still want to write a safe version of f, but that this time we want to be a bit more descriptive about what exactly went wrong when the function returns nothing. As we've commented already, there are two potential problems: complex square roots, and division by zero. But when fMaybe returns Nothing, it provides no information about which of these two things happened.

So let's define a custom data type that expresses these two possible types of error:


data MyFloatError = ImaginarySqrtError | ZeroDivisionError deriving (Eq, Show)

Our newest variant of f will return either one of these two errors or an actual floating-point number. Enter the Either monad. Or, to be more precise, the Either monads - for any type a, the partially applied type constructor Either a gives rise to a monad. In our case, we will be using Either MyFloatError as our monad t. Before we do so, we need to instance the functor Either MyFloatError into our interface MonadicFloatOps:


instance MonadicFloatOps (Either MyFloatError) where
    madd = (return .) . (+)
    msqrt x
        | x >= 0 = Right (sqrt x)
        | otherwise = Left ImaginarySqrtError
    mrecip x
        | x == 0 = Left ZeroDivisionError
        | otherwise = Right (1 / x)

Again, madd is trivial to implement, since addition cannot fail on its own. But this time, msqrt returns an ImaginarySqrtError when it fails, whereas mrecip returns a ZeroDivisionError when it fails. This way, when f fails, we avoid erasing the information about how it failed.
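
Like Maybe, the Either MyFloatError monad short-circuits on the first failure, so the error that f reports is the one from the first operation that fails:


ghci> mrecip 0 >>= msqrt :: Either MyFloatError Float
Left ZeroDivisionError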

As before, we can define fEither, our newest variant of f, like this:


fEither = f :: Float -> Either MyFloatError Float

and here are some test values:


ghci> fEither 0.3
Right 3.0209703
ghci> fEither 0.5
Right 2.828427
ghci> fEither 0
Left ZeroDivisionError
ghci> fEither 2
Left ImaginarySqrtError

The List monad

We've already seen that the square root is an operation that can fail, but it is also sometimes described mathematically as a multivalued function - a function that can return multiple possible values. This is because every positive real number has two different square roots: a positive square root and a negative square root. Often we only want the positive square root of a number, but occasionally we are interested in both.

By making use of the list monad [], we can express the fact that a mathematical calculation has multiple possible results, depending on which branch is chosen for each multivalued expression involved. Our newest variant of the function f will therefore have the type signature Float -> [Float], and its output will list all of the possible values of the expression - one for each choice of signs in the two square-root terms, for a maximum of $4$ different possible values. As before, we just need to instance the list monad into MonadicFloatOps, since its monad instance is built-in. Here are the definitions of the three operations:


instance MonadicFloatOps [] where
    madd = (return .) . (+)
    msqrt x
        | x > 0 = [sqrt x, -sqrt x]
        | x == 0 = [0]
        | otherwise = []
    mrecip x
        | x == 0 = []
        | otherwise = [1 / x]

Again, nothing special happens for the madd operation. The mrecip implementation is also pretty plain, but notice that we return an empty list [] when dividing by zero, so that we are still expressing the case in which f fails, as we did with the Maybe monad, except that the "failure" result is an empty list rather than Nothing. For msqrt, there can be zero, one or two distinct results, since there are no square roots of a negative number, there is one square root of zero, and there are two square roots of any positive number.
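
The branching itself comes from the list monad's bind, which applies its continuation to every element and concatenates the results (xs >>= k is concat (map k xs)). For instance, chaining msqrt into mrecip yields one reciprocal for each square root:


ghci> msqrt 4 >>= mrecip :: [Float]
[0.5,-0.5]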

Here's the definition of fList:


fList = f :: Float -> [Float]

And here are our test outputs:


ghci> fList 0.3
[3.0209703,0.6305132,-0.6305132,-3.0209703]
ghci> fList 0.5
[2.828427,0.0,0.0,-2.828427]
ghci> fList 0
[]
ghci> fList 2
[]

One peculiarity of our implementation is the duplicate value 0.0 that can be seen in the result of fList 0.5. Can you see why this happens?

The Reader monads

We've just discussed how the square-root function is sometimes viewed as a multivalued function. If, instead of enumerating all possible values of the function f, we want to compute just one of them - determined by a choice of whether to use positive or negative square roots - we can make use of the Reader monad. This monad allows us to express values that depend on some shared environment, in a read-only capacity. It is often used in software as a way of making many environment variables available to a function without stacking up a bunch of input arguments.

In a more complicated piece of software, we might use the monad Reader env to pass around environment variables, where env is some user-defined data type, usually a record type with a field corresponding to each environment variable. But for this simple contrived example our environment only consists of a single bit of information, namely the answer to the question "do we want to use positive square roots, or negative square roots?" So we will use the Reader Bool monad.
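
As a quick sketch of what that record-based setup might look like in a bigger program (Config and its fields are hypothetical, not part of this post's code):


-- Hypothetical environment record for a larger program.
data Config = Config { usePositiveRoot :: Bool, logEnabled :: Bool }

-- `asks` (from Control.Monad.Reader) projects one field out of the environment.
positiveRoot :: Reader Config Bool
positiveRoot = asks usePositiveRoot

For our purposes, though, a single Bool suffices, and the instance looks like this: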


instance MonadicFloatOps (Reader Bool) where
    madd = (return .) . (+)
    msqrt x = do
        pos <- ask
        let sx = sqrt x
        return $ if pos then sx else -sx
    mrecip x = return (1 / x)

The implementations of madd and mrecip are straightforward. Notice that unlike in the previous examples, we are not implementing any protections on the unsafe square-root and reciprocal operations. The only nontrivial bit of this instance is where we query the shared environment in the msqrt implementation and return either sqrt x or -sqrt x depending on the configuration variable.

Now we can define fReader like this:


fReader = f :: Float -> Reader Bool Float

Here are some sample outputs. Keep in mind that when x is a Float, the expression fReader x evaluates not to a floating-point number but to a Reader Bool Float value, which cannot be evaluated to a Float until it receives an environment of type Bool. The function runReader is used to give it this argument.


ghci> runReader (fReader 0.3) True
3.0209703
ghci> runReader (fReader 0.3) False
-3.0209703
ghci> runReader (fReader 0.5) True
2.828427
ghci> runReader (fReader 0.5) False
-2.828427
ghci> runReader (fReader 0) True
Infinity
ghci> runReader (fReader 2) True
NaN

The IO monad

Now let's do something quite different. Suppose that, instead of being more interested in the safety of our function f, we are interested in its speed. To that end, suppose that we've written an optimized implementation of some mathematical function in another language like C, and would like to hook it up with our Haskell function to improve its performance. Or this could merely be for the sake of convenience, for instance if we already have an implementation of some mathematical function written in C and don't want to implement it all over again in Haskell, but we need it as a subcalculation for one of our Haskell functions.

In this case, I'll show how we can use an external implementation of the square-root function. To do this, we will need to let t = IO, since running an external executable file will require our Haskell program to carry out some I/O. Here's some C code that reads a floating-point number from its first argument and writes the approximate square root of this number to stdout using an iterative method:


#include <stdio.h>

double iterative_sqrt(double x, double est, double tol) {
    double est_prev = 0;
    double delta = est;
    while (delta > tol || delta < -tol) {
        est_prev = est;
        est = (est + x/est)/2;
        delta = est - est_prev;
    }
    return est;
}

int main(int argc, char** argv) {
    float x = 0;
    sscanf(argv[1], "%f", &x);
    printf("%.15f\n\n", iterative_sqrt(x, 1.0, 1e-15));
}

I make no claims that this is any faster than Haskell's built-in square-root function, and in fact it probably isn't. But this is only for illustrative purposes anyways, and it is no stretch of the imagination to think that for certain applications one might want to use an optimized C subroutine for some obscure mathematical calculation that is used in a Haskell function.

I've compiled this C code to an executable file sqrt.o. Here's how we can instance IO into MonadicFloatOps in such a way that msqrt delegates square-root calculations to this external executable:


instance MonadicFloatOps IO where
    madd = (return .) . (+)
    msqrt x = do
        -- run ./sqrt.o with x as its command-line argument, capturing its stdout
        (_, Just hout, _, _)
            <- createProcess (proc "./sqrt.o" [show x]) { std_out = CreatePipe }
        sqrtString <- hGetLine hout   -- read the first line the process prints
        return (read sqrtString)      -- parse it back into a Float
    mrecip x = return (1 / x)

From here, the definition of our new variant of f is easy:


fIO = f :: Float -> IO Float

and here are some test outputs:


ghci> fIO 0.3
3.0209703
ghci> fIO 0.5
2.828427
ghci> fIO 0
1.0e15
ghci> fIO 2
*** Exception: Prelude.read: no parse

Note that in this version we made no attempt to ensure the safety of f, so in the cases where f is not defined, the outputs of fIO can be messy. For instance, the iterative method computes the square root of 0 as a tiny nonzero number, whose reciprocal is the huge 1.0e15 above, and for the input 2 the C program's output for the square root of -1 is something that Haskell's read cannot parse at all, hence the exception.

The Writer monads

Another interesting piece of functionality that we might want to inject into f is operation counting. We might want to do this if we were carrying out a theoretical analysis of some computational algorithm and wanted to keep track of how many floating-point operations a certain subroutine carried out for each test case. This seems a little silly for the function f, because we can count these operations easily just by looking at f. But this could become very useful if f were a more complex function, especially one defined recursively, or if f were being called multiple times as part of some other function.

We can accomplish this using the Writer family of monads. Like Reader, a Writer monad represents a computation with a shared environment, but it can be contrasted with Reader in that the Writer monads provide a write-only environment. That means that the environment cannot affect the computation, but it "accumulates" information drawn from the computation as it runs. This is ideal for operation counting, because the running operation count should be continually updated but should not affect the result of the function evaluation. Another practical use-case of the Writer monads might be to implement logging of some sort - say, appending a message to some log string each time a certain kind of function is called.
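
As a tiny illustration of this accumulation, here is the Writer monad over the String monoid, where values are combined by concatenation:


ghci> runWriter (tell "one. " >> tell "two. " >> return 42)
(42,"one. two. ")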

We will be using the Writer Int monad to implement operation counting for f, since our running operation count will have a type of Int. Unlike in previous examples, we cannot immediately instance Writer Int into MonadicFloatOps because it is not automatically an instance of Monad - in order to use the built-in monad instance for Writer a, the type a must be instanced into Monoid. This is because the writer needs to be told "how to accumulate" values of type a. In our case, simple addition will do. If we were implementing something like logging, our monoid operation might be string concatenation.

Here's our monoid instance for Int:


instance Semigroup Int where
    (<>) = (+)

instance Monoid Int where
    mempty = 0
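
Note that these are orphan instances, which GHC may warn about. As an aside, base's Data.Monoid module already provides a newtype, Sum, whose monoid is exactly addition - we could have used Writer (Sum Int) with tell (Sum 1) instead, at the cost of some wrapping and unwrapping:


ghci> import Data.Monoid (Sum(..))
ghci> getSum (Sum 2 <> Sum 3)
5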

and here is our instance of Writer Int into MonadicFloatOps:


instance MonadicFloatOps (Writer Int) where
    madd x y = tell 1 >> return (x + y)
    msqrt x = tell 1 >> return (sqrt x)
    mrecip x = tell 1 >> return (1 / x)

Again, the unsafe inputs fall through to the default Float behavior, because we haven't guarded against any of the problematic arguments to msqrt or mrecip. The variant fWriter can be written with nothing more than a type annotation:


fWriter = f :: Float -> Writer Int Float

and here are our test cases:


ghci> runWriter (fWriter 0.3)
(3.0209703,6)
ghci> runWriter (fWriter 0.5)
(2.828427,6)
ghci> runWriter (fWriter 0)
(Infinity,6)
ghci> runWriter (fWriter 2)
(NaN,6)

Of course, the operation counts are all equal to 6, because the function f is evaluated the same way for each of these calls. I think that the real power of this monad is better illustrated in a use-case where the function f is used as a subroutine of some other function. For instance, suppose that we want to approximate the definite integral of $f$ over an interval $[a,b]$ using left Riemann sums:

$$\int_a^b f(x) \, dx \approx \sum_{k=0}^{n-1} f\left(a + k \cdot \frac{b-a}{n}\right) \cdot \frac{b-a}{n}$$

We can use our MonadicFloatOps instance for Writer Int to seamlessly integrate an operation count into our Riemann sum routine. Here's a quick implementation of left Riemann sums approximating this integral, where a,b are the endpoints of the interval and n is the number of rectangles:


analyzeRiemannSum :: MonadicFloatOps m => Float -> Float -> Int -> m Float
analyzeRiemannSum a b n = do
    let nf = fromIntegral n
    let step = (b - a)/nf
    let xs = [a + k*step | k <- [0..nf-1] ]   -- the left endpoint of each subinterval
    ys <- sequence $ map f xs                 -- evaluate f monadically at each endpoint
    foldM madd 0 (map (step *) ys)            -- total up the rectangle areas using madd

Yes, okay, admittedly I cheated, because I'm not counting the floating-point multiplications using (*) towards the operation count. If we wanted to include these in the operation count, we would just need to add a function called mmultiply to our MonadicFloatOps interface. But I won't do that here.

Anyhow, here are some test cases with $[a,b] = [0.1,0.9]$:


ghci> runWriter (analyzeRiemannSum 0.1 0.9 20)
(2.5338423,140)
ghci> runWriter (analyzeRiemannSum 0.1 0.9 100)
(2.529984,700)

Note that our analyzeRiemannSum function isn't specific to the Writer Int monad, but rather works for any MonadicFloatOps instance. More on this in just a second.

The State monads

For one final example, I'll show how a State monad can be used to implement not only operation counting for the function f, but also memoization for some operations involved in its calculation, in particular the square-root operation. It's a little silly to use memoization here, since we aren't doing particularly expensive computations. But this is all for illustrative purposes anyways.

The State monads are ideal here because different parts of the computation must interact with a shared state in a way that neither a Reader nor a Writer supports. Computations must be able to read from the shared state in order to find pre-computed values in the memo and avoid re-computing them, and they must also be able to write to the shared state in order to record newly computed values that haven't been seen before. Since we need both reading and writing capabilities in this use-case, State is the right tool.
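
Concretely, the State primitives we'll use below are get and gets (read the whole state, or a field of it) and modify (apply a function to it); runState then returns both the result and the final state. A tiny example:


ghci> runState (modify (+1) >> get) (41 :: Int)
(42,42)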

Here's a custom data type that we can use to hold both a memo of square-root values and an operation count:


data FloatOpsMemo = FloatOpsMemo {
    numOps :: Int,
    sqrtMemo :: [(Float, Float)]
} deriving (Eq, Show)

And here's an implementation of MonadicFloatOps for the state monad State FloatOpsMemo:


instance MonadicFloatOps (State FloatOpsMemo) where
    madd x y = modify (\s -> s { numOps = 1 + numOps s }) >> return (x + y)
    msqrt x = do
        memo <- fmap sqrtMemo get
        let lookupRes = lookup x memo
        case lookupRes of
            Nothing ->
                modify (\s -> s { numOps = 1 + numOps s, sqrtMemo = (x,sqx):(sqrtMemo s) })
                >> return sqx
                where sqx = sqrt x
            (Just sqx) -> return sqx
    mrecip x = modify (\s -> s { numOps = 1 + numOps s }) >> return (1 / x)

Admittedly there is a bit more boilerplate involved in incrementing the operation counts for madd and mrecip, because we need to unpack the operation-count component of our FloatOpsMemo state. The memo implementation for msqrt, which checks whether a square-root value is present in the memo before proceeding to calculate it, is also a bit verbose, but in a larger program the memo functionality would probably be abstracted out into a helper function written elsewhere, rather than implemented by hand. Note that we are only incrementing the operation count for msqrt when the memo lookup misses.
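
Such a helper might look something like this (a hypothetical sketch, not part of the code for this post), after which msqrt could be defined as simply withMemo sqrt:


-- Hypothetical helper: route a unary operation through the sqrtMemo table,
-- charging one operation only on a cache miss.
withMemo :: (Float -> Float) -> Float -> State FloatOpsMemo Float
withMemo op x = do
    memo <- gets sqrtMemo
    case lookup x memo of
        Just y  -> return y            -- cache hit: no new operation performed
        Nothing -> do                  -- cache miss: compute, count, and record
            let y = op x
            modify (\s -> s { numOps = 1 + numOps s
                            , sqrtMemo = (x, y) : sqrtMemo s })
            return y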

The definition of fMemo, the memoized version of f, is:


fMemo = f :: Float -> State FloatOpsMemo Float

And here are the test cases:


ghci> runState (fMemo 0.3) (FloatOpsMemo { numOps = 0, sqrtMemo = [] })
(3.0209703,FloatOpsMemo {numOps = 6, sqrtMemo = [(0.7,0.83666),(0.3,0.5477226)]})
ghci> runState (fMemo 0.5) (FloatOpsMemo { numOps = 0, sqrtMemo = [] })
(2.828427,FloatOpsMemo {numOps = 5, sqrtMemo = [(0.5,0.70710677)]})
ghci> runState (fMemo 0) (FloatOpsMemo { numOps = 0, sqrtMemo = [] })
(Infinity,FloatOpsMemo {numOps = 6, sqrtMemo = [(1.0,1.0),(0.0,0.0)]})
ghci> runState (fMemo 2) (FloatOpsMemo { numOps = 0, sqrtMemo = [] })
(NaN,FloatOpsMemo {numOps = 6, sqrtMemo = [(-1.0,NaN),(2.0,1.4142135)]})

An interesting consequence of this implementation is that the final state of each run tells us exactly which square roots have been calculated over its course. Another interesting thing to note, which differs from our experience with the Writer operation counter, is that numOps is only 5 for the case of fMemo 0.5. Can you see why this is?

Remember our analyzeRiemannSum function, and how it was defined for any MonadicFloatOps instance? Well, guess what - we can use it for State FloatOpsMemo too. This will compute left Riemann sums for f, memoizing all square roots over the course of the entire calculation. Here are two examples:


ghci> runState (analyzeRiemannSum 0.1 0.9 20) (FloatOpsMemo { numOps = 0, sqrtMemo = [] })
(2.5338423,FloatOpsMemo {numOps = 131, sqrtMemo = [(0.13999999,0.3741657),...]})
ghci> runState (analyzeRiemannSum 0.1 0.9 100) (FloatOpsMemo { numOps = 0, sqrtMemo = [] })
(2.529984,FloatOpsMemo {numOps = 674, sqrtMemo = [(0.10800004,0.3286336),...]})

There are a lot of values in these memos, so I've truncated them. But notice how the operation counts are a bit lower for these runs. This is because of the memoization we've implemented for msqrt. Not that I'm claiming this leads to any real performance gains in this case - but one can see how it would be handy to be able to incorporate memoization this seamlessly into a computational problem that could benefit from it a lot more.

