From 7372c2bfb439f226c4d8c8eb0364500d1e19bfb6 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 18:05:13 -0400 Subject: [PATCH 1/8] More general block types in broadcast style --- Project.toml | 2 +- src/abstractblocksparsearray/broadcast.jl | 4 ++-- src/blocksparsearrayinterface/broadcast.jl | 20 ++++++++++++++++---- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index f76da0df..d0a2ff15 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.10" +version = "0.7.11" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/abstractblocksparsearray/broadcast.jl b/src/abstractblocksparsearray/broadcast.jl index 3527831c..ecef502c 100644 --- a/src/abstractblocksparsearray/broadcast.jl +++ b/src/abstractblocksparsearray/broadcast.jl @@ -1,8 +1,8 @@ using BlockArrays: AbstractBlockedUnitRange, BlockSlice -using Base.Broadcast: Broadcast +using Base.Broadcast: Broadcast, BroadcastStyle function Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray}) - return BlockSparseArrayStyle{ndims(arraytype)}() + return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype))) end # Fix ambiguity error with `BlockArrays`. diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index d8ab5ec8..56d73ad2 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -3,21 +3,33 @@ using GPUArraysCore: @allowscalar using MapBroadcast: Mapped using DerivableInterfaces: DerivableInterfaces, @interface -abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end +abstract type AbstractBlockSparseArrayStyle{N,B} <: AbstractArrayStyle{N} end -function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle}) - return BlockSparseArrayInterface() +function DerivableInterfaces.interface( + ::Type{<:AbstractBlockSparseArrayStyle{N,B}} +) where {N,B} + return BlockSparseArrayInterface(interface(B)) end -struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end +struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: + AbstractBlockSparseArrayStyle{N,B} + blockstyle::B +end +function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N} + return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle) +end # Define for new sparse array types. # function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray}) # return BlockSparseArrayStyle{ndims(arraytype)}() # end +BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}()) BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}() BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}() +function BlockSparseArrayStyle{M,B}(::Val{N}) where {M,B<:AbstractArrayStyle{M},N} + return BlockSparseArrayStyle{N}(B(Val(N))) +end Broadcast.BroadcastStyle(a::BlockSparseArrayStyle, ::DefaultArrayStyle{0}) = a function Broadcast.BroadcastStyle( From 7ecf6ebbf5f81be5e9d26360435094c5e0ed6a93 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 19:06:08 -0400 Subject: [PATCH 2/8] Define mixing block sparse array styles --- src/blocksparsearrayinterface/broadcast.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 56d73ad2..3975cb28 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -1,13 +1,22 @@ -using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted +using Base.Broadcast: + Broadcast, BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted using GPUArraysCore: @allowscalar using MapBroadcast: Mapped using DerivableInterfaces: DerivableInterfaces, @interface -abstract type AbstractBlockSparseArrayStyle{N,B} <: AbstractArrayStyle{N} end +abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: + AbstractArrayStyle{N} end +blockstyle(::AbstractBlockSparseArrayStyle{<:Any,B}) where {<:Any,B} = B() + +function Broadcast.BroadcastStyle( + style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle +) + return BlockSparseArrayStyle(BroadcastStyle(blockstyle(style1), blockstyle(style2))) +end function DerivableInterfaces.interface( ::Type{<:AbstractBlockSparseArrayStyle{N,B}} -) where {N,B} +) where {N,B<:AbstractArrayStyle{N}} return BlockSparseArrayInterface(interface(B)) end From 9c238194a3e879cbfac3d45f55fd8a6615310e3d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 19:07:33 -0400 Subject: [PATCH 3/8] Define mixing block sparse array styles --- src/blocksparsearrayinterface/broadcast.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 3975cb28..1d00a463 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -6,7 +6,8 @@ using DerivableInterfaces: DerivableInterfaces, @interface abstract type AbstractBlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: AbstractArrayStyle{N} end -blockstyle(::AbstractBlockSparseArrayStyle{<:Any,B}) where {<:Any,B} = B() + +blockstyle(::AbstractBlockSparseArrayStyle{N,B}) where {N,B<:AbstractArrayStyle{N}} = B() function Broadcast.BroadcastStyle( style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle From 5c8b83cfc5420d1c6d9c7b7cfd65cc685052d528 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 19:16:21 -0400 Subject: [PATCH 4/8] Fix combining block broadcast styles --- src/blocksparsearrayinterface/broadcast.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 1d00a463..03be3b3a 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -12,7 +12,8 @@ blockstyle(::AbstractBlockSparseArrayStyle{N,B}) where {N,B<:AbstractArrayStyle{ function Broadcast.BroadcastStyle( style1::AbstractBlockSparseArrayStyle, style2::AbstractBlockSparseArrayStyle ) - return BlockSparseArrayStyle(BroadcastStyle(blockstyle(style1), blockstyle(style2))) + style = Broadcast.result_style(blockstyle(style1), blockstyle(style2)) + return BlockSparseArrayStyle(style) end function DerivableInterfaces.interface( From f3e9290b51120f0ecb8296ba95dcd83af0065b17 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 19:22:29 -0400 Subject: [PATCH 5/8] Fix tests --- src/blocksparsearrayinterface/broadcast.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index 03be3b3a..78620869 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -30,11 +30,9 @@ function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N} return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle) end -# Define for new sparse array types. -# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray}) -# return BlockSparseArrayStyle{ndims(arraytype)}() -# end - +function BlockSparseArrayStyle{N,B}() where {N,B<:AbstractArrayStyle{N}} + return BlockSparseArrayStyle{N,B}(B()) +end BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}()) BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}() BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}() From 3bd326b27b32f7337e4a7ef2e65cff227c0865aa Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 21:03:49 -0400 Subject: [PATCH 6/8] Catch cases when block type can't be determined from similartype --- src/blocksparsearray/blocksparsearray.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/blocksparsearray/blocksparsearray.jl b/src/blocksparsearray/blocksparsearray.jl index abc3a992..16b37f1c 100644 --- a/src/blocksparsearray/blocksparsearray.jl +++ b/src/blocksparsearray/blocksparsearray.jl @@ -171,11 +171,22 @@ function BlockSparseArray{T,N,A}( return BlockSparseArray{T,N,A}(undef, (dim1, dim_rest...)) end +function unchecked_similartype(a, args...) + A = Base.promote_op(similar, a, args...) + return !isconcretetype(A) ? Array{T,N} : A +end + function BlockSparseArray{T,N}( ::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}} ) where {T,N} axt = Tuple{blockaxistype.(axes)...} - A = similartype(Array{T}, axt) + # Ideally we would use: + # ```julia + # A = similartype(Array{T}, axt) + # ``` + # but that doesn't work when `similar` isn't defined or + # isn't type stable. + A = unchecked_similartype(Array{T}, axt) return BlockSparseArray{T,N,A}(undef, axes) end From 0c8dc98465ecd15336c9d57fa1688dfdf27a3930 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 22:32:10 -0400 Subject: [PATCH 7/8] Fix constructor --- src/blocksparsearray/blocksparsearray.jl | 10 ++++++---- test/test_basics.jl | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/blocksparsearray/blocksparsearray.jl b/src/blocksparsearray/blocksparsearray.jl index 16b37f1c..7d782c8c 100644 --- a/src/blocksparsearray/blocksparsearray.jl +++ b/src/blocksparsearray/blocksparsearray.jl @@ -171,9 +171,11 @@ function BlockSparseArray{T,N,A}( return BlockSparseArray{T,N,A}(undef, (dim1, dim_rest...)) end -function unchecked_similartype(a, args...) - A = Base.promote_op(similar, a, args...) - return !isconcretetype(A) ? Array{T,N} : A +function similartype_unchecked( + A::Type{<:AbstractArray{T}}, axt::Type{<:Tuple{Vararg{Any,N}}} +) where {T,N} + A′ = Base.promote_op(similar, A, axt) + return !isconcretetype(A′) ? Array{T,N} : A′ end function BlockSparseArray{T,N}( @@ -186,7 +188,7 @@ function BlockSparseArray{T,N}( # ``` # but that doesn't work when `similar` isn't defined or # isn't type stable. - A = unchecked_similartype(Array{T}, axt) + A = similartype_unchecked(Array{T}, axt) return BlockSparseArray{T,N,A}(undef, axes) end diff --git a/test/test_basics.jl b/test/test_basics.jl index aad02b1c..83bc4d1d 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -33,6 +33,7 @@ using BlockSparseArrays: eachblockstoredindex, eachstoredblock, eachstoredblockdiagindex, + similartype_unchecked, sparsemortar, view! using GPUArraysCore: @allowscalar @@ -44,6 +45,20 @@ using TestExtras: @constinferred using TypeParameterAccessors: TypeParameterAccessors, Position include("TestBlockSparseArraysUtils.jl") +@testset "similartype_unchecked" begin + @test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Int})) === + Matrix{Float32} + @test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Base.OneTo{Int}})) === + Matrix{Float32} + @test @constinferred(similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int})) === + Matrix{Float32} + @test @constinferred(similartype_unchecked(JLArray{Float32}, NTuple{2,Int})) === + JLMatrix{Float32} + @test @constinferred( + similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) + ) === JLMatrix{Float32} +end + arrayts = (Array, JLArray) @testset "BlockSparseArrays (arraytype=$arrayt, eltype=$elt)" for arrayt in arrayts, elt in (Float32, Float64, Complex{Float32}, Complex{Float64}) From a238b61216a67806b25d8bbe5ae9faed0823899d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 12 Jun 2025 22:52:29 -0400 Subject: [PATCH 8/8] Fix tests in Julia 1.10 --- test/test_basics.jl | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index 83bc4d1d..8da0276d 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -50,13 +50,21 @@ include("TestBlockSparseArraysUtils.jl") Matrix{Float32} @test @constinferred(similartype_unchecked(Array{Float32}, NTuple{2,Base.OneTo{Int}})) === Matrix{Float32} - @test @constinferred(similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int})) === - Matrix{Float32} - @test @constinferred(similartype_unchecked(JLArray{Float32}, NTuple{2,Int})) === - JLMatrix{Float32} - @test @constinferred( - similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) - ) === JLMatrix{Float32} + if VERSION < v"1.11-" + # Not type stable in Julia 1.10. + @test similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int}) === Matrix{Float32} + @test similartype_unchecked(JLArray{Float32}, NTuple{2,Int}) === JLMatrix{Float32} + @test similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) === + JLMatrix{Float32} + else + @test @constinferred(similartype_unchecked(AbstractArray{Float32}, NTuple{2,Int})) === + Matrix{Float32} + @test @constinferred(similartype_unchecked(JLArray{Float32}, NTuple{2,Int})) === + JLMatrix{Float32} + @test @constinferred( + similartype_unchecked(JLArray{Float32}, NTuple{2,Base.OneTo{Int}}) + ) === JLMatrix{Float32} + end end arrayts = (Array, JLArray)