diff --git a/Compiler/CompAction.hs b/Compiler/CompAction.hs index f2ba92f..8e2c8aa 100644 --- a/Compiler/CompAction.hs +++ b/Compiler/CompAction.hs @@ -26,7 +26,7 @@ deriving instance Show T.BitFormat deriving instance Show T.RegFileReaders deriving instance Show T.RegFileBase -show_finfun :: Show a => Int -> (CustomExtract.EFin -> a) -> String +show_finfun :: Show a => Int -> (T.Fin -> a) -> String show_finfun n f = "{ " ++ intercalate " ; " (map (show . f) $ T.getFins n) ++ " }" instance Show T.FullFormat where @@ -66,10 +66,10 @@ instance Show (T.Expr ty) where show (T.BinBitBool n m o e1 e2) = "BinBitBool " ++ show n ++ " " ++ show m ++ " " ++ show o ++ " " ++ show e1 ++ " " ++ show e2 show (T.ITE fk e1 e2 e3) = "ITE " ++ show fk ++ " " ++ show e1 ++ " " ++ show e2 ++ " " ++ show e3 show (T.Eq k e1 e2) = "Eq " ++ show k ++ " " ++ show e1 ++ " " ++ show e2 - show (T.ReadStruct n fk fs e i) = "ReadStruct " ++ show n ++ " " ++ show_finfun n fk ++ " " ++ show_finfun n fs ++ " " ++ show e ++ " " ++ show i + show (T.ReadStruct n fk fs e i) = "ReadStruct " ++ show n ++ " " ++ show_finfun n fk ++ " " ++ show_finfun n fs ++ " " ++ show e ++ " " ++ show (T.to_nat n i) show (T.BuildStruct n fk fs fe) = "BuildStruct " ++ show n ++ " " ++ show_finfun n fk ++ " " ++ show_finfun n fs ++ " " ++ show_finfun n fe show (T.ReadArray n m k e1 e2) = "ReadArray " ++ show n ++ " " ++ show m ++ " " ++ show e1 ++ " " ++ show e2 - show (T.ReadArrayConst n k e i) = "ReadArrayConst " ++ show n ++ " " ++ show k ++ " " ++ show e ++ " " ++ show i + show (T.ReadArrayConst n k e i) = "ReadArrayConst " ++ show n ++ " " ++ show k ++ " " ++ show e ++ " " ++ show (T.to_nat n i) show (T.BuildArray n k f) = "BuildArray " ++ show n ++ " " ++ show k ++ " " ++ show_finfun n f deriving instance Show T.SyncRead diff --git a/Compiler/Compiler.v b/Compiler/Compiler.v index 3ce6204..b9ef0bc 100644 --- a/Compiler/Compiler.v +++ b/Compiler/Compiler.v @@ -1,4 +1,5 @@ Require Import Kami.StateMonad Kami.Syntax Kami.Properties Kami.PProperties Kami.PPlusProperties Kami.Notations Kami.Lib.EclecticLib. +Require Import Kami.StdLib.Fin. Import Word.Notations. Require Import ZArith. @@ -277,7 +278,7 @@ Section Compile. Proof. split; intros. - unfold SyncRead_eqb in H. - repeat (apply andb_prop in H; dest; rewrite String.eqb_eq in *). + repeat (apply andb_prop in H; dest; rewrite String.eqb_eq in * ). destruct r, r'; simpl in *; subst; auto. - rewrite H. unfold SyncRead_eqb; repeat rewrite eqb_refl; auto. @@ -595,7 +596,7 @@ Section Semantics. (HIn : In (dataArray, (existT _ (SyntaxKind (Array idxNum Data)) regVal)) updatedRegs) cont calls val contArray (HContArray : contArray = - BuildArray (fun i : Fin.t num => + BuildArray (fun i : Fin num => ReadArray (Var type _ regVal) (CABit Add (Var type (SyntaxKind _) (evalExpr idx) :: @@ -663,7 +664,7 @@ Section Semantics. (HRegVal : In (dataArray, (existT _ (SyntaxKind (Array idxNum Data)) regV)) updatedRegs) (HWriteMap : WfRegMapExpr (UpdRegMap readRegName pred - (BuildArray (fun i : Fin.t num => + (BuildArray (fun i : Fin num => ReadArray (Var type _ regV) (CABit Add (Var type (SyntaxKind _) (evalExpr idx) :: @@ -682,7 +683,7 @@ Section Semantics. (HRegVal2 : In (dataArray, existT _ (SyntaxKind (Array idxNum Data)) regVal) updatedRegs) (contArray : Expr type (SyntaxKind (Array num Data))) (HContArray : contArray = - BuildArray (fun i : Fin.t num => + BuildArray (fun i : Fin num => ReadArray (Var type _ regVal) (CABit Add (Var type (SyntaxKind _) idx :: @@ -816,7 +817,7 @@ Section EActionT_Semantics. (cont : type (Array num Data) -> EActionT type retK) (contArray : Expr type (SyntaxKind (Array num Data))) (HContArray : contArray = - BuildArray (fun i : Fin.t num => + BuildArray (fun i : Fin num => ReadArray (Var type _ regV) (CABit Add (Var type (SyntaxKind _) (evalExpr idx) :: @@ -878,7 +879,7 @@ Section EActionT_Semantics. (HANewUml : anewUml = (UmUpd (readRegName, existT _ _ (evalExpr - (BuildArray (fun i : Fin.t num => + (BuildArray (fun i : Fin num => ReadArray (Var type _ regV) (CABit Add (Var type (SyntaxKind _) (evalExpr idx) :: @@ -893,7 +894,7 @@ Section EActionT_Semantics. (HRegVal2 : In (dataArray, existT _ (SyntaxKind (Array idxNum Data)) regVal) o) (contArray : Expr type (SyntaxKind (Array num Data))) (HContArray : contArray = - BuildArray (fun i : Fin.t num => + BuildArray (fun i : Fin num => ReadArray (Var type _ regVal) (CABit Add (Var type (SyntaxKind _) idx :: diff --git a/Compiler/CompilerProps.v b/Compiler/CompilerProps.v index 7a121f7..5081cbc 100644 --- a/Compiler/CompilerProps.v +++ b/Compiler/CompilerProps.v @@ -1,4 +1,5 @@ Require Import Kami.StateMonad Kami.Syntax Kami.Properties Kami.PProperties Kami.PPlusProperties Kami.Lib.EclecticLib Kami.Notations Kami.Compiler.Compiler. +Require Import Kami.StdLib.Fin. Import Word.Notations. Require Import ZArith. @@ -1396,7 +1397,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i))) <- ReadArrayConst val i] else newArr)%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regVal))))]) :: tl upds0)) (r, @@ -1405,7 +1406,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum)(proj1_sig (to_nat i))) <- ReadArrayConst val i] else newArr)%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regVal))))]) :: tl upds0)) as P0. { unfold WfRegMapExpr in *; dest; split; auto; constructor. } @@ -1414,7 +1415,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum)(proj1_sig (to_nat i))) <- ReadArrayConst val i] else newArr)%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regVal))))::x). repeat split; auto. @@ -1444,7 +1445,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regVal))))]) :: tl upds0)) (r, @@ -1453,7 +1454,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum)(proj1_sig (to_nat i))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regVal))))]) :: tl upds0)) as P0. { unfold WfRegMapExpr in *; dest; split; auto; constructor. } @@ -1462,7 +1463,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum)(proj1_sig (to_nat i))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regVal))))::x). repeat split; auto. @@ -1532,7 +1533,7 @@ Proof. existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (natToWord (Nat.log2_up idxNum)(proj1_sig (to_nat i)))])%kami_expr))))]) :: tl upds0)) @@ -1542,7 +1543,7 @@ Proof. existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i)))])%kami_expr))))]) :: tl upds0)) as P0. @@ -1551,7 +1552,7 @@ Proof. exists (UmUpd (readReg, existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i)))])%kami_expr))))::x). @@ -2564,7 +2565,7 @@ Proof. rewrite (getKindAttr_map_fst _ _ (prevPrevRegsTrue H2)) in H4; contradiction. -- rewrite in_app_iff; auto. + reflexivity. -Qed. +Qed. Lemma PriorityUpds_exist o (HNoDup : NoDup (map fst o)): forall upds @@ -3031,7 +3032,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i))) <- ReadArrayConst val i] else newArr)%kami_expr) @@ -3048,7 +3049,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regV))))::x); split; auto. @@ -3075,7 +3076,7 @@ Proof. existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (natToWord (Nat.log2_up idxNum) (proj1_sig (to_nat i)))])%kami_expr)))) :: x); split; auto. @@ -3199,7 +3200,7 @@ Proof. existT (fullType type) (SyntaxKind (Array rfIdxNum rfData)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array rfIdxNum rfData))) (i : Fin.t rfNum) => + (fun (newArr : Expr type (SyntaxKind (Array rfIdxNum rfData))) (i : Fin rfNum) => (IF ReadArrayConst (ReadStruct e (Fin.FS (Fin.FS Fin.F1))) i then newArr @[ ReadStruct e Fin.F1 + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- @@ -3237,7 +3238,7 @@ Proof. existT (fullType type) (SyntaxKind (Array rfIdxNum rfData)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array rfIdxNum rfData))) (i : Fin.t rfNum) => + (fun (newArr : Expr type (SyntaxKind (Array rfIdxNum rfData))) (i : Fin rfNum) => (newArr @[ ReadStruct e Fin.F1 + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst (ReadStruct e (Fin.FS Fin.F1)) i])%kami_expr) (getFins rfNum) (Var type (SyntaxKind (Array rfIdxNum rfData)) regV))))) :: nil); simpl in *. @@ -3383,7 +3384,7 @@ Proof. specialize (H _ _ _ P0 _ HIn _ _ _ HESemAction); dest. exists (UmMeth (x, existT SignT (Bit (Nat.log2_up rfIdxNum), Array rfNum rfData) (evalExpr e, (evalExpr (BuildArray - (fun i : Fin.t rfNum => + (fun i : Fin rfNum => (Var type (SyntaxKind (Array rfIdxNum rfData)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up rfIdxNum))) (evalExpr e) + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i))))])%kami_expr)))))::x0); split. @@ -3537,7 +3538,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst val i] else newArr)%kami_expr) @@ -3558,7 +3559,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regV))))::x0); split; auto. @@ -3593,7 +3594,7 @@ Proof. existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i))))])%kami_expr)))) :: x0); split; auto. @@ -3876,7 +3877,7 @@ Proof. (Var type (SyntaxKind (Array rfNum rfData)) (evalExpr (BuildArray - (fun i0 : Fin.t rfNum => + (fun i0 : Fin rfNum => (Var type (SyntaxKind (Array rfIdxNum rfData)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up rfIdxNum))) (evalExpr e) + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i0))))])%kami_expr))))))]); simpl; repeat intro; auto. @@ -4042,7 +4043,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst val i] else newArr)%kami_expr) @@ -4063,7 +4064,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regV))))::x); split; auto. @@ -4098,7 +4099,7 @@ Proof. existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i))))])%kami_expr)))) :: x); split; auto. @@ -4265,7 +4266,7 @@ Proof. (Var type (SyntaxKind (Array rfNum rfData)) (evalExpr (BuildArray - (fun i0 : Fin.t rfNum => + (fun i0 : Fin rfNum => (Var type (SyntaxKind (Array rfIdxNum rfData)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up rfIdxNum))) (evalExpr e) + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i0))))])%kami_expr))))))]). @@ -4404,7 +4405,7 @@ Proof. specialize (H _ _ _ _ _ P0 HIn _ _ _ HESemAction); dest. exists (UmMeth (meth, existT SignT (Void, Array rfNum rfData) (WO, (evalExpr (BuildArray - (fun i : Fin.t rfNum => + (fun i : Fin rfNum => (Var type (SyntaxKind (Array rfIdxNum rfData)) regVal @[ Var type (SyntaxKind (Bit (Nat.log2_up rfIdxNum))) idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i))))])%kami_expr)))))::x); split. @@ -4571,7 +4572,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (IF ReadArrayConst mask0 i then newArr @[ idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst val i] else newArr)%kami_expr) @@ -4592,7 +4593,7 @@ Proof. existT (fullType type) (SyntaxKind (Array idxNum Data)) (evalExpr (fold_left - (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin.t num) => + (fun (newArr : Expr type (SyntaxKind (Array idxNum Data))) (i : Fin num) => (newArr @[ idx + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i)))) <- ReadArrayConst val i])%kami_expr) (getFins num) (Var type (SyntaxKind (Array idxNum Data)) regV))))::x); split; auto. @@ -4627,7 +4628,7 @@ Proof. existT (fullType type) (SyntaxKind (Array num Data)) (evalExpr (BuildArray - (fun i : Fin.t num => + (fun i : Fin num => (Var type (SyntaxKind (Array idxNum Data)) regV @[ Var type (SyntaxKind (Bit (Nat.log2_up idxNum))) (evalExpr idx) + Const type (ZToWord _ (Z.of_nat (proj1_sig (Fin.to_nat i))))])%kami_expr)))) :: x); split; auto. diff --git a/Compiler/CompilerSimpleSem.v b/Compiler/CompilerSimpleSem.v index 6419aeb..f28a978 100644 --- a/Compiler/CompilerSimpleSem.v +++ b/Compiler/CompilerSimpleSem.v @@ -60,7 +60,7 @@ Section SemSimple. (HisAddr : isAddr = false) (HWriteMap : Sem_RmeSimple (UpdRegRME readReg pred - (BuildArray (fun i : Fin.t num => + (BuildArray (fun i : Fin num => ReadArray arr (CABit Add (Var type (SyntaxKind _) (evalExpr idx) :: @@ -137,7 +137,7 @@ Section SemSimple. (HIn : In (dataArray, (existT _ (SyntaxKind (Array idxNum Data)) regVal)) updatedRegs) cont calls val contArray (HContArray : contArray = - BuildArray (fun i : Fin.t num => + BuildArray (fun i : Fin num => ReadArray (Var type _ regVal) (CABit Add (Var type (SyntaxKind _) (evalExpr idx) :: @@ -177,7 +177,7 @@ Section SemSimple. (HRegVal2 : In (dataArray, existT _ (SyntaxKind (Array idxNum Data)) regVal) updatedRegs) (contArray : Expr type (SyntaxKind (Array num Data))) (HContArray : contArray = - BuildArray (fun i : Fin.t num => + BuildArray (fun i : Fin num => ReadArray (Var type _ regVal) (CABit Add (Var type (SyntaxKind _) idx :: diff --git a/Compiler/PrettyPrintVerilog.hs b/Compiler/PrettyPrintVerilog.hs index d5775d0..d0fff76 100644 --- a/Compiler/PrettyPrintVerilog.hs +++ b/Compiler/PrettyPrintVerilog.hs @@ -10,22 +10,17 @@ import Control.Monad.State.Lazy import qualified Data.Map.Lazy as H import Debug.Trace import Numeric +import Specif log2_up :: Int -> Int log2_up = ceiling . (logBase 2) . fromIntegral -intToFin :: Int -> Int -> (Int,Int) -intToFin = (,) - ppDealSize0 :: T.Kind -> String -> String -> String ppDealSize0 ty def str = if T.size ty == 0 then def else str ppVecLen :: Int -> String ppVecLen n = "[" ++ show (n-1) ++ ":0]" -finToInt :: (Int,Int) -> Int -finToInt = snd - deformat :: String -> String deformat = concatMap (\c -> if c == '\n' then "\\n" else c:[]) @@ -197,7 +192,7 @@ ppRtlExpr who e = return $ ppDealSize0 k "0" (new ++ '[' : xidx ++ "]") T.ReadArrayConst n k vec idx -> do - let xidx = finToInt idx + let xidx = T.to_nat n idx -- finToInt idx xvec <- ppRtlExpr who vec new <- optionAddToTrunc who (T.Array n k) vec return $ ppDealSize0 k "0" (new ++ '[' : show xidx ++ "]") diff --git a/Compiler/Test.v b/Compiler/Test.v index 52d1e7b..85a3e64 100644 --- a/Compiler/Test.v +++ b/Compiler/Test.v @@ -2,6 +2,7 @@ Require Import Kami.Extraction. Require Import BinNat. Import FinFun.Fin2Restrict. Require Import Kami.AllNotations. +Require Import Kami.StdLib.Fin. Class toString (X : Type) := { to_string : X -> string @@ -71,16 +72,16 @@ Definition Counter := Bit 2. Definition init_val : word Xlen := Xlen 'h"e". (* mask = {true; false; false; false; true} *) -Definition mask_func1 : Fin.t num -> bool := fun (i : Fin.t num) => +Definition mask_func1 : Fin num -> bool := fun (i : Fin num) => match i with - | F1 _ => true - | FS _ (F1 _) => false - | FS _ (FS _ (F1 _)) => false - | FS _ (FS _ (FS _ (F1 _))) => false + | inl _ => true + | inr (inl _) => false + | inr (inr (inl _)) => false + | inr (inr (inr (inl _))) => false | _ => true - end. + end. -Definition mask_func2 : Fin.t num -> bool := fun i => negb (mask_func1 i). +Definition mask_func2 : Fin num -> bool := fun i => negb (mask_func1 i). Definition mask1 : ConstT (Array num Bool) := ConstArray mask_func1. Definition mask2 : ConstT (Array num Bool) := ConstArray mask_func2. @@ -114,42 +115,44 @@ Qed. (*good masks*) -Lemma mask1_under_true : exists (i : Fin.t num), mask_func1 i = true /\ f2n i < num - (write_index - read_under_index). +Definition f2n {n : nat} (i : Fin n) := proj1_sig (Fin.to_nat i). + +Lemma mask1_under_true : exists (i : Fin num), mask_func1 i = true /\ f2n i < num - (write_index - read_under_index). Proof. exists F1; simpl; auto. Qed. -Lemma mask1_under_false : exists (i : Fin.t num), mask_func1 i = false /\ f2n i < num - (write_index - read_under_index). +Lemma mask1_under_false : exists (i : Fin num), mask_func1 i = false /\ f2n i < num - (write_index - read_under_index). Proof. exists (FS F1); simpl; auto. Qed. -Lemma mask1_over_true : exists (i : Fin.t num), mask_func1 i = true /\ f2n i > (read_over_index - write_index). +Lemma mask1_over_true : exists (i : Fin num), mask_func1 i = true /\ f2n i > (read_over_index - write_index). Proof. exists (FS (FS (FS (FS F1)))); unfold f2n; simpl; auto. Qed. -Lemma mask1_over_false : exists (i : Fin.t num), mask_func1 i = false /\ f2n i > (read_over_index - write_index). +Lemma mask1_over_false : exists (i : Fin num), mask_func1 i = false /\ f2n i > (read_over_index - write_index). Proof. exists (FS (FS (FS F1))); unfold f2n; simpl; auto. Qed. -Lemma mask2_under_true : exists (i : Fin.t num), mask_func2 i = true /\ f2n i < num - (write_index - read_under_index). +Lemma mask2_under_true : exists (i : Fin num), mask_func2 i = true /\ f2n i < num - (write_index - read_under_index). Proof. exists (FS F1); simpl; auto. Qed. -Lemma mask2_under_false : exists (i : Fin.t num), mask_func2 i = false /\ f2n i < num - (write_index - read_under_index). +Lemma mask2_under_false : exists (i : Fin num), mask_func2 i = false /\ f2n i < num - (write_index - read_under_index). Proof. exists F1; simpl; auto. Qed. -Lemma mask2_over_true : exists (i : Fin.t num), mask_func2 i = true /\ f2n i > (read_over_index - write_index). +Lemma mask2_over_true : exists (i : Fin num), mask_func2 i = true /\ f2n i > (read_over_index - write_index). Proof. exists (FS (FS (FS F1))); unfold f2n; simpl; auto. Qed. -Lemma mask2_over_false : exists (i : Fin.t num), mask_func2 i = false /\ f2n i > (read_over_index - write_index). +Lemma mask2_over_false : exists (i : Fin num), mask_func2 i = false /\ f2n i > (read_over_index - write_index). Proof. exists (FS (FS (FS (FS F1)))); unfold f2n; simpl; auto. Qed. @@ -275,7 +278,7 @@ Definition expected_read_over(val : word Xlen) : ConstT (Array num Data) := ConstArray (fun i => if f2n i bool) : ConstT (Array num Data) := +Definition expected_read_under_masked(mask_val non_mask_val : word Xlen)(mf nmf : Fin num -> bool) : ConstT (Array num Data) := ConstArray (fun i => if f2n i bool) : ConstT (Array num Data) := +Definition expected_read_over_masked(mask_val non_mask_val : word Xlen)(mf nmf : Fin num -> bool) : ConstT (Array num Data) := ConstArray (fun i => match Compare_dec.le_lt_dec (num - (read_over_index - write_index)) (f2n i) with | left _ => init_val | right pf => if mf (read_over_Fin_to_write_Fin i pf) then mask_val else if nmf (read_over_Fin_to_write_Fin i pf) then non_mask_val else init_val end). -Definition expected_read_ot_mt(write_val old_val : word Xlen)(wmf omf : Fin.t num -> bool)(ot : OverlapType)(mt : MaskType) := +Definition expected_read_ot_mt(write_val old_val : word Xlen)(wmf omf : Fin num -> bool)(ot : OverlapType)(mt : MaskType) := match ot,mt with | Over,IsWrMask => expected_read_over_masked write_val old_val wmf omf | Over,NotIsWrMask => expected_read_over write_val diff --git a/Extraction.v b/Extraction.v index 2fff29c..e1674f7 100644 --- a/Extraction.v +++ b/Extraction.v @@ -1,5 +1,6 @@ Require Export List String Ascii BinInt BinNat. Require Export Kami.Syntax Kami.Compiler.CompilerSimple Kami.Compiler.Compiler Kami.Compiler.Rtl Kami.LibStruct Kami.Compiler.UnverifiedIncompleteCompiler. +Require Export Kami.StdLib.Fin. Require Import Kami.Notations. @@ -14,7 +15,6 @@ Set Extraction KeepSingleton. Unset Extraction AutoInline. Extract Inductive sigT => "(,)" ["(,)"]. -Extract Inductive Fin.t => "CustomExtract.EFin" ["CustomExtract.fin0" "CustomExtract.finS"] "CustomExtract.finRec". Extract Inductive N => "Prelude.Integer" ["0" "(\x -> x)"] "(\fn0 fnpos x -> if x Prelude.== 0 then fn0 () else fnpos x)". Extract Inlined Constant fst => "Prelude.fst". @@ -34,14 +34,6 @@ Extract Constant Nat.log2_up => "(\x -> Prelude.ceiling (Prelude.logBase 2 (Prel Extract Constant List.fold_left => "(\f bs a -> Data.List.foldl' f a bs)". Extract Constant natToWord => "(\sz n -> Prelude.toInteger n)". Extract Constant wordToNat => "(\_ -> Prelude.fromIntegral)". -Extract Constant sumSizes => "(\n f -> Prelude.sum (Prelude.map (\i -> f (n Prelude.-1,i)) [0..(n Prelude.-1)]))". -Extract Constant nth_Fin => "(\xs (_,i) -> xs Prelude.!! i)". -Extract Constant nth_Fin_map2 => "(\_ _ _ x -> x)". -Extract Constant getFins => "(\n -> Prelude.map ((,) (n Prelude.- 1)) [0..(n Prelude.- 1)])". -Extract Constant Fin.to_nat => "(\_ (_,i) -> i)". -Extract Constant Fin.cast => "(\_ x _ -> x)". -Extract Constant Fin.of_nat_lt => "(\i n -> (n Prelude.- 1,i))". -Extract Constant Fin_eq_dec => "(\_ x y -> x Prelude.== y)". Extract Inlined Constant getBool => "Prelude.id". Extract Inlined Constant String.append => "(Prelude.++)". Extract Constant ZToWord => "(\n x -> Prelude.mod x (2 Prelude.^ n))". @@ -73,7 +65,7 @@ Extract Inlined Constant Z.modulo => "Prelude.mod". Extract Inlined Constant N.succ_pos => "(\x -> x Prelude.+ 1)". Extract Inlined Constant N.add => "(Prelude.+)". Extract Inlined Constant N.sub => "(Prelude.-)". -Extract Inlined Constant N.mul => "(Prelude.*)". +Extract Inlined Constant N.mul => "(Prelude.* )". Extract Inlined Constant N.eqb => "(Prelude.==)". Extract Inlined Constant N.ltb => "(Prelude.<)". Extract Inlined Constant N.of_nat => "Prelude.toInteger". diff --git a/Lib/EclecticLib.v b/Lib/EclecticLib.v index b0152b9..ff0517b 100644 --- a/Lib/EclecticLib.v +++ b/Lib/EclecticLib.v @@ -4,6 +4,7 @@ Require Import Coq.Arith.Div2. Require Import Coq.NArith.NArith. Require Import Arith_base. Require Import Arith Coq.ZArith.Znat Psatz. +Require Import Kami.StdLib.Fin. Import ListNotations. @@ -58,117 +59,13 @@ Proof. exists x; split; [auto | rewrite H; auto]. Qed. *) -Fixpoint Fin_t_foldr - (A : Type) - (n : nat) - (init : A) - := match n return - forall (f : Fin.t n -> A -> A), A - with - | 0 => fun _ => init - | S m => fun f => f Fin.F1 (Fin_t_foldr m init (fun i => f (Fin.FS i))) - end. - -Section nth_Fin. - Variable A: Type. - Fixpoint nth_Fin (ls: list A): Fin.t (length ls) -> A := - match ls return Fin.t (length ls) -> A with - | nil => fun pf => Fin.case0 _ pf - | x :: xs => fun i => - match i in Fin.t n return n = length (x :: xs) -> A with - | Fin.F1 _ => fun _ => x - | Fin.FS _ y => fun pf => - nth_Fin xs - match eq_add_S _ _ pf in _ = Y return Fin.t Y with - | eq_refl => y - end - end eq_refl - end. - - Definition nth_Fin' (ls: list A) n (pf: n = length ls) (i: Fin.t n): A := - nth_Fin ls (Fin.cast i pf). - - Fixpoint nth_Fin'' (ls: list A) n (pf: n <= length ls) {struct ls} : Fin.t n -> A. - Proof. - refine( - match ls return (n <= length ls) -> Fin.t n -> A with - | nil => fun pf i => Fin.case0 _ (Fin.cast i _) - | x :: xs => fun pf i => - match i in Fin.t m return m <= length (x :: xs) -> A with - | Fin.F1 _ => fun _ => x - | Fin.FS _ z => fun pf => nth_Fin'' xs _ _ z - end _ - end _). - all: cbn in *; abstract omega. - Defined. - - Lemma nth_Fin'_nth : forall n d (i: Fin.t n) (xs: list A) (len_eq: n = length xs), - let i' := proj1_sig (Fin.to_nat i) in - nth_Fin' xs len_eq i = nth i' xs d. - Proof. - induction n; cbn; intros *; try easy. - destruct xs; cbn in *; try easy. - inversion len_eq. - destruct i eqn:?; cbn; auto. - destruct (Fin.to_nat _) eqn:?; cbn. - assert (n0 = n); subst. - { inversion len_eq; subst; auto. } - specialize (IHn d t xs (f_equal pred len_eq)). - rewrite Heqs in IHn; cbn in IHn; auto. - Qed. - - Lemma nth_Fin_nth : forall d (xs: list A) (i: Fin.t (length xs)), - let i' := proj1_sig (Fin.to_nat i) in - nth_Fin xs i = nth i' xs d. - Proof. - cbn; intros. - rewrite <- (nth_Fin'_nth _ _ _ eq_refl). - unfold nth_Fin'; f_equal. - clear; induction i; cbn; auto. - rewrite <- IHi; auto. - Qed. -End nth_Fin. - -Definition fin_case n x : - forall (P : Fin.t (S n) -> Type), - P Fin.F1 -> - (forall y, P (Fin.FS y)) -> - P x := - match x in Fin.t n0 - return - forall P, - match n0 return (Fin.t n0 -> (Fin.t n0 -> Type) -> Type) with - | 0 => fun _ _ => False - | S m => fun x P => P Fin.F1 -> (forall x0, P (Fin.FS x0)) -> P x - end x P - with - | Fin.F1 _ => fun _ H1 _ => H1 - | Fin.FS _ _ => fun _ _ HS => HS _ - end. - +(* + Accepts one argument: v : Fin (S n); and destructs v into two + cases where v = F1 and another where v = FS. +*) Ltac fin_dep_destruct v := - pattern v; apply fin_case; clear v; intros. - -Lemma Fin_cast_lemma : forall m n i (p q : m = n), - Fin.cast i p = Fin.cast i q. -Proof. - intros. - rewrite (UIP_nat _ _ p q); reflexivity. -Defined. - -Lemma fin_to_nat_cast : forall n (i: Fin.t n) m (Heq: n = m), - proj1_sig (Fin.to_nat (Fin.cast i Heq)) = proj1_sig (Fin.to_nat i). -Proof. - induction n; cbn; intros *; try easy. - destruct m; try easy. - assert (n = m) by auto. - destruct i eqn:?; cbn; auto. - assert (n0 = n) by (subst; auto); subst. - specialize (IHn t m eq_refl). - destruct (Fin.to_nat t) eqn:?; cbn in *. - rewrite <- (Fin_cast_lemma _ eq_refl). - destruct (Fin.to_nat (Fin.cast t eq_refl)) eqn:?; cbn in *; auto. -Qed. + (* pattern v; apply fin_case; clear v; intros. *) + simpl in v; destruct v as [i|y]; [destruct i|idtac]. Definition UIP(X : Type) := forall (x y : X)(p q : x = y), p = q. @@ -203,76 +100,6 @@ Proof. contradiction. Defined. -Definition map_length_red := - (fun (A B : Type) (f : A -> B) (l : list A) => - list_ind (fun l0 : list A => Datatypes.length (map f l0) = Datatypes.length l0) eq_refl - (fun (a : A) (l0 : list A) (IHl : Datatypes.length (map f l0) = Datatypes.length l0) => - f_equal_nat nat S (Datatypes.length (map f l0)) (Datatypes.length l0) IHl) l) - : forall (A B : Type) (f : A -> B) (l : list A), Datatypes.length (map f l) = Datatypes.length l. - -Section nth_Fin_map2. - Variable A B: Type. - Variable g: A -> B. - Variable f: B -> Type. - - Fixpoint nth_Fin_map2 (ls: list A): - forall (p : Fin.t (length (map g ls))) - (val: f (g (nth_Fin ls (Fin.cast p (map_length_red g ls))))), - f (nth_Fin (map g ls) p). - refine - match ls return forall (p : Fin.t (length (map g ls))) - (val: f (g (nth_Fin ls (Fin.cast p (map_length_red g ls))))), - f (nth_Fin (map g ls) p) with - | nil => fun i _ => Fin.case0 (fun j => f (nth_Fin (map g nil) j)) i - | x :: xs => fun p => _ - end. - fin_dep_destruct p. - + exact val. - + apply (nth_Fin_map2 xs y). - match goal with - | |- f (g (nth_Fin xs (Fin.cast y ?P))) => - rewrite (hedberg eq_nat_dec P (f_equal Init.Nat.pred (map_length_red g (x :: xs)))) - end. - exact val. - Defined. -End nth_Fin_map2. - -Section Fin. - -Fixpoint Fin_forallb{n} : (Fin.t n -> bool) -> bool := - match n return (Fin.t n -> bool) -> bool with - | 0 => fun _ => true - | S m => fun p => p Fin.F1 && Fin_forallb (fun i => p (Fin.FS i)) - end. - -Lemma Fin_forallb_correct{n} : forall p : Fin.t n -> bool, - Fin_forallb p = true <-> forall i, p i = true. -Proof. - induction n; intros; split; intros. - apply (Fin.case0 (fun i => p i = true)). - reflexivity. - simpl in H. - fin_dep_destruct i. - destruct (p F1); [auto|discriminate]. - apply (IHn (fun j => p (FS j))). - destruct (p F1); [auto|discriminate]. - simpl. - apply andb_true_intro; split. - apply H. - apply IHn. - intro; apply H. -Qed. - -Definition Fin_cast : forall m n, Fin.t m -> m = n -> Fin.t n := - fun m n i pf => match pf in _ = y return Fin.t y with - | eq_refl => i - end. - -End Fin. - - - - Lemma inversionPair A B (a1 a2: A) (b1 b2: B): (a1, b1) = (a2, b2) -> a1 = a2 /\ b1 = b2. @@ -347,27 +174,6 @@ Ltac existT_destruct dec := subst end. -Fixpoint Fin_eq_dec n a {struct a}: forall (b: Fin.t n), {a = b} + {a <> b}. -Proof. - refine - match a in Fin.t n return forall b: Fin.t n, {a = b} + {a <> b} with - | Fin.F1 _ => fun b => match b with - | Fin.F1 _ => left eq_refl - | _ => right _ - end - | Fin.FS _ x => fun b => match b in Fin.t (S m) return forall x: Fin.t m, (forall y: Fin.t m, {x = y} + {x <> y}) -> {Fin.FS x = b} + {Fin.FS x <> b} with - | Fin.F1 _ => fun _ _ => right _ - | Fin.FS _ y => fun _ f => - match f y with - | left eq1 => left (f_equal Fin.FS eq1) - | right neq => right _ - end - end x (Fin_eq_dec _ x) - end; intro; clear Fin_eq_dec; try (abstract discriminate). - abstract (injection H; intros; existT_destruct Nat.eq_dec; tauto). -Defined. - - Section fold_left_right. Variable A B: Type. Variable f: A -> B -> A. @@ -584,62 +390,16 @@ Proof. induction n; destruct xs; cbn; try easy; eauto. Qed. -Fixpoint getFins n := - match n return list (Fin.t n) with - | 0 => nil - | S m => Fin.F1 :: map Fin.FS (getFins m) - end. - -Fixpoint getFinsBound m n: list (Fin.t n) := - match m return (list (Fin.t n)) with - | 0 => nil - | S k => match n with - | 0 => nil - | S n' => Fin.F1 :: map Fin.FS (getFinsBound k n') - end - end. - -Definition mapOrFins n (x: Fin.t n) := fold_left (fun a b => x = b \/ a) (getFins n) False. - -Lemma getFins_length : forall n, length (getFins n) = n. -Proof. - induction n; cbn; auto. - rewrite map_length; auto. -Qed. - -Lemma getFins_all : forall n (i: Fin.t n), In i (getFins n). -Proof. - induction i; cbn; auto using in_map. -Qed. - -Lemma getFins_nth_error : forall n (i: Fin.t n), - let i' := proj1_sig (Fin.to_nat i) in - nth_error (getFins n) i' = Some i. -Proof. - induction i; cbn in *; auto. - destruct (Fin.to_nat i); cbn in *. - apply map_nth_error; auto. -Qed. - -Lemma getFins_nth : forall n d (i: Fin.t n), - let i' := proj1_sig (Fin.to_nat i) in - nth i' (getFins n) d = i. -Proof. - intros. - apply nth_error_nth. - apply getFins_nth_error. -Qed. - Section Arr. Variable A: Type. Variable def: A. - Definition list_arr n (arr : Fin.t n -> A):= map arr (getFins n). + Definition list_arr {n} (arr : Fin n -> A):= map arr (getFins n). Lemma list_arr_correct : - forall n (arr : Fin.t n -> A)(i: nat), + forall n (arr : Fin n -> A)(i: nat), match lt_dec i n with - | left pf => arr (Fin.of_nat_lt pf) + | left pf => arr (of_nat_lt pf) | right _ => def end = nth_default def (list_arr arr) i. Proof. @@ -661,8 +421,8 @@ Section Arr. Qed. Lemma list_arr_correct_simple : - forall n (arr : Fin.t n -> A) i, - nth_error (list_arr arr) (proj1_sig (Fin.to_nat i)) = Some (arr i). + forall n (arr : Fin n -> A) i, + nth_error (list_arr arr) (proj1_sig (to_nat i)) = Some (arr i). Proof. intros. unfold list_arr; apply map_nth_error. @@ -723,16 +483,24 @@ Proof. induction ls; simpl; auto. Qed. -Lemma mapOrFins_true n: forall (x: Fin.t n), mapOrFins x. +Lemma mapOrFins_true n: forall (i: Fin n), mapOrFins i. Proof. - induction x; unfold mapOrFins in *; simpl; intros. - - apply fold_left_or_init; auto. - - rewrite fold_left_map. - eapply (@fold_left_or_impl _ (fun b => x = b) (getFins n) _ False (Fin.FS x = Fin.F1 \/ False)); try tauto; congruence. + induction n as [|n IH]. + + exact (case0 (fun i => mapOrFins i)). + + destruct i as [u|j]; unfold mapOrFins; simpl. + - apply (fold_left_or_init); + left; destruct u; reflexivity. + - rewrite fold_left_map. + unfold mapOrFins in IH. + exact (fold_left_or_impl + (fun k : Fin n => j = k) + (getFins n) + (fun k : Fin n => inr j = FS k) + (False_ind (inr j = F1 \/ False)) + (fun (k : Fin n) (H : j = k) => ltac:(rewrite H; reflexivity)) + (IH j)). Qed. - - Lemma list_split A B C (f: A -> C) (g: B -> C): forall l l1 l2, map f l = map g l1 ++ map g l2 -> exists l1' l2', @@ -2392,7 +2160,7 @@ Theorem div2_minus_2 : forall n k, -> Nat.div2 (n - 2 * k) = Nat.div2 n - k. induction n as [n] using strong; intros. - do 2 (destruct n; simpl in *; intuition; repeat rewrite untimes2 in *). + do 2 (destruct n; simpl in *; intuition; repeat rewrite untimes2 in * ). destruct k; simpl in *; intuition. destruct k; simpl in *; intuition. @@ -3373,10 +3141,10 @@ Proof. apply IHm; lia. Qed. -Lemma Fineqb_refl {m} (n : t m) : - Fin.eqb n n = true. +Lemma Fineqb_refl {m} (n : Fin m) : + Kami.StdLib.Fin.eqb n n = true. Proof. - rewrite Fin.eqb_eq; reflexivity. + rewrite Kami.StdLib.Fin.eqb_eq; reflexivity. Qed. Lemma Nat_mod_congr a b c : @@ -3454,12 +3222,13 @@ Proof. Qed. Lemma list_arr_length {A : Type} n : - forall (arr : t n -> A), + forall (arr : Fin n -> A), n = length (list_arr arr). Proof. unfold list_arr; intros. rewrite map_length, getFins_length; reflexivity. Qed. + Lemma firstn_map {A B: Type} (l : list A) (f : A -> B): forall n, firstn n (map f l) = map f (firstn n l). @@ -3609,21 +3378,21 @@ Proof. - destruct l; auto; discriminate. Qed. -Lemma Fin_eqb_neq {n : nat} (p q : Fin.t n): - Fin.eqb p q = false <-> p <> q. +Lemma Fin_eqb_neq {n : nat} (p q : Fin n): + Kami.StdLib.Fin.eqb p q = false <-> p <> q. Proof. red; split; repeat intro. - - rewrite <- Fin.eqb_eq in H0; rewrite H0 in H; discriminate. - - destruct Fin.eqb eqn:G; auto. + - rewrite <- Kami.StdLib.Fin.eqb_eq in H0; rewrite H0 in H; discriminate. + - destruct Kami.StdLib.Fin.eqb eqn:G; auto. exfalso. - rewrite Fin.eqb_eq in G; contradiction. + rewrite Kami.StdLib.Fin.eqb_eq in G; contradiction. Qed. Section FifoProps. Variable size : nat. Local Notation lgSize := (Nat.log2_up size). Variable A : Type. - Variable implArray : Fin.t size -> A. + Variable implArray : Fin size -> A. Variable enqP1 deqP1 : Z. Variable enqP1Bnd : (0 <= enqP1 < 2 ^ Z.of_nat (lgSize + 1))%Z. @@ -3632,7 +3401,7 @@ Section FifoProps. Local Notation deq := (deqP1 mod (2 ^ (Z.of_nat lgSize)))%Z. Local Notation cutLen := ((enqP1 - deqP1) mod (2 ^ (Z.of_nat (lgSize + 1))))%Z. - Definition convertToList {n} (kamiArray : Fin.t n -> A) := @list_arr A n kamiArray. + Definition convertToList {n} (kamiArray : Fin n -> A) := @list_arr A n kamiArray. Local Notation specList := (firstn (Z.to_nat cutLen) (rotateList (Z.to_nat deq) (convertToList implArray))). @@ -3685,7 +3454,7 @@ Section FifoProps. Lemma hdCorrect : enqP1 <> deqP1 -> - hd_error specList = Some (implArray (Fin.of_nat_lt deq_lt_size)). + hd_error specList = Some (implArray (of_nat_lt deq_lt_size)). Proof. intros. rewrite hd_firstn. @@ -3773,13 +3542,13 @@ Section FifoProps. } rewrite <- (Z.mod_small _ _ P), Zminus_mod_idemp_l; reflexivity. Qed. - + Lemma listSnoc (val : A) : cutLen <> Z.of_nat size -> snoc val specList = firstn (Z.to_nat ((enqP1 + 1 - deqP1) mod 2 ^ Z.of_nat (lgSize + 1))) (rotateList (Z.to_nat deq) (convertToList - (fun i => if (Fin.eqb i (Fin.of_nat_lt enq_lt_size)) then val else implArray i))). + (fun i => if (Kami.StdLib.Fin.eqb i (of_nat_lt enq_lt_size)) then val else implArray i))). Proof. intros HNotFull. rewrite cutLen_succ, Z2Nat.inj_add; auto; try lia; @@ -3803,8 +3572,8 @@ Section FifoProps. (firstn (Z.to_nat cutLen) (rotateList (Z.to_nat deq) (convertToList - (fun i : t size => - if Fin.eqb i (of_nat_lt enq_lt_size) + (fun i : Fin size => + if Kami.StdLib.Fin.eqb i (of_nat_lt enq_lt_size) then val else implArray i)))) m <> None) as G1. { intro G1; rewrite G1 in G0; discriminate. } @@ -3818,7 +3587,7 @@ Section FifoProps. { unfold convertToList. rewrite <- list_arr_length; reflexivity. } - assert (length (convertToList (fun i : t size => if Fin.eqb i (of_nat_lt enq_lt_size) then val else implArray i)) = size) as P0. + assert (length (convertToList (fun i : Fin size => if Kami.StdLib.Fin.eqb i (of_nat_lt enq_lt_size) then val else implArray i)) = size) as P0. { unfold convertToList. rewrite <- list_arr_length; reflexivity. } @@ -3831,9 +3600,9 @@ Section FifoProps. unfold convertToList. repeat rewrite list_arr_correct_simple. f_equal. - destruct Fin.eqb eqn:G; auto. + destruct Kami.StdLib.Fin.eqb eqn:G; auto. exfalso. - rewrite Fin.eqb_eq in G. + rewrite Kami.StdLib.Fin.eqb_eq in G. assert (proj1_sig (to_nat (of_nat_lt P1)) = proj1_sig (to_nat (of_nat_lt enq_lt_size))). { rewrite G; auto. } repeat rewrite to_nat_of_nat in H; simpl in H. @@ -3875,7 +3644,9 @@ Section FifoProps. -- rewrite H in P4. assert (enqP1 - deqP1 = Z.of_nat m)%Z as P5 by lia. rewrite P5, Z.mod_small in l; try lia. - - assert (length (convertToList (fun i : t size => if Fin.eqb i (of_nat_lt enq_lt_size) then val else implArray i)) = size) as P. + * rewrite <- (Nat2Z.id m), <- Z2Nat.inj_lt in l; lia. + * rewrite <- (Nat2Z.id m), <- Z2Nat.inj_lt in l; lia. + - assert (length (convertToList (fun i : Fin size => if Kami.StdLib.Fin.eqb i (of_nat_lt enq_lt_size) then val else implArray i)) = size) as P. { unfold convertToList, list_arr. rewrite map_length, getFins_length; reflexivity. } @@ -3912,7 +3683,7 @@ Section FifoProps. rewrite <- P0 at 1. rewrite list_arr_correct_simple, Fineqb_refl; reflexivity. Qed. - + End FifoProps. Lemma app_emptyb {A : Type} (l1 l2 : list A) : diff --git a/Lib/VectorFacts.v b/Lib/VectorFacts.v index 6d20652..d0e905d 100644 --- a/Lib/VectorFacts.v +++ b/Lib/VectorFacts.v @@ -1,21 +1,23 @@ -Require Coq.Vectors.Vector. -Import Vectors.VectorDef.VectorNotations. +Require Kami.StdLib.Vector. +Require Import Kami.StdLib.Fin. +Import VectorDef. +Import VectorNotations. Set Implicit Arguments. Set Asymmetric Patterns. Definition Vector_caseS' {A'} (Q : nat -> Type) - (P : forall {n} (v : Vector.t A' (S n)), Q n -> Type) - (H : forall h {n} t q, @P n (h :: t) q) {n} (v: Vector.t A' (S n)) + (P : forall {n} (v : Vec A' (S n)), Q n -> Type) + (H : forall h {n} t q, @P n (h :: t) q) {n} (v: Vec A' (S n)) (q : Q n) : P v q. Proof. specialize (fun h t => H h _ t q). change n with (pred (S n)) in H, q |- *. set (Sn := S n) in *. - pose (fun Sn (v : Vector.t A' Sn) (q : Q (pred Sn)) => - match Sn return Vector.t A' Sn -> Q (pred Sn) -> Type with + pose (fun Sn (v : Vec A' Sn) (q : Q (pred Sn)) => + match Sn return Vec A' Sn -> Q (pred Sn) -> Type with | S n' => P n' | 0 => fun _ _ => True end v q) as P'. @@ -24,85 +26,71 @@ Proof. | _ => P' Sn v q end). change (forall h (t : match Sn with - | S n' => Vector.t A' n' - | 0 => Vector.t A' Sn + | S n' => Vec A' n' + | 0 => Vec A' Sn end), P' Sn (match Sn return match Sn with - | S n' => Vector.t A' n' - | 0 => Vector.t A' Sn - end -> Vector.t A' Sn + | S n' => Vec A' n' + | 0 => Vec A' Sn + end -> Vec A' Sn with | S _ => fun t => h :: t | 0 => fun t => t end t) q) in H. clearbody P'; clear P. clearbody Sn. - destruct v as [|h ? t]. - { constructor. } - { apply H. } + destruct Sn. + + constructor. + + destruct v. + apply H. Defined. Definition Vector_nth_map' A (f: A -> Type) (n: nat): - forall v (p: Fin.t n), + forall v (p: Fin n), f (Vector.nth v p) -> Vector.nth (Vector.map f v) p. Proof. - induction p. - - apply (Vector.caseS (fun n1 v1 => f (Vector.nth v1 Fin.F1) -> Vector.nth (Vector.map f v1) Fin.F1) (fun h n1 t => fun x => x) v). - - apply (Vector.caseS (fun n1 v1 => forall p4, (forall v2, f (Vector.nth v2 p4) -> Vector.nth (Vector.map f v2) p4) -> - f (Vector.nth v1 (Fin.FS p4)) -> Vector.nth (Vector.map f v1) (Fin.FS p4)) - (fun h n1 t p4 IHp2 => - fun X => IHp2 t X) - v p IHp - ). + induction n; intros v p. + - contradiction. + - destruct v as [v0 v]. + destruct p as [p|p]. + + destruct p; simpl; trivial. + + exact (IHn v p). Defined. -Definition Vector_nth_map A (f: A -> Type) n (v: Vector.t A n) p +Definition Vector_nth_map A (f: A -> Type) n (v: Vec A n) p (m: f (Vector.nth v p)): Vector.nth (Vector.map f v) p := @Vector_nth_map' _ f n v p m. -Definition Vector_nth_map2_l' A B (g: A -> B) (f: B -> Type) n (v: Vector.t A n) (p: Fin.t n): - (forall p: Fin.t n, Vector.nth (Vector.map (fun x => f (g x)) v) p) -> - f (Vector.nth (Vector.map g v) p) := - Fin.t_rect - (fun n0 p3 => - forall v0, - (forall p, Vector.nth (Vector.map (fun x => f (g x)) v0) p) -> f (Vector.nth (Vector.map g v0) p3)) - (fun n0 v0 => - Vector.caseS - (fun n1 v1 => - (forall p, Vector.nth (Vector.map (fun x => f (g x)) v1) p) -> f (Vector.nth (Vector.map g v1) Fin.F1)) - (fun h n1 t => fun x => x Fin.F1) v0) - (fun n0 p3 IHp1 v0 => - Vector.caseS - (fun n1 v1 => - forall p4, - (forall v2, - (forall p, Vector.nth (Vector.map (fun x => f (g x)) v2) p) -> f (Vector.nth (Vector.map g v2) p4)) -> - (forall p, Vector.nth (Vector.map (fun x => f (g x)) v1) p) -> f (Vector.nth (Vector.map g v1) (Fin.FS p4))) - (fun h n1 t p4 IHp2 => - fun X => (IHp2 t (fun (p: Fin.t n1) => (X (Fin.FS p)))) - ) v0 p3 IHp1) n p v. +Definition Vector_nth_map2_l' A B (g: A -> B) (f: B -> Type) n : + forall (v: Vec A n) (p: Fin n), + (forall p: Fin n, Vector.nth (Vector.map (fun x => f (g x)) v) p) -> + f (Vector.nth (Vector.map g v) p). +Proof. + induction n; intros v p. + - contradiction. + - destruct v as [v0 v]; destruct p as [p|p]; intro h. + + destruct p; simpl; exact (h Fin.F1). + + exact (IHn v p (fun q => h (inr q))). +Defined. -Definition Vector_nth_map2_l A B (g: A -> B) (f: B -> Type) n (v: Vector.t A n) - (m: forall p: Fin.t n, Vector.nth (Vector.map (fun x => f (g x)) v) p) - (p: Fin.t n): f (Vector.nth (Vector.map g v) p) := @Vector_nth_map2_l' _ _ g f n v p m. +Definition Vector_nth_map2_l A B (g: A -> B) (f: B -> Type) n (v: Vec A n) + (m: forall p: Fin n, Vector.nth (Vector.map (fun x => f (g x)) v) p) + (p: Fin n): f (Vector.nth (Vector.map g v) p) := @Vector_nth_map2_l' _ _ g f n v p m. -Definition Vector_nth_map2_r' A B (g: A -> B) (f: B -> Type) n (v: Vector.t A n) (p: Fin.t n): - f (g (Vector.nth v p)) -> - f (Vector.nth (Vector.map g v) p). +Definition Vector_nth_map2_r' A B (g: A -> B) (f: B -> Type) n: + forall (v: Vec A n) (p: Fin n), + f (g (Vector.nth v p)) -> + f (Vector.nth (Vector.map g v) p). Proof. - induction p. - - apply (Vector.caseS (fun n1 v1 => f (g (Vector.nth v1 Fin.F1)) -> - f (Vector.nth (Vector.map g v1) Fin.F1)) (fun h n1 t => fun x => x) v). - - apply (Vector.caseS (fun n1 v1 => forall p4, (forall v2, f (g (Vector.nth v2 p4)) -> f (Vector.nth (Vector.map g v2) p4)) -> - f (g (Vector.nth v1 (Fin.FS p4))) -> f (Vector.nth (Vector.map g v1) (Fin.FS p4))) - (fun h n1 t p4 IHp2 => - fun X => IHp2 t X) - v p IHp - ). + induction n; intros v p. + - contradiction. + - destruct v as [v0 v]. + destruct p as [p|p]. + + destruct p; simpl; trivial. + + exact (IHn v p). Defined. -Definition Vector_nth_map2_r A B (g: A -> B) (f: B -> Type) n (v: Vector.t A n) (p: Fin.t n) +Definition Vector_nth_map2_r A B (g: A -> B) (f: B -> Type) n (v: Vec A n) (p: Fin n) (m: f (g (Vector.nth v p))): f (Vector.nth (Vector.map g v) p) := @Vector_nth_map2_r' _ _ g f n v p m. @@ -110,40 +98,18 @@ Section find. Variable A: Type. Variable f: A -> bool. - Fixpoint Vector_find' n (v: Vector.t A n): match n with - | 0 => unit - | S m => Fin.t (S m) - end := - match v in Vector.t _ n0 return match n0 with - | 0 => unit - | S m0 => Fin.t (S m0) - end with - | Vector.nil => tt - | Vector.cons h n1 t => if f h - then Fin.F1 - else match n1 as n0 return (Vector.t _ n0 -> Fin.t (S n0)) with - | 0 => fun _ => Fin.F1 - | S n2 => fun t0 => Fin.FS (Vector_find' t0) - end t - end. - - Definition Vector_find n (v: Vector.t A (S n)): Fin.t (S n) := Vector_find' v. - - - Fixpoint Vector_find_opt n (v: Vector.t A n): option (Fin.t n) := - match v in Vector.t _ n0 return option (Fin.t n0) with - | Vector.nil => None - | Vector.cons h n1 t => if f h - then Some Fin.F1 - else match n1 as n0 return (Vector.t _ n0 -> option (Fin.t (S n0))) - with - | 0 => fun _ => None - | S n2 => fun t0 => - match Vector_find_opt t0 with - | None => None - | Some i => Some (Fin.FS i) - end - end t + Fixpoint Vector_find_opt {n} : forall v : Vec A n, option (Fin n) := + match n with + | 0 => fun _ => None + | S m => fun v : Vec A (S m) => + if f (fst v) + then Some Fin.F1 + else + let acc : option (Fin m) := Vector_find_opt (snd v) in + match acc with + | None => None + | Some p => Some (Fin.FS p) + end end. End find. diff --git a/LibStruct.v b/LibStruct.v index a4396b0..d76b175 100644 --- a/LibStruct.v +++ b/LibStruct.v @@ -1,4 +1,5 @@ Require Import Kami.Syntax Kami.Notations. +Require Import Kami.StdLib.Fin. (* TODO: move to KamiStdLib? *) Definition extractArbitraryRange ty sz (inst: Bit sz ## ty) (range: nat * nat): @@ -27,14 +28,14 @@ Definition nullStruct: Kind := Fixpoint BuildStructActionCont (ty: Kind -> Type) k - n: - forall (kinds : Fin.t n -> Kind) - (names : Fin.t n -> string) + {n}: + forall (kinds : Fin n -> Kind) + (names : Fin n -> string) (acts : forall i, ActionT ty (kinds i)) (cont: (forall i, Expr ty (SyntaxKind (kinds i))) -> ActionT ty k), ActionT ty k := - match n return forall (kinds : Fin.t n -> Kind) - (names : Fin.t n -> string) + match n return forall (kinds : Fin n -> Kind) + (names : Fin n -> string) (acts : forall i, ActionT ty (kinds i)) (cont : (forall i, Expr ty (SyntaxKind (kinds i))) -> ActionT ty k), ActionT ty k with @@ -48,24 +49,22 @@ Fixpoint BuildStructActionCont (fun i => acts (Fin.FS i)) (fun exps => cont (fun i => - match i in Fin.t (S m) return - forall (ks: - Fin.t (S m) -> Kind), - ty (ks Fin.F1) -> - (forall i: Fin.t m, Expr ty (SyntaxKind (ks (Fin.FS i)))) -> - Expr ty (SyntaxKind (ks i)) - with - | Fin.F1 _ => fun ks next exps => #next - | Fin.FS _ j => fun ks next exps => exps j + match i return + forall ks : Fin (S m) -> Kind, + ty (ks Fin.F1) -> + (forall j : Fin m, Expr ty (SyntaxKind (ks (Fin.FS j)))) -> + Expr ty (SyntaxKind (ks i)) with + | inl u => match u with tt => fun ks next exps => #next end + | inr j => fun ks next exps => exps j end kinds next exps)) end. -Definition BuildStructAction ty n (kinds: Fin.t n -> Kind) (names: Fin.t n -> string) (acts: forall i, ActionT ty (kinds i)) := +Definition BuildStructAction ty n (kinds: Fin n -> Kind) (names: Fin n -> string) (acts: forall i, ActionT ty (kinds i)) := BuildStructActionCont kinds names acts (fun x => Return (BuildStruct kinds names x)). Lemma WfConcatActionT_BuildStructActionCont: forall m k n kinds names acts cont, - (forall (i:Fin.t n), WfConcatActionT (acts i) m) -> + (forall (i:Fin n), WfConcatActionT (acts i) m) -> (forall x, WfConcatActionT (cont x) m) -> @WfConcatActionT type k (@BuildStructActionCont type k n kinds names acts cont) m. diff --git a/NotationsTest.v b/NotationsTest.v index 0a89494..6d6b6af 100644 --- a/NotationsTest.v +++ b/NotationsTest.v @@ -1,4 +1,5 @@ Require Import Kami.Syntax Kami.Notations Kami.Tactics. +Require Import Kami.StdLib.Fin. Section mod_test. Variable a: string. Local Notation "^ x" := (a ++ "." ++ x)%string (at level 0). diff --git a/Properties.v b/Properties.v index f8b310d..eb541eb 100644 --- a/Properties.v +++ b/Properties.v @@ -5,6 +5,7 @@ Require Import Coq.Sorting.Permutation. Require Import Coq.Sorting.PermutEq. Require Import RelationClasses Setoid Morphisms. Require Import ZArith. +Require Import Kami.StdLib.Fin. Definition filterRegs f m (o: RegsT) := filter (fun x => f (getBool (in_dec string_dec (fst x) (map fst (getAllRegisters m))))) o. @@ -498,13 +499,12 @@ Section evalExpr. repeat f_equal. Qed. - Lemma fin_to_nat_bound : forall n (x: Fin.t n), proj1_sig (Fin.to_nat x) < n. + Lemma fin_to_nat_bound : forall {n} (x: Fin n), proj1_sig (Fin.to_nat x) < n. Proof. - induction x; cbn; try lia. - destruct (Fin.to_nat x); cbn in *; lia. + intros n x. exact (proj2_sig (Fin.to_nat x)). Qed. - Lemma fin_to_word_id : forall n (i : Fin.t n), + Lemma fin_to_word_id : forall {n} (i : Fin n), wordToNat (natToWord (Nat.log2_up n) (proj1_sig (Fin.to_nat i))) = proj1_sig (Fin.to_nat i). Proof. intros. @@ -6716,12 +6716,12 @@ Proof. Qed. Lemma arr_nth_Fin' {A : Type} : - forall m (arr : t m -> A), - arr = (nth_Fin' _ (list_arr_length arr)). + forall m (arr : Fin m -> A), + arr = (nth_Fin' _ (list_arr_length m arr)). Proof. intros. apply functional_extensionality; intros. - rewrite (nth_Fin'_nth (arr x)). + rewrite (nth_Fin'_nth m (arr x)). rewrite <- nth_default_eq, <- list_arr_correct. destruct lt_dec. - specialize (of_nat_to_nat_inv x) as P. diff --git a/Simulator/CoqSim/Eval.v b/Simulator/CoqSim/Eval.v index c2cee1a..624916e 100644 --- a/Simulator/CoqSim/Eval.v +++ b/Simulator/CoqSim/Eval.v @@ -1,5 +1,4 @@ -Require Import Compare_dec List String Streams FinFun. -Import ListNotations Fin2Restrict. +Require Import Compare_dec String Streams. Require Import Kami.AllNotations. @@ -8,6 +7,10 @@ Require Import Kami.Simulator.CoqSim.TransparentProofs. Require Import Kami.Simulator.CoqSim.HaskellTypes. Import Kami.Simulator.CoqSim.HaskellTypes.Notations. Require Import Program. +Require Import Kami.StdLib.Fin. +Require Import Kami.StdLib.Vector. +Require Import List. +Import ListNotations. Section Eval. @@ -47,8 +50,17 @@ Fixpoint print_Val(k : Kind)(ff : FullFormat k) : eval_Kind k -> string := match ff with | FBool n _ => fun x => pad_with " " n (if x then "1" else "0") | FBit n m bf => fun x => pad_with "0" m (print_BF bf x) - | FStruct n fk fs ffs => fun x => ("{ " ++ String.concat "; " (v_to_list (vmap (fun '(str1,str2) => str1 ++ ":" ++ str2) (add_strings fs (tup_to_vec _ (fun i => print_Val (ffs i)) x)))) ++ "; }")%string - | FArray n k' ff' => fun x => ("[" ++ String.concat "; " (List.map (fun i => natToDecStr (f2n i) ++ "=" ++ print_Val ff' (vector_index i x)) (getFins n)) ++ "; ]")%string + | FStruct n fk fs ffs => fun x : Tuple (fun i : Fin n => eval_Kind (fk i)) => + ("{ " ++ + String.concat "; " + (Vector.to_list + (Vector.map + (fun '(str1,str2) => str1 ++ ":" ++ str2) + (Vector.add_strings fs + (tup_to_vec _ (fun i : Fin n => @print_Val (fk i) (ffs i)) x)))) ++ + "; }")%string + | FArray n k' ff' => fun x => + ("[" ++ String.concat "; " (List.map (fun i => natToDecStr (f2n i) ++ "=" ++ print_Val ff' (vector_index i x)) (getFins n)) ++ "; ]")%string end. (* for checkpointing *) @@ -56,7 +68,7 @@ Fixpoint print_Val2(k : Kind)(ff : FullFormat k) : eval_Kind k -> string := match ff with | FBool n _ => fun x => pad_with " " n (if x then "tt" else "ff") | FBit n m bf => fun x => pad_with "0" m (print_BF bf x) - | FStruct n fk fs ffs => fun x => (("{ " ++ String.concat " ; " (v_to_list ((tup_to_vec _ (fun i => print_Val2 (ffs i)) x)))) ++ " }")%string + | FStruct n fk fs ffs => fun x => (("{ " ++ String.concat " ; " (Vector.to_list ((tup_to_vec _ (fun i => print_Val2 (ffs i)) x)))) ++ " }")%string | FArray n k' ff' => fun x => ("[ " ++ String.concat " ; " (List.map (fun i => print_Val2 ff' (vector_index i x)) (getFins n)) ++ " ]")%string end. @@ -88,8 +100,8 @@ Definition default_val_FK(k : FullKind) : eval_FK k := | NativeKind T t => t end. -Fixpoint rand_tuple{n} : forall ts : Fin.t n -> Type, (forall i, IO (ts i)) -> IO (Tuple ts) := - match n with +Fixpoint rand_tuple{n} : forall ts : Fin n -> Type, (forall i, IO (ts i)) -> IO (Tuple ts) := + match n as m return forall ts : Fin m -> Type, (forall i, IO (ts i)) -> IO (Tuple ts) with | 0 => fun _ _ => ret tt | S m => fun ts mxs => ( do x <- mxs Fin.F1; @@ -146,10 +158,10 @@ Definition eval_BinBit{m n p}(op : BinBitOp m n p) : BV m -> BV n -> BV p := Definition eval_CABit{n}(op : CABitOp) : list (BV n) -> BV n := match op with - | Add => bv_add | Mul => bv_mul | Band => bv_band | Bxor => bv_bxor + | Add => bv_add end. Definition eval_BinBitBool{m n}(op : BinBitBoolOp m n) : BV m -> BV n -> bool := @@ -206,22 +218,21 @@ Fixpoint eval_Expr{k}(e : Expr eval_Kind k) : eval_FK k := | @FromNative _ k' e' => eval_KindFromType (eval_Expr e') end. -Fixpoint get_chunk_struct{n} : forall (f : Fin.t n -> nat)(v : BV (sumSizes f))(i : Fin.t n), BV (f i) := +Fixpoint get_chunk_struct{n} : forall (f : Fin n -> nat)(v : BV (sumSizes f))(i : Fin n), BV (f i) := match n with - | 0 => fun f _ i => case0 (fun j => BV (f j)) i + | 0 => fun f _ i => Fin.case0 (fun j => BV (f j)) i | _ => fun f v i => fin_case _ (fun j => BV (f j)) (bv_trunc_msb v) (fun j => get_chunk_struct (fun k => f (FS k)) (bv_trunc_lsb v) j) end. -Fixpoint get_chunk_array{n} : forall (k : nat)(v : BV (n * k))(i : Fin.t n), BV k. - refine match n with - | 0 => fun _ _ i => case0 _ i - | S m => fun k v i => _ +Fixpoint get_chunk_array{n} : forall (k : nat)(v : BV (n * k))(i : Fin n), BV k := + match n with + | 0 => fun _ _ i => Fin.case0 _ i + | S m => fun k (v : BV ((S m) * k)) (i : Fin (S m)) => + match i with + | inl _ => @bv_trunc_lsb k (m * k) v + | inr j => @get_chunk_array m k (@bv_trunc_msb k (m * k) v) j + end end. -Proof. - dependent destruction i. - - exact (bv_trunc_lsb v). - - exact (get_chunk_array _ _ (bv_trunc_msb v) i). -Defined. Fixpoint val_unpack(k : Kind) : BV (size k) -> eval_Kind k := match k return BV (size k) -> eval_Kind k with diff --git a/Simulator/CoqSim/HaskellTypes.v b/Simulator/CoqSim/HaskellTypes.v index bcad94d..557fb00 100644 --- a/Simulator/CoqSim/HaskellTypes.v +++ b/Simulator/CoqSim/HaskellTypes.v @@ -1,7 +1,8 @@ Require Extraction. -Require Import String Fin. +Require Import String. Require Import Kami.All. Require Import Kami.Simulator.CoqSim.Misc. +Require Import Kami.StdLib.Fin. Extraction Language Haskell. @@ -32,22 +33,26 @@ Extract Constant map_of_list => "Data.Map.Strict.fromList". Parameter Vector : nat -> Type -> Type. -Parameter vector_index : forall {X n}, Fin.t n -> Vector n X -> X. +Parameter vector_index_aux : forall {X n}, nat -> Vector n X -> X. + +Definition vector_index {X n} (p : Fin n) : Vector n X -> X := + vector_index_aux (proj1_sig (Fin.to_nat p)). + Parameter vector_map : forall {X Y n}, (X -> Y) -> Vector n X -> Vector n Y. Parameter vector_eq : forall {X n}, (X -> X -> bool) -> Vector n X -> Vector n X -> bool. Parameter vector_to_list : forall {X n}, Vector n X -> list X. -Parameter make_vector : forall {X n}, (Fin.t n -> X) -> Vector n X. +Parameter make_vector : forall {X n}, (Fin n -> X) -> Vector n X. Parameter vector_slice : forall {X n} (i m : nat), Vector n X -> Vector m X. Parameter vector_updates : forall {X n}, Vector n X -> list (nat * X) -> Vector n X. -Fixpoint Fin_to_list{X n} : (Fin.t n -> X) -> list X := +Fixpoint Fin_to_list{X n} : (Fin n -> X) -> list X := match n with - | 0 => fun _ => nil - | S m => fun f => cons (f Fin.F1) (Fin_to_list (fun i => f (Fin.FS i))) + | 0 => fun _ => List.nil + | S m => fun f => List.cons (f F1) (Fin_to_list (fun i => f (FS i))) end. Extract Constant Vector "a" => "Data.Vector.Vector a". -Extract Constant vector_index => "(\_ (n,i) v -> v Data.Vector.! i)". +Extract Constant vector_index_aux => "(\_ i v -> v Data.Vector.! i)". Extract Constant vector_map => "(\_ -> Data.Vector.map)". Extract Constant vector_eq => "(\_ eqb v1 v2 -> Data.Vector.foldr (Prelude.&&) Prelude.True (Data.Vector.zipWith eqb v1 v2))". Extract Constant vector_to_list => "(\ _ -> Data.Vector.toList)". @@ -99,7 +104,7 @@ Extract Constant bv_srl => "(\_ _ -> Data.BitVector.shr)". Extract Constant bv_sra => "(\_ _ -> Data.BitVector.ashr)". Extract Constant bv_concat => "(\_ _ -> (Data.BitVector.#))". Extract Constant bv_add => "(\_ -> Prelude.foldr (Prelude.+) 0)". -Extract Constant bv_mul => "(\_ -> Prelude.foldr (Prelude.*) 1)". +Extract Constant bv_mul => "(\_ -> Prelude.foldr (Prelude.* ) 1)". Extract Constant bv_band => "(\n -> Prelude.foldr (Data.Bits..&.) (Data.BitVector.ones n))". Extract Constant bv_bor => "(\n -> Prelude.foldr (Data.Bits..|.) (Data.BitVector.zeros n))". Extract Constant bv_bxor => "(\n -> Prelude.foldr Data.Bits.xor (Data.BitVector.zeros n))". @@ -160,4 +165,4 @@ Extract Constant Harr_updates => "(\a ps -> Control.Monad.foldM (\_ (i,e) -> Dat Extract Constant Arr "a" => "Data.Vector.Mutable.MVector (Control.Monad.Primitive.PrimState Prelude.IO) a". Extract Constant arr_repl => "Data.Vector.Mutable.replicate". Extract Constant arr_slice => "(\i m a -> Data.Vector.Generic.freeze (Data.Vector.Mutable.slice i m a))". -Extract Constant arr_updates => "(\a ps -> Control.Monad.foldM (\_ (i,x) -> Data.Vector.Mutable.write a i x) () ps)". \ No newline at end of file +Extract Constant arr_updates => "(\a ps -> Control.Monad.foldM (\_ (i,x) -> Data.Vector.Mutable.write a i x) () ps)". diff --git a/Simulator/CoqSim/Misc.v b/Simulator/CoqSim/Misc.v index 84ca8c6..a8da2e2 100644 --- a/Simulator/CoqSim/Misc.v +++ b/Simulator/CoqSim/Misc.v @@ -1,99 +1,41 @@ -Require Import Fin Bool Kami.Lib.EclecticLib String Ascii List Streams. +Require Import Bool Kami.Lib.EclecticLib String Ascii List Streams. +Require Import Kami.StdLib.Vector. +Require Import Kami.StdLib.Fin. Import ListNotations. -Fixpoint Fin n := - match n with - | 0 => Empty_set - | S m => (unit + Fin m)%type - end. - -Section Vector. - -Fixpoint Vec X n : Type := - match n with - | 0 => unit - | S m => (X * Vec X m)%type - end. - -Fixpoint vec_index{n X} : Fin.t n -> Vec X n -> X := - match n with - | 0 => case0 _ - | S m => fun i v => fin_case i _ (fst v) (fun j => vec_index j (snd v)) - end. - -Fixpoint mkVec{n X} : (Fin.t n -> X) -> Vec X n := - match n with - | 0 => fun _ => tt - | S m => fun f => (f Fin.F1, mkVec (fun j => f (Fin.FS j))) - end. - -Fixpoint VecEq{n X} : (X -> X -> bool) -> Vec X n -> Vec X n -> bool := - match n with - | 0 => fun _ _ _ => true - | S m => fun eq v1 v2 => eq (fst v1) (fst v2) && VecEq eq (snd v1) (snd v2) - end. - -Fixpoint vmap{n X Y}(f : X -> Y) : Vec X n -> Vec Y n := - match n with - | 0 => fun _ => tt - | S m => fun '(x,xs) => (f x, vmap f xs) - end. - -Fixpoint v_to_list{n X} : Vec X n -> list X := - match n with - | 0 => fun _ => [] - | S m => fun '(x,xs) => x::v_to_list xs - end. - -Fixpoint add_indices_aux{n X} : nat -> Vec X n -> Vec (nat * X) n := - match n return nat -> Vec X n -> Vec (nat * X) n with - | 0 => fun _ _ => tt - | S m => fun acc '(x,xs) => ((acc,x), add_indices_aux (S acc) xs) - end. - -Definition add_indices{n X} : Vec X n -> Vec (nat * X) n := add_indices_aux 0. - -Fixpoint add_strings{n X} : (Fin.t n -> string) -> Vec X n -> Vec (string * X) n := - match n return (Fin.t n -> string) -> Vec X n -> Vec (string * X) n with - | 0 => fun _ _ => tt - | S m => fun strs '(x,xs) => ((strs Fin.F1,x),add_strings (fun j => strs (Fin.FS j)) xs) - end. - -End Vector. - Section Tuple. -Fixpoint Tuple{n} : (Fin.t n -> Type) -> Type := +Fixpoint Tuple{n} : (Fin n -> Type) -> Type := match n with | 0 => fun _ => unit | S m => fun ts => ((ts Fin.F1) * (Tuple (fun j => ts (Fin.FS j))))%type end. -Fixpoint Tup_map{n} : forall (ts1 ts2 : Fin.t n -> Type)(fs : forall i, ts1 i -> ts2 i)(t : Tuple ts1), Tuple ts2 := +Fixpoint Tup_map{n} : forall (ts1 ts2 : Fin n -> Type)(fs : forall i, ts1 i -> ts2 i)(t : Tuple ts1), Tuple ts2 := match n with | 0 => fun _ _ _ _ => tt | S m => fun ts1 ts2 fs t => (fs F1 (fst t), Tup_map (fun i => ts1 (FS i)) (fun i => ts2 (FS i)) (fun i => fs (FS i)) (snd t)) end. -Fixpoint tup_index{n} : forall (i : Fin.t n) ts, Tuple ts -> ts i := +Fixpoint tup_index{n} : forall (i : Fin n) ts, Tuple ts -> ts i := match n with - | 0 => case0 _ + | 0 => case0 (fun i => _) | S m => fun i ts t => fin_case i _ (fst t) (fun j => tup_index j (fun j => ts (Fin.FS j)) (snd t)) end. -Fixpoint mkTup{n} : forall ts : Fin.t n -> Type, (forall i, ts i) -> Tuple ts := +Fixpoint mkTup{n} : forall ts : Fin n -> Type, (forall i, ts i) -> Tuple ts := match n with | 0 => fun _ _ => tt | S m => fun ts es => (es Fin.F1, mkTup (fun j => ts (Fin.FS j)) (fun j => es (Fin.FS j))) end. -Fixpoint TupEq{n} : forall ts : Fin.t n -> Type, (forall i, ts i -> ts i -> bool) -> Tuple ts -> Tuple ts -> bool := +Fixpoint TupEq{n} : forall ts : Fin n -> Type, (forall i, ts i -> ts i -> bool) -> Tuple ts -> Tuple ts -> bool := match n with | 0 => fun _ _ _ _ => true | S m => fun ts eqs t1 t2 => eqs Fin.F1 (fst t1) (fst t2) && TupEq (fun j => ts (Fin.FS j)) (fun j => eqs (Fin.FS j)) (snd t1) (snd t2) end. -Fixpoint tup_to_vec{n X} : forall (ts : Fin.t n -> Type)(to_X : forall i, ts i -> X), Tuple ts -> Vec X n := +Fixpoint tup_to_vec{n X} : forall (ts : Fin n -> Type)(to_X : forall i, ts i -> X), Tuple ts -> Vec X n := match n with | 0 => fun _ _ _ => tt | S m => fun ts to_X '(x,t) => (to_X Fin.F1 x, tup_to_vec (fun j => ts (Fin.FS j)) (fun j => to_X (Fin.FS j)) t) @@ -103,8 +45,8 @@ End Tuple. Section Lookup. -Fixpoint Fin_lookup{X}(pred : X -> bool){n} : (Fin.t n -> X) -> option (Fin.t n) := - match n return (Fin.t n -> X) -> option (Fin.t n) with +Fixpoint Fin_lookup{X}(pred : X -> bool){n} : (Fin n -> X) -> option (Fin n) := + match n return (Fin n -> X) -> option (Fin n) with | 0 => fun _ => None | S m => fun f => if pred (f F1) then Some F1 else match Fin_lookup pred (fun j => f (FS j)) with diff --git a/Simulator/CoqSim/RegisterFile.v b/Simulator/CoqSim/RegisterFile.v index a977e4a..3b612f4 100644 --- a/Simulator/CoqSim/RegisterFile.v +++ b/Simulator/CoqSim/RegisterFile.v @@ -1,5 +1,4 @@ Require Import String. -Require Import FinFun. Require Import Kami.AllNotations. Require Import Kami.Syntax. @@ -8,6 +7,9 @@ Require Import Kami.Simulator.CoqSim.Misc. Require Import Kami.Simulator.CoqSim.Eval. Require Import Kami.Simulator.CoqSim.HaskellTypes. Import Kami.Simulator.CoqSim.HaskellTypes.Notations. +Require Import Kami.StdLib.Vector. +Require Import Kami.StdLib.Fin. +Require Import List. Section RegFile. @@ -100,10 +102,10 @@ Defined. Definition file_writes_mask(file : RegFile)(i : nat)(mask : Val (Array (chunk_size file) Bool))(vals : Val (Array (chunk_size file) (kind file))) : list (nat * Val (kind file)) := let mask_indices := filter (fun i => vector_index i mask) (getFins _) in - map (fun j => (i + Fin2Restrict.f2n j, vector_index j vals)) mask_indices. + map (fun j => (i + Fin.f2n j, vector_index j vals)) mask_indices. Definition file_writes_no_mask(file : RegFile)(i : nat)(vals : Val (Array (chunk_size file) (kind file))) : list (nat * Val (kind file)) := - map (fun j => (i + Fin2Restrict.f2n j, vector_index j vals)) (getFins _). + map (fun j => (i + Fin.f2n j, vector_index j vals)) (getFins _). Definition void_nil : {k : Kind & Val k} := existT _ (Bit 0) (nat_to_bv 0). @@ -119,8 +121,8 @@ Proof. * exact (error "Kind mismatch."). Defined. -Fixpoint Tup_lookup{n} : forall (i : Fin.t n)(ks : Fin.t n -> Kind), Tuple (fun i => Val (ks i)) -> {k : Kind & Val k} := - match n return forall (i : Fin.t n)(ks : Fin.t n -> Kind), Tuple (fun i => Val (ks i)) -> {k : Kind & Val k} with +Fixpoint Tup_lookup{n} : forall (i : Fin n)(ks : Fin n -> Kind), Tuple (fun i => Val (ks i)) -> {k : Kind & Val k} := + match n return forall (i : Fin n)(ks : Fin n -> Kind), Tuple (fun i => Val (ks i)) -> {k : Kind & Val k} with | 0 => fun i => case0 _ i | S m => fun i ks X => fin_case i _ (existT _ (ks F1) (fst X)) (fun j => (Tup_lookup j _ (snd X))) end. diff --git a/Simulator/CoqSim/TransparentProofs.v b/Simulator/CoqSim/TransparentProofs.v index 0a0dc32..ebb9b23 100644 --- a/Simulator/CoqSim/TransparentProofs.v +++ b/Simulator/CoqSim/TransparentProofs.v @@ -72,13 +72,17 @@ Proof. pose (proj1 (Fin_forallb_correct _) H0). rewrite (hedberg Nat.eq_dec _ eq_refl) in e. simpl in e. - apply e. + specialize e with i. + rewrite <- (cast_eq i (eq_refl)) in e. + exact e. apply String_eqb_eq2. apply andb_true_iff in H0; destruct H0. pose (proj1 (Fin_forallb_correct _) H1). rewrite (hedberg Nat.eq_dec _ eq_refl) in e. simpl in e. - apply e. + specialize e with i. + rewrite <- (cast_eq i (eq_refl)) in e. + exact e. + simpl in H0. rewrite (@silly_lemma_false) in H0 by auto; discriminate. - rewrite H0; apply Kind_decb_refl. @@ -131,4 +135,4 @@ Proof. - right; congruence. - right; congruence. - right; congruence. -Defined. \ No newline at end of file +Defined. diff --git a/StdLib/Fin.v b/StdLib/Fin.v new file mode 100644 index 0000000..4676367 --- /dev/null +++ b/StdLib/Fin.v @@ -0,0 +1,507 @@ +Require Import Bool String Ascii List. +Require Import Lt. +Require Import Nat. +Require Import PeanoNat. +Require Import ProofIrrelevance. +Import ListNotations. + +Fixpoint Fin (n : nat) := + match n with + | 0 => Empty_set + | S m => (unit + Fin m)%type + end. + +Definition F1 {n : nat} : Fin (S n) := inl tt. + +Definition FS {n : nat} (i : Fin n) : Fin (S n) := @inr unit (Fin n) i. + +Definition case0 (F : Fin 0 -> Type) (x : Fin 0) : F x := Empty_set_rect F x. + + +Definition caseS' {n : nat} (p : Fin (S n)) : forall (P : Fin (S n) -> Type) + (P1 : P F1) + (PS : forall (p : Fin n), P (FS p)), P p. +Proof. + destruct p. + - intros; destruct u; exact P1. + - intros; apply PS. +Defined. + +Definition caseS (P: forall {n}, Fin (S n) -> Type) + (P1: forall n, @P n F1) (PS : forall {n} (p: Fin n), P (FS p)) + {n} (p: Fin (S n)) : P p := caseS' p P (P1 n) PS. + +Lemma FS_inj : forall n (i j: Fin n) (H: FS i = FS j), i = j. +Proof. + intro n. + induction n. + + exact (case0 (fun i => forall j : Fin 0, FS i = FS j -> i = j)). + + intros i j H. injection H. trivial. +Qed. + +Fixpoint to_nat {n : nat} : forall (i : Fin n), {m : nat | m < n} := + match n return forall i : Fin n, {m : nat | m < n} with + | 0 => case0 (fun _ => {m : nat | m < 0}) + | S m => + fun i : Fin (S m) => + match i with + | inl _ => exist _ 0 (Nat.lt_0_succ m) + | inr j => + exist (fun o => o < S m) + (S (proj1_sig (@to_nat m j))) + (Nat.le_lt_trans + (S (proj1_sig (@to_nat m j))) m (S m) + (proj2_sig (@to_nat m j)) + (Nat.lt_succ_diag_r m)) + end + end. + +Definition f2n {n : nat} (i : Fin n) := proj1_sig (to_nat i). + +Goal proj1_sig (@to_nat 1 F1) = 0. Proof. reflexivity. Qed. + +Goal proj1_sig (@to_nat 4 F1) = 0. Proof. reflexivity. Qed. + +Goal proj1_sig (@to_nat 4 (FS (FS (FS F1)))) = 3. Proof. reflexivity. Qed. + +Fixpoint of_nat (x n : nat) : (Fin n) + {exists m, x = n + m}. + refine + match n with + | 0 => inright (ex_intro (fun m => x = 0 + m) x (Nat.add_0_l x)) + | S n => + match x with + | 0 => inleft F1 + | S x => + match of_nat x n with + | inleft i => inleft (FS i) + | inright H => + inright + (ex_ind + (fun m (Hm : x = n + m) => + ex_intro + (fun o => S x = S n + o) m _ + ) + H) + end + end + end. +Proof. + abstract (rewrite Hm; rewrite <- (plus_Sn_m n m); reflexivity). +Defined. + +Fixpoint of_nat_lt {x n : nat} : x < n -> Fin n := + match n with + | 0 => fun H : x < 0 => False_rect _ (Nat.nlt_0_r x H) + | S n => + match x with + | 0 => fun _ => F1 + | S x => fun H : S x < S n => FS (@of_nat_lt x n (Lt.lt_S_n x n H)) + end + end. + +Lemma to_nat_of_nat_inv : forall (x n : nat) (H : x < n), proj1_sig (@to_nat n (@of_nat_lt x n H)) = x. +Proof. + induction x. + + induction n. + - intro Hcontr; contradict Hcontr; exact (Nat.nlt_0_r 0). + - intro H; simpl; reflexivity. + + induction n. + - intro Hcontr; contradict Hcontr; exact (Nat.nlt_0_r (S x)). + - intro H. + simpl. + exact (f_equal S (IHx n (lt_S_n _ _ H))). +Qed. + +Lemma of_nat_ext {n} {m} (H H0 : n < m) : of_nat_lt H = of_nat_lt H0. +Proof. + rewrite (proof_irrelevance (n < m) H H0). + reflexivity. +Qed. + +Lemma to_nat_S : forall n (i : Fin n), S (proj1_sig (@to_nat n i)) = proj1_sig (@to_nat (S n) (inr i)). +Proof. + induction n as [|m HInd]. + + exact (case0 (fun i => S (proj1_sig (to_nat i)) = (proj1_sig (@to_nat (S 0) (inr i))))). + + intro i; simpl; reflexivity. +Qed. + +Lemma of_nat_to_nat_inv {n} : forall i : Fin n, of_nat_lt (proj2_sig (to_nat i)) = i. +Proof. + induction n. + + exact (case0 (fun i => of_nat_lt (proj2_sig (to_nat i)) = i)). + + intro i. + destruct i as [j|j]. + - destruct j; reflexivity. + - simpl. + apply (f_equal (@inr unit (Fin n))). + rewrite <- (IHn j) at 5. + apply (@of_nat_ext (proj1_sig (to_nat j)) n). +Qed. + +Lemma of_nat_to_nat_FS {n} : + forall (i : Fin n), + of_nat_lt (proj2_sig (to_nat (FS i))) = + FS (of_nat_lt (proj2_sig (to_nat i))). +Proof. + induction n. + + exact + (case0 (fun i => + of_nat_lt (proj2_sig (to_nat (FS i))) = + FS (of_nat_lt (proj2_sig (to_nat i))))). + + intro i. + destruct i as [i|i]. + - reflexivity. + - simpl. + apply (f_equal (fun j => @inr unit (Fin (S n)) (@inr unit (Fin n) j))). + apply (@of_nat_ext (proj1_sig (to_nat i)) n). +Qed. + +Lemma to_nat_of_nat {n} {m} (H : n < m) : to_nat (of_nat_lt H) = exist _ n H. +Proof. + induction m. + + exact (False_ind _ (Nat.nlt_0_r _ H)). + + destruct n as [|n]. + - simpl; rewrite (proof_irrelevance (0 < S m) (Nat.lt_0_succ m) H); reflexivity. + - simpl. + apply (eq_sig_hprop (fun o H H0 => proof_irrelevance (o < S m) H H0) _ _). + simpl. + exact (f_equal S (to_nat_of_nat_inv n m (lt_S_n n m H))). +Qed. + +Lemma to_nat_inj {n} (i j : Fin n) : proj1_sig (to_nat i) = proj1_sig (to_nat j) -> i = j. +Proof. + induction n. + + exact (case0 (fun _ => proj1_sig (to_nat i) = proj1_sig (to_nat j) -> i = j) i). + + destruct i, j. + - destruct u, u0. exact (fun _ => eq_refl). + - exact (fun Hcontr : 0 = S _ => False_ind _ (O_S _ Hcontr)). + - exact (fun Hcontr : S _ = 0 => False_ind _ (O_S _ (eq_sym Hcontr))). + - exact (fun H => f_equal (@inr unit (Fin n)) (IHn f f0 (Nat.succ_inj _ _ H) )). +Qed. + +Lemma eq_dec {n} (i j : Fin n): {i = j} + {i <> j}. +Proof. + induction n. + + exact (case0 (fun _ => {i = j}+{i <> j}) i). + + destruct i, j. + - destruct u, u0; exact (left _ (eq_refl)). + - right; discriminate. + - right; discriminate. + - destruct (IHn f f0) as [Heq|Hneq]. + * exact (left _ (f_equal (@inr unit (Fin n)) Heq)). + * right; intro Heq; injection Heq; exact (Hneq). +Defined. + +Definition eqb {n m} (i : Fin n) (j : Fin m) : bool := + match Nat.eq_dec n m with + | left Heq => + if eq_dec i (eq_rect_r Fin j Heq) + then true + else false + | _ => false + end. + +Lemma eqb_nat_eq : forall n m (i : Fin n) (j : Fin m), eqb i j = true -> n = m. +Proof. + intros n m i j H. + unfold eqb in H. + destruct (Nat.eq_dec n m) as [Heq|Hneq] in H. + + assumption. + + contradict H; discriminate. +Qed. + +Lemma eqb_eq : forall n (i j : Fin n), eqb i j = true <-> i = j. +Proof. + intros n i j. + unfold eqb. + destruct (Nat.eq_dec n n) as [Heq|Hneq]. + + unfold eq_rect_r. + rewrite <- (Eq_rect_eq.eq_rect_eq nat n (fun y : nat => Fin y) j (eq_sym Heq)). + destruct (eq_dec i j) as [H|H]. + - exact (conj (fun _ => H) (fun _ => eq_refl)). + - split. + * discriminate. + * contradiction. + + contradiction. +Qed. + +Local Fixpoint cast_aux {n} {m} : n = m -> Fin n -> Fin m := + match n, m with + | 0, m => fun _ => case0 (fun _ => Fin m) + | S n, 0 => fun (H : S n = 0) _ => False_rect (Fin 0) (Nat.neq_succ_0 n H) + | S n, S m => + fun (H : S n = S m) i => + match i with + | inl _ => @F1 m + | inr j => FS (cast_aux (Nat.succ_inj n m H) j) + end + end. + +Definition cast {n} (i: Fin n) {m} (H : n = m) : Fin m := cast_aux H i. + +Lemma cast_eq {n} : forall (i : Fin n) (H : n = n), i = cast i H. +Proof. + induction n. + + exact (case0 (fun i => forall H : 0 = 0, i = cast i H)). + + induction i as [i|j]. + - intro H; simpl; destruct i; reflexivity. + - intro H. + exact (f_equal (@inr unit (Fin n)) (IHn j (Nat.succ_inj n n H))). +Qed. + +Fixpoint Fin_foldr {A : Type} (n : nat) (init : A) : forall (f : Fin n -> A -> A), A := + match n with + | 0 => fun _ => init + | S n => fun f => f F1 (Fin_foldr n init (fun i => f (FS i))) + end. + +Fixpoint nth_Fin {A : Type} (xs : list A) : Fin (length xs) -> A := + match xs with + | [] => case0 (fun _ => A) + | x :: xs => + fun i => + match i with + | inl _ => x + | inr j => nth_Fin xs j + end + end. + +Definition nth_Fin' {A : Type} (xs : list A) {n : nat} (H : n = length xs) (i : Fin n) : A := + nth_Fin xs (cast i H). + +Fixpoint nth_Fin'' {A : Type} (xs : list A) : forall n : nat, n <= length xs -> Fin n -> A := + match xs as ys return forall n, n <= length ys -> Fin n -> A with + | [] => + fun n (H : n <= 0) (i : Fin n) => + let Heq : n = 0 := eq_sym (Le.le_n_0_eq n H) in + case0 (fun _ => A) (cast i Heq) + | y :: ys => + fun n => + match n as m return m <= length (y :: ys) -> Fin m -> A with + | 0 => fun _ => case0 (fun _ => A) + | S m => + fun (H : S m <= S (length ys)) (i : Fin (S m)) => + match i with + | inl _ => y + | inr j => nth_Fin'' ys m (le_S_n _ _ H) j + end + end + end. + +Lemma nth_Fin'_nth {A : Type} : forall (n : nat) (default : A) (i : Fin n) (xs : list A) (H : n = length xs), + nth_Fin' xs H i = nth (proj1_sig (to_nat i)) xs default. +Proof. + induction n. + + exact (fun default => case0 (fun i => forall xs (H : 0 = length xs), nth_Fin' xs H i = nth (proj1_sig (to_nat i)) xs default)). + + intro default. + destruct i as [i|j]. + - destruct i. + destruct xs as [|x xs]. + * exact (fun Hcontr => False_ind _ (Nat.neq_succ_0 _ Hcontr)). + * intro H. + unfold nth_Fin'. + simpl. + reflexivity. + - destruct xs as [|x xs]. + * intro H. + contradict H. + exact (Nat.neq_succ_0 n). + * intro H. + unfold nth_Fin'. + simpl. + unfold nth_Fin' in IHn. + exact (IHn default j xs (Nat.succ_inj n (length xs) H)). +Qed. + +Lemma nth_Fin_nth {A : Type} (default : A) : forall (xs : list A) (i : Fin (length xs)), + nth_Fin xs i = nth (proj1_sig (to_nat i)) xs default. +Proof. + induction xs. + + exact (case0 (fun i => nth_Fin [] i = nth (proj1_sig (to_nat i)) [] default)). + + destruct i as [i|j]. + - reflexivity. + - exact (IHxs j). +Qed. + +Definition fin_case {n : nat} + (i : Fin (S n)) : forall + (F : Fin (S n) -> Type) + (f1 : F F1) + (fs : forall j, F (FS j)), F i := + match i return + forall F : Fin (S n) -> Type, + F (@inl unit (Fin n) tt) -> + (forall j : Fin n, F (@FS n j)) -> + F i with + | inl u => match u with tt => fun _ f1 _ => f1 end + | inr j => fun _ _ fs => fs j + end. + +(* TODO: LLEE: replace the fin_dep_destruct tactic from EclecticLib. *) + +Lemma Fin_cast_lemma : forall (n m : nat) (i : Fin n) (H H0 : n = m), cast i H = cast i H0. +Proof. + intros n m i H H0. + rewrite (proof_irrelevance (n = m) H H0). + reflexivity. +Qed. + +Lemma fin_to_nat_cast : forall (n : nat) (i: Fin n) (m : nat) (H: n = m), + proj1_sig (to_nat (cast i H)) = proj1_sig (to_nat i). +Proof. + induction n as [|n]. + + exact (case0 (fun i => forall m H, proj1_sig (to_nat (cast i H)) = proj1_sig (to_nat i))). + + destruct i as [i|j]. + - destruct m as [|m]. + * intro H; contradict H; exact (Nat.neq_succ_0 n). + * intro H; reflexivity. + - destruct m as [|m]. + * intro H; contradict H; exact (Nat.neq_succ_0 n). + * intro H. + simpl. + exact (f_equal S (IHn j m (Nat.succ_inj n m H))). +Qed. + +Fixpoint map_length_red {A B : Type} (f : A -> B) (xs : list A) := + match xs return length (map f xs) = length xs with + | [] => ltac:(reflexivity) + | y :: ys => + f_equal_nat nat S + (length (map f ys)) + (length ys) + (map_length_red f ys) + end. + +Fixpoint nth_Fin_map2_aux {A B : Type} (f : A -> B) (F : B -> Type) (xs : list A) : + forall (i : Fin (length (map f xs))) + (j : Fin (length xs)), + f2n i = f2n j -> + F (f (nth_Fin xs j)) -> + F (nth_Fin (map f xs) i) := + match xs return + forall (i : Fin (length (map f xs))) + (j : Fin (length xs)), + f2n i = f2n j -> + F (f (nth_Fin xs j)) -> + F (nth_Fin (map f xs) i) with + | [] => case0 (fun i => _) + | y :: ys => + fun (i : Fin (length (map f (y :: ys)))) + (j : Fin (length (y :: ys))) => + match i as k return + @f2n (length (map f (y :: ys))) k = @f2n (length (y :: ys)) j -> + F (f (nth_Fin (y :: ys) j)) -> + F (nth_Fin (map f (y :: ys)) k) with + | inl u => + match j as k return + f2n (inl u : Fin (length (map f (y :: ys)))) = @f2n (length (y :: ys)) k -> + F (f (nth_Fin (y :: ys) k)) -> + F (nth_Fin (map f (y :: ys)) (inl u : Fin (length (map f (y :: ys))))) with + | inl v => + match u, v with + | tt, tt => + fun _ (x : F (f (nth_Fin (y :: ys) (inl tt : Fin (length (y :: ys)))))) => x + end + | inr _ => fun H : 0 = S _ => False_rect _ (Nat.neq_0_succ _ H) + end + | inr a => + match j as k return + f2n (inr a : Fin (length (map f (y :: ys)))) = @f2n (length (y :: ys)) k -> + F (f (nth_Fin (y :: ys) k)) -> + F (nth_Fin (map f (y :: ys)) (inr a : Fin (length (map f (y :: ys))))) with + | inl _ => fun H : S _ = 0 => False_rect _ (Nat.neq_succ_0 _ H) + | inr b => + fun H => + nth_Fin_map2_aux f F ys + (a : Fin (length (map f ys))) + (b : Fin (length ys)) + (Nat.succ_inj _ _ H) + end + end + end. + +Definition nth_Fin_map2 {A B : Type} (f : A -> B) (F : B -> Type) (xs : list A) (i : Fin (length (map f xs))) : + F (f (nth_Fin xs (cast i (map_length_red f xs)))) -> + F (nth_Fin (map f xs) i) := + nth_Fin_map2_aux f F xs i (cast i (map_length_red f xs)) + (eq_sym (fin_to_nat_cast + (length (map f xs)) i + (length xs) + (map_length_red f xs))). + +Fixpoint Fin_forallb {n} : (Fin n -> bool) -> bool := + match n with + | 0 => fun _ => true + | S n => fun f => f F1 && Fin_forallb (fun i => f (FS i)) + end. + +Lemma Fin_forallb_correct {n} : + forall f : Fin n -> bool, + Fin_forallb f = true <-> forall i, f i = true. +Proof. + induction n as [|n]. + + intro f; split; intro H. + - exact (case0 (fun i => f i = true)). + - unfold Fin_forallb; reflexivity. + + intro f; split; intro H. + destruct i as [u|j]. + - destruct u; simpl; unfold Fin_forallb in H; exact (proj1 (andb_prop (f F1) _ H)). + - unfold Fin_forallb in H. + fold (@Fin_forallb n (fun i : Fin n => f (FS i))) in H. + exact (proj1 (IHn (fun i => f (FS i))) (proj2 (andb_prop _ _ H)) j). + - simpl; apply (andb_true_intro); split. + * exact (H F1). + * exact (proj2 (IHn (fun i => f (FS i))) (fun i => H (FS i))). +Qed. + +Fixpoint getFins (n : nat) : list (Fin n) := + match n with + | 0 => [] + | S m => F1 :: map FS (getFins m) + end. + +Fixpoint getFinsBound n m : list (Fin m) := + match n with + | 0 => [] + | S n => + match m with + | 0 => [] + | S m => F1 :: map FS (getFinsBound n m) + end + end. + +Definition mapOrFins {n : nat} (i : Fin n) := fold_left (fun acc x => i = x \/ acc) (getFins n) False. + +Lemma getFins_length : forall n, length (getFins n) = n. +Proof. + induction n as [|n IH]. + + reflexivity. + + simpl; rewrite map_length; rewrite IH; reflexivity. +Qed. + +Lemma getFins_all : forall n (i : Fin n), In i (getFins n). +Proof. + induction n as [|n IH]. + + exact (case0 (fun i => In i (getFins 0))). + + destruct i as [u|j]. + - destruct u; left; reflexivity. + - simpl. right. exact (in_map FS (getFins n) j (IH j)). +Qed. + +Lemma getFins_nth_error : forall {n} (i : Fin n), + nth_error (getFins n) (proj1_sig (to_nat i)) = Some i. +Proof. + induction n as [|n IH]. + + exact (case0 (fun i => nth_error (getFins 0) (proj1_sig (to_nat i)) = Some i)). + + destruct i as [u|j]. + - destruct u; reflexivity. + - simpl; exact (@map_nth_error _ _ FS (proj1_sig (to_nat j)) (getFins n) j (IH j)). +Qed. + +Lemma getFins_nth {A : Type} : forall n (default : Fin n) (i : Fin n), + nth (proj1_sig (to_nat i)) (getFins n) default = i. +Proof. + intros. + apply nth_error_nth. + apply getFins_nth_error. +Qed. diff --git a/StdLib/StdEquiv.v b/StdLib/StdEquiv.v new file mode 100644 index 0000000..f4b9627 --- /dev/null +++ b/StdLib/StdEquiv.v @@ -0,0 +1,55 @@ +Require Import Kami.StdLib.Fin. +Require Import Kami.StdLib.VectorDef. +Require Import Vector. + +Fixpoint fin_to_fin_new {n} (i : Fin.t n) : Fin n := + match i with + | Fin.F1 => F1 + | Fin.FS j => FS (fin_to_fin_new j) + end. + +Fixpoint fin_new_to_fin {n} (i : Fin n) : Fin.t n := + match n as n0 return (Fin n0 -> Fin.t n0) with + | O => StdLib.Fin.case0 (fun _ => Fin.t 0) + | S m => + (fun j => + match j with + | inl _ => Fin.F1 + | inr k => Fin.FS (fin_new_to_fin k) + end) + end i. + +Lemma F2FN_id {n} (i : Fin.t n) : + fin_new_to_fin (fin_to_fin_new i) = i. +Proof. induction i; simpl; [|rewrite IHi]; reflexivity. Qed. + +Lemma FN2F_id {n} (i : Fin n) : + fin_to_fin_new (fin_new_to_fin i) = i. +Proof. + induction n; destruct i; simpl. + - destruct u; reflexivity. + - rewrite IHn; reflexivity. +Qed. + +Fixpoint vec_to_vec_new {A n} (v : Vector.t A n) : Vec A n := + match v with + | Vector.nil _ => Kami.StdLib.VectorDef.nil _ + | Vector.cons _ h _ tl => Kami.StdLib.VectorDef.cons _ h _ (vec_to_vec_new tl) + end. + +Fixpoint vec_new_to_vec {A n} (v : Vec A n) : Vector.t A n := + match n as n0 return (Vec A n0 -> Vector.t A n0) with + | O => StdLib.VectorDef.case0 (fun _ => Vector.t A 0) (Vector.nil _) + | S m => (fun v' => Vector.cons _ (fst v') _ (vec_new_to_vec (snd v'))) + end v. + +Lemma V2VN_id {A n} (v : Vector.t A n) : + vec_new_to_vec (vec_to_vec_new v) = v. +Proof. induction v; simpl; [|rewrite IHv]; reflexivity. Qed. + +Lemma VN2V_id {A n} (v : Vec A n) : + vec_to_vec_new (vec_new_to_vec v) = v. +Proof. + induction n; destruct v; simpl; auto. + rewrite IHn; reflexivity. +Qed. diff --git a/StdLib/Vector.v b/StdLib/Vector.v new file mode 100644 index 0000000..2c4f445 --- /dev/null +++ b/StdLib/Vector.v @@ -0,0 +1,7 @@ +Require Kami.StdLib.Fin. +Require Kami.StdLib.VectorDef. +Require Kami.StdLib.VectorSpec. +Require Kami.StdLib.VectorEq. +Include Kami.StdLib.VectorDef. +Include Kami.StdLib.VectorSpec. +Include Kami.StdLib.VectorEq. diff --git a/StdLib/VectorDef.v b/StdLib/VectorDef.v new file mode 100644 index 0000000..974222e --- /dev/null +++ b/StdLib/VectorDef.v @@ -0,0 +1,313 @@ +Require Import String. +Require Import Kami.StdLib.Fin. +Require Import Arith_base. +Import EqNotations. +Local Open Scope nat_scope. + +Fixpoint Vec X n : Type := + match n with + | 0 => unit + | S m => (X * Vec X m)%type + end. + +Definition nil : forall (A : Type), Vec A 0 := fun _ => tt. +Definition cons : forall (A : Type), A -> forall (n : nat), Vec A n -> Vec A (S n) := + fun (A : Type) (a : A) (n : nat) (v : Vec A n) => (a, v). + + +Local Notation "[ ]" := (nil _) (format "[ ]"). +Local Notation "h :: t" := (cons _ h _ t) (at level 60, right associativity). + +Section SCHEMES. + + Definition rectS {A} (P : forall {n}, Vec A (S n) -> Type) + (bas : forall (a : A), P (a :: [])) + (rect : forall a {n} (v : Vec A (S n)), P v -> P (a :: v)) : + forall (n : nat) (v : Vec A (S n)), P v. + Proof. + induction n; intros. + - repeat destruct v. + apply (bas a). + - destruct v; apply rect, IHn. + Defined. + + Definition case0 {A} (P : Vec A 0 -> Type) (H : P []) (v : Vec A 0) := + match v as u return (P u) with + | tt => H + end. + + Definition caseS {A} (P : forall {n}, Vec A (S n) -> Type) + (H : forall h {n} t, @P n (h :: t)) {n} (v : Vec A (S n)) : P v := + let (a, v0) as p return (P p) := v in H a v0. + + Definition caseS' {A} {n : nat} (v : Vec A (S n)) (P : Vec A (S n) -> Type) + (H : forall h t, P (h :: t)) : P v := + let (a, v0) as p return (P p) := v in H a v0. + + Definition rect2 {A B} (P : forall {n}, Vec A n -> Vec B n -> Type) + (bas : P [] []) (rect : forall {n v1 v2}, @P n v1 v2 -> + forall a b, P (a :: v1) (b :: v2)) : + forall (n : nat) (v1 : Vec A n) (v2 : Vec B n), P v1 v2. + Proof. + induction n; intros; destruct v1, v2. + - apply bas. + - apply rect, IHn. + Defined. + +End SCHEMES. + +Section BASES. +(** The first element of a non empty vector *) +Definition hd {A} := @caseS _ (fun n v => A) (fun h n t => h). +Global Arguments hd {A} {n} v. + +(** The last element of an non empty vector *) +Definition last {A} := @rectS _ (fun _ _ => A) (fun a => a) (fun _ _ _ H => H). +Global Arguments last {A} {n} v. + +(** Build a vector of n{^ th} [a] *) +Definition const {A} (a:A) := nat_rect _ [] (fun n x => cons _ a n x). + +(** The [p]{^ th} element of a vector of length [m]. *) +Fixpoint nth {n X} : Vec X n -> Fin n -> X := + match n with + | 0 => fun _ => Fin.case0 _ + | S m => fun v i => fin_case i _ (fst v) (fun j => nth (snd v) j) + end. + +(** An equivalent definition of [nth]. *) +Definition nth_order {A} {n} (v: Vec A n) {p} (H: p < n) := +(nth v (Fin.of_nat_lt H)). + +Definition replace {A n} (v : Vec A n) (p : Fin n) (a : A) : Vec A n. +Proof. + induction n. + - inversion p. + - simpl in *. + destruct p, v. + + apply (a, v). + + apply (a0, (IHn v f)). +Defined. + +(** Version of replace with [lt] *) +Definition replace_order {A n} (v: Vec A n) {p} (H: p < n) := +replace v (Fin.of_nat_lt H). + +(** Remove the first element of a non empty vector *) +Definition tl {A} := @caseS _ (fun n v => Vec A n) (fun h n t => t). +Global Arguments tl {A} {n} v. + +(** Destruct a non empty vector *) +Definition uncons {A} {n : nat} (v : Vec A (S n)) : A * Vec A n := (hd v, tl v). + +(** Remove last element of a non-empty vector *) +Definition shiftout {A} := @rectS _ (fun n _ => Vec A n) (fun a => []) + (fun h _ _ H => h :: H). +Global Arguments shiftout {A} {n} v. + +(** Add an element at the end of a vector *) +Fixpoint shiftin {A} {n : nat} (a : A) (v : Vec A n) : Vec A (S n) := + match n as n0 return (Vec A n0 -> Vec A (S n0)) with + | O => (fun v' => (a, v')) + | S m => (fun v' => (fst v', (shiftin a (snd v')))) + end v. + +(** Copy last element of a vector *) +Definition shiftrepeat {A} := @rectS _ (fun n _ => Vec A (S (S n))) + (fun h => h :: h :: []) (fun h _ _ H => h :: H). +Global Arguments shiftrepeat {A} {n} v. + +(** Take first [p] elements of a vector *) +Fixpoint take {A} {n} (p:nat) (le:p <= n) (v:Vec A n) : Vec A p := + match p as p0 return p0 <= n -> Vec A n -> Vec A p0 with + | O => fun _ _ => [] + | S p' => + match n as n0 return S p' <= n0 -> Vec A n0 -> Vec A (S p') with + | O => fun le => False_rect _ (Nat.nle_succ_0 p' le) + | S n' => fun le v' => ((fst v'), take p' (le_S_n p' _ le) (snd v')) + end + end le v. + +(** Remove [p] last elements of a vector *) +Lemma trunc : forall {A} {n} (p:nat), n > p -> Vec A n + -> Vec A (n - p). +Proof. + induction p as [| p f]; intros H v. + rewrite <- minus_n_O. + exact v. + + apply shiftout. + + rewrite minus_Sn_m. + apply f. + auto with *. + exact v. + auto with *. +Defined. + +(** Concatenation of two vectors *) +Fixpoint append {A}{n}{p} (v : Vec A n) (w : Vec A p) : Vec A (n + p) := + match n as n0 return Vec A n0 -> Vec A (n0 + p) with + | O => (fun _ => w) + | S m => (fun v' => (fst v', (append (snd v') w))) + end v. + +Infix "++" := append. + +(** Split a vector into two parts *) +Fixpoint splitat {A} (l : nat) {r : nat} : + Vec A (l + r) -> Vec A l * Vec A r := + match l with + | 0 => fun v => ([], v) + | S l' => fun v => + let (v1, v2) := splitat l' (tl v) in + (hd v::v1, v2) + end. + +(** Two definitions of the tail recursive function that appends two lists but *) +(* reverses the first one *) + +(** This one has the exact expected computational behavior *) +Fixpoint rev_append_tail {A n p} (v : Vec A n) (w: Vec A p) + : Vec A (tail_plus n p) := + match n as n0 return Vec A n0 -> Vec A (tail_plus n0 p) with + | O => (fun _ => w) + | S m => (fun v' => rev_append_tail (snd v') ((fst v') :: w)) + end v. + +Import EqdepFacts. + +(** This one has a better type *) +Definition rev_append {A n p} (v: Vec A n) (w: Vec A p) + :Vec A (n + p) := + rew <- (plus_tail_plus n p) in (rev_append_tail v w). + +(** rev [a₁ ; a₂ ; .. ; an] is [an ; a{n-1} ; .. ; a₁] *) + +(* Caution : There is a lot of rewrite garbage in this definition *) +Definition rev {A n} (v : Vec A n) : Vec A n := + rew <- (plus_n_O _) in (rev_append v []). + +End BASES. +Local Notation "v [@ p ]" := (nth v p) (at level 1). + +Section ITERATORS. +(** * Here are special non dependent useful instantiation of induction schemes *) + +(** Uniform application on the arguments of the vector *) +Definition map {A} {B} (f : A -> B) : forall {n} (v : Vec A n), Vec B n := + fix map_fix {n} (v : Vec A n) : Vec B n := + match n as n0 return Vec A n0 -> Vec B n0 with + | O => (fun _ => []) + | S m => (fun v' => (f (fst v'), (map_fix (snd v')))) + end v. + +(** map2 g [x1 .. xn] [y1 .. yn] = [(g x1 y1) .. (g xn yn)] *) +Definition map2 {A B C} (g:A -> B -> C) : + forall (n : nat), Vec A n -> Vec B n -> Vec C n := +@rect2 _ _ (fun n _ _ => Vec C n) (nil C) (fun _ _ _ H a b => (g a b) :: H). +Global Arguments map2 {A B C} g {n} v1 v2. + +(** fold_left f b [x1 .. xn] = f .. (f (f b x1) x2) .. xn *) +Definition fold_left {A B:Type} (f : B -> A -> B) : forall (b : B) {n} (v : Vec A n), B := + fix fold_left_fix (b:B) {n} (v : Vec A n) : B := + match n as n0 return Vec A n0 -> B with + | O => fun _ => b + | S m => (fun v' => (fold_left_fix (f b (fst v')) (snd v'))) + end v. + +(** fold_right f [x1 .. xn] b = f x1 (f x2 .. (f xn b) .. ) *) +Definition fold_right {A B : Type} (f : A -> B -> B) := + fix fold_right_fix {n} (v : Vec A n) (b : B) : B := + match n as n0 return Vec A n0 -> B with + | O => fun _ => b + | S m => (fun v' => f (fst v') (fold_right_fix (snd v') b)) + end v. + +(** fold_right2 g c [x1 .. xn] [y1 .. yn] = g x1 y1 (g x2 y2 .. (g xn yn c) .. ) *) +(* c is before the vectors to be compliant with "refolding". *) +Definition fold_right2 {A B C} (g:A -> B -> C -> C) (c: C) := +@rect2 _ _ (fun _ _ _ => C) c (fun _ _ _ H a b => g a b H). + + +(** fold_left2 f b [x1 .. xn] [y1 .. yn] = g .. (g (g a x1 y1) x2 y2) .. xn yn *) +Definition fold_left2 {A B C: Type} (f : A -> B -> C -> A) := + fix fold_left2_fix (a : A) {n} (v : Vec B n) : Vec C n -> A := + match n as n0 return Vec B n0 -> Vec C n0 -> A with + | O => (fun v' w => case0 (fun _ => A) a w) + | S m => + (fun v' w => caseS' w (fun _ => A) + (fun wh wt => fold_left2_fix (f a (fst v') wh) (snd v') wt)) + end v. + +End ITERATORS. + +Section SCANNING. +Inductive Forall {A} (P: A -> Prop): forall {n} (v: Vec A n), Prop := + |Forall_nil: Forall P [] + |Forall_cons {n} x (v: Vec A n): P x -> Forall P v -> Forall P (x::v). +Hint Constructors Forall : core. + +Inductive Exists {A} (P:A->Prop): forall {n}, Vec A n -> Prop := + |Exists_cons_hd {m} x (v: Vec A m): P x -> Exists P (x::v) + |Exists_cons_tl {m} x (v: Vec A m): Exists P v -> Exists P (x::v). +Hint Constructors Exists : core. + +Inductive In {A} (a:A): forall {n}, Vec A n -> Prop := + |In_cons_hd {m} (v: Vec A m): In a (a::v) + |In_cons_tl {m} x (v: Vec A m): In a v -> In a (x::v). +Hint Constructors In : core. + +Inductive Forall2 {A B} (P:A->B->Prop): forall {n}, Vec A n -> Vec B n -> Prop := + |Forall2_nil: Forall2 P [] [] + |Forall2_cons {m} x1 x2 (v1:Vec A m) v2: P x1 x2 -> Forall2 P v1 v2 -> + Forall2 P (x1::v1) (x2::v2). +Hint Constructors Forall2 : core. + +Inductive Exists2 {A B} (P:A->B->Prop): forall {n}, Vec A n -> Vec B n -> Prop := + |Exists2_cons_hd {m} x1 x2 (v1: Vec A m) (v2: Vec B m): P x1 x2 -> Exists2 P (x1::v1) (x2::v2) + |Exists2_cons_tl {m} x1 x2 (v1:Vec A m) v2: Exists2 P v1 v2 -> Exists2 P (x1::v1) (x2::v2). +Hint Constructors Exists2 : core. + +End SCANNING. + +Section VECTORLIST. +(** * vector <=> list functions *) + +Fixpoint of_list {A} (l : list A) : Vec A (List.length l) := +match l as l' return Vec A (List.length l') with + |Datatypes.nil => [] + |(h :: tail)%list => (h :: (of_list tail)) +end. + +Definition to_list {A}{n} (v : Vec A n) : list A := +Eval cbv delta beta in fold_right (fun h H => Datatypes.cons h H) v Datatypes.nil. + +Fixpoint add_indices_aux{n X} : nat -> Vec X n -> Vec (nat * X) n := + match n return nat -> Vec X n -> Vec (nat * X) n with + | 0 => fun _ _ => tt + | S m => fun acc '(x,xs) => ((acc,x), add_indices_aux (S acc) xs) + end. + +Definition add_indices{n X} : Vec X n -> Vec (nat * X) n := add_indices_aux 0. + +Fixpoint add_strings{n X} : (Fin n -> string) -> Vec X n -> Vec (string * X) n := + match n return (Fin n -> string) -> Vec X n -> Vec (string * X) n with + | 0 => fun _ _ => tt + | S m => fun strs '(x,xs) => ((strs Fin.F1,x),add_strings (fun j => strs (Fin.FS j)) xs) + end. + +End VECTORLIST. + +Module VectorNotations. +Declare Scope vector_scope. +Delimit Scope vector_scope with vector. +Notation "[ ]" := [] (format "[ ]") : vector_scope. +Notation "h :: t" := (h :: t) (at level 60, right associativity) + : vector_scope. +Notation "[ x ]" := (x :: []) : vector_scope. +Notation "[ x ; y ; .. ; z ]" := (cons _ x _ (cons _ y _ .. (cons _ z _ (nil _)) ..)) : vector_scope. +Notation "v [@ p ]" := (nth v p) (at level 1, format "v [@ p ]") : vector_scope. +Infix "++" := append : vector_scope. +Open Scope vector_scope. +End VectorNotations. diff --git a/StdLib/VectorEq.v b/StdLib/VectorEq.v new file mode 100644 index 0000000..e72dfb8 --- /dev/null +++ b/StdLib/VectorEq.v @@ -0,0 +1,63 @@ +Require Import Kami.StdLib.VectorDef. +Require Import Kami.StdLib.VectorSpec. +Import Kami.StdLib.VectorDef.VectorNotations. + +Section BEQ. + + Variables (A: Type) (A_beq: A -> A -> bool). + Hypothesis A_eqb_eq: forall x y, A_beq x y = true <-> x = y. + + Definition eqb: + forall {m n} (v1: Vec A m) (v2: Vec A n), bool := + fix fix_beq {m n} v1 v2 := + match m as m0 return (Vec A m0 -> bool) with + | O => (fun _ => match n with + | O => true + | S n' => false + end) + | S m' => (fun v1' => + match n as n0 return (Vec A n0 -> bool) with + | O => fun _ => false + | S n' => (fun v2' => (A_beq (fst v1') (fst v2') && + fix_beq (snd v1') (snd v2'))%bool) + end v2) + end v1. + + Lemma eqb_nat_eq: forall m n (v1: Vec A m) (v2: Vec A n) + (Hbeq: eqb v1 v2 = true), m = n. + Proof. + intros m n v1; revert n. + induction m; destruct n; auto; try discriminate; intros. + destruct v1, v2; simpl in *. + apply andb_prop in Hbeq; destruct Hbeq. + f_equal; eauto. + Qed. + + Lemma eqb_eq: forall n (v1: Vec A n) (v2: Vec A n), + eqb v1 v2 = true <-> v1 = v2. + Proof. + refine (@rect2 _ _ _ _ _); [now constructor | simpl]. + intros ? ? ? Hrec h1 h2; destruct Hrec; destruct (A_eqb_eq h1 h2); split. + + intros Hbeq. apply andb_prop in Hbeq; destruct Hbeq. + f_equal; now auto. + + intros Heq. destruct (cons_inj Heq). apply andb_true_intro. + split; now auto. + Qed. + + Definition eq_dec n (v1 v2: Vec A n): {v1 = v2} + {v1 <> v2}. + Proof. + case_eq (eqb v1 v2); intros. + + left; now apply eqb_eq. + + right. intros Heq. apply <- eqb_eq in Heq. congruence. + Defined. + +End BEQ. + +Section CAST. + + Definition cast: forall {A m} (v: Vec A m) {n}, m = n -> Vec A n. + Proof. + intros; subst; auto. + Defined. + +End CAST. diff --git a/StdLib/VectorSpec.v b/StdLib/VectorSpec.v new file mode 100644 index 0000000..489bb03 --- /dev/null +++ b/StdLib/VectorSpec.v @@ -0,0 +1,181 @@ +Require Import Kami.StdLib.Fin. +Require Import Kami.StdLib.VectorDef. +Import Kami.StdLib.VectorDef.VectorNotations. + +Definition cons_inj {A} {a1 a2} {n} {v1 v2 : Vec A n} + (eq : a1 :: v1 = a2 :: v2) : a1 = a2 /\ v1 = v2 := + match eq in _ = x return caseS _ (fun a2' _ v2' => fun v1' => a1 = a2' /\ v1' = v2') x v1 + with | eq_refl => conj eq_refl eq_refl + end. + +Lemma eta {A} {n} (v : Vec A (S n)) : v = hd v :: tl v. +Proof. +intros; apply caseS with (v:=v); intros; reflexivity. +Defined. + +(** Lemmas are done for functions that use [Fin.t] but thanks to [Peano_dec.le_unique], all +is true for the one that use [lt] *) + +Lemma eq_nth_iff A n (v1 v2: Vec A n): + (forall p1 p2, p1 = p2 -> v1 [@ p1 ] = v2 [@ p2 ]) <-> v1 = v2. +Proof. +split. +- revert n v1 v2; refine (@rect2 _ _ _ _ _); simpl; intros. + + reflexivity. + + f_equal. apply (H0 Fin.F1 Fin.F1 eq_refl). + apply H. intros p1 p2 H1; + apply (H0 (Fin.FS p1) (Fin.FS p2) (f_equal (@Fin.FS n) H1)). +- intros; now f_equal. +Qed. + +Lemma nth_order_last A: forall n (v: Vec A (S n)) (H: n < S n), + nth_order v H = last v. +Proof. +unfold nth_order; refine (@rectS _ _ _ _); now simpl. +Qed. + +(* Lemma shiftin_nth A a n (v: Vec A n) k1 k2 (eq: k1 = k2): *) +(* nth (shiftin a v) (Fin.L_R 1 k1) = nth v k2. *) +(* Proof. *) +(* subst k2; induction k1. *) +(* - generalize dependent n. apply caseS ; intros. now simpl. *) +(* - generalize dependent n. refine (@caseS _ _ _) ; intros. now simpl. *) +(* Qed. *) + +Lemma shiftin_last A a n (v: Vec A n): last (shiftin a v) = a. +Proof. + induction n. + - destruct v; reflexivity. + - apply IHn. +Qed. + +(* Lemma shiftrepeat_nth A: forall n k (v: Vec A (S n)), *) +(* nth (shiftrepeat v) (Fin.L_R 1 k) = nth v k. *) +(* Proof. *) +(* refine (@Fin.rectS _ _ _); lazy beta; [ intros n v | intros n p H v ]. *) +(* - revert n v; refine (@caseS _ _ _); simpl; intros. now destruct t. *) +(* - revert p H. *) +(* refine (match v as v' in t _ m return match m as m' return t A m' -> Prop with *) +(* |S (S n) => fun v => forall p : Fin.t (S n), *) +(* (forall v0 : t A (S n), (shiftrepeat v0) [@ Fin.L_R 1 p ] = v0 [@p]) -> *) +(* (shiftrepeat v) [@Fin.L_R 1 (Fin.FS p)] = v [@Fin.FS p] *) +(* |_ => fun _ => True end v' with *) +(* |[] => I |h :: t => _ end). destruct n0. exact I. now simpl. *) +(* Qed. *) + +Lemma shiftrepeat_last A: forall n (v: Vec A (S n)), last (shiftrepeat v) = last v. +Proof. +refine (@rectS _ _ _ _); now simpl. +Qed. + +Lemma const_nth A (a: A) n (p: Fin n): (const a n)[@ p] = a. +Proof. + induction n; [inversion p|]. + simpl in *. + destruct p. + - destruct u; simpl; reflexivity. + - simpl; apply IHn. +Qed. + +Lemma nth_map {A B} (f: A -> B) {n} v (p1 p2: Fin n) (eq: p1 = p2): + (map f v) [@ p1] = f (v [@ p2]). +Proof. + subst p2; induction n. + - inversion p1. + - destruct p1. + + destruct u; simpl; reflexivity. + + apply IHn. +Qed. + +Lemma nth_map2 {A B C} (f: A -> B -> C) {n} v w (p1 p2 p3: Fin n): + p1 = p2 -> p2 = p3 -> (map2 f v w) [@p1] = f (v[@p2]) (w[@p3]). +Proof. +intros; subst p2; subst p3; revert n v w p1. +refine (@rect2 _ _ _ _ _); simpl. +- exact (Fin.case0 _). +- intros n v1 v2 H a b p; revert n p v1 v2 H; refine (@Fin.caseS _ _ _); + now simpl. +Qed. + +Lemma fold_left_right_assoc_eq {A B} {f: A -> B -> A} + (assoc: forall a b c, f (f a b) c = f (f a c) b) + {n} (v: Vec B n): forall a, fold_left f a v = fold_right (fun x y => f y x) v a. +Proof. +assert (forall n h (v: Vec B n) a, fold_left f (f a h) v = f (fold_left f a v) h). +- induction n0; auto; intros. + destruct v0; simpl. + rewrite assoc, <- IHn0; reflexivity. +- induction n; auto; intros. + destruct v; simpl; rewrite <- IHn. + apply H. +Qed. + +Lemma to_list_of_list_opp {A} (l: list A): to_list (of_list l) = l. +Proof. +induction l. +- reflexivity. +- unfold to_list; simpl. now f_equal. +Qed. + +Lemma take_O : forall {A} {n} le (v:Vec A n), take 0 le v = []. +Proof. + intros; destruct n; reflexivity. +Qed. + +Lemma take_idem : forall {A} p n (v:Vec A n) le le', + take p le' (take p le v) = take p le v. +Proof. + induction p; intros n v le le'. + - destruct n; auto. + - destruct n. inversion le. simpl. apply f_equal. apply IHp. +Qed. + +Lemma take_app : forall {A} {n} (v:Vec A n) {m} (w:Vec A m) le, take n le (append v w) = v. +Proof. + induction n; intros m w v le; simpl. + - destruct m; apply take_O. + - destruct m; simpl; f_equal. + apply IHn. +Qed. + +(* Proof is irrelevant for [take] *) +Lemma take_prf_irr : forall {A} p {n} (v:Vec A n) le le', take p le v = take p le' v. +Proof. + induction p; intros n v le le'. + - destruct n; reflexivity. + - destruct n; [inversion le|]; simpl; f_equal; apply IHp. +Qed. + +Lemma uncons_cons {A} : forall {n : nat} (a : A) (v : Vec A n), + uncons (a::v) = (a,v). +Proof. reflexivity. Qed. + +Lemma append_comm_cons {A} : forall {n m : nat} (v : Vec A n) (w : Vec A m) (a : A), + a :: (v ++ w) = (a :: v) ++ w. +Proof. reflexivity. Qed. + +Lemma splitat_append {A} : forall {n m : nat} (v : Vec A n) (w : Vec A m), + splitat n (v ++ w) = (v, w). +Proof with simpl; auto. + intros n m v. + generalize dependent m. + induction n; intros... + - destruct v; reflexivity. + - destruct v; simpl; rewrite IHn; reflexivity. +Qed. + +Lemma append_splitat {A} : forall {n m : nat} (v : Vec A n) (w : Vec A m) (vw : Vec A (n+m)), + splitat n vw = (v, w) -> + vw = v ++ w. +Proof with auto. + intros n m v. + generalize dependent m. + induction n; intros; inversion H... + destruct (splitat n (tl vw)) as [v' w'] eqn:Heq. + apply pair_equal_spec in H1. + destruct H1; subst. + rewrite <- append_comm_cons. + rewrite (eta vw); simpl. + f_equal. + apply IHn; assumption. +Qed. diff --git a/Syntax.v b/Syntax.v index 1422961..6e62467 100644 --- a/Syntax.v +++ b/Syntax.v @@ -1,5 +1,8 @@ -Require Export Bool Ascii String Fin List FunctionalExtensionality Psatz PeanoNat. +Require Export Bool Ascii String FunctionalExtensionality Psatz PeanoNat. +Require Export Kami.StdLib.Fin. +Require Export Kami.StdLib.Vector. Require Export Kami.Lib.VectorFacts Kami.Lib.EclecticLib. +Require Export List. Require Export Kami.Lib.Word Kami.Lib.WordProperties. Export ListNotations. @@ -19,7 +22,7 @@ Global Open Scope list_scope. Inductive Kind := | Bool : Kind | Bit : nat -> Kind -| Struct : forall n, (Fin.t n -> Kind) -> (Fin.t n -> string) -> Kind +| Struct : forall {n}, (Fin n -> Kind) -> (Fin n -> string) -> Kind | Array : nat -> Kind -> Kind. Inductive FullKind: Type := @@ -29,8 +32,8 @@ Inductive FullKind: Type := Inductive ConstT: Kind -> Type := | ConstBool: bool -> ConstT Bool | ConstBit n: word n -> ConstT (Bit n) -| ConstStruct n fk fs (fv: forall i, ConstT (fk i)): ConstT (@Struct n fk fs) -| ConstArray n k (fk: Fin.t n -> ConstT k): ConstT (Array n k). +| ConstStruct {n} fk fs (fv: forall i, ConstT (fk i)): ConstT (@Struct n fk fs) +| ConstArray {n} {k} (fk: Fin n -> ConstT k): ConstT (Array n k). Inductive ConstFullT: FullKind -> Type := | SyntaxConst k: ConstT k -> ConstFullT (SyntaxKind k) @@ -98,7 +101,7 @@ Fixpoint type (k: Kind): Type := | Bool => bool | Bit n => word n | Struct n fk fs => forall i, type (fk i) - | Array n k' => Fin.t n -> type k' + | Array n k' => Fin n -> type k' end. Fixpoint evalConstT k (e: ConstT k): type k := @@ -131,19 +134,19 @@ Section Phoas. Expr (SyntaxKind Bool) | ITE k: Expr (SyntaxKind Bool) -> Expr k -> Expr k -> Expr k | Eq k: Expr (SyntaxKind k) -> Expr (SyntaxKind k) -> Expr (SyntaxKind Bool) - | ReadStruct n (fk: Fin.t n -> Kind) (fs: Fin.t n -> string) + | ReadStruct {n} (fk: Fin n -> Kind) (fs: Fin n -> string) (e: Expr (SyntaxKind (Struct fk fs))) i: Expr (SyntaxKind (fk i)) - | BuildStruct n (fk: Fin.t n -> Kind) (fs: Fin.t n -> string) + | BuildStruct {n} (fk: Fin n -> Kind) (fs: Fin n -> string) (fv: forall i, Expr (SyntaxKind (fk i))): Expr (SyntaxKind (Struct fk fs)) - | ReadArray n m k: Expr (SyntaxKind (Array n k)) -> + | ReadArray {n} m k: Expr (SyntaxKind (Array n k)) -> Expr (SyntaxKind (Bit m)) -> Expr (SyntaxKind k) - | ReadArrayConst n k: Expr (SyntaxKind (Array n k)) -> - Fin.t n -> + | ReadArrayConst {n} k: Expr (SyntaxKind (Array n k)) -> + Fin n -> Expr (SyntaxKind k) - | BuildArray n k: (Fin.t n -> Expr (SyntaxKind k)) -> Expr (SyntaxKind (Array n k)) + | BuildArray {n} k: (Fin n -> Expr (SyntaxKind k)) -> Expr (SyntaxKind (Array n k)) | Kor k: list (Expr (SyntaxKind k)) -> Expr (SyntaxKind k) | ToNative k: Expr (SyntaxKind k) -> Expr (@NativeKind (type k) (evalConstT (getDefaultConst k))) @@ -153,22 +156,22 @@ Section Phoas. Definition UpdateArray n m k (e: Expr (SyntaxKind (Array n k))) (i: Expr (SyntaxKind (Bit m))) (v: Expr (SyntaxKind k)) := - BuildArray (fun i' : Fin.t n => - ITE (Eq i (Const (natToWord _ (proj1_sig (Fin.to_nat i'))))) v + BuildArray (fun i' : Fin n => + ITE (Eq i (Const (natToWord _ (proj1_sig (to_nat i'))))) v (ReadArrayConst e i')). Definition UpdateArrayConst n k (e: Expr (SyntaxKind (Array n k))) - (i: Fin.t n) + (i: Fin n) (v: Expr (SyntaxKind k)) := - BuildArray (fun i' : Fin.t n => - match Fin.eq_dec i i' with + BuildArray (fun i' : Fin n => + match Kami.StdLib.Fin.eq_dec i i' with | left _ => v | right _ => ReadArrayConst e i' end). - Definition UpdateStruct n (fk: Fin.t n -> Kind) (fs: Fin.t n -> string) + Definition UpdateStruct n (fk: Fin n -> Kind) (fs: Fin n -> string) (e: Expr (SyntaxKind (Struct fk fs))) i (v: Expr (SyntaxKind (fk i))) := - BuildStruct fk fs (fun i' => match Fin_eq_dec i i' with + BuildStruct fk fs (fun i' => match Kami.StdLib.Fin.eq_dec i i' with | left pf => match pf in _ = Y return Expr (SyntaxKind (fk Y)) with @@ -268,10 +271,10 @@ Section Phoas. end. Defined. - Fixpoint sumSizes n: (Fin.t n -> nat) -> nat := - match n return (Fin.t n -> nat) -> nat with + Fixpoint sumSizes {n}: (Fin n -> nat) -> nat := + match n return (Fin n -> nat) -> nat with | 0 => fun _ => 0 - | S m => fun sizes => sumSizes (fun x => sizes (Fin.FS x)) + sizes Fin.F1 + | S m => fun sizes => sumSizes (fun x => sizes (FS x)) + sizes F1 end. Fixpoint size (k: Kind) {struct k} := @@ -285,19 +288,19 @@ Section Phoas. (* ConstExtract: LSB, MIDDLE, MSB *) (* Concat: MSB, LSB *) - Fixpoint concatStructExpr n {struct n}: - forall (sizes: Fin.t n -> nat) + Fixpoint concatStructExpr {n} {struct n}: + forall (sizes: Fin n -> nat) (f: forall i, Expr (SyntaxKind (Bit (sizes i)))), Expr (SyntaxKind (Bit (sumSizes sizes))) := match n return forall - (sizes: Fin.t n -> nat) + (sizes: Fin n -> nat) (f: forall i, Expr (SyntaxKind (Bit (sizes i)))), Expr (SyntaxKind (Bit (sumSizes sizes))) with | 0 => fun _ _ => Const WO | S m => fun sizes f => BinBit - (Concat _ _) (f Fin.F1) - (@concatStructExpr m (fun x => (sizes (Fin.FS x))) (fun x => f (Fin.FS x))) + (Concat _ _) (f F1) + (@concatStructExpr m (fun x => (sizes (FS x))) (fun x => f (FS x))) end. Fixpoint pack (k: Kind): Expr (SyntaxKind k) -> Expr (SyntaxKind (Bit (size k))). @@ -323,31 +326,83 @@ Section Phoas. end; abstract lia. Defined. - Fixpoint sumSizesMsbs n (i: Fin.t n) {struct i}: (Fin.t n -> nat) -> nat := - match i in Fin.t n return (Fin.t n -> nat) -> nat with - | Fin.F1 _ => fun _ => 0 - | Fin.FS m f => fun sizes => sumSizesMsbs f (fun j => sizes (Fin.FS j)) + sizes Fin.F1 + Fixpoint sumSizesMsbs {n} : forall i: Fin n, (Fin n -> nat) -> nat := + match n as m return forall i : Fin m, (Fin m -> nat) -> nat with + | 0 => Fin.case0 (fun i => (Fin 0 -> nat) -> nat) + | S m => + fun i => + match i with + | inl _ => fun _ => 0 + | inr j => + fun sizes : Fin (S m) -> nat => + sumSizesMsbs j (fun k : Fin m => sizes (FS k)) + sizes F1 + end end. - Lemma helper_sumSizes n (i: Fin.t n): - forall (sizes: Fin.t n -> nat), sumSizes sizes = (sumSizes sizes - (sumSizesMsbs i sizes + sizes i)) + sizes i + sumSizesMsbs i sizes. + Lemma sumSizesMsbsLt {n} : forall (sizes : Fin n -> nat) (i : Fin n), sumSizesMsbs i sizes + sizes i <= sumSizes sizes. Proof. - induction i; simpl; intros; auto. - - lia. - - specialize (IHi (fun x => sizes (Fin.FS x))). - lia. + induction n as [|m IH]. + + exact (fun sizes => Fin.case0 (fun i => sumSizesMsbs i sizes + sizes i <= sumSizes sizes)). + + intro sizes; destruct i as [u|j]. + - destruct u; simpl; unfold F1; exact (le_plus_r (sumSizes (fun j : Fin m => sizes (FS j))) (sizes F1)). + - simpl. + rewrite <- (Nat.add_assoc + (sumSizesMsbs j (fun k : Fin m => sizes (FS k))) + (sizes F1) + (sizes (inr j : Fin (S m)))). + rewrite (Nat.add_comm + (sizes F1) + (sizes (inr j : Fin (S m)))). + rewrite (Nat.add_assoc + (sumSizesMsbs j (fun k : Fin m => sizes (FS k))) + (sizes (inr j : Fin (S m))) + (sizes F1)). + apply (plus_le_compat_r + (sumSizesMsbs j (fun k : Fin m => sizes (FS k)) + sizes (inr j)) + (sumSizes (fun k : Fin m => sizes (FS k))) + (sizes F1)). + exact (IH (fun k : Fin m => sizes (FS k)) j). + Qed. + + Lemma helper_sumSizes n : forall (i: Fin n), + forall (sizes: Fin n -> nat), + sumSizes sizes = + (sumSizes sizes - (sumSizesMsbs i sizes + sizes i)) + sizes i + sumSizesMsbs i sizes. + Proof. + intros i sizes. + rewrite (Nat.add_comm _ (sumSizesMsbs i sizes)). + rewrite (Nat.add_comm _ (sizes i)). + rewrite (Nat.add_assoc + (sumSizesMsbs i sizes) + (sizes i) + _). + rewrite (le_plus_minus_r + (sumSizesMsbs i sizes + sizes i) + (sumSizes sizes) + (sumSizesMsbsLt sizes i)). + reflexivity. Qed. - Lemma helper_array n (i: Fin.t n): + Lemma helper_array {n} (i: Fin n): forall size_k, - n * size_k = (proj1_sig (Fin.to_nat i) * size_k) + size_k + (n * size_k - ((proj1_sig (Fin.to_nat i) * size_k) + size_k)) . + n * size_k = + (proj1_sig (to_nat i) * size_k) + size_k + + (n * size_k - ((proj1_sig (to_nat i) * size_k) + size_k)) . Proof. - induction i; simpl; intros; auto. - - lia. - - case_eq (Fin.to_nat i); simpl; intros. - rewrite H in *; simpl in *. - rewrite IHi at 1. - lia. + intro size_k. + set (Hlt := + @mult_le_compat_r + (S (proj1_sig (to_nat i))) + n + size_k + (proj2_sig (to_nat i)) + : S (proj1_sig (to_nat i)) * size_k <= n * size_k). + simpl in Hlt. + rewrite (Nat.add_comm size_k (proj1_sig (to_nat i) * size_k)) in Hlt. + rewrite (le_plus_minus_r + (proj1_sig (to_nat i) * size_k + size_k) + (n * size_k) Hlt). + reflexivity. Qed. Fixpoint unpack (k: Kind): Expr (SyntaxKind (Bit (size k))) -> Expr (SyntaxKind k) := @@ -362,11 +417,11 @@ Section Phoas. _ (ConstExtract _ _ (sumSizesMsbs i (fun j => size (fk j))) - (@castBits _ _ (helper_sumSizes i (fun j => size (fk j))) e))) + (@castBits _ _ (helper_sumSizes n i (fun j => size (fk j))) e))) | Array n k => fun e => BuildArray - (fun i => unpack _ (ConstExtract (proj1_sig (Fin.to_nat i) * size k) _ _ + (fun i => unpack _ (ConstExtract (proj1_sig (to_nat i) * size k) _ _ (@castBits _ _ (helper_array _ _) e))) end. End BitOps. @@ -388,7 +443,7 @@ Section Phoas. match k return FullFormat k with | Bool => FBool 1 Hex | Bit n => FBit n ((n+3)/4) Hex - | Struct n fk fs => FStruct fk fs (fun i => fullFormatHex (fk i)) + | Struct n fk fs => FStruct n fk fs (fun i => fullFormatHex (fk i)) | Array n k => FArray n (fullFormatHex k) end. @@ -396,7 +451,7 @@ Section Phoas. match k return FullFormat k with | Bool => FBool 1 Binary | Bit n => FBit n n Binary - | Struct n fk fs => FStruct fk fs (fun i => fullFormatBinary (fk i)) + | Struct n fk fs => FStruct n fk fs (fun i => fullFormatBinary (fk i)) | Array n k => FArray n (fullFormatBinary k) end. @@ -404,7 +459,7 @@ Section Phoas. match k return FullFormat k with | Bool => FBool 1 Decimal | Bit n => FBit n 0 Decimal - | Struct n fk fs => FStruct fk fs (fun i => fullFormatDecimal (fk i)) + | Struct n fk fs => FStruct n fk fs (fun i => fullFormatDecimal (fk i)) | Array n k => FArray n (fullFormatDecimal k) end. @@ -479,7 +534,7 @@ Definition RuleT := Attribute (Action Void). Inductive RegFileInitT (IdxNum: nat) (Data: Kind) := | RFNonFile (init: option (ConstT Data)) -| RFFile (isAscii: bool) (isArg: bool) (file: string) (offset size: nat) (init: Fin.t IdxNum -> ConstT Data). +| RFFile (isAscii: bool) (isArg: bool) (file: string) (offset size: nat) (init: Fin IdxNum -> ConstT Data). Record SyncRead := { readReqName : string ; readResName : string ; @@ -522,7 +577,7 @@ Definition getRegFileRegisters m := end | RFFile isAscii isArg file offset size init => Some (SyntaxConst (ConstArray init)) end) :: match readers with - | Async _ => nil + | Async _ => [] | Sync isAddr read => if isAddr then map (fun x => (readRegName x, existT RegInitValT (SyntaxKind (Bit (Nat.log2_up IdxNum))) @@ -551,14 +606,14 @@ Definition getStructVal ty ls := (BuildStruct (fun i => snd (nth_Fin (map (@projT1 _ _) ls) i)) (fun j => fst (nth_Fin (map (@projT1 _ _) ls) j)) (fun k => nth_Fin_map2 (@projT1 _ _) (fun x => Expr ty (SyntaxKind (snd x))) - ls k (projT2 (nth_Fin ls (Fin.cast k (map_length_red (@projT1 _ _) ls)))))). + ls k (projT2 (nth_Fin ls (Kami.StdLib.Fin.cast k (map_length_red (@projT1 _ _) ls)))))). Arguments getStructVal : simpl never. Definition getStructConst ls := (ConstStruct (fun i => snd (nth_Fin (map (@projT1 _ _) ls) i)) (fun j => fst (nth_Fin (map (@projT1 _ _) ls) j)) (fun k => nth_Fin_map2 (@projT1 _ _) (fun x => ConstT (snd x)) - ls k (projT2 (nth_Fin ls (Fin.cast k (map_length_red (@projT1 _ _) ls)))))). + ls k (projT2 (nth_Fin ls (Kami.StdLib.Fin.cast k (map_length_red (@projT1 _ _) ls)))))). Arguments getStructConst : simpl never. Definition WriteRq lgIdxNum Data := (getStruct (cons ("addr", Bit lgIdxNum) @@ -579,11 +634,11 @@ Definition WriteRqMask lgIdxNum num Data := (getStruct (cons ("addr", Bit lgIdxN Definition buildNumDataArray num dataArray IdxNum Data ty (idx: ty (Bit (Nat.log2_up IdxNum))) := ReadReg dataArray (SyntaxKind (Array IdxNum Data)) (fun val => - Return (BuildArray (fun i: Fin.t num => + Return (BuildArray (fun i: Fin num => ReadArray (Var ty _ val) (CABit Add (Var ty (SyntaxKind _) idx :: - Const ty (natToWord _ (proj1_sig (Fin.to_nat i))) :: nil))))). + Const ty (natToWord _ (proj1_sig (to_nat i))) :: nil))))). Definition updateNumDataArray num dataArray IdxNum Data ty (idxData: ty (WriteRq (Nat.log2_up IdxNum) (Array num Data))): @@ -594,11 +649,11 @@ Definition updateNumDataArray num dataArray IdxNum Data ty (idxData: ty (WriteRq (fold_left (fun newArr i => (UpdateArray newArr (CABit Add (ReadStruct (Var ty (SyntaxKind _) idxData) - Fin.F1 :: - Const ty (natToWord _ (proj1_sig (Fin.to_nat i))) :: + F1 :: + Const ty (natToWord _ (proj1_sig (to_nat i))) :: nil)) (ReadArrayConst (ReadStruct (Var ty (SyntaxKind _) idxData) - (Fin.FS Fin.F1)) i))) (getFins num) + (FS F1)) i))) (getFins num) (Var ty (SyntaxKind (Array IdxNum Data)) val)) (Return (Const _ WO))). @@ -610,14 +665,14 @@ Definition updateNumDataArrayMask num dataArray IdxNum Data ty (idxData: ty (Wri WriteReg dataArray (fold_left (fun newArr i => ITE - (ReadArrayConst (ReadStruct (Var ty (SyntaxKind _) idxData) (Fin.FS (Fin.FS Fin.F1))) i) + (ReadArrayConst (ReadStruct (Var ty (SyntaxKind _) idxData) (FS (FS F1))) i) (UpdateArray newArr (CABit Add (ReadStruct (Var ty (SyntaxKind _) idxData) - Fin.F1 :: Const ty (natToWord _ (proj1_sig (Fin.to_nat i))) :: + F1 :: Const ty (natToWord _ (proj1_sig (to_nat i))) :: nil)) (ReadArrayConst (ReadStruct (Var ty (SyntaxKind _) idxData) - (Fin.FS Fin.F1)) i)) + (FS F1)) i)) newArr ) (getFins num) (Var ty (SyntaxKind (Array IdxNum Data)) val)) @@ -1348,7 +1403,7 @@ Proof. | _,_ => false end). destruct (Nat.eqb n m) eqn:G. - exact (Fin_forallb (fun i => Kind_decb (ks i) (ks' (Fin_cast i (mk_eq _ _ G)))) && Fin_forallb (fun i => String.eqb (fs i) (fs' (Fin_cast i (mk_eq _ _ G))))). + exact (Fin_forallb (fun i => Kind_decb (ks i) (ks' (Fin.cast i (mk_eq _ _ G)))) && Fin_forallb (fun i => String.eqb (fs i) (fs' (Fin.cast i (mk_eq _ _ G))))). exact false. Defined. @@ -1359,8 +1414,8 @@ Proof. - rewrite silly_lemma_true with (pf := (Nat.eqb_refl _)) by apply Nat.eqb_refl. rewrite andb_true_iff; split; rewrite Fin_forallb_correct; intros. - + rewrite (hedberg Nat.eq_dec _ eq_refl); simpl; apply H. - + rewrite (hedberg Nat.eq_dec _ eq_refl); simpl; apply String.eqb_refl. + + rewrite (hedberg Nat.eq_dec _ eq_refl); simpl; rewrite <- (cast_eq i (eq_refl n)); apply H. + + rewrite (hedberg Nat.eq_dec _ eq_refl); simpl; rewrite <- (cast_eq i (eq_refl n)); apply String.eqb_refl. - rewrite andb_true_iff; split; auto. apply Nat.eqb_refl. Qed. @@ -1373,13 +1428,19 @@ Proof. - destruct (n =? n0)%nat eqn:G. + simpl in H0. rewrite (@silly_lemma_true bool (n =? n0)%nat _ _ G) in H0 by auto. - pose proof G. + pose proof G. (* <=== *) rewrite Nat.eqb_eq in H1 by auto. rewrite andb_true_iff in H0; destruct H0 as [G1 G2]; rewrite Fin_forallb_correct in G1,G2; subst. rewrite (hedberg Nat.eq_dec _ eq_refl) in G1,G2; simpl in *. setoid_rewrite H in G1. setoid_rewrite String.eqb_eq in G2. - f_equal; extensionality i; auto. + f_equal; extensionality i. + * rewrite (G1 i). + rewrite <- (cast_eq i (eq_refl)). + reflexivity. + * rewrite (G2 i). + rewrite <- (cast_eq i (eq_refl)). + reflexivity. + simpl in H0. rewrite silly_lemma_false in H0; try discriminate; auto. - rewrite H0; apply Kind_decb_refl. @@ -1428,9 +1489,9 @@ Proof. extensionality x. apply Fin.case0. apply x. - + destruct (IHn (fun i => k (Fin.FS i)) (fun i => X (Fin.FS i)) (fun i => s (Fin.FS i)) - (fun i => e1 (Fin.FS i)) (fun i => e2 (Fin.FS i))). - * destruct (X Fin.F1 (e1 Fin.F1) (e2 Fin.F1)). + + destruct (IHn (fun i => k (FS i)) (fun i => X (FS i)) (fun i => s (FS i)) + (fun i => e1 (FS i)) (fun i => e2 (FS i))). + * destruct (X F1 (e1 F1) (e2 F1)). -- left. extensionality x. apply (Fin.caseS' x); try assumption; apply equal_f_dep; assumption. @@ -1444,8 +1505,8 @@ Proof. apply Fin.case0. apply x. + simpl in *. - destruct (IHn (fun i => e1 (Fin.FS i)) (fun i => e2 (Fin.FS i))). - * destruct (IHk (e1 Fin.F1) (e2 Fin.F1)). + destruct (IHn (fun i => e1 (FS i)) (fun i => e2 (FS i))). + * destruct (IHk (e1 F1) (e2 F1)). -- left. extensionality x. apply (Fin.caseS' x); try assumption; apply equal_f; assumption. @@ -1574,7 +1635,7 @@ Section Semantics. | BuildStruct n fk fs fv => fun i => @evalExpr _ (fv i) | ReadArray n m k fv i => match lt_dec (Z.to_nat (wordVal _ (@evalExpr _ i))) n with - | left pf => fun fv => fv (Fin.of_nat_lt pf) + | left pf => fun fv => fv (of_nat_lt pf) | right _ => fun _ => evalConstT (getDefaultConst k) end (@evalExpr _ fv) | ReadArrayConst n k fv i => @@ -2324,29 +2385,29 @@ Local Definition option_bind Local Notation "X >>- F" := (option_bind X F) (at level 85, only parsing). Fixpoint struct_get_field_index' - (name: string) n + (name: string) {n} := match n return - forall (get_name : Fin.t n -> string), - option (Fin.t n) + forall (get_name : Fin n -> string), + option (Fin n) with | 0 => fun _ => None | S m => fun get_name => - if String.eqb (get_name Fin.F1) name - then Some Fin.F1 - else match struct_get_field_index' name _ (fun i => get_name (Fin.FS i)) with - | Some i => Some (Fin.FS i) + if String.eqb (get_name F1) name + then Some F1 + else match struct_get_field_index' name (fun i => get_name (FS i)) with + | Some i => Some (FS i) | None => None end end. -Definition struct_get_field_index n (kinds: Fin.t n -> Kind) (names: Fin.t n -> string) ty (e: Expr ty (SyntaxKind (Struct kinds names))) name +Definition struct_get_field_index n (kinds: Fin n -> Kind) (names: Fin n -> string) ty (e: Expr ty (SyntaxKind (Struct kinds names))) name := struct_get_field_index' name names. Local Definition struct_get_field_aux (ty: Kind -> Type) (n : nat) - (get_kind : Fin.t n -> Kind) - (get_name : Fin.t n -> string) + (get_kind : Fin n -> Kind) + (get_name : Fin n -> string) (packet : Expr ty (SyntaxKind (Struct get_kind get_name))) (name : string) : option ({kind : Kind & Expr ty (SyntaxKind kind)}) @@ -2361,8 +2422,8 @@ Local Definition struct_get_field_aux Definition struct_get_field (ty: Kind -> Type) (n : nat) - (get_value : Fin.t n -> Kind) - (get_name : Fin.t n -> string) + (get_value : Fin n -> Kind) + (get_name : Fin n -> string) (packet : Expr ty (SyntaxKind (Struct get_value get_name))) (name : string) (k : Kind) @@ -2383,8 +2444,8 @@ Defined. Definition struct_get_field_default (ty: Kind -> Type) (n : nat) - (get_value : Fin.t n -> Kind) - (get_name : Fin.t n -> string) + (get_value : Fin n -> Kind) + (get_name : Fin n -> string) (packet : Expr ty (SyntaxKind (Struct get_value get_name))) (name : string) (kind : Kind) @@ -2398,8 +2459,8 @@ Definition struct_get_field_default Definition struct_set_field (ty: Kind -> Type) (n : nat) - (get_kind : Fin.t n -> Kind) - (get_name : Fin.t n -> string) + (get_kind : Fin n -> Kind) + (get_name : Fin n -> string) (packet : Expr ty (SyntaxKind (Struct get_kind get_name))) (name : string) (kind : Kind) @@ -2420,8 +2481,8 @@ Defined. Definition struct_set_field_default (ty: Kind -> Type) (n : nat) - (get_kind : Fin.t n -> Kind) - (get_name : Fin.t n -> string) + (get_kind : Fin n -> Kind) + (get_name : Fin n -> string) (packet : Expr ty (SyntaxKind (Struct get_kind get_name))) (name : string) (kind : Kind) @@ -2474,7 +2535,6 @@ Hint Unfold Datatypes.fst Datatypes.snd String.append - EclecticLib.nth_Fin : KamiDb. (* TODO + PUAR: Linux/Certikos diff --git a/Tactics.v b/Tactics.v index 0f18631..f7ba05d 100644 --- a/Tactics.v +++ b/Tactics.v @@ -65,6 +65,8 @@ Local Ltac finish_append := Ltac discharge_append := simpl; unfold getBool in *; process_append; finish_append. +(* TODO: LLEE: *) +(* Goal forall (a b c: string), (a ++ "a" <> a ++ "b" /\ a ++ "a" ++ b <> c ++ "b" ++ b @@ -73,7 +75,7 @@ Goal forall (a b c: string), /\ (a ++ "a") ++ b <> a ++ "b" ++ a /\ (a ++ (b ++ "b")) ++ "c" <> (a ++ b) ++ "d")%string. Proof. intuition idtac; discharge_append. Qed. - +*) Ltac discharge_DisjKey := repeat match goal with | |- DisjKey _ _ => diff --git a/Tutorial/ExtractEx.v b/Tutorial/ExtractEx.v index 91a6956..bb0a684 100644 --- a/Tutorial/ExtractEx.v +++ b/Tutorial/ExtractEx.v @@ -4,7 +4,7 @@ Require Import Kami.Tutorial.TacticsEx. (* Example of how to extract a module to be used by the Haskell simulator *) Definition IncrMod : BaseModule := IncrementerImpl 5 "test". - +(* Separate Extraction getFins @@ -27,3 +27,4 @@ Separate Extraction unpack IncrMod. +*) diff --git a/Utila.v b/Utila.v index 6f6356f..b3367e2 100644 --- a/Utila.v +++ b/Utila.v @@ -3,6 +3,7 @@ expressions. *) Require Import Kami.Syntax Kami.Notations Kami.LibStruct. +Require Import Kami.StdLib.Fin. Require Import List. Import Word.Notations. Require Import Kami.Lib.EclecticLib. @@ -1305,7 +1306,7 @@ Section utila. rewrite map_map, <- IHxs; auto. Qed. - Lemma list_to_array_id : forall n (xs: ArrTy type n) (i: Fin.t n), + Lemma list_to_array_id : forall n (xs: ArrTy type n) (i: Fin n), let i' := Fin.cast i (array_to_list_len xs) in (evalExpr (list_to_array (array_to_list xs))) i' = (evalExpr xs) i. Proof. @@ -1313,7 +1314,7 @@ Section utila. unfold array_to_list, array_to_list'. erewrite nth_Fin_nth. rewrite map_nth; cbn. - rewrite fin_to_nat_cast, getFins_nth; auto. + rewrite fin_to_nat_cast, (@getFins_nth (Expr type (SyntaxKind A))); auto. Unshelve. auto. Qed. @@ -1330,7 +1331,7 @@ Section utila. unfold array_to_list; intros; apply array_to_list'_forall. Qed. - Lemma array_to_list_nth {ty} : forall n (xs: ArrTy ty n) (i: Fin.t n) i', + Lemma array_to_list_nth {ty} : forall n (xs: ArrTy ty n) (i: Fin n) i', i' = proj1_sig (Fin.to_nat i) -> nth_error (array_to_list xs) i' = Some (ReadArrayConst xs i). Proof. @@ -1372,7 +1373,7 @@ Section utila. - induction ys; constructor; inv Hall; auto. Qed. - Definition fin_to_bit {ty n} (i: Fin.t n) : Bit (Nat.log2_up n) @# ty := + Definition fin_to_bit {ty n} (i: Fin n) : Bit (Nat.log2_up n) @# ty := Const _ (natToWord _ (proj1_sig (Fin.to_nat i))). Definition array_forall_except {ty n}