The objective was to create a weight-balanced binary search tree with a well-defined interface. Any critiques or suggestions for improvement are welcome.
(** Weight-balanced binary search trees over a totally ordered element type. *)
module WBT = struct
  (** Raised by [min_value] / [max_value] when applied to an empty tree. *)
  exception Empty_tree

  (** Maximum allowed absolute difference between the weights (element
      counts) of the left and right subtrees of any node. *)
  let bal_factor = 2

  (** Element types usable in a tree: a type with a total order [compare]. *)
  module type Ordered_type = sig
    include Stdlib.Map.OrderedType
  end

  (** Interface of an immutable ordered set backed by a balanced tree. *)
  module type Tree_type = sig
    type elt
    type t

    val mem : elt -> t -> bool
    (** [mem x t] is [true] iff [x] is stored in [t]. *)

    val weight : t -> int
    (** Number of elements stored in the tree. *)

    val insert : elt -> t -> t
    (** [insert x t] adds [x]; returns [t] itself if [x] is already present. *)

    val remove : elt -> t -> t
    (** [remove x t] deletes [x]; returns [t] itself if [x] is absent. *)

    val insert_seq : elt Seq.t -> t -> t
    (** Inserts every element of the sequence, in order. *)

    val of_seq : elt Seq.t -> t
    (** Tree holding the elements of the sequence. *)

    val to_seq : t -> elt Seq.t
    (** Elements of the tree in increasing order. *)
  end

  module Tree (T : Ordered_type) : Tree_type with type elt = T.t = struct
    type elt = T.t

    (* [Node (v, w, l, r)]: [w] is the number of elements in the subtree;
       every element of [l] compares below [v], every element of [r] above. *)
    type t = Leaf | Node of elt * int * t * t

    let weight = function
      | Leaf -> 0
      | Node (_, w, _, _) -> w

    (* Smart constructor: the weight is always derived from the children, so
       it cannot drift out of sync with the real element count. *)
    let node v l r = Node (v, weight l + weight r + 1, l, r)

    let rec mem v = function
      | Leaf -> false
      | Node (v', _, l, r) ->
        let c = T.compare v v' in
        (* Descend into exactly one subtree per level. *)
        c = 0 || mem v (if c < 0 then l else r)

    let rec max_value = function
      | Leaf -> raise Empty_tree
      | Node (v, _, _, Leaf) -> v
      | Node (_, _, _, r) -> max_value r

    let rec min_value = function
      | Leaf -> raise Empty_tree
      | Node (v, _, Leaf, _) -> v
      | Node (_, _, l, _) -> min_value l

    let balanced = function
      | Leaf -> true
      | Node (_, _, l, r) -> abs (weight l - weight r) <= bal_factor

    (* NOTE(review): the balance criterion is additive (|wl - wr| <= 2), which
       is much stricter than the classic ratio-based weight-balance condition,
       and the rotations below move one element at a time via remove/insert
       instead of an O(1) restructuring. Behaviour kept; worth revisiting for
       performance. *)

    (* Insert [v]. When [v] is already present, the original tree is returned
       physically unchanged, which lets callers detect the no-op with [==]. *)
    let rec insert v = function
      | Leaf -> Node (v, 1, Leaf, Leaf)
      | Node (v', _, l, r) as n ->
        let c = T.compare v v' in
        if c = 0 then n
        else if c < 0 then
          let l' = insert v l in
          (* [l' == l]: [v] was already in [l], nothing to rebuild. *)
          if l' == l then n else balance (node v' l' r)
        else
          let r' = insert v r in
          if r' == r then n else balance (node v' l r')

    (* Remove [v]. When [v] is absent, the original tree is returned physically
       unchanged. *)
    and remove v = function
      | Leaf -> Leaf
      | Node (v', _, l, r) as n ->
        let c = T.compare v v' in
        if c < 0 then
          let l' = remove v l in
          if l' == l then n else balance (node v' l' r)
        else if c > 0 then
          let r' = remove v r in
          if r' == r then n else balance (node v' l r')
        else
          begin match l, r with
          | Leaf, _ -> r
          | _, Leaf -> l
          | _ ->
            (* Replace the root by its in-order predecessor (largest of [l])
               or successor (smallest of [r]), taken from the heavier side so
               that ordering is preserved. *)
            if weight l >= weight r then
              let m = max_value l in
              balance (node m (remove m l) r)
            else
              let m = min_value r in
              balance (node m l (remove m r))
          end

    (* Move the smallest element of the right subtree up to the root and the
       old root down into the left subtree (one element shifts left). *)
    and rotate_left = function
      | Leaf -> Leaf
      | Node (_, _, _, Leaf) as n -> n
      | Node (v, _, l, r) ->
        let v' = min_value r in
        node v' (insert v l) (remove v' r)

    (* Mirror image of [rotate_left]: one element shifts right. *)
    and rotate_right = function
      | Leaf -> Leaf
      | Node (_, _, Leaf, _) as n -> n
      | Node (v, _, l, r) ->
        let v' = max_value l in
        node v' (remove v' l) (insert v r)

    (* Repeatedly shift elements from the heavier side to the lighter side
       until the root satisfies [balanced]; each shift reduces the weight
       difference by 2, so the loop terminates. *)
    and balance = function
      | Leaf -> Leaf
      | n when balanced n -> n
      | Node (_, _, l, r) as n ->
        if weight l < weight r then balance (rotate_left n)
        else balance (rotate_right n)

    let insert_seq seq t = Seq.fold_left (fun acc x -> insert x acc) t seq

    let of_seq seq = insert_seq seq Leaf

    let to_seq t =
      (* [push t stack] pushes the left spine of [t] onto [stack] as
         (element, right subtree) pairs, smallest element on top. *)
      let rec push t stack =
        match t with
        | Leaf -> stack
        | Node (v, _, l, r) -> push l ((v, r) :: stack)
      in
      let rec next stack () =
        match stack with
        | [] -> Seq.Nil
        | (v, r) :: rest -> Seq.Cons (v, next (push r rest))
      in
      fun () -> next (push t []) ()
  end
end