YAHT: Simple state monad (note to self)
- Magnus Therning
First off, keep in mind that the `State st a` type is a function type (`st -> (st, a)`) and that `returnState a` wraps `a` into a function that returns a (state-value, value) pair when it’s applied to a state-value.
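Here is a minimal sketch of these two definitions. It assumes YAHT's plain type-synonym encoding of `State`, which is not the `State` type from the mtl or transformers libraries. `wrappedX` is a hypothetical name used only for illustration.

```haskell
-- YAHT-style state: just a function from a state to a (state, value) pair.
-- This is a plain type synonym, not Control.Monad.State.
type State st a = st -> (st, a)

-- Wrap a value into a state function that leaves the state untouched
-- and pairs it with the wrapped value.
returnState :: a -> State st a
returnState a = \st -> (st, a)

-- Hypothetical example: wrap the character 'x'.
wrappedX :: State Integer Char
wrappedX = returnState 'x'
```

Applying `wrappedX` to a state-value such as `0` gives `(0,'x')`. The state passes through unchanged, and the wrapped value comes out alongside it.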
`bindState` is a bit more complicated to think about. It takes two arguments, both functions. The first (`m`) is of type `st -> (st, a)` (i.e. `State st a`). The second (`k`) is of type `a -> (st -> (st, b))` (i.e. `a -> State st b`). In other words, the second argument is a function that takes a value and returns a function that we can apply to a state. The returned function has “wrapped up” in it the end result of the computation.

Looking more closely at the implementation of `bindState`, we see that it returns a function taking one argument (`st`). Inside it, `m` is used to calculate a new (state-value, value) pair (`st'` and `a`) based on `st`. Next, `k` is used to “wrap” `a` into a function (`m'`) that we can apply a state-value to. The last step is to apply the function wrapping the new value (`m'`) to the new state (`st'`).
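The description above corresponds roughly to the following sketch. It uses the same names as the prose: `m`, `k`, `st`, `st'`, `a` and `m'`. It assumes the type-synonym `State` from YAHT. `tick` and `twoTicks` are hypothetical helpers, added only to show two computations being chained.

```haskell
-- YAHT-style state (plain type synonym, not Control.Monad.State).
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

-- m :: State st a        (the first computation)
-- k :: a -> State st b   (builds the next computation from a value)
bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st   -- run m on the incoming state: new state and value
      m'       = k a    -- k "wraps" the value a into a new state function
  in m' st'             -- finally apply m' to the new state st'

-- Hypothetical example: return the current state as the value and
-- increment the state.
tick :: State Integer Integer
tick = \st -> (st + 1, st)

-- Run tick twice, threading the state, and return both values.
twoTicks :: State Integer (Integer, Integer)
twoTicks =
  tick `bindState` \x ->
  tick `bindState` \y ->
  returnState (x, y)
```

Running `twoTicks 0` should give `(2,(0,1))`. The first `tick` sees state `0`, and `bindState` then hands the updated state `1` on to the second `tick`.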
Looking at `mapTreeStateM` on `(Leaf a)`, we see that `f a` is `m` in `bindState`, and `\b -> returnState (Leaf b)` is `k`. From our look at `bindState`, we can see that `b` here will be our newly calculated value (i.e. the new value coming out of `f a` will be wrapped into a leaf again).
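To make the roles concrete, here is a sketch of just the `Leaf` case, written twice. The first version uses `bindState`. The second inlines `bindState` and `returnState` by hand. `mapLeaf`, `mapLeafInlined` and `double` are hypothetical names, and the definitions assume the YAHT-style `State` synonym.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- The Leaf case of mapTreeStateM: f a plays the role of m in bindState,
-- and (\b -> returnState (Leaf b)) plays the role of k.
mapLeaf :: (a -> State st b) -> a -> State st (Tree b)
mapLeaf f a = f a `bindState` \b -> returnState (Leaf b)

-- Same thing with bindState and returnState inlined by hand:
-- run f a on the state, then put the resulting value b back into a Leaf.
mapLeafInlined :: (a -> State st b) -> a -> State st (Tree b)
mapLeafInlined f a = \st ->
  let (st', b) = f a st
  in (st', Leaf b)

-- Hypothetical leaf function: doubles the value and counts calls in the state.
double :: Integer -> State Integer Integer
double a = \st -> (st + 1, 2 * a)
```

Both `mapLeaf double 5 0` and `mapLeafInlined double 5 0` evaluate to `(1,Leaf 10)`. The value `10` produced by `f a` is the `b` that ends up in the new leaf.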
I found it easier to understand when I applied an “identity function” to a tree:
doNothing :: Integer -> State Integer Integer
doNothing a = returnState a
Working through on paper what happens when applying `mapTreeStateM` to `doNothing`, a tree of one leaf, and an initial state of 0 made it quite clear what is going on.
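Here is a sketch of that exercise, with the hand evaluation recorded as comments. It assumes the YAHT definitions of `State`, `returnState`, `bindState`, `Tree` and `mapTreeStateM`, which are reproduced here as a sketch so the block stands alone.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` \b ->
  returnState (Leaf b)
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` \lhs' ->
  mapTreeStateM f rhs `bindState` \rhs' ->
  returnState (Branch lhs' rhs')

doNothing :: Integer -> State Integer Integer
doNothing a = returnState a

-- Hand evaluation of  mapTreeStateM doNothing (Leaf 7) 0:
--
--   mapTreeStateM doNothing (Leaf 7) 0
--   = (doNothing 7 `bindState` \b -> returnState (Leaf b)) 0
--   = let (st', a) = doNothing 7 0         -- returnState 7 0 = (0, 7)
--         m'       = returnState (Leaf a)   -- k applied to a = 7
--     in m' st'
--   = returnState (Leaf 7) 0
--   = (0, Leaf 7)
```

`mapTreeStateM doNothing (Leaf 7) 0` gives back `(0,Leaf 7)`, so neither the state nor the tree changes. The same holds for larger trees.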
Analysing the recursive step after this is fairly straightforward. Adding parentheses makes it a bit easier to read, and I found it helpful to rewrite it to only consider the left branches (the right branch is simply replaced by `Leaf (-1)`):
mapTreeStateMLO :: (a -> State st Integer) -> Tree a -> State st (Tree Integer)
mapTreeStateMLO f (Leaf a) =
    (f a) `bindState` (\b -> returnState (Leaf b))
mapTreeStateMLO f (Branch lhs rhs) =
    (mapTreeStateMLO f lhs) `bindState` (\lhs' ->
        returnState (Branch lhs' (Leaf (-1))))
The extension to right branches comes fairly naturally after that.
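For comparison, here is a sketch of both versions side by side. The first is the left-only `mapTreeStateMLO` from above. The second is the full `mapTreeStateM`, which threads the state through the right branch as well. `record` and `sampleTree` are hypothetical. `record` keeps the list of leaf values visited so far as its state, which makes the left-to-right order visible. The definitions assume the YAHT-style `State` synonym.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

-- Left branches only; the right branch is replaced by Leaf (-1).
mapTreeStateMLO :: (a -> State st Integer) -> Tree a -> State st (Tree Integer)
mapTreeStateMLO f (Leaf a) =
    (f a) `bindState` (\b -> returnState (Leaf b))
mapTreeStateMLO f (Branch lhs rhs) =
    (mapTreeStateMLO f lhs) `bindState` (\lhs' ->
        returnState (Branch lhs' (Leaf (-1))))

-- Full version: the state coming out of the left branch is fed into the
-- right branch before the two new subtrees are rebuilt.
mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` \b ->
  returnState (Leaf b)
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` \lhs' ->
  mapTreeStateM f rhs `bindState` \rhs' ->
  returnState (Branch lhs' rhs')

-- Hypothetical leaf function: append the leaf value to the state.
record :: Integer -> State [Integer] Integer
record a = \st -> (st ++ [a], a)

sampleTree :: Tree Integer
sampleTree = Branch (Leaf 1) (Branch (Leaf 2) (Leaf 3))
```

Starting from an empty list, the left-only version should stop after the first leaf. It gives `([1],Branch (Leaf 1) (Leaf (-1)))`. The full version visits all three leaves and returns `sampleTree` with the state `[1,2,3]`.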
I wrote a few functions to play with `mapTreeStateM`. All of them use a single integer to represent their state.
State is the number of leaves in the tree:
countLeaf :: Integer -> State Integer Integer
countLeaf a = \st -> (st + 1, a)
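Here is a sketch of running `countLeaf` over a whole tree. It starts from a count of 0 and keeps only the final state. `leafCount` is a hypothetical helper. The supporting definitions are the YAHT-style ones, repeated so the block stands alone.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` \b ->
  returnState (Leaf b)
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` \lhs' ->
  mapTreeStateM f rhs `bindState` \rhs' ->
  returnState (Branch lhs' rhs')

-- State is the number of leaves seen so far; the leaf value is unchanged.
countLeaf :: Integer -> State Integer Integer
countLeaf a = \st -> (st + 1, a)

-- Hypothetical helper: start counting at 0 and keep only the final state.
leafCount :: Tree Integer -> Integer
leafCount t = fst (mapTreeStateM countLeaf t 0)
```

`leafCount (Branch (Leaf 10) (Branch (Leaf 20) (Leaf 30)))` should be `3`. The tree half of the result is the original tree, since `countLeaf` returns each leaf value unchanged.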
State is the maximum leaf value seen so far:
maxLeaf :: Integer -> State Integer Integer
maxLeaf a = \st -> (max st a, a)
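Here is a similar sketch for `maxLeaf`. The hypothetical helper `largestLeaf` starts from 0 and keeps only the final state. The other definitions are the YAHT-style ones.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` \b ->
  returnState (Leaf b)
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` \lhs' ->
  mapTreeStateM f rhs `bindState` \rhs' ->
  returnState (Branch lhs' rhs')

-- State is the largest leaf value seen so far.
maxLeaf :: Integer -> State Integer Integer
maxLeaf a = \st -> (max st a, a)

-- Hypothetical helper: start from 0 and keep only the final state.
largestLeaf :: Tree Integer -> Integer
largestLeaf t = fst (mapTreeStateM maxLeaf t 0)
```

`largestLeaf (Branch (Leaf 4) (Branch (Leaf 9) (Leaf 2)))` should be `9`. Because the initial state is 0, a tree whose leaves are all negative reports 0. Seeding the state with the first leaf value would avoid that.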
This function numbers each leaf in the tree:
numberLeaf :: Integer -> State Integer (Integer, Integer)
numberLeaf a = \st -> (st + 1, (st, a))
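Here is a sketch for `numberLeaf`. This time it is the tree half of the result that matters: each leaf value is paired with its position in a left-to-right traversal. `numberLeaves` is a hypothetical helper, and the other definitions are the YAHT-style ones.

```haskell
type State st a = st -> (st, a)

returnState :: a -> State st a
returnState a = \st -> (st, a)

bindState :: State st a -> (a -> State st b) -> State st b
bindState m k = \st ->
  let (st', a) = m st
      m'       = k a
  in m' st'

data Tree a = Leaf a | Branch (Tree a) (Tree a)
  deriving (Show, Eq)

mapTreeStateM :: (a -> State st b) -> Tree a -> State st (Tree b)
mapTreeStateM f (Leaf a) =
  f a `bindState` \b ->
  returnState (Leaf b)
mapTreeStateM f (Branch lhs rhs) =
  mapTreeStateM f lhs `bindState` \lhs' ->
  mapTreeStateM f rhs `bindState` \rhs' ->
  returnState (Branch lhs' rhs')

-- State is the next number to hand out; each leaf becomes (number, value).
numberLeaf :: Integer -> State Integer (Integer, Integer)
numberLeaf a = \st -> (st + 1, (st, a))

-- Hypothetical helper: number from 0 and keep only the new tree.
numberLeaves :: Tree Integer -> Tree (Integer, Integer)
numberLeaves t = snd (mapTreeStateM numberLeaf t 0)
```

For `Branch (Leaf 10) (Branch (Leaf 20) (Leaf 30))`, this should give `Branch (Leaf (0,10)) (Branch (Leaf (1,20)) (Leaf (2,30)))`. The final state is `3`, the number of leaves.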