diff --git a/Project.toml b/Project.toml index f76da0d..d0a2ff1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.10" +version = "0.7.11" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractblocksparsearray/broadcast.jl b/src/abstractblocksparsearray/broadcast.jl index 3527831..ecef502 100644 --- a/src/abstractblocksparsearray/broadcast.jl +++ b/src/abstractblocksparsearray/broadcast.jl @@ -1,8 +1,8 @@ using BlockArrays: AbstractBlockedUnitRange, BlockSlice -using Base.Broadcast: Broadcast +using Base.Broadcast: Broadcast, BroadcastStyle function Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray}) - return BlockSparseArrayStyle{ndims(arraytype)}() + return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype))) end # Fix ambiguity error with `BlockArrays`. diff --git a/src/blocksparsearray/blocksparsearray.jl b/src/blocksparsearray/blocksparsearray.jl index abc3a99..7d782c8 100644 --- a/src/blocksparsearray/blocksparsearray.jl +++ b/src/blocksparsearray/blocksparsearray.jl @@ -171,11 +171,24 @@ function BlockSparseArray{T,N,A}( return BlockSparseArray{T,N,A}(undef, (dim1, dim_rest...)) end +function similartype_unchecked( + A::Type{<:AbstractArray{T}}, axt::Type{<:Tuple{Vararg{Any,N}}} +) where {T,N} + A′ = Base.promote_op(similar, A, axt) + return !isconcretetype(A′) ? Array{T,N} : A′ +end + function BlockSparseArray{T,N}( ::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}} ) where {T,N} axt = Tuple{blockaxistype.(axes)...} - A = similartype(Array{T}, axt) + # Ideally we would use: + # ```julia + # A = similartype(Array{T}, axt) + # ``` + # but that doesn't work when `similar` isn't defined or + # isn't type stable. + A = similartype_unchecked(Array{T}, axt) return BlockSparseArray{T,N,A}(undef, axes) end diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index d8ab5ec..7862086 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -1,23 +1,44 @@ -using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted +using Base.Broadcast: + Broadcast, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted using GPUArraysCore: @allowscalar using MapBroadcast: Mapped using DerivableInterfaces: DerivableInterfaces, @interface -abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end +abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: + AbstractArrayStyle{N} end -function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle}) - return BlockSparseArrayInterface() +blockstyle(::AbstractBlockSparseArrayStyle{N,B}) where {N,B<:AbstractArrayStyle{N}} = B() + +function Broadcast.BroadcastStyle( + style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle +) + style = Broadcast.result_style(blockstyle(style1), blockstyle(style2)) + return BlockSparseArrayStyle(style) end -struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end +function DerivableInterfaces.interface( + ::Type{<:AbstractBlockSparseArrayStyle{N,B}} +) where {N,B<:AbstractArrayStyle{N}} + return BlockSparseArrayInterface(interface(B)) +end -# Define for new sparse array types. -# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray}) -# return BlockSparseArrayStyle{ndims(arraytype)}() -# end +struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: + AbstractBlockSparseArrayStyle{N,B} + blockstyle::B +end +function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N} + return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle) +end +function BlockSparseArrayStyle{N,B}() where {N,B<:AbstractArrayStyle{N}} + return BlockSparseArrayStyle{N,B}(B()) +end +BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}()) BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}() BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}() +function BlockSparseArrayStyle{M,B}(::Val{N}) where {M,B<:AbstractArrayStyle{M},N} + return BlockSparseArrayStyle{N}(B(Val(N))) +end Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = a function Broadcast.BroadcastStyle( diff --git a/test/test_basics.jl b/test/test_basics.jl index aad02b1..8da0276 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -33,6 +33,7 @@ using BlockSparseArrays: eachblockstoredindex, eachstoredblock, eachstoredblockdiagindex, + similartype_unchecked, sparsemortar, view! using GPUArraysCore: @allowscalar @@ -44,6 +45,28 @@ using TestExtras: @constinferred using TypeParameterAccessors: TypeParameterAccessors, Position include("TestBlockSparseArraysUtils.jl") +@testset "similartype_unchecked" begin + @test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Int})) === + Matrix{Float32} + @test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Base.OneTo{Int}})) === + Matrix{Float32} + if VERSION < v"1.11-" + # Not type stable in Julia 1.10. + @test similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int}) === Matrix{Float32} + @test similartype_unchecked(JLArray{Float32}, NTuple{2,Int}) === JLMatrix{Float32} + @test similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) === + JLMatrix{Float32} + else + @test @constinferred(similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int})) === + Matrix{Float32} + @test @constinferred(similartype_unchecked(JLArray{Float32}, NTuple{2,Int})) === + JLMatrix{Float32} + @test @constinferred( + similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) + ) === JLMatrix{Float32} + end +end + arrayts = (Array, JLArray) @testset "BlockSparseArrays (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts, elt in (Float32, Float64, Complex{Float32}, Complex{Float64})