From fcd2f9b99aad351800329a823921dc5cc7c09030 Mon Sep 17 00:00:00 2001
From: Gregory
Date: Mon, 10 Jan 2022 20:08:00 +0200
Subject: [PATCH] binary search tree

---
 searching/binary_search_tree.go      | 432 +++++++++++++++++++++++++++
 searching/binary_search_tree_test.go | 346 +++++++++++++++++++++++
 searching/symbol_table.go            |  25 ++
 3 files changed, 803 insertions(+)
 create mode 100644 searching/binary_search_tree.go
 create mode 100644 searching/binary_search_tree_test.go
 create mode 100644 searching/symbol_table.go

diff --git a/searching/binary_search_tree.go b/searching/binary_search_tree.go
new file mode 100644
index 0000000..4ddb1f3
--- /dev/null
+++ b/searching/binary_search_tree.go
@@ -0,0 +1,432 @@
+package searching
+
+import "github.com/fotonmoton/algorithms/fundamentals/queue"
+
+// bstNode is a single node of the tree. n caches the number of nodes in the
+// subtree rooted at this node (the node itself included).
+type bstNode[K any, V any] struct {
+	left  *bstNode[K, V]
+	right *bstNode[K, V]
+	key   K
+	val   V
+	n     int64
+}
+
+// bst is an unbalanced binary search tree whose keys are ordered by cmp.
+//
+// TODO: maybe pass pointers for recursive funcs?
+type bst[K any, V any] struct {
+	root *bstNode[K, V]
+	cmp  func(*K, *K) int
+}
+
+// NewBST returns an empty SymbolTable backed by a binary search tree.
+// cmp must return a negative number when the first key is smaller than the
+// second, zero when the keys are equal and a positive number otherwise.
+func NewBST[K any, V any](cmp func(*K, *K) int) SymbolTable[K, V] {
+	return &bst[K, V]{nil, cmp}
+}
+
+// Put stores val under key, replacing the value of an already existing key.
+func (t *bst[K, V]) Put(key K, val V) {
+	t.root = t.put(key, val, t.root)
+}
+
+// put inserts key into the subtree rooted at node and returns the new
+// subtree root with refreshed size counters.
+func (t *bst[K, V]) put(key K, val V, node *bstNode[K, V]) *bstNode[K, V] {
+	if node == nil {
+		return &bstNode[K, V]{nil, nil, key, val, 1}
+	}
+
+	cmp := t.cmp(&key, &node.key)
+
+	if cmp < 0 {
+		node.left = t.put(key, val, node.left)
+	}
+
+	if cmp == 0 {
+		node.val = val
+	}
+
+	if cmp > 0 {
+		node.right = t.put(key, val, node.right)
+	}
+
+	node.n = t.size(node.left) + t.size(node.right) + 1
+	return node
+}
+
+// Get returns a pointer to the value stored under key, nil if key is absent.
+func (t *bst[K, V]) Get(key K) *V {
+	return t.get(key, t.root)
+}
+
+// get looks key up in the subtree rooted at node.
+func (t *bst[K, V]) get(key K, node *bstNode[K, V]) *V {
+	if node == nil {
+		return nil
+	}
+
+	cmp := t.cmp(&key, &node.key)
+
+	if cmp < 0 {
+		return t.get(key, node.left)
+	}
+
+	if cmp > 0 {
+		return t.get(key, node.right)
+	}
+
+	return &node.val
+}
+
+// Size returns the number of key-value pairs in the tree.
+func (t *bst[_, __]) Size() int64 {
+	return t.size(t.root)
+}
+
+// size returns the cached node count of the subtree, 0 for a nil subtree.
+func (t *bst[K, V]) size(node *bstNode[K, V]) int64 {
+	if node == nil {
+		return 0
+	}
+
+	return node.n
+}
+
+// Min returns a pointer to the smallest key, nil if the tree is empty.
+func (t *bst[K, _]) Min() *K {
+	if t.root == nil {
+		return nil
+	}
+
+	return &t.min(t.root).key
+}
+
+// min returns the leftmost node of the subtree; node must not be nil.
+func (t *bst[K, V]) min(node *bstNode[K, V]) *bstNode[K, V] {
+	if node.left == nil {
+		return node
+	}
+
+	return t.min(node.left)
+}
+
+// Max returns a pointer to the largest key, nil if the tree is empty.
+func (t *bst[K, _]) Max() *K {
+	if t.root == nil {
+		return nil
+	}
+
+	return &t.max(t.root).key
+}
+
+// max returns the rightmost node of the subtree; node must not be nil.
+func (t *bst[K, V]) max(node *bstNode[K, V]) *bstNode[K, V] {
+	if node.right == nil {
+		return node
+	}
+
+	return t.max(node.right)
+}
+
+// Floor returns the largest key less than or equal to key, nil if there is none.
+func (t *bst[K, V]) Floor(key K) *K {
+	largest := t.floor(key, t.root)
+
+	if largest == nil {
+		return nil
+	}
+
+	return &largest.key
+}
+
+// floor returns the node holding the largest key <= key in the subtree.
+func (t *bst[K, V]) floor(key K, node *bstNode[K, V]) *bstNode[K, V] {
+	if node == nil {
+		return nil
+	}
+
+	cmp := t.cmp(&key, &node.key)
+
+	if cmp == 0 {
+		return node
+	}
+
+	if cmp < 0 {
+		return t.floor(key, node.left)
+	}
+
+	larger := t.floor(key, node.right)
+
+	if larger != nil {
+		return larger
+	}
+
+	return node
+}
+
+// Ceiling returns the smallest key greater than or equal to key, nil if there is none.
+func (t *bst[K, V]) Ceiling(key K) *K {
+	smallest := t.ceiling(key, t.root)
+
+	if smallest == nil {
+		return nil
+	}
+
+	return &smallest.key
+}
+
+// ceiling returns the node holding the smallest key >= key in the subtree.
+func (t *bst[K, V]) ceiling(key K, node *bstNode[K, V]) *bstNode[K, V] {
+	if node == nil {
+		return nil
+	}
+
+	cmp := t.cmp(&key, &node.key)
+
+	if cmp == 0 {
+		return node
+	}
+
+	if cmp > 0 {
+		return t.ceiling(key, node.right)
+	}
+
+	smaller := t.ceiling(key, node.left)
+
+	if smaller != nil {
+		return smaller
+	}
+
+	return node
+}
+
+// Rank returns the number of keys strictly less than key.
+func (t *bst[K, V]) Rank(key K) int64 {
+	return t.rank(key, t.root)
+}
+
+// rank counts the keys of the subtree that are strictly less than key.
+func (t *bst[K, V]) rank(key K, node *bstNode[K, V]) int64 {
+	if node == nil {
+		return 0
+	}
+
+	cmp := t.cmp(&key, &node.key)
+
+	// If we found key in a tree then left subtree
+	// will always contain keys less than current node key
+	// and right subtree will always contain greater keys (by BST definition).
+	// So we simply return left subtree size
+	if cmp == 0 {
+		return t.size(node.left)
+	}
+
+	// If current node key is bigger than key for which rank is searched
+	// we should descend deeper in left subtree
+	if cmp < 0 {
+		return t.rank(key, node.left)
+	}
+
+	// If we found node with key that is less than search key
+	// we get the size of the left subtree, add 1 to count current node in
+	// rank value and descend deeper in right subtree.
+	return 1 + t.size(node.left) + t.rank(key, node.right)
+}
+
+// KeyByRank returns the key that has exactly i smaller keys in the tree,
+// nil if no such key exists.
+func (t *bst[K, V]) KeyByRank(i int64) *K {
+	node := t.keyByRank(i, t.root)
+
+	if node == nil {
+		return nil
+	}
+
+	return &node.key
+}
+
+// keyByRank returns the node of the given rank within the subtree.
+func (t *bst[K, V]) keyByRank(rank int64, node *bstNode[K, V]) *bstNode[K, V] {
+	if node == nil {
+		return nil
+	}
+
+	// We need left subtree size to subtract it from our index
+	// when we descend deeper in right subtree
+	leftSize := t.size(node.left)
+
+	if rank < leftSize {
+		return t.keyByRank(rank, node.left)
+	}
+
+	if rank > leftSize {
+		// We subtract the left subtree size plus one for the current node
+		return t.keyByRank(rank-leftSize-1, node.right)
+	}
+
+	return node
+}
+
+// Contains reports whether key is present in the tree.
+func (t *bst[K, V]) Contains(key K) bool {
+	return t.Get(key) != nil
+}
+
+// IsEmpty reports whether the tree holds no keys.
+func (t *bst[K, V]) IsEmpty() bool {
+	return t.Size() == 0
+}
+
+// DeleteMin removes the smallest key; it is a no-op on an empty tree.
+func (t *bst[K, V]) DeleteMin() {
+
+	if t.root == nil {
+		return
+	}
+
+	t.root = t.deleteMin(t.root)
+}
+
+// deleteMin removes the leftmost node of the subtree and returns the new
+// subtree root; node must not be nil.
+func (t *bst[K, V]) deleteMin(node *bstNode[K, V]) *bstNode[K, V] {
+	if node.left == nil {
+		return node.right
+	}
+
+	node.left = t.deleteMin(node.left)
+	node.n = t.size(node.left) + t.size(node.right) + 1
+
+	return node
+}
+
+// DeleteMax removes the largest key; it is a no-op on an empty tree.
+func (t *bst[K, V]) DeleteMax() {
+
+	if t.root == nil {
+		return
+	}
+
+	t.root = t.deleteMax(t.root)
+}
+
+// deleteMax removes the rightmost node of the subtree and returns the new
+// subtree root; node must not be nil.
+func (t *bst[K, V]) deleteMax(node *bstNode[K, V]) *bstNode[K, V] {
+	if node.right == nil {
+		return node.left
+	}
+
+	node.right = t.deleteMax(node.right)
+	node.n = t.size(node.left) + t.size(node.right) + 1
+
+	return node
+}
+
+// Delete removes key and its value; it is a no-op if key is absent.
+func (t *bst[K, V]) Delete(key K) {
+	t.root = t.delete(key, t.root)
+}
+
+// delete removes key from the subtree and returns the new subtree root.
+// A node with two children is replaced by the smallest node of its right subtree.
+func (t *bst[K, V]) delete(key K, node *bstNode[K, V]) *bstNode[K, V] {
+	if node == nil {
+		return nil
+	}
+
+	cmp := t.cmp(&key, &node.key)
+
+	if cmp < 0 {
+		node.left = t.delete(key, node.left)
+	} else if cmp > 0 {
+		node.right = t.delete(key, node.right)
+	} else {
+
+		// Shortcut: we can return left or right subtree if we have only one of them
+		// without size recalculation and pointers juggling
+		if node.right == nil {
+			return node.left
+		}
+		if node.left == nil {
+			return node.right
+		}
+
+		// Needed to delete "min" node in right subtree
+		tmp := node
+		// We substitute current node with "min" node from right subtree.
+		// When "node" is returned to the caller it replaces "tmp" in the
+		// parent, so "tmp" becomes unreachable and can be garbage collected.
+		// At least it should work as described
+		node = t.min(tmp.right)
+		// We prevent "node" duplication in the tree by deleting it from right subtree
+		node.right = t.deleteMin(tmp.right)
+		// Left subtree stays unchanged
+		node.left = tmp.left
+	}
+	node.n = t.size(node.left) + t.size(node.right) + 1
+	return node
+}
+
+// KeysBetween returns, in ascending order, the keys k with lo <= k <= hi.
+func (t *bst[K, V]) KeysBetween(lo, hi K) []K {
+	q := queue.NewQueue[K]()
+	t.keysBetween(lo, hi, t.root, q)
+	keys := make([]K, 0, q.Size())
+
+	for !q.IsEmpty() {
+		keys = append(keys, q.Dequeue())
+	}
+
+	return keys
+}
+
+// keysBetween enqueues the keys of the subtree that fall in [lo, hi] using
+// an in-order walk that skips subtrees lying outside of the range.
+func (t *bst[K, V]) keysBetween(lo, hi K, node *bstNode[K, V], q queue.Queue[K]) {
+	if node == nil {
+		return
+	}
+	cmplo := t.cmp(&lo, &node.key)
+	cmphi := t.cmp(&hi, &node.key)
+
+	if cmplo < 0 {
+		t.keysBetween(lo, hi, node.left, q)
+	}
+
+	if cmplo <= 0 && cmphi >= 0 {
+		q.Enqueue(node.key)
+	}
+
+	if cmphi > 0 {
+		t.keysBetween(lo, hi, node.right, q)
+	}
+}
+
+// Keys returns all keys in ascending order.
+func (t *bst[K, V]) Keys() []K {
+	if t.IsEmpty() {
+		return []K{}
+	}
+
+	q := queue.NewQueue[K]()
+	t.keysBetween(*t.Min(), *t.Max(), t.root, q)
+	keys := make([]K, 0, q.Size())
+
+	for !q.IsEmpty() {
+		keys = append(keys, q.Dequeue())
+	}
+
+	return keys
+}
+
+// SizeBetween returns the number of keys k with lo <= k <= hi.
+func (t *bst[K, V]) SizeBetween(lo K, hi K) int64 {
+	q := queue.NewQueue[K]()
+	t.keysBetween(lo, hi, t.root, q)
+
+	return int64(q.Size())
+}
diff --git a/searching/binary_search_tree_test.go b/searching/binary_search_tree_test.go
new file mode 100644
index 0000000..1447716
--- /dev/null
+++ b/searching/binary_search_tree_test.go
@@ -0,0 +1,346 @@
+package searching
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func intCompare(a, b *int) int {
+	if *a < *b {
+		return -1
+	}
+
+	if *a > *b {
+		return 1
+	}
+
+	return 0
+}
+
+func TestPut(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	table.Put(1, 10)
+	table.Put(2, 20)
+
+	assert.Equal(t, 10, *table.Get(1))
+	assert.Equal(t, 20, *table.Get(2))
+
+	// rewrite
+	table.Put(1, 11)
+
+	assert.Equal(t, 11, *table.Get(1))
+	assert.Equal(t, 20, *table.Get(2))
+	assert.Equal(t, int64(2), table.Size())
+
+}
+
+// TODO: test with delete
+func TestGet(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Nil(t, table.Get(0))
+
+	table.Put(1, 2)
+
+	assert.Equal(t, 2, *table.Get(1))
+}
+
+// TODO: test with delete
+func TestSize(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Equal(t, int64(0), table.Size())
+
+	table.Put(1, 1)
+
+	assert.Equal(t, int64(1), table.Size())
+
+	table.Put(2, 2)
+
+	assert.Equal(t, int64(2), table.Size())
+}
+
+// TODO: test with delete
+func TestMin(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Nil(t, table.Min())
+
+	table.Put(3, 3)
+
+	assert.Equal(t, 3, *table.Min())
+
+	table.Put(2, 2)
+
+	assert.Equal(t, 2, *table.Min())
+
+	table.Put(4, 4)
+
+	assert.Equal(t, 2, *table.Min())
+
+	table.Put(1, 1)
+
+	assert.Equal(t, 1, *table.Min())
+}
+
+// TODO: test with delete
+func TestMax(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Nil(t, table.Max())
+
+	table.Put(1, 1)
+	assert.Equal(t, 1, *table.Max())
+
+	table.Put(2, 2)
+	assert.Equal(t, 2, *table.Max())
+
+	table.Put(5, 5)
+	assert.Equal(t, 5, *table.Max())
+
+	table.Put(4, 4)
+	assert.Equal(t, 5, *table.Max())
+
+	table.Put(3, 3)
+	assert.Equal(t, 5, *table.Max())
+
+	table.Put(5, 55)
+	assert.Equal(t, 5, *table.Max())
+
+	table.Put(6, 6)
+	assert.Equal(t, 6, *table.Max())
+}
+
+// TODO: test with delete
+func TestFloor(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Nil(t, table.Floor(0))
+
+	table.Put(1, 1)
+	assert.Equal(t, 1, *table.Floor(1))
+
+	table.Put(5, 5)
+	assert.Equal(t, 5, *table.Floor(5))
+	assert.Equal(t, 1, *table.Floor(4))
+
+	table.Put(4, 4)
+	assert.Equal(t, 5, *table.Floor(5))
+	assert.Equal(t, 4, *table.Floor(4))
+	assert.Equal(t, 1, *table.Floor(3))
+	assert.Nil(t, table.Floor(0))
+}
+
+// TODO: test with delete
+func TestCeiling(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Nil(t, table.Ceiling(0))
+
+	table.Put(5, 5)
+	assert.Equal(t, 5, *table.Ceiling(5))
+
+	table.Put(4, 4)
+	assert.Equal(t, 4, *table.Ceiling(0))
+	assert.Equal(t, 5, *table.Ceiling(5))
+
+	table.Put(3, 3)
+	assert.Equal(t, 3, *table.Ceiling(0))
+	assert.Equal(t, 3, *table.Ceiling(1))
+	assert.Equal(t, 3, *table.Ceiling(3))
+	assert.Equal(t, 4, *table.Ceiling(4))
+	assert.Equal(t, 5, *table.Ceiling(5))
+	assert.Nil(t, table.Ceiling(6))
+}
+
+// TODO: test with delete
+func TestRank(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Equal(t, int64(0), table.Rank(1))
+
+	table.Put(0, 0)
+	assert.Equal(t, int64(1), table.Rank(1))
+
+	table.Put(1, 1)
+	assert.Equal(t, int64(2), table.Rank(2))
+
+	table.Put(4, 4)
+	assert.Equal(t, int64(2), table.Rank(2))
+	assert.Equal(t, int64(2), table.Rank(3))
+	assert.Equal(t, int64(3), table.Rank(5))
+
+	table.Put(2, 2)
+	assert.Equal(t, int64(2), table.Rank(2))
+	assert.Equal(t, int64(3), table.Rank(3))
+	assert.Equal(t, int64(4), table.Rank(5))
+
+	table.Put(3, 3)
+	assert.Equal(t, int64(2), table.Rank(2))
+	assert.Equal(t, int64(3), table.Rank(3))
+	assert.Equal(t, int64(4), table.Rank(4))
+	assert.Equal(t, int64(5), table.Rank(5))
+}
+
+// TODO: test with delete
+func TestKeyByRank(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.Nil(t, table.KeyByRank(1))
+
+	table.Put(0, 0)
+	assert.Nil(t, table.KeyByRank(1))
+	assert.Equal(t, 0, *table.KeyByRank(table.Rank(0)))
+
+	table.Put(5, 5)
+	assert.Equal(t, 5, *table.KeyByRank(table.Rank(5)))
+	assert.EqualValues(t, 1, table.Rank(*table.KeyByRank(1)))
+}
+
+func TestContains(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.False(t, table.Contains(1))
+
+	table.Put(1, 1)
+
+	assert.True(t, table.Contains(1))
+	assert.False(t, table.Contains(2))
+
+	table.Delete(1)
+
+	assert.False(t, table.Contains(1))
+}
+
+func TestDeleteMin(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	table.DeleteMin()
+
+	table.Put(0, 0)
+	assert.EqualValues(t, 1, table.Size())
+
+	table.DeleteMin()
+	assert.EqualValues(t, 0, table.Size())
+
+	table.Put(5, 5)
+	table.Put(0, 0)
+	table.Put(1, 1)
+	table.Put(2, 2)
+
+	assert.Equal(t, 0, *table.Get(0))
+
+	table.DeleteMin()
+
+	assert.Nil(t, table.Get(0))
+	assert.EqualValues(t, 3, table.Size())
+}
+
+func TestDeleteMax(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	table.DeleteMax()
+
+	table.Put(0, 0)
+	assert.EqualValues(t, 1, table.Size())
+
+	table.DeleteMax()
+	assert.EqualValues(t, 0, table.Size())
+
+	table.Put(0, 0)
+	table.Put(5, 5)
+	table.Put(1, 1)
+	table.Put(2, 2)
+
+	assert.Equal(t, 5, *table.Get(5))
+
+	table.DeleteMax()
+
+	assert.Nil(t, table.Get(5))
+	assert.EqualValues(t, 3, table.Size())
+}
+
+// TODO: add more cases
+func TestDelete(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	table.Delete(0)
+
+	table.Put(0, 0)
+
+	table.Delete(0)
+	assert.EqualValues(t, 0, table.Size())
+	assert.Nil(t, table.Get(0))
+
+	table.Put(0, 0)
+	table.Put(5, 5)
+	table.Put(1, 1)
+	table.Put(2, 2)
+
+	assert.Equal(t, 1, *table.Get(1))
+
+	table.Delete(1)
+	assert.Nil(t, table.Get(1))
+	assert.EqualValues(t, 3, table.Size())
+
+	table.Delete(2)
+	table.Delete(5)
+	table.Delete(0)
+	assert.EqualValues(t, 0, table.Size())
+}
+
+func TestKeysBetween(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.EqualValues(t, []int{}, table.KeysBetween(0, 10))
+
+	table.Put(1, 1)
+
+	assert.EqualValues(t, []int{}, table.KeysBetween(2, 10))
+	assert.EqualValues(t, []int{1}, table.KeysBetween(1, 1))
+
+	table.Put(2, 2)
+	table.Put(5, 5)
+
+	assert.EqualValues(t, []int{5}, table.KeysBetween(3, 10))
+	assert.EqualValues(t, []int{1, 2, 5}, table.KeysBetween(1, 5))
+}
+
+func TestKeys(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.EqualValues(t, []int{}, table.Keys())
+
+	table.Put(1, 1)
+
+	assert.EqualValues(t, []int{1}, table.Keys())
+
+	table.Put(2, 2)
+	table.Put(5, 5)
+
+	assert.EqualValues(t, []int{1, 2, 5}, table.Keys())
+
+	table.Delete(2)
+
+	assert.EqualValues(t, []int{1, 5}, table.Keys())
+
+}
+
+func TestSizeBetween(t *testing.T) {
+	table := NewBST[int, int](intCompare)
+
+	assert.EqualValues(t, 0, table.SizeBetween(0, 10))
+
+	table.Put(1, 1)
+
+	assert.EqualValues(t, 0, table.SizeBetween(2, 10))
+	assert.EqualValues(t, 1, table.SizeBetween(1, 1))
+
+	table.Put(2, 2)
+	table.Put(5, 5)
+
+	assert.EqualValues(t, 1, table.SizeBetween(3, 10))
+	assert.EqualValues(t, 3, table.SizeBetween(1, 5))
+}
diff --git a/searching/symbol_table.go b/searching/symbol_table.go
new file mode 100644
index 0000000..2cf4b09
--- /dev/null
+++ b/searching/symbol_table.go
@@ -0,0 +1,25 @@
+package searching
+
+// SymbolTable is an ordered symbol table that maps keys K to values V.
+//
+// TODO: think about pointer semantics: where should pointers be used?
+// Does the Go compiler silently convert values to pointers when they leave the table?
+type SymbolTable[K any, V any] interface {
+	Put(K, V)               // add value V with associated key K to symbol table
+	Get(K) *V               // get value V with associated key K from symbol table, nil if value is absent
+	Size() int64            // number of key-value pairs
+	Min() *K                // smallest key
+	Max() *K                // largest key
+	Floor(K) *K             // largest key less than or equal to K
+	Ceiling(K) *K           // smallest key greater or equal to K
+	Rank(K) int64           // number of keys less than K. Rank(*KeyByRank(i)) = i
+	KeyByRank(int64) *K     // key of specified rank. *KeyByRank(Rank(K)) = K
+	Contains(K) bool        // check if key K exists in symbol table
+	IsEmpty() bool          // check if symbol table is empty
+	DeleteMin()             // delete value with smallest key
+	DeleteMax()             // delete value with largest key
+	Delete(K)               // delete value associated with key K.
+	KeysBetween(K, K) []K   // keys between two other keys in sorted order
+	Keys() []K              // all existing keys in sorted order
+	SizeBetween(K, K) int64 // number of keys between two keys
+}