ITree.EqmR.Monads.State
Section StateEq.
Variable S : Type.
Definition state (s a : Type) : Type :=
s -> s * a.
Definition eqmR_state (A B : Type) (R : relationH A B) :=
fun (x : state S A) (y : state S B) =>
forall (s : S),
R (snd (x s)) (snd (y s)) /\ fst (x s) = fst (y s).
Lemma equal_typ {A} {x : A} : x = x. reflexivity. Qed.
#[global]
Instance EqmR_state : EqmR (state S) := {|eqmR := eqmR_state|}.
(* Couple properties about state mayRet. *)
Definition state_prop {A : Type} (ma : state S A) :=
(fun a : A => exists s s', ma s = (s', a)).
Definition const_state {A : Type} (s : S) (a : A) : S -> (S * A) :=
fun s => (s, a).
Lemma mayRet_state {A:Type} (ma : state S A) (a : A) :
(exists s s' : S, (ma s) = (s', a)) <->
a ∈ ma.
Proof.
split. intro.
- edestruct H as (? & ? & ?); clear H.
red. red. repeat intro.
red in EQR; unfold eqmR_state in EQR.
specialize (EQR x). crunch. subst. rewrite H0 in H; auto.
- intros EQR. repeat red in EQR.
epose (state_prop ma : A -> Prop) as Q.
assert (eqmR (diagonal_prop Q) ma ma).
{ intro. cbn. split.
- intros.
unfold state_prop. split. repeat red.
exists s. destruct (ma s) eqn: Hma. cbn in *. subst. eauto.
exists s. destruct (ma s) eqn: Hma. cbn in *. subst. eauto.
- reflexivity.
}
specialize (EQR (diagonal_prop Q) (diagonal_prop_PER Q) H).
unfold diagonal_prop in EQR.
destruct EQR. clear H1 H.
edestruct H0 as (? & ? & ?).
auto.
Qed.
(* The eqmR_state definition is EqmR_OK. *)
Ltac destruct_hyp :=
match goal with
| [H : (_ /\ _) /\ fst (?x ?s0) = fst (?y ?s1) |- _] =>
destruct H as ((? & ?) & ?); destruct (x s0) eqn: ?HX;
destruct (y s1) eqn : ?HY; cbn in *; subst
| [H : (_ /\ _) /\ _ = _ |- _] =>
destruct H as ((? & ?) & <-)
end.
Ltac rewrite_eq :=
do 2 match goal with
| [H : forall _, ?x = _ -> _ |- _] => destruct (H x) as (? & ? & <-); try reflexivity; clear H
end; eauto.
Ltac ok_crush :=
cbn in *;
match goal with
| [ x : S , H : eqmR_state _ _ _ _ _ |- _ ] => specialize (H x)
end;
destruct_hyp; split; [ split | ]; auto; intros * <-; rewrite_eq.
Instance eqmR_conj_state: RelConj (state S).
Proof.
intros A B R1 R2 ma mb H H0.
cbn in *. unfold eqmR_state in *.
intros. split; cycle 1.
specialize (H s); specialize (H0 s). crunch.
auto.
specialize (H s). crunch. subst.
specialize (H0 s). crunch. subst. eauto.
specialize (H0 s). crunch. subst. eauto.
Qed.
Lemma eqmR_transport_refl_St:
forall (A : Type) (R : relationH A A), Reflexive R -> Reflexive (eqmR R).
Proof.
repeat intro. split; eauto.
Qed.
Instance eqmR_Proper_eq_St:
forall A B : Type, Proper (eq_rel ==> eqmR (@eq A) ==> eqmR (@eq B) ==> Basics.flip Basics.impl) eqmR.
Proof.
repeat intro. split; eauto; repeat red in H0, H1, H2; edestruct H0, H1, H2.
- apply H. rewrite H3, H5. eauto.
- edestruct H0, H1, H2. rewrite H4, H6. eauto.
Unshelve. all: eauto.
Qed.
(* State monad is a monad. *)
Definition ret_state (A : Type) (x : A) := fun s : S => (s, x).
Definition bind_state (A B : Type) (sa : state S A) (f : A -> state S B) :=
(fun s : S => (f (snd (sa s)) (fst (sa s)))).
Global Instance Monad_state : Monad (state S) :=
{|
ret := ret_state;
bind := bind_state
|}.
Global Instance MonadState_state : MonadState S (state S) :=
{|
get := (fun x => (x, x));
put := (fun v => (fun _ => (v, tt)))
|}.
Ltac unfold_state := unfold fmap, bind, Monad_state, bind_state, ret, ret_state; cbn in *.
Instance eqmR_prod_fst_inv_state : FmapFst (state S).
Proof.
repeat intro.
repeat red in H. unfold_state.
specialize (H s); crunch;
destruct (m1 s) eqn: Hma; destruct (m2 s) eqn: Hma';
cbn in *; subst; auto.
destruct p, p0; inv H; eauto.
Defined.
Instance eqmR_prod_snd_inv_state : FmapSnd (state S).
Proof.
repeat intro.
repeat red in H. unfold_state.
specialize (H s); crunch;
destruct (m1 s) eqn: Hma; destruct (m2 s) eqn: Hma';
cbn in *; subst; auto.
destruct p, p0; inv H; eauto.
Defined.
Instance eqmR_prod_rel_state : RelProd (state S).
Proof.
intros. red. split; eauto. red in H, H0.
unfold eqmR_state in *.
destruct (snd (m1 s)) eqn: Hm1.
destruct (snd (m2 s)) eqn: Hm2.
cbn in *. constructor; eauto.
specialize (H s); specialize (H0 s).
rewrite Hm1 in H, H0.
rewrite Hm2 in H, H0.
crunch. cbn in *; eauto.
specialize (H s); specialize (H0 s).
rewrite Hm1 in H, H0.
rewrite Hm2 in H, H0.
crunch. eauto. cbn in *; eauto.
cbn in *; unfold eqmR_state in *; crunch; eauto.
specialize (H s); specialize (H0 s). crunch.
unfold bind_state, ret_state in *. cbn in *. eauto.
Defined.
Instance eqmR_sum_inl_state : FmapInl (state S).
Proof.
repeat intro.
repeat red in H. unfold_state.
specialize (H s); crunch;
destruct (m1 s) eqn: Hma; destruct (m2 s) eqn: Hma';
cbn in *; subst; auto.
Defined.
Instance eqmR_sum_inr_state : FmapInr (state S).
Proof.
repeat intro.
repeat red in H. unfold_state.
specialize (H s); crunch;
destruct (m1 s) eqn: Hma; destruct (m2 s) eqn: Hma';
cbn in *; subst; auto.
Defined.
Instance eqmR_sum_rel_state : RelSum (state S).
Proof.
repeat intro.
repeat red in H. unfold_state.
split; repeat intro.
- specialize (H s); crunch;
destruct (m1 s) eqn: Hma; destruct (m2 s) eqn: Hma';
cbn in *; subst; auto.
destruct s1, s3; inv H.
apply H0. auto. apply H1; auto.
- specialize (H (A1 + B1)%type (A2 + B2)%type inl inl inr inr (RA ⊕ RB)).
assert (inl_P :ProperH (respectfulH A1 A2 (A1 + B1) (A2 + B2) RA (RA ⊕ RB)) inl inl) by
(repeat intro; constructor; eauto).
assert (inr_P: ProperH (respectfulH B1 B2 (A1 + B1) (A2 + B2) RB (RA ⊕ RB)) inr inr) by
(repeat intro; constructor; eauto).
destruct (m1 s) eqn: Hma; destruct (m2 s) eqn: Hma'.
cbn.
specialize (H inl_P inr_P s). cbn in *.
rewrite Hma, Hma' in H; cbn in *.
crunch; eauto.
destruct s1, s3; crunch; eauto.
Defined.
Lemma eqmR_prod_sum_rel_state :
forall (A B C : Type) (ma ma' : state S (C + A * B)%type)
(RA : relationH A A) (RB : relationH B B) (RC : relationH C C),
(fst_sum <$> ma) ≈{ RC ⊕ RA } (fst_sum <$> ma') ->
(snd_sum <$> ma) ≈{ RC ⊕ RB } (snd_sum <$> ma') -> ma ≈{ RC ⊕ RA ⊗ RB } ma'.
Proof.
intros. repeat red in H. repeat red in H0.
intro. specialize (H s); specialize (H0 s). crunch.
all : unfold fmap, fst_sum, snd_sum in *.
all : inv H0; inv H; inv H1; inv H2;
destruct (ma s) eqn: Hma; destruct (ma' s) eqn: Hma'; cbn in *; destruct s1 eqn: Hs1; destruct s3 eqn: Hs3.
all : eauto.
all : repeat match goal with
| [H : inl _ = _ |- _] => inv H
| [H : inr _ = _ |- _] => inv H
end; eauto.
constructor. destruct p, p0. eauto.
Qed.
Instance eqmR_transport_PER_state : TransportPER (state S).
Proof.
split.
- repeat intro; crunch. repeat red in H0. symmetry. eauto.
symmetry. edestruct H0. symmetry. eauto.
edestruct H0. symmetry. eauto.
- repeat intro. intros. repeat intro. unfold TransitiveH in H.
cbn in *. unfold eqmR_state in *.
destruct (x s) eqn: HX.
specialize (H0 s).
specialize (H1 s).
crunch. subst. etransitivity; eauto. rewrite <- HX. auto.
cbn. rewrite HX in *. cbn in *. subst. auto.
Defined.
Instance eqmR_rel_comp_state : RelComp (state S).
Proof.
repeat intro. cbn in *. unfold eqmR_state in *.
intros. specialize (H s). specialize (H0 s).
destruct H. destruct H0. split; eauto. esplit; split; eauto.
rewrite H1, H2. auto.
Defined.
Instance eqmR_lift_transpose_state : LiftTranspose (state S).
Proof.
intros.
split; repeat intro; repeat red in H; specialize (H s) ;
crunch; eauto.
Defined.
Instance eqmR_Proper_state :
forall A B : Type, Proper (eq_rel (A:=A) (B:=B) ==> eq_rel) eqmR.
Proof.
repeat intro; split; repeat intro; crunch; repeat red in H0;
specialize (H0 s); crunch; subst; try eexists; try split;
eauto; apply H; auto.
Defined.
Instance eqmR_Proper_mono_state:
forall A B : Type,
Proper (subrelationH (A:=A) (B:=B) ==> subrelationH (B:=state S B)) eqmR.
Proof.
repeat intro; crunch; repeat red in H0; specialize (H0 s);
crunch; subst; eauto.
Defined.
Instance eqmR_transport_Equiv_state: TransportEquiv (state S).
Proof.
repeat intro; destruct H; split; try typeclasses eauto.
repeat intro; eauto.
eapply eqmR_transport_PER_state; split; eauto.
apply eqmR_transport_PER_state; split; eauto.
Qed.
Instance transport_refl : TransportReflexive (state S).
repeat intro; eauto.
Qed.
Instance transport_symm : TransportSymmetric (state S).
repeat intro; eauto. repeat red in H0. edestruct H0; eauto.
Qed.
Instance transport_trans : TransportTransitive (state S).
repeat intro; eauto. repeat red in H0. edestruct H0, H1; split; subst; eauto.
rewrite H3. auto.
Qed.
Global Instance EqmR_OK_state : EqmR_OK (state S).
Proof.
constructor; try typeclasses eauto.
Defined.
Opaque mayRet.
(* EqmRMonad properties for state monad. *)
Lemma mayRet_image :
forall A (a: A) ma, (a ∈ ma -> image ma a a).
Proof.
intros. apply H.
Qed.
Lemma image_eqmR_state :
forall (A : Type) (ma : state S A), ma ≈{ image ma } ma.
Proof.
intros. cbn. unfold eqmR_state. intros. cbn.
destruct (ma s) eqn: Hma.
cbn in *. split.
- intros. repeat red. intros. repeat red in EQR.
edestruct EQR. rewrite Hma in *. cbn in *; auto.
- reflexivity.
Qed.
Instance eqmR_mayRet_l_state : MayRetL (state S).
Proof.
red;intros.
rewrite <- mayRet_state in H. edestruct H as (? & ? & ?); clear H.
cbn in *.
destruct (ma1 x) eqn: Hma. cbn in *.
inv H0. rename EQH into H.
unfold eqmR_state in H. cbn in H.
specialize (H x). setoid_rewrite Hma in H.
cbn in *. setoid_rewrite <- mayRet_state.
destruct (ma2 x) eqn: Hma2. cbn in *.
crunch. eauto.
Defined.
Instance eqmR_mayRet_r_state : MayRetR (state S).
Proof.
red;intros.
rewrite <- mayRet_state in H. edestruct H as (? & ? & ?); clear H.
cbn in *.
destruct (ma1 x) eqn: Hma. cbn in *.
inv H0. rename EQH into H.
unfold eqmR_state in H. cbn in H.
specialize (H x). setoid_rewrite Hma in H. crunch.
cbn in *. setoid_rewrite <- mayRet_state.
exists a. split; eauto. rewrite H1 in H. auto.
Defined.
Instance eqmR_ret_state : RetFinal (state S).
Proof.
red; intros. cbn. unfold eqmR_state. cbn. intros.
split.
- eauto.
- reflexivity.
Defined.
Instance eqmR_bind_ProperH_state : ProperBind (state S).
Proof.
red;intros.
cbn in *. unfold eqmR_state in *. cbn in *. intro.
specialize (H s).
destruct (ma1 s) eqn: Hma1. destruct (ma2 s) eqn: Hma2. cbn in *. subst.
destruct H.
specialize (H0 _ _ H). subst.
edestruct H0; eauto. constructor; eauto.
unfold bind_state. rewrite Hma1, Hma2. cbn. eauto.
unfold bind_state. rewrite Hma1, Hma2. cbn. eauto.
Qed.
Instance eqmR_bind_ret_l_state : BindRetL (state S).
Proof.
red;intros. repeat red. cbn. intros.
unfold bind_state, ret_state. cbn in *. split; auto.
Defined.
Instance eqmR_bind_ret_r_state : BindRetR (state S).
Proof.
repeat intro. cbn. split; try split; try reflexivity.
Defined.
Instance eqmR_bind_bind_state : BindBind (state S).
Proof.
repeat intro. cbn.
unfold bind_state.
split; try split; try reflexivity;
intros; subst.
Defined.
Global Instance EqmRMonad_state : EqmRMonadLaws (state S).
Proof.
constructor; try typeclasses eauto.
repeat intro. eapply image_eqmR_state.
Qed.
(* EqmRMonadInverses for state monad. *)
Lemma are_continuations_functions:
forall {A B : Type} (a : A) (k : A -> state S B) (mb mb' : state S B),
k a = mb /\ k a = mb' -> mb = mb'.
Proof.
intros. destruct H. rewrite <- H, <- H0. reflexivity.
Qed.
Context {IS: inhabited S}.
Instance ret_inv_St : RetInv (state S).
Proof.
inv IS.
repeat intro. repeat red in H.
edestruct H. crunch; eauto. Unshelve. eauto.
Defined.
Definition imageA {A : Type} : relationH A A :=
fun a a' =>
forall (R : relationH A A) (PH : PER R), R a a'.
Definition mayRetA {A : Type} : A -> Prop :=
(fun (a : A) => imageA a a).
Instance state_mayRet_bind_inv : MayRetBindInv (state S).
Proof.
red; intros.
apply mayRet_state in H.
setoid_rewrite <- mayRet_state.
crunch.
unfold bind, Monad_state, bind_state in H.
destruct (ma x) eqn:Hma1.
exists a. split; eauto; eauto.
Defined.
Lemma fmap_inv_state_hetero :
forall (A B : Type) (ma ma' : state S A) (f f' : A -> B) (R : relationH B B),
(f <$> ma) ≈{ R } (f' <$> ma') ->
forall a : A, a ∈ ma -> exists a', a' ∈ ma' /\ ret (f a) ≈{ R } ret (f' a').
Proof.
repeat intro.
repeat red in H.
unfold fmap, bind, ret, Monad_state, bind_state, ret_state in H.
eapply mayRet_state in H0.
setoid_rewrite <- mayRet_state. crunch.
destruct (ma' x) eqn: Hma.
exists a0. specialize (H x). crunch. cbn in *; rewrite H0, Hma in H, H1.
crunch; eauto.
eapply eqmR_ret. typeclasses eauto.
cbn in *.
rewrite Hma, H0 in H. eauto.
Qed.
Definition state_prop_ {A : Type} (ma ma': state S A) :=
(fun a a' : A => exists s s' s'' : S, (ma s) = (s', a) /\ (ma' s) = (s'', a')).
Lemma image_state {A:Type} (ma ma': state S A) (a a': A) :
(exists s s' s'': S, (ma s) = (s', a) /\ (ma' s) = (s'', a')) ->
image_ _ ma ma' a a'.
Proof.
intro.
- edestruct H as (? & ? & ?); clear H.
red. repeat intro.
red in EQR; unfold eqmR_state in EQR.
specialize (EQR x). crunch. subst.
rewrite H, H2 in H0. auto.
Qed.
Instance fmap_inv_mayRet_state : FmapInvRet (state S).
Proof.
repeat intro.
repeat red in H.
unfold fmap, bind, ret, Monad_state, bind_state, ret_state in H.
eapply mayRet_state in H0.
crunch.
specialize (H x). crunch. cbn in *. rewrite H0 in H, H1.
crunch; eauto. cbn in *; constructor.
Defined.
Instance fmap_inv_state_ : FmapInv_mayRet (state S).
Proof.
repeat intro.
Ltac sunfold := unfold fmap, bind, ret, Monad_state, bind_state, ret_state in *; cbn in *; unfold eqmR_state in *.
sunfold.
specialize (H s). crunch; destruct (ma1 s) eqn: Hma, (ma2 s) eqn: Hma'; cbn in *; subst; eauto.
red. rewrite<- !mayRet_state; split; eauto.
Defined.
Instance eqmR_inv_state : EqmRInv (state S).
Proof.
repeat intro. rewrite <- mayRet_state in H0. destruct H0 as (?&?&?).
repeat red in H. specialize (H x). rewrite H0 in H.
destruct H; eauto.
Qed.
Global Instance EqmRMonadInverses_state : EqmRInversionProperties (state S).
Proof.
constructor; try typeclasses eauto.
Defined.
Global Instance EQM_state : EQM (state S) _.
Proof.
econstructor; typeclasses eauto.
Defined.
End StateEq.
Example bind_ret_inv_counterexample:
forall A B (a:A) (b:B) (x y : nat),
bind (m := state nat) (fun x => (x+1, a)) (fun x => (fun y => (y-1, b))) ≈{eq} ret b.
Proof.
intros.
repeat red. unfold bind, ret, Monad_state, bind_state, ret_state; cbn.
intros. split; eauto. lia.
Qed.
Definition ma : state nat nat := (fun s => if Nat.eqb s 2 then (s, 0) else (s + 1, 1)).
Definition k : nat -> state nat nat := (fun s => if Nat.eqb s 0 then (fun s => (s, s)) else (fun s => (s, 1))).
Lemma ma_mayRet :
0 ∈ ma.
Proof. apply mayRet_state. exists 2. cbn. exists 2; eauto. Qed.
Lemma k_mayRet :
0 ∈ k 0.
Proof. apply mayRet_state. cbn. eauto. Qed.
Lemma bind_mayRet:
~ (0 ∈ bind ma k).
Proof.
intro. apply mayRet_state in H. destruct H as (? & ? & ?).
cbn in *. unfold bind_state in H.
unfold ma in H.
destruct (Nat.eqb x 2) eqn: Hx. cbn in *.
apply EqNat.beq_nat_true in Hx. subst. inv H. cbn in *. inv H.
Qed.