aboutsummaryrefslogtreecommitdiff
path: root/trie.go
blob: 3ed3c70acea29aae470f621daa84ff11c5d0ce5c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
package eventbus

import (
	"slices"
	"sync"
)

// node is one level of a trie keyed by path segments. Each node holds the
// values registered at exactly its path plus its children, keyed by the next
// path segment.
//
// The zero value is ready to use: Put allocates the children map lazily.
type node[T comparable] struct {
	children map[string]*node[T] // next path segment -> child; may be nil
	values   []T                 // values registered at this node's path
	mu       sync.RWMutex        // guards children and values
}

// Put registers value at the node reached by following path from n, creating
// intermediate nodes as needed. An empty path stores value on n itself.
// Duplicates are kept: every Put appends another entry.
//
// Locks are taken top-down (parent before child), and each is held until the
// recursive call below it returns.
func (n *node[T]) Put(path []string, value T) {
	n.mu.Lock()
	defer n.mu.Unlock()

	if len(path) == 0 {
		n.values = append(n.values, value)
		return
	}

	head, tail := path[0], path[1:]
	child, ok := n.children[head]
	if !ok {
		if n.children == nil {
			n.children = make(map[string]*node[T])
		}
		child = &node[T]{}
		n.children[head] = child
	}

	child.Put(tail, value)
}

// Get returns every value stored along the way from n to the node addressed
// by prefix: the values of n, of each intermediate node, and of the final
// node. At each level Get follows the child that matches the next segment and
// also the child named wildcardNode, if present. A segment that is itself
// wildcardNode is followed only once.
//
// Results are ordered depth-first: a node's own values, then the exact-match
// subtree, then the wildcard subtree.
//
// The returned slice is freshly allocated and never aliases the trie's
// internal storage, so callers may keep or modify it after Get returns.
func (n *node[T]) Get(prefix []string, wildcardNode string) []T {
	return n.appendMatches(nil, prefix, wildcardNode)
}

// appendMatches appends to dst the values Get would return for prefix and
// returns the extended slice. Collecting into a single slice avoids
// allocating an intermediate result at every trie level.
func (n *node[T]) appendMatches(dst []T, prefix []string, wildcardNode string) []T {
	n.mu.RLock()
	defer n.mu.RUnlock()

	// Copying (rather than returning n.values) keeps callers from aliasing
	// storage that Put/Remove mutate in place.
	dst = append(dst, n.values...)
	if len(prefix) == 0 {
		return dst
	}

	head, tail := prefix[0], prefix[1:]
	// When head is the wildcard itself, the wildcard branch below already
	// covers it; skipping here avoids visiting the same child twice.
	if head != wildcardNode {
		if child, ok := n.children[head]; ok {
			dst = child.appendMatches(dst, tail, wildcardNode)
		}
	}
	if child, ok := n.children[wildcardNode]; ok {
		dst = child.appendMatches(dst, tail, wildcardNode)
	}
	return dst
}

// Remove deletes one occurrence of value from n and from every descendant
// holding it. It then prunes descendants left with neither values nor
// children.
//
// NOTE(review): only n.mu is locked; descendants are mutated without taking
// their own locks. This is safe only if every trie operation enters through
// this same root (Put and Get lock it first) — confirm with callers.
func (n *node[T]) Remove(value T) {
	n.mu.Lock()
	defer n.mu.Unlock()

	n.Walk(func(cur, child *node[T], name string) {
		child.remove(value)
		if len(child.children) == 0 && len(child.values) == 0 {
			delete(cur.children, name)
		}
	})
	// Walk only visits descendants; n's own values need handling too.
	n.remove(value)
}

// Clear calls remover once for every value stored in n and its descendants.
// It then empties all of them and prunes every descendant node.
//
// NOTE(review): same locking caveat as Remove — only n.mu is held.
func (n *node[T]) Clear(remover func(value T)) {
	n.mu.Lock()
	defer n.mu.Unlock()

	n.Walk(func(cur, child *node[T], name string) {
		// Report every value first, then drop them all at once. Deleting
		// while ranging over child.values would shift elements under the
		// iterator and skip some of them.
		for _, value := range child.values {
			remover(value)
		}
		child.values = nil
		if len(child.children) == 0 && len(child.values) == 0 {
			delete(cur.children, name)
		}
	})

	for _, value := range n.values {
		remover(value)
	}
	n.values = nil
}

// Walk visits every descendant of n in post-order. For each child it first
// walks the child's own subtree, then calls cb(parent, child, segment). cb may
// delete child from parent.children: deleting entries during a Go map range
// is permitted. Sibling order follows map iteration and is unspecified.
//
// Walk takes no locks itself; Remove and Clear call it while holding n.mu.
func (n *node[T]) Walk(cb func(cur, child *node[T], name string)) {
	for name, child := range n.children {
		child.Walk(cb)
		cb(n, child, name)
	}
}

// remove deletes the first occurrence of value from n.values, if any. It does
// not lock; callers are responsible for synchronisation.
func (n *node[T]) remove(value T) {
	idx := slices.Index(n.values, value)
	if idx == -1 {
		return
	}
	n.values = slices.Delete(n.values, idx, idx+1)
}