diff --git a/src/stage1/recurse_fwd.jl b/src/stage1/recurse_fwd.jl index 2c561e73..9cb763f5 100644 --- a/src/stage1/recurse_fwd.jl +++ b/src/stage1/recurse_fwd.jl @@ -24,7 +24,8 @@ function (::∂☆new{1})(B::Type, xs::AbstractTangentBundle{1}...) tangent_nt = NamedTuple{names}(tangent_tup) Tangent{B, typeof(tangent_nt)}(tangent_nt) end - return TaylorBundle{1, B}(the_primal, (the_partial,)) + B2 = typeof(the_primal) # HACK: if the_primal actually has types in it then we want to make sure we get DataType not Type(...) + return TaylorBundle{1, B2}(the_primal, (the_partial,)) end function (::∂☆new{N})(B::Type, xs::AbstractTangentBundle{N}...) where {N} diff --git a/test/forward.jl b/test/forward.jl index 1e4b7142..4f2c6ae6 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -148,6 +148,18 @@ end end +@testset "types in tuples" begin + function foo(a) + tup = (a, 2a, Int) + return tup[2] + end + + let var"'" = Diffractor.PrimeDerivativeFwd + @test foo'(100.0) == 2.0 + end +end + + @testset "taylor_compatible" begin taylor_compatible = Diffractor.taylor_compatible