diff --git a/README.md b/README.md index 021f160..e7ffb5e 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,9 @@ type TypeMapper interface { // This is really only useful for mapping a value as an interface, as interfaces // cannot at this time be referenced directly without a pointer. MapTo(interface{}, interface{}) TypeMapper + // Maps the outputs types of the function to the handler. + // The handler is run whenever the type is requested. + MapHandler(interface{}) TypeMapper // Provides a possibility to directly insert a mapping based on type and value. // This makes it possible to directly map type arguments not possible to instantiate // with reflect like unidirectional channels. diff --git a/inject.go b/inject.go index 3ff713c..b735c24 100644 --- a/inject.go +++ b/inject.go @@ -43,6 +43,9 @@ type TypeMapper interface { // This is really only useful for mapping a value as an interface, as interfaces // cannot at this time be referenced directly without a pointer. MapTo(interface{}, interface{}) TypeMapper + // Maps the outputs types of the function to the handler. + // The handler is run whenever the type is requested. + MapHandler(interface{}) TypeMapper // Provides a possibility to directly insert a mapping based on type and value. // This makes it possible to directly map type arguments not possible to instantiate // with reflect like unidirectional channels. @@ -148,6 +151,19 @@ func (i *injector) MapTo(val interface{}, ifacePtr interface{}) TypeMapper { return i } +// Maps all of val's output types to the function val, which +// is executed whenever the type is requested. +func (i *injector) MapHandler(val interface{}) TypeMapper { + t := reflect.TypeOf(val) + v := reflect.ValueOf(val) + + for idx := 0; idx < t.NumOut(); idx++ { + i.values[t.Out(idx)] = v + } + + return i +} + // Maps the given reflect.Type to the given reflect.Value and returns // the Typemapper the mapping has been registered in. func (i *injector) Set(typ reflect.Type, val reflect.Value) TypeMapper { @@ -159,17 +175,40 @@ func (i *injector) Get(t reflect.Type) reflect.Value { val := i.values[t] if val.IsValid() { - return val + if val.Kind() != reflect.Func || val.Type() == t { + return val + } + + if results, err := i.Invoke(val.Interface()); err == nil { + for _, r := range results { + if r.Type() == t { + return r + } + } + } } // no concrete types found, try to find implementors // if t is an interface if t.Kind() == reflect.Interface { for k, v := range i.values { - if k.Implements(t) { + if !k.Implements(t) { + continue + } + + if v.Kind() != reflect.Func { val = v break } + + if results, err := i.Invoke(v.Interface()); err == nil { + for _, r := range results { + if r.Type().Implements(t) { + val = r + break + } + } + } } } diff --git a/inject_test.go b/inject_test.go index eb94471..756de72 100644 --- a/inject_test.go +++ b/inject_test.go @@ -157,3 +157,56 @@ func TestInjectImplementors(t *testing.T) { expect(t, injector.Get(inject.InterfaceOf((*fmt.Stringer)(nil))).IsValid(), true) } + +func Test_InjectorMapHandler(t *testing.T) { + injector := inject.New() + + timesRun := 0 // Count number of times the handler was run + + handler := func() (string, int) { + timesRun++ + return "some dependency", 11 + } + + injector.MapHandler(handler) + + expect(t, injector.Get(reflect.TypeOf("string")).IsValid(), true) + expect(t, injector.Get(reflect.TypeOf(11)).IsValid(), true) + expect(t, injector.Get(reflect.TypeOf(handler)).IsValid(), false) // Handler itself should NOT be mapped + expect(t, timesRun, 2) +} + +func Test_InjectorInvokeWithMapHandler(t *testing.T) { + injector := inject.New() + + handler := func() (string, int) { + return "some dependency", 11 + } + + injector.MapHandler(handler) + + var ( + s string + i int + ) + + injector.Invoke(func(j string, k int) { + s = j + i = k + }) + + expect(t, s, "some dependency") + expect(t, i, 11) +} + +func Test_InjectorCanAcceptFuncs(t *testing.T) { + injector := inject.New() + + handler := func() (string, int) { + return "some dependency", 11 + } + + injector.Map(handler) + + expect(t, injector.Get(reflect.TypeOf(handler)).IsValid(), true) +}