aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAleksandr Kiryukhin <aakiryukhin@avito.ru>2022-07-26 04:56:30 +0300
committerAleksandr Kiryukhin <aakiryukhin@avito.ru>2022-07-26 04:56:30 +0300
commita84b0dc4a3605eb532f95bbab75463e2a0596b9c (patch)
tree9f46e0129ab0ca70270526eba41c9aa6b1157397
parentd3ad517530f0b03e66fe575afe90ec1429ec567d (diff)
Added get by interface method
-rw-r--r--README.md16
-rw-r--r--di.go44
-rw-r--r--di_test.go36
3 files changed, 80 insertions, 16 deletions
diff --git a/README.md b/README.md
index c20db6e..70745a9 100644
--- a/README.md
+++ b/README.md
@@ -17,12 +17,17 @@ di.Register("service id", func () (*Service, error) { /* construct service */ })
Get dependencies by type:
```go
-services, err := di.Get[Service]()
+services, err := di.GetByType[Service]()
```
Get dependencies by type and id:
```go
-service, err := di.GetById[Service]("service id")
+service, err := di.Get[Service]("service id")
+```
+
+Get dependencies by interface:
+```go
+services, err := di.GetByInterface[Worker]() // Worker is interface for many workers
```
### Go doc
@@ -30,8 +35,9 @@ service, err := di.GetById[Service]("service id")
```go
package di // import "go.neonxp.dev/di"
-func Get[T any]() ([]*T, error)
-func GetById[T any](id string) (*T, error)
+func Get[T any](id string) (*T, error)
+func GetByInterface[Interface any]() ([]Interface, error)
+func GetByType[T any]() ([]*T, error)
func Register[T any](id string, constructor func() (*T, error))
```
@@ -53,7 +59,7 @@ di.Register("serviceB", func() (*ServiceB, error) { // <- Register service B, th
})
// Do work ...
-service, err := di.GetById[ServiceB]("serviceB") // <- Get instantinated service B
+service, err := di.Get[ServiceB]("serviceB") // <- Get instantinated service B
if err != nil {
panic(err)
}
diff --git a/di.go b/di.go
index b015577..8e41ee9 100644
--- a/di.go
+++ b/di.go
@@ -15,8 +15,13 @@ func init() {
cache = sync.Map{}
}
+// Register service in di
+func Register[T any](id string, constructor func() (*T, error)) {
+ services.Store(id, constructor)
+}
+
// Get services by type
-func Get[T any]() ([]*T, error) {
+func GetByType[T any]() ([]*T, error) {
var err error
result := []*T{}
services.Range(func(id, constructor any) bool {
@@ -40,13 +45,38 @@ func Get[T any]() ([]*T, error) {
return result, err
}
-// Get service by type and id
-func GetById[T any](id string) (*T, error) {
+// Get services by interface
+func GetByInterface[Interface any]() ([]Interface, error) {
+ var err error
+ result := []Interface{}
+ services.Range(func(id, constructor any) bool {
+ if constructor, ok := constructor.(func() (Interface, error)); ok {
+ if instance, ok := cache.Load(id); ok {
+ if instance, ok := instance.(Interface); ok {
+ result = append(result, instance)
+ }
+ return true
+ }
+ instance, instErr := constructor()
+ if instErr != nil {
+ err = instErr
+ return false
+ }
+ cache.Store(id, instance)
+ result = append(result, instance)
+ }
+ return true
+ })
+ return result, err
+}
+
+// Get service by id and type
+func Get[T any](id string) (*T, error) {
if instance, ok := cache.Load(id); ok {
if instance, ok := instance.(*T); ok {
return instance, nil
}
- return nil, fmt.Errorf("invalid type %t for service %s", instance, id)
+ return nil, fmt.Errorf("invalid type for service %s (%t)", id, instance)
}
if constructor, ok := services.Load(id); ok {
if constructor, ok := constructor.(func() (*T, error)); ok {
@@ -57,11 +87,7 @@ func GetById[T any](id string) (*T, error) {
cache.Store(id, instance)
return instance, nil
}
- return nil, fmt.Errorf("invalid type %t for service %s", constructor, id)
+ return nil, fmt.Errorf("invalid constructor")
}
return nil, fmt.Errorf("unknown service %s", id)
}
-
-func Register[T any](id string, constructor func() (*T, error)) {
- services.Store(id, constructor)
-}
diff --git a/di_test.go b/di_test.go
index cd0eb23..a67a82e 100644
--- a/di_test.go
+++ b/di_test.go
@@ -11,7 +11,7 @@ func ExampleGet() {
return &ServiceA{}, nil
})
di.Register("serviceB", func() (*ServiceB, error) { // <- Register service B, that depends from service A
- serviceA, err := di.Get[ServiceA]() // <- Get dependency from container by type
+ serviceA, err := di.GetByType[ServiceA]() // <- Get dependency from container by type
if err != nil {
return nil, err
}
@@ -22,13 +22,29 @@ func ExampleGet() {
})
// Do work...
- service, err := di.GetById[ServiceB]("serviceB") // <- Get instantinated service B
+ service, err := di.Get[ServiceB]("serviceB") // <- Get instantinated service B
if err != nil {
panic(err)
}
service.DoStuff() // Output: Hello, world!
}
+func ExampleGet_interface() {
+ di.Register("worker1", func() (*Worker1, error) {
+ return &Worker1{}, nil
+ })
+ di.Register("worker2", func() (*Worker2, error) {
+ return &Worker2{}, nil
+ })
+ workers, err := di.GetByInterface[Worker]()
+ if err != nil {
+ panic(err)
+ }
+ for _, w := range workers {
+ w.Do()
+ }
+}
+
type ServiceA struct{}
func (d *ServiceA) DoStuff() {
@@ -42,3 +58,19 @@ type ServiceB struct {
func (d *ServiceB) DoStuff() {
d.ServiceA.DoStuff()
}
+
+type Worker interface {
+ Do()
+}
+
+type Worker1 struct{}
+
+func (w *Worker1) Do() {
+ fmt.Println("Worker 1 says hello")
+}
+
+type Worker2 struct{}
+
+func (w *Worker2) Do() {
+ fmt.Println("Worker 2 says hello")
+}