diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean index 1cf870d30dfb..da9c19021214 100644 --- a/src/Init/Data/BitVec/Lemmas.lean +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -619,6 +619,8 @@ theorem le_toInt {w : Nat} {x : BitVec w} : -2 ^ w ≤ 2 * x.toInt := by · rw [← Nat.two_pow_pred_add_two_pow_pred (by omega), ← Nat.two_mul, Nat.add_sub_cancel] simp only [Nat.zero_lt_succ, Nat.mul_lt_mul_left, Int.natCast_mul, Int.Nat.cast_ofNat_Int] norm_cast; omega +@[simp] theorem toInt_cast (h : w = v) (x : BitVec w) : (cast h x).toInt = x.toInt := by + simp [toInt_eq_toNat_bmod, h] /-! ### slt -/ @@ -673,6 +675,12 @@ theorem zeroExtend_eq_setWidth {v : Nat} {x : BitVec w} : (x.setWidth v).toFin = Fin.ofNat' (2^v) x.toNat := by ext; simp +theorem setWidth'_eq {x : BitVec w} (h : w ≤ v) : x.setWidth' h = x.setWidth v := by + apply eq_of_toNat_eq + rw [toNat_setWidth, toNat_setWidth'] + rw [Nat.mod_eq_of_lt] + exact Nat.lt_of_lt_of_le x.isLt (Nat.pow_le_pow_right (Nat.zero_lt_two) h) + @[simp] theorem setWidth_eq (x : BitVec n) : setWidth n x = x := by apply eq_of_toNat_eq let ⟨x, lt_n⟩ := x @@ -2026,6 +2034,23 @@ theorem append_def (x : BitVec v) (y : BitVec w) : (x ++ y).toNat = x.toNat <<< n ||| y.toNat := rfl +/-- Helper theorem to show that the expression in `(x ++ y).toFin` is inbounds. -/ +theorem toNat_append_lt {m n : Nat} (x : BitVec m) (y : BitVec n) : + x.toNat <<< n ||| y.toNat < 2 ^ (m + n) := by + have hnLe : 2^n ≤ 2 ^(m + n) := by + rw [Nat.pow_add] + exact Nat.le_mul_of_pos_left (2 ^ n) (Nat.two_pow_pos m) + apply Nat.or_lt_two_pow + · have := Nat.two_pow_pos n + rw [Nat.shiftLeft_eq, Nat.pow_add, Nat.mul_lt_mul_right] + <;> omega + · omega + +@[simp] theorem toFin_append {x : BitVec m} {y : BitVec n} : + (x ++ y).toFin = @Fin.mk (2^(m+n)) (x.toNat <<< n ||| y.toNat) (toNat_append_lt x y) := by + ext + simp + theorem getLsbD_append {x : BitVec n} {y : BitVec m} : getLsbD (x ++ y) i = if i < m then getLsbD y i else getLsbD x (i - m) := by simp only [append_def, getLsbD_or, getLsbD_shiftLeftZeroExtend, getLsbD_setWidth'] @@ -2072,6 +2097,173 @@ theorem msb_append {x : BitVec w} {y : BitVec v} : ext simp [getElem_append] +theorem append_zero {n m : Nat} {x : BitVec n} : + x ++ 0#m = x.signExtend (n + m) <<< m := by + induction m + case zero => + simp [signExtend] + case succ i ih => + simp [bv_toNat] + sorry + +def lhs (x : BitVec n) (y : BitVec m) : Int := (x++y).toInt +def rhs (x : BitVec n) (y : BitVec m) : Int := if n == 0 then y.toInt else (x.toInt * (2^m)) + y.toNat + +def eq (x: BitVec n) (y: BitVec m) : Bool := (lhs x y) = (rhs x y) + + +#eval (-5#10 ++ 3#2).toInt + +def test : Bool := Id.run do + for i in [0, 1, 2, 3, 4, 5, 6, 7, 8] do + for j in [0, 1, 2, 3, 4, 5, 6, 7, 8] do + for n in [0, 1, 2, 3, 4] do + for m in [0, 1, 2, 3, 4] do + let x := BitVec.ofNat n i + let y := BitVec.ofNat m j + if (!eq x y) then + return false + return true + +private theorem Nat.lt_mul_of_le_of_lt_of_lt {a b c : Nat} (hab : a ≤ b) (ha : 0 < a) (hc : 1 < c) : + a < b * c := by + have : a * 1 < b * c := Nat.mul_lt_mul_of_le_of_lt' (by omega) (by simp [hc]) (by omega) + simp at this + simp [this] + +private theorem Nat.two_pow_lt_two_pow_add {n m : Nat} (h : m ≠ 0) : + 2 ^ n < 2 ^ (n + m) := by + apply Nat.pow_lt_pow_of_lt (by omega) (by omega) + +@[simp] theorem signExtend_shiftLeft_msb {n m : Nat} {x : BitVec n} : + (signExtend (n + m) x <<< m).msb = x.msb := by + induction m + case zero => + simp [signExtend] + case succ i ih => + rw [← ih] + rw [msb_setWidth] + + unfold BitVec.msb getMsbD + simp + by_cases h : (0 < n + i) + · + rw [← Nat.add_assoc] + simp [h] + have h' : (0 < n + i + 1) := by omega + have hh : (n + i - (i + 1)) = (n + i - i - 1) := by + omega + rw [hh] + simp + have hhh : (n + i - 1 - i) = (n + i - i - 1) := by omega + rw [hhh] + simp + rw [getLsbD_signExtend] + + + + + + + simp [BitVec.msb, getMsbD] + + by_cases h : 0 < n + (i + 1) + · simp [h] + + sorry + · simp [h] + sorry + +@[simp] theorem signExtend_toNat_shift_mod : + ((signExtend (n + m) x).toNat <<< m) % ↑(2 ^ (n + m)) = (signExtend (n + m) x).toNat <<< m := + sorry + +@[simp] theorem toInt_append_zero {n m : Nat} {x : BitVec n} : + (x ++ 0#m).toInt = x.toInt * (2 ^ m) := by + by_cases m0 : m = 0 + · subst m0 + simp + · simp only [ofNat_eq_ofNat, append_zero, toInt_eq_msb_cond] + by_cases h1 : (signExtend (n + m) x <<< m).msb + · by_cases h2: x.msb + · norm_cast + simp [h1, h2] + norm_cast + rw [Int.sub_mul, Nat.pow_add] + norm_cast + simp + rw [Nat.shiftLeft_eq] + norm_cast + have aa := @Nat.pow_pos 2 m (by omega) + norm_cast + have bb := @Nat.mul_right_cancel_iff (2^m) ((signExtend (n + m) x).toNat) + apply bb + rfl + rw [Nat.mul_right_cancel (m := 2 ^ m)] + simp [aa] + rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)] + norm_cast + simp [h3] + simp_all + rw [Nat.shiftLeft_eq] + · simp only [signExtend_shiftLeft_of_lt] at h1 + contradiction + · by_cases h2: x.msb + · simp [signExtend_shiftLeft_of_lt, h2] at h1 + · sorry + +@[simp] theorem toInt_append {x : BitVec n} {y : BitVec m} : + (x ++ y).toInt = if n == 0 then y.toInt else x.toInt * (2 ^ m) + y.toNat := by + by_cases n0 : n = 0 + · subst n0 + simp [BitVec.eq_nil x] + · by_cases m0 : m = 0 + · subst m0 + simp [BitVec.eq_nil y, n0] + · simp [m0] + by_cases y0 : y = 0 + · simp [toInt_append_zero, y0, n0] + rw [toInt_eq_toNat_cond] + rw [toInt_eq_toNat_cond] + split + · + split + <;> norm_cast + <;> simp + <;> rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)] + <;> norm_cast + <;> simp [h3] + <;> simp_all + · rw [Nat.shiftLeft_eq] + · rename_i aa bb + rw [Nat.shiftLeft_eq] at aa + rw [Nat.pow_add] at aa + rw [← Nat.mul_assoc] at aa + + sorry + · + split + <;> norm_cast + <;> simp + <;> rw [Nat.mod_eq_of_lt (a := x.toNat) (by omega)] + <;> norm_cast + <;> simp [h3] + <;> simp_all + · rename_i aa bb + rw [Nat.shiftLeft_eq] at aa + rw [Nat.pow_add] at aa + rw [← Nat.mul_assoc] at aa + + + + + simp_all + + + sorry + · simp [Nat.shiftLeft_eq, Int.sub_mul, Nat.pow_add] + · sorry + @[simp] theorem cast_append_right (h : w + v = w + v') (x : BitVec w) (y : BitVec v) : (x ++ y).cast h = x ++ y.cast (by omega) := by ext