State Monad, sequences of random numbers and monadic code

I'm trying to grasp the State Monad, and to that end I wanted to write monadic code that generates a sequence of random numbers using a Linear Congruential Generator (probably not a good one, but my intention is just to learn the State Monad, not to build a good RNG library).

The generator is just this (I want to generate a sequence of Bools for simplicity):

type Seed = Int

random :: Seed -> (Bool, Seed)
random seed = let (a, c, m) = (1664525, 1013904223, 2^32)  -- some params for the LCG
                  seed' = (a*seed + c) `mod` m
              in  (even seed', seed')   -- return True/False if seed' is even/odd 

Don't worry about the numbers; this is just an update rule for the seed that (according to Numerical Recipes) should generate a pseudo-random sequence of Ints. Now, if I want to generate random numbers sequentially, I'd do:

rand3Bools :: Seed -> ([Bool], Seed)
rand3Bools seed0  = let (b1, seed1) = random seed0
                        (b2, seed2) = random seed1
                        (b3, seed3) = random seed2
                    in  ([b1,b2,b3], seed3)

Ok, so I could avoid this boilerplate by using a State Monad:

import Control.Monad.State
import Control.Monad (replicateM)

data Random = Random {seed :: Seed, value :: Bool}

nextVal :: State Random Bool
nextVal = do 
   Random seed val <- get 
   let seed' = updateSeed seed
       val'  = even seed'
   put (Random seed' val')
   return val'

updateSeed :: Seed -> Seed
updateSeed seed = let (a,c,m) = (1664525, 1013904223, 2^32) in (a*seed + c) `mod` m

And finally:

getNRandSt :: Int -> State Random [Bool]
getNRandSt n = replicateM n nextVal 

getNRand :: Int -> Seed -> [Bool]
getNRand n seed = evalState (getNRandSt n) (Random seed True)
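Roughly speaking, for the State monad `replicateM n m` runs `m` n times in sequence, feeding the state produced by each run into the next one and collecting the results in a list. A minimal sketch of that behavior on plain state-passing functions, with no monad involved (`replicateState` is a hypothetical name; the real replicateM is written generically for any Applicative, not like this):

```haskell
type Seed = Int

-- The LCG step from the question: a Bool result plus the next seed.
random :: Seed -> (Bool, Seed)
random seed = let (a, c, m) = (1664525, 1013904223, 2^32)
                  seed' = (a*seed + c) `mod` m
              in  (even seed', seed')

-- Run a state-passing step n times, threading the state and collecting results.
replicateState :: Int -> (s -> (a, s)) -> s -> ([a], s)
replicateState n step s0
  | n <= 0    = ([], s0)
  | otherwise = let (x, s1)  = step s0                        -- first run
                    (xs, sN) = replicateState (n - 1) step s1 -- remaining runs
                in  (x : xs, sN)
```

With this, `replicateState 3 random seed0` computes exactly what rand3Bools does by hand.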

OK, this works fine and gives me a list of n pseudo-random Bools for each given seed. But...

I can read what I've done (mainly based on this example: http://www.haskell.org/pipermail/beginners/2008-September/000275.html) and replicate it to do other things. But I don't think I really understand what's happening behind the do-notation and monadic functions (like replicateM).

Can anyone help me with some of these doubts?

1 - I've tried to desugar the nextVal function to understand what it does, but I couldn't. I can guess that it extracts the current state, updates it and then passes the state on to the next computation, but this is just based on reading the do-sugar as if it were English.

How do I really desugar this function to the original >>= and return functions step-by-step?

2 - I couldn't grasp what exactly the put and get functions do. I can guess that they "pack" and "unpack" the state, but the mechanics behind the do-sugar are still elusive to me.

Well, any other general remarks about this code are very welcome. I sometimes feel with Haskell that I can write code that works and does what I expect it to do, but I can't "follow the evaluation" as I'm accustomed to doing with imperative programs.


The State monad does look kind of confusing at first; let's do as Norman Ramsey suggested and walk through how to implement it from scratch. Warning: this is pretty lengthy!

First, State has two type parameters: the type of the contained state data and the type of the final result of the computation. We'll use stateData and result respectively as type variables for them here. This makes sense if you think about it; the defining characteristic of a State-based computation is that it modifies a state while producing an output.

Less obvious is that the data constructor wraps a function from a state to a result and a modified state, like so:

newtype State stateData result = State (stateData -> (result, stateData))

So while the monad is called "State", the value actually wrapped by the monad is a State-based computation, not the current value of the contained state.

Keeping that in mind, we shouldn't be surprised to find that the function runState used to execute a computation in the State monad is actually nothing more than an accessor for the wrapped function itself, and could be defined like this:

runState (State f) = f
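In Haskell that accessor can also come for free from record syntax. A minimal toy sketch (not the library's own definition) where `runState` is generated as a field selector; `doubleAndCount` is a hypothetical example computation:

```haskell
-- Record syntax generates runState :: State stateData result -> (stateData -> (result, stateData)).
newtype State stateData result = State { runState :: stateData -> (result, stateData) }

-- Hypothetical computation: result is the state doubled; the new state is the old one plus 1.
doubleAndCount :: State Int Int
doubleAndCount = State (\d -> (d * 2, d + 1))
```

Here `runState doubleAndCount 5` evaluates to `(10, 6)`: runState just hands back the wrapped function, which is then applied to the initial state.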

So what does it mean when you define a function that returns a State value? Let's ignore for a moment the fact that State is a monad, and just look at the underlying types. First, consider this function (which doesn't actually do anything with the state):

len2State :: String -> State Int Bool
len2State s = return ((length s) == 2)

If you look at the definition of State, we can see that here the stateData type is Int and the result type is Bool, so the function wrapped by the data constructor must have the type Int -> (Bool, Int). Now, imagine a State-less version of len2State; obviously, it would have type String -> Bool. So how would you go about converting such a function into one returning a value that fits into a State wrapper?

Well, obviously, the converted function will need to take a second parameter, an Int representing the state value. It also needs to return a state value, another Int. Since we're not actually doing anything with the state in this function, let's just do the obvious thing: pass that Int right on through. Here's a State-shaped function, defined in terms of the State-less version:

len2 :: String -> Bool
len2 s = ((length s) == 2)

len2State :: String -> (Int -> (Bool, Int))
len2State s i = (len2 s, i)

But that's kind of silly and redundant. Let's generalize the conversion so that we can pass in the result value, and turn anything into a State-like function.

convert :: Bool -> (Int -> (Bool, Int))
convert r d = (r, d)

len2 s = ((length s) == 2)

len2State :: String -> (Int -> (Bool, Int))
len2State s = convert (len2 s)
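Nothing in convert actually depends on Bool or Int, so the same body type-checks with fully polymorphic types. A sketch of that generalized signature (same definition as above, only the type changed):

```haskell
-- Wrap a plain result into a state-passing function that leaves the state unchanged.
convert :: result -> (stateData -> (result, stateData))
convert r d = (r, d)

len2 :: String -> Bool
len2 s = length s == 2

-- Specialized use: result Bool, state Int.
len2State :: String -> (Int -> (Bool, Int))
len2State s = convert (len2 s)
```

For example, `len2State "ab" 7` gives `(True, 7)`: the state 7 passes through untouched.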

What if we want a function that changes the state? Obviously we can't build one with convert, since we wrote that to pass the state through. Let's keep it simple and write a function to overwrite the state with a new value. What kind of type would it need? It'll need an Int for the new state value, and of course will have to return a function stateData -> (result, stateData), because that's what our State wrapper needs. Overwriting the state value doesn't really have a sensible result value outside the State computation, so our result here will just be (), the zero-element tuple that represents "no value" in Haskell.

overwriteState :: Int -> (Int -> ((), Int))
overwriteState newState _ = ((), newState)

That was easy! Now, let's actually do something with that state data. Let's rewrite len2State from above into something more sensible: we'll compare the string length to the current state value.

lenState :: String -> (Int -> (Bool, Int))
lenState s i = ((length s) == i, i)

Can we generalize this into a converter and a State-less function, like we did before? Not quite as easily. Our len function will need to take the state as an argument, but we don't want it to "know about" state. Awkward, indeed. However, we can write a quick helper function that handles everything for us: we'll give it a function that needs to use the state value, and it'll pass the value in and then package everything back up into a State-shaped function leaving len none the wiser.

useState :: (Int -> Bool) -> Int -> (Bool, Int)
useState f d = (f d, d)

len :: String -> Int -> Bool
len s i = (length s) == i

lenState :: String -> (Int -> (Bool, Int))
lenState s = useState (len s)
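The helper generalizes the same way as convert. A sketch with polymorphic types, where the state is read by the state-unaware function but never modified:

```haskell
-- Feed the current state to a state-unaware function; pass the state through unchanged.
useState :: (stateData -> result) -> stateData -> (result, stateData)
useState f d = (f d, d)

len :: String -> Int -> Bool
len s i = length s == i

lenState :: String -> (Int -> (Bool, Int))
lenState s = useState (len s)
```

So `lenState "abc" 3` gives `(True, 3)` and `lenState "abc" 4` gives `(False, 4)`.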

Now, the tricky part: what if we want to string these functions together? Let's say we want to use lenState on a string, then double the state value if the result is False, then check the string again, and finally return True if either check did. We have all the parts we need for this task, but writing it all out would be a pain. Can we make a function that automatically chains together two functions that each return State-like functions? Sure thing! We just need to make sure it takes two arguments: the State function returned by the first function, and a function that takes the prior function's result as an argument. Let's see how it turns out:

chainStates :: (Int -> (result1, Int)) -> (result1 -> (Int -> (result2, Int))) -> (Int -> (result2, Int))
chainStates prev f d = let (r, d') = prev d
                       in f r d'

All this is doing is applying the first state function to some state data, then applying the second function to the result and the modified state data. Simple, right?
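To see the plumbing, here is chainStates applied to a small hypothetical pair of steps (`increment` is an illustrative name), with the evaluation written out by hand in comments:

```haskell
chainStates :: (Int -> (result1, Int)) -> (result1 -> (Int -> (result2, Int))) -> (Int -> (result2, Int))
chainStates prev f d = let (r, d') = prev d
                       in  f r d'

extractState :: Int -> (Int, Int)
extractState d = (d, d)

overwriteState :: Int -> (Int -> ((), Int))
overwriteState newState _ = ((), newState)

-- Hypothetical example: read the state, then overwrite it with state + 1.
increment :: Int -> ((), Int)
increment = chainStates extractState (\s -> overwriteState (s + 1))

-- Evaluating increment 5 by hand:
--   chainStates extractState (\s -> overwriteState (s + 1)) 5
-- = let (r, d') = extractState 5 in (\s -> overwriteState (s + 1)) r d'
-- = let (r, d') = (5, 5) in overwriteState (r + 1) d'
-- = overwriteState 6 5
-- = ((), 6)
```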

Now, the interesting part: Between chainStates and convert, we should almost be able to turn any combination of State-less functions into a State-enabled function! The only thing we need now is a replacement for useState that returns the state data as its result, so that chainStates can pass it along to the functions that don't know anything about the trick we're pulling on them. Also, we'll use lambdas to accept the result from the previous functions and give them temporary names. Okay, let's make this happen:

extractState :: Int -> (Int, Int)
extractState d = (d, d)

chained :: String -> (Int -> (Bool, Int))
chained str = chainStates  extractState         $ \state1 ->
              let check1 = (len str state1) in
              chainStates (overwriteState (
                  if check1 
                  then state1 
                  else state1 * 2))             $ \_ ->
              chainStates  extractState         $ \state2 ->
              let check2 = (len str state2) in
              convert (check1 || check2)

And try it out:

> chained "abcd" 2
(True, 4)
> chained "abcd" 3
(False, 6)
> chained "abcd" 4
(True, 4)
> chained "abcdef" 5
(False, 10)
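For anyone who wants to load the example into GHCi, here are the pieces of the walkthrough collected into one self-contained sketch (state kept as Int, as in the text; convert is given a polymorphic result type), which should reproduce the four results shown:

```haskell
-- Return a result and pass the state through unchanged.
convert :: result -> (Int -> (result, Int))
convert r d = (r, d)

-- Replace the state; the result is ().
overwriteState :: Int -> (Int -> ((), Int))
overwriteState newState _ = ((), newState)

-- Return the current state as the result, leaving it unchanged.
extractState :: Int -> (Int, Int)
extractState d = (d, d)

-- Run prev on the state, then feed its result and the new state to f.
chainStates :: (Int -> (result1, Int)) -> (result1 -> (Int -> (result2, Int))) -> (Int -> (result2, Int))
chainStates prev f d = let (r, d') = prev d
                       in  f r d'

-- Does the string's length equal the given number?
len :: String -> Int -> Bool
len s i = length s == i

chained :: String -> (Int -> (Bool, Int))
chained str = chainStates extractState $ \state1 ->
              let check1 = len str state1 in
              chainStates (overwriteState (if check1 then state1 else state1 * 2)) $ \_ ->
              chainStates extractState $ \state2 ->
              let check2 = len str state2 in
              convert (check1 || check2)
```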

Of course, we can't forget that State is actually a monad that wraps the State-like functions and keeps us away from them, so none of our nifty functions that we've built will help us with the real thing. Or will they? In a shocking twist, it turns out that the real State monad provides all the same functions, under different names:

runState (State s) = s
return r = State (convert r)
(>>=) s f = State (\d -> let (r, d') = (runState s) d in
                         runState (f r) d')
get = State extractState
put d = State (overwriteState d)

Note that >>= is almost identical to chainStates, but there was no good way to define it using chainStates. So, to wrap things up, we can rewrite the final example using the real State:

chained str = get                               >>= \state1 ->
              let check1 = (len str state1) in
              put (if check1 
                  then state1 else state1 * 2)  >>= \_ ->
              get                               >>= \state2 ->
              let check2 = (len str state2) in
              return (check1 || check2)

Or, all candied up with the equivalent do notation:

chained str = do
        state1 <- get
        let check1 = len str state1
        _ <- put (if check1 then state1 else state1 * 2)
        state2 <- get
        let check2 = (len str state2)
        return (check1 || check2)
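Two hedged caveats about running this against a current library: in mtl (built on transformers), `State s` is a type synonym for `StateT s Identity`, so there is no `State` data constructor to pattern match on (the `state` function wraps an `s -> (a, s)` function instead), and since GHC 7.10 every Monad also needs Functor and Applicative instances. A minimal self-contained sketch of a toy State with those instances, running both versions of chained (renamed `chainedBind` and `chainedDo` here, which are hypothetical names):

```haskell
import Control.Monad (ap, liftM)

-- Toy State: a wrapped function from a state to (result, new state).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap = liftM

instance Applicative (State s) where
  pure r = State (\d -> (r, d))      -- same idea as convert
  (<*>)  = ap

instance Monad (State s) where
  -- The chainStates plumbing, plus unwrapping with runState.
  s >>= f = State (\d -> let (r, d') = runState s d
                         in  runState (f r) d')

get :: State s s
get = State (\d -> (d, d))           -- same idea as extractState

put :: s -> State s ()
put d = State (\_ -> ((), d))        -- same idea as overwriteState

len :: String -> Int -> Bool
len s i = length s == i

chainedBind :: String -> State Int Bool
chainedBind str = get >>= \state1 ->
                  let check1 = len str state1 in
                  put (if check1 then state1 else state1 * 2) >>= \_ ->
                  get >>= \state2 ->
                  let check2 = len str state2 in
                  return (check1 || check2)

chainedDo :: String -> State Int Bool
chainedDo str = do
  state1 <- get
  let check1 = len str state1
  put (if check1 then state1 else state1 * 2)
  state2 <- get
  let check2 = len str state2
  return (check1 || check2)
```

Both should give the same pairs as the function-based version, e.g. `runState (chainedDo "abcd") 3` is `(False, 6)`.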

First of all, your example is overly complicated because it doesn't need to store the val in the state monad; only the seed is the persistent state. Second, I think you will have better luck if instead of using the standard state monad, you re-implement all of the state monad and its operations yourself, with their types. I think you will learn more this way. Here are a couple of declarations to get you started:

data MyState s a = MyState (s -> (s, a))

get :: MyState s s
put :: s -> MyState s ()

Then you can write your own connectives:

unit :: a -> MyState s a
bind :: MyState s a -> (a -> MyState s b) -> MyState s b

Finally:

data Seed = Seed Int
nextVal :: MyState Seed Bool
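One possible way to fill in those declarations, as a sketch. It keeps the `(s, a)` pair order of the declaration above (the library uses `(a, s)`); `runMyState`, `updateSeed`, `seedValue` and the Functor/Applicative/Monad instances are additions of mine so that do notation and replicateM can use unit and bind, and the LCG constants are the ones from the question:

```haskell
import Control.Monad (ap, liftM, replicateM)

newtype MyState s a = MyState (s -> (s, a))

-- Hypothetical helper: unwrap and run a computation.
runMyState :: MyState s a -> s -> (s, a)
runMyState (MyState f) = f

get :: MyState s s
get = MyState (\s -> (s, s))

put :: s -> MyState s ()
put s = MyState (\_ -> (s, ()))

unit :: a -> MyState s a
unit a = MyState (\s -> (s, a))

-- Run m, then run the computation k builds from m's result, on m's final state.
bind :: MyState s a -> (a -> MyState s b) -> MyState s b
bind m k = MyState (\s -> let (s', a) = runMyState m s
                          in  runMyState (k a) s')

instance Functor (MyState s) where
  fmap = liftM
instance Applicative (MyState s) where
  pure  = unit
  (<*>) = ap
instance Monad (MyState s) where
  (>>=) = bind

data Seed = Seed Int

seedValue :: Seed -> Int
seedValue (Seed n) = n

updateSeed :: Seed -> Seed
updateSeed (Seed n) = Seed ((1664525 * n + 1013904223) `mod` 2^32)

-- Only the seed is state; the Bool is just the result.
nextVal :: MyState Seed Bool
nextVal = do
  seed <- get
  let seed' = updateSeed seed
  put seed'
  return (even (seedValue seed'))
```

With this, `runMyState (replicateM 3 nextVal) (Seed 42)` returns the final seed paired with three Bools.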

As for your trouble desugaring, the do notation you are using is pretty sophisticated. But desugaring is a line-at-a-time mechanical procedure. As near as I can make out, your code should desugar like this (going back to your original types and code, which I disagree with):

nextVal = get >>= \(Random seed val) ->
                      let seed' = updateSeed seed
                          val'  = even seed'
                      in  put (Random seed' val') >>= \_ -> return val'

In order to make the nesting structure a bit clearer, I've taken major liberties with the indentation.
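One way to convince yourself that the translation is faithful is to put both versions side by side and compare them on a few seeds. A minimal sketch with a hand-rolled toy State (the `Random` record and `updateSeed` follow the question; `nextValDo` and `nextValDesugared` are hypothetical names). Since `put` has no bound result, the compiler's own translation uses `>>`, which means the same as `>>= \_ ->`:

```haskell
import Control.Monad (ap, liftM)

newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap = liftM
instance Applicative (State s) where
  pure a = State (\s -> (a, s))
  (<*>)  = ap
instance Monad (State s) where
  m >>= k = State (\s -> let (a, s') = runState m s
                         in  runState (k a) s')

get :: State s s
get = State (\s -> (s, s))

put :: s -> State s ()
put s = State (\_ -> ((), s))

type Seed = Int
data Random = Random { seed :: Seed, value :: Bool }

updateSeed :: Seed -> Seed
updateSeed s = let (a, c, m) = (1664525, 1013904223, 2^32) in (a*s + c) `mod` m

-- The question's do-block.
nextValDo :: State Random Bool
nextValDo = do
  Random s _ <- get
  let s' = updateSeed s
      v' = even s'
  put (Random s' v')
  return v'

-- The same computation written with >>= and lambdas.
nextValDesugared :: State Random Bool
nextValDesugared =
  get >>= \(Random s _) ->
    let s' = updateSeed s
        v' = even s'
    in  put (Random s' v') >>= \_ -> return v'
```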


You've got a couple of great responses. What I do when working with the State monad is, in my mind, replace State s a with s -> (s,a) (after all, that's really what it is; the library actually orders the pair as (a,s), but that doesn't change the picture).

You then get a type for bind that looks like:

(>>=) :: (s -> (s,a)) ->
         (a -> s -> (s,b)) ->
         (s -> (s,b))

and you see that bind is just a specialized kind of function composition operator, like (.)
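That view can be made literal with no newtype at all. A sketch using plain functions in this answer's `(s, a)` order (`StateF`, `bindF`, `unitF`, `getF`, `putF` and `incr` are hypothetical names): after adapting `k` with `uncurry (flip k)` so that it accepts the pair produced by `m`, bind is ordinary `(.)`:

```haskell
-- A state computation as a bare function, in (state, result) order.
type StateF s a = s -> (s, a)

-- bind is function composition once k is adapted to take the (state, result) pair.
bindF :: StateF s a -> (a -> StateF s b) -> StateF s b
bindF m k = uncurry (flip k) . m

unitF :: a -> StateF s a
unitF a s = (s, a)

getF :: StateF s s
getF s = (s, s)

putF :: s -> StateF s ()
putF s _ = (s, ())

-- Hypothetical example: increment the state and return the old value.
incr :: StateF Int Int
incr = getF `bindF` \old -> putF (old + 1) `bindF` \_ -> unitF old
```

For instance, `incr 5` evaluates to `(6, 5)`: the new state first, then the old value as the result.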

I wrote a blog/tutorial on the state monad here. It's probably not particularly good, but writing it helped me grok things a little better.

Link: http://www.djcxy.com/p/42912.html
