diff --git a/Project.toml b/Project.toml index d0aab94f..613811f7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.5" +version = "0.7.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -31,7 +31,7 @@ Adapt = "4.1.1" Aqua = "0.8.9" ArrayLayouts = "1.10.4" BlockArrays = "1.2.0" -DerivableInterfaces = "0.5" +DerivableInterfaces = "0.5.2" DiagonalArrays = "0.3" Dictionaries = "0.4.3" FillArrays = "1.13.0" @@ -44,7 +44,7 @@ SparseArraysBase = "0.5" SplitApplyCombine = "1.2.3" TensorAlgebra = "0.3.2" Test = "1.10" -TypeParameterAccessors = "0.2.0, 0.3" +TypeParameterAccessors = "0.4" julia = "1.10" [extras] diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index 2e916729..35779057 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -29,10 +29,16 @@ axis(a::AbstractVector) = axes(a, 1) function eachblockaxis(a::AbstractVector) return map(axis, blocks(a)) end +function blockaxistype(a::AbstractVector) + return eltype(eachblockaxis(a)) +end # Take a collection of axes and mortar them # into a single blocked axis. function mortar_axis(axs) + return blockrange(axs) +end +function mortar_axis(axs::Vector{<:Base.OneTo{<:Integer}}) return blockedrange(length.(axs)) end diff --git a/src/abstractblocksparsearray/abstractblocksparsearray.jl b/src/abstractblocksparsearray/abstractblocksparsearray.jl index e90e97e0..f82e2ca9 100644 --- a/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -19,12 +19,12 @@ end # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N} - return @interface BlockSparseArrayInterface() getindex(a, I...) + return @interface interface(a) getindex(a, I...) end # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.getindex(a::AbstractBlockSparseArray{<:Any,0}) - return @interface BlockSparseArrayInterface() getindex(a) + return @interface interface(a) getindex(a) end ## # Fix ambiguity error with `BlockArrays`. @@ -39,7 +39,7 @@ end ## ## # Fix ambiguity error with `BlockArrays`. ## function Base.getindex(a::AbstractBlockSparseArray, I::Vararg{AbstractVector}) -## ## return @interface BlockSparseArrayInterface() getindex(a, I...) +## ## return @interface interface(a) getindex(a, I...) ## return ArrayLayouts.layout_getindex(a, I...) ## end @@ -47,13 +47,13 @@ end function Base.setindex!( a::AbstractBlockSparseArray{<:Any,N}, value, I::Vararg{Int,N} ) where {N} - @interface BlockSparseArrayInterface() setindex!(a, value, I...) + @interface interface(a) setindex!(a, value, I...) return a end # Fix ambiguity error. function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value) - @interface BlockSparseArrayInterface() setindex!(a, value) + @interface interface(a) setindex!(a, value) return a end diff --git a/src/abstractblocksparsearray/arraylayouts.jl b/src/abstractblocksparsearray/arraylayouts.jl index b875c5d9..684fe646 100644 --- a/src/abstractblocksparsearray/arraylayouts.jl +++ b/src/abstractblocksparsearray/arraylayouts.jl @@ -27,9 +27,8 @@ function Base.similar( elt::Type, axes, ) where {A,B} - # TODO: Check that this equals `similartype(blocktype(B), elt, axes)`, - # or maybe promote them? - output_blocktype = similartype(blocktype(A), elt, axes) + # TODO: Use something like `Base.promote_op(*, A, B)` to determine the output block type. + output_blocktype = similartype(blocktype(A), Type{elt}, Tuple{blockaxistype.(axes)...}) return similar(BlockSparseArray{elt,length(axes),output_blocktype}, axes) end diff --git a/src/abstractblocksparsearray/unblockedsubarray.jl b/src/abstractblocksparsearray/unblockedsubarray.jl index fc80e92f..a10d5419 100644 --- a/src/abstractblocksparsearray/unblockedsubarray.jl +++ b/src/abstractblocksparsearray/unblockedsubarray.jl @@ -30,10 +30,6 @@ function Broadcast.BroadcastStyle(arraytype::Type{<:UnblockedSubArray}) return BroadcastStyle(blocktype(parenttype(arraytype))) end -function TypeParameterAccessors.similartype(arraytype::Type{<:UnblockedSubArray}, elt::Type) - return similartype(blocktype(parenttype(arraytype)), elt) -end - function Base.similar( a::UnblockedSubArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} ) diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index b72f1e41..86425925 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -29,8 +29,8 @@ const AnyAbstractBlockSparseVecOrMat{T,N} = Union{ AnyAbstractBlockSparseVector{T},AnyAbstractBlockSparseMatrix{T} } -function DerivableInterfaces.interface(::Type{<:AnyAbstractBlockSparseArray}) - return BlockSparseArrayInterface() +function DerivableInterfaces.interface(arrayt::Type{<:AnyAbstractBlockSparseArray}) + return BlockSparseArrayInterface(interface(blocktype(arrayt))) end # a[1:2, 1:2] @@ -231,9 +231,9 @@ function Base.similar( end function blocksparse_similar(a, elt::Type, axes::Tuple) - return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}( - undef, axes - ) + ndims = length(axes) + blockt = similartype(blocktype(a), Type{elt}, Tuple{blockaxistype.(axes)...}) + return BlockSparseArray{elt,ndims,blockt}(undef, axes) end @interface ::AbstractBlockSparseArrayInterface function Base.similar( a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}} diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 5c1ed5be..b4ce4038 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -16,7 +16,13 @@ using BlockArrays: blocklength, blocks, findblockindex -using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface, zero! +using DerivableInterfaces: + DerivableInterfaces, + @interface, + AbstractArrayInterface, + DefaultArrayInterface, + interface, + zero! using LinearAlgebra: Adjoint, Transpose using SparseArraysBase: AbstractSparseArrayInterface, @@ -101,18 +107,47 @@ blockstype(a::BlockArray) = blockstype(typeof(a)) blocktype(arraytype::Type{<:BlockArray}) = eltype(blockstype(arraytype)) blocktype(a::BlockArray) = eltype(blocks(a)) -abstract type AbstractBlockSparseArrayInterface{N} <: AbstractSparseArrayInterface{N} end +abstract type AbstractBlockSparseArrayInterface{N,B<:AbstractArrayInterface{N}} <: + AbstractSparseArrayInterface{N} end + +function blockinterface(interface::AbstractBlockSparseArrayInterface{<:Any,B}) where {B} + return B() +end # TODO: Also support specifying the `blocktype` along with the `eltype`. -function Base.similar(::AbstractBlockSparseArrayInterface, T::Type, ax::Tuple) - return similar(BlockSparseArray{T}, ax) +function Base.similar(interface::AbstractBlockSparseArrayInterface, T::Type, ax::Tuple) + # TODO: Generalize by storing the block interface in the block sparse array interface. + N = length(ax) + B = similartype(typeof(blockinterface(interface)), Type{T}, Tuple{blockaxistype.(ax)...}) + return similar(BlockSparseArray{T,N,B}, ax) end -struct BlockSparseArrayInterface{N} <: AbstractBlockSparseArrayInterface{N} end +struct BlockSparseArrayInterface{N,B<:AbstractArrayInterface{N}} <: + AbstractBlockSparseArrayInterface{N,B} + blockinterface::B +end +function BlockSparseArrayInterface{N}(blockinterface::AbstractArrayInterface{N}) where {N} + return BlockSparseArrayInterface{N,typeof(blockinterface)}(blockinterface) +end +function BlockSparseArrayInterface{M,B}(::Val{N}) where {M,B<:AbstractArrayInterface{M},N} + B′ = B(Val(N)) + return BlockSparseArrayInterface(B′) +end +function BlockSparseArrayInterface{N}() where {N} + return BlockSparseArrayInterface{N}(DefaultArrayInterface{N}()) +end BlockSparseArrayInterface(::Val{N}) where {N} = BlockSparseArrayInterface{N}() BlockSparseArrayInterface{M}(::Val{N}) where {M,N} = BlockSparseArrayInterface{N}() BlockSparseArrayInterface() = BlockSparseArrayInterface{Any}() +function DerivableInterfaces.combine_interface_rule( + interface1::AbstractBlockSparseArrayInterface, + interface2::AbstractBlockSparseArrayInterface, +) + B = interface(blockinterface(interface1), blockinterface(interface2)) + return BlockSparseArrayInterface(B) +end + @interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(a::AbstractArray) return error("Not implemented") end diff --git a/test/Project.toml b/test/Project.toml index 9ce1d39a..8c52cd72 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -38,4 +38,4 @@ Suppressor = "0.2" TensorAlgebra = "0.3.2" Test = "1" TestExtras = "0.3" -TypeParameterAccessors = "0.3" +TypeParameterAccessors = "0.4"