aboutsummaryrefslogtreecommitdiff
path: root/di.go
blob: 8e41ee9b088c7a7fcb4b9e171e4048a838eedabc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package di

import (
	"fmt"
	"reflect"
	"sync"
)

// Package-level registries shared by every function in this package.
//
// Both maps are used through their zero value, which sync.Map documents as
// empty and ready for use, so no init-time reset is needed. (A reset in
// init would also discard anything registered by an earlier init in the
// same package.)
var (
	// services maps a service id to the constructor passed to Register.
	services sync.Map
	// cache maps a service id to the instance its constructor produced, so
	// later lookups reuse that instance instead of constructing a new one.
	cache sync.Map
)

// Register stores constructor under id so the service can later be obtained
// with Get[T](id) or found by GetByType[T].
//
// Registering an id again replaces its constructor and discards any instance
// already built for that id, so the next lookup uses the new constructor.
func Register[T any](id string, constructor func() (*T, error)) {
	services.Store(id, constructor)
	// Without this, Get would keep returning the instance built by the old
	// constructor, or fail with a type mismatch if the new one produces a
	// different type.
	cache.Delete(id)
}

// GetByType returns one instance of every registered service whose
// constructor has the exact type func() (*T, error). Services that have not
// been built yet are constructed and cached under their id, so later calls
// (including Get) return the same instance.
//
// The order of the result is unspecified; it follows sync.Map.Range.
// A service whose cached value under the same id is not a *T is skipped.
// If a constructor fails, iteration stops and GetByType returns nil together
// with that constructor's error.
func GetByType[T any]() ([]*T, error) {
	var err error
	result := []*T{}
	services.Range(func(id, registered any) bool {
		constructor, ok := registered.(func() (*T, error))
		if !ok {
			return true // registered for a different type
		}
		cached, ok := cache.Load(id)
		if !ok {
			instance, ctorErr := constructor()
			if ctorErr != nil {
				err = ctorErr
				return false
			}
			// Another goroutine may have cached its own instance while we were
			// constructing; LoadOrStore keeps the first one so every caller
			// sees the same singleton.
			cached, _ = cache.LoadOrStore(id, instance)
		}
		if instance, ok := cached.(*T); ok {
			result = append(result, instance)
		}
		return true
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}

// GetByInterface returns one instance of every registered service whose
// constructor produces a value usable as Interface. A constructor qualifies
// when its result type is exactly Interface or, when Interface is an
// interface type, when its result type implements Interface. Matching
// services that have not been built yet are constructed and cached under
// their id, so later calls (including Get and GetByType) return the same
// instance.
//
// Register stores constructors of the form func() (*T, error), and Go
// function types are not covariant. A type assertion to
// func() (Interface, error) can therefore never match them for a real
// interface type, so the constructor signature is inspected with reflection
// instead.
//
// The order of the result is unspecified; it follows sync.Map.Range.
// If a constructor fails, iteration stops and GetByInterface returns nil
// together with that constructor's error.
func GetByInterface[Interface any]() ([]Interface, error) {
	target := reflect.TypeOf((*Interface)(nil)).Elem()
	var err error
	result := []Interface{}
	services.Range(func(id, registered any) bool {
		constructor := reflect.ValueOf(registered)
		if !constructorYields(constructor, target) {
			return true
		}
		cached, ok := cache.Load(id)
		if !ok {
			out := constructor.Call(nil)
			if !out[1].IsNil() {
				err = out[1].Interface().(error)
				return false
			}
			// Cache the concrete value (e.g. *T) so Get/GetByType find it too;
			// LoadOrStore keeps an instance a concurrent caller stored first.
			cached, _ = cache.LoadOrStore(id, out[0].Interface())
		}
		if instance, ok := cached.(Interface); ok {
			result = append(result, instance)
		}
		return true
	})
	if err != nil {
		return nil, err
	}
	return result, nil
}

// constructorYields reports whether fn is a non-nil func() (R, error) whose
// result type R is target itself or, for an interface target, implements it.
func constructorYields(fn reflect.Value, target reflect.Type) bool {
	if !fn.IsValid() || fn.Kind() != reflect.Func || fn.IsNil() {
		return false
	}
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	sig := fn.Type()
	if sig.NumIn() != 0 || sig.NumOut() != 2 || sig.Out(1) != errorType {
		return false
	}
	produced := sig.Out(0)
	return produced == target ||
		(target.Kind() == reflect.Interface && produced.Implements(target))
}

// Get returns the instance of the service registered under id, constructing
// and caching it on first use.
//
// It returns an error if id is unknown, if the registered constructor is not
// a func() (*T, error), or if the instance cached under id is not a *T. If
// the constructor fails, its error is returned unchanged.
func Get[T any](id string) (*T, error) {
	if cached, ok := cache.Load(id); ok {
		return typedInstance[T](id, cached)
	}
	registered, ok := services.Load(id)
	if !ok {
		return nil, fmt.Errorf("unknown service %s", id)
	}
	constructor, ok := registered.(func() (*T, error))
	if !ok {
		return nil, fmt.Errorf("invalid constructor for service %s (%T)", id, registered)
	}
	instance, err := constructor()
	if err != nil {
		return nil, err
	}
	// Another goroutine may have cached its own instance while we were
	// constructing; LoadOrStore keeps the first one so every caller sees the
	// same singleton.
	actual, _ := cache.LoadOrStore(id, instance)
	return typedInstance[T](id, actual)
}

// typedInstance asserts that the value cached under id is a *T.
func typedInstance[T any](id string, cached any) (*T, error) {
	if instance, ok := cached.(*T); ok {
		return instance, nil
	}
	return nil, fmt.Errorf("invalid type for service %s (%T)", id, cached)
}