Skip to content

Add a default rule for custom blocks #3570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions pyomo/core/base/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations
import copy
import functools
import logging
import sys
import weakref
Expand Down Expand Up @@ -2120,7 +2121,6 @@ def __init__(self, *args, **kwargs):
# initializer
self._dense = kwargs.pop('dense', True)
kwargs.setdefault('ctype', Block)
ActiveIndexedComponent.__init__(self, *args, **kwargs)
if _options is not None:
deprecation_warning(
"The Block 'options=' keyword is deprecated. "
Expand All @@ -2129,19 +2129,10 @@ def __init__(self, *args, **kwargs):
"the function arguments",
version='5.7.2',
)
if self.is_indexed():

def rule_wrapper(model, *_idx):
return _rule(model, *_idx, **_options)

else:

def rule_wrapper(model):
return _rule(model, **_options)

self._rule = Initializer(rule_wrapper)
self._rule = Initializer(functools.partial(_rule, **_options))
else:
self._rule = Initializer(_rule)
ActiveIndexedComponent.__init__(self, *args, **kwargs)
if _concrete:
# Call self.construct() as opposed to just setting the _constructed
# flag so that the base class construction procedure fires (this
Expand Down Expand Up @@ -2426,6 +2417,7 @@ class CustomBlock(Block):
def __init__(self, *args, **kwargs):
if self._default_ctype is not None:
kwargs.setdefault('ctype', self._default_ctype)
kwargs.setdefault("rule", getattr(self, '_default_rule', None))
Block.__init__(self, *args, **kwargs)

def __new__(cls, *args, **kwargs):
Expand All @@ -2446,13 +2438,56 @@ def __new__(cls, *args, **kwargs):
return super().__new__(cls._indexed_custom_block, *args, **kwargs)


def declare_custom_block(name, new_ctype=None):
class _custom_block_rule_redirect(object):
"""Functor to redirect the default rule to a BlockData method"""

def __init__(self, cls, name):
self.cls = cls
self.name = name

def __call__(self, block, *args, **kwargs):
return getattr(self.cls, self.name)(block, *args, **kwargs)


def declare_custom_block(name, new_ctype=None, rule=None):
"""Decorator to declare components for a custom block data class

This decorator simplifies the definition of custom derived Block
classes. With this decorator, developers must only implement the
derived "Data" class. The decorator automatically creates the
derived containers using the provided name, and adds them to the
current module:

>>> @declare_custom_block(name="FooBlock")
... class FooBlockData(BlockData):
... # custom block data class
... pass

>>> s = FooBlock()
>>> type(s)
<class 'ScalarFooBlock'>

>>> s = FooBlock([1,2])
>>> type(s)
<class 'IndexedFooBlock'>

It is frequently desirable for the custom class to have a default
``rule`` for constructing and populating new instances. The default
rule can be provided either as an explicit function or a string. If
a string, the rule is obtained by attribute lookup on the derived
Data class:

>>> @declare_custom_block(name="BarBlock", rule="build")
... class BarBlockData(BlockData):
... def build(self, *args):
... self.x = Var(initialize=5)

>>> m = pyo.ConcreteModel()
>>> m.b = BarBlock([1,2])
>>> print(m.b[1].x.value)
5
>>> print(m.b[2].x.value)
5

"""

def block_data_decorator(block_data):
Expand All @@ -2476,9 +2511,16 @@ def block_data_decorator(block_data):
"_ComponentDataClass": block_data,
# By default this new block does not declare a new ctype
"_default_ctype": None,
# Define the default rule (may be None)
"_default_rule": rule,
},
)

# If the default rule is a string, then replace it with a
# function that will look up the attribute on the data class.
if type(rule) is str:
comp._default_rule = _custom_block_rule_redirect(block_data, rule)

if new_ctype is not None:
if new_ctype is True:
comp._default_ctype = comp
Expand Down
14 changes: 10 additions & 4 deletions pyomo/core/base/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from pyomo.core.pyomoobject import PyomoObject
from pyomo.core.base.component_namer import name_repr, index_repr
from pyomo.core.base.global_set import UnindexedComponent_index
from pyomo.core.base.initializer import PartialInitializer

logger = logging.getLogger('pyomo.core')

Expand Down Expand Up @@ -451,10 +452,15 @@ def __init__(self, **kwds):
self.doc = kwds.pop('doc', None)
self._name = kwds.pop('name', None)
if kwds:
raise ValueError(
"Unexpected keyword options found while constructing '%s':\n\t%s"
% (type(self).__name__, ','.join(sorted(kwds.keys())))
)
# If there are leftover keywords, and the component has a
# rule, pass those keywords on to the rule
if getattr(self, '_rule', None) is not None:
self._rule = PartialInitializer(self._rule, **kwds)
else:
raise ValueError(
"Unexpected keyword options found while constructing '%s':\n\t%s"
% (type(self).__name__, ','.join(sorted(kwds.keys())))
)
#
# Verify that ctype has been specified.
#
Expand Down
38 changes: 26 additions & 12 deletions pyomo/core/base/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pyomo.common.pyomo_typing import overload
from typing import Union, Type

from pyomo.common.deprecation import RenamedClass
from pyomo.common.deprecation import RenamedClass, deprecated
from pyomo.common.errors import DeveloperError, TemplateExpressionError
from pyomo.common.formatting import tabular_writer
from pyomo.common.log import is_debug_set
Expand Down Expand Up @@ -545,7 +545,7 @@ def __init__(self, template_info, component, index):
def expr(self):
# Note that it is faster to just generate the expression from
# scratch than it is to clone it and replace the IndexTemplate objects
self.set_value(self.parent_component().rule(self.parent_block(), self.index()))
self.set_value(self.parent_component()._rule(self.parent_block(), self.index()))
return self.expr

def template_expr(self):
Expand Down Expand Up @@ -640,9 +640,9 @@ def __init__(self, *args, **kwargs):
_init = self._pop_from_kwargs('Constraint', kwargs, ('rule', 'expr'), None)
# Special case: we accept 2- and 3-tuples as constraints
if type(_init) is tuple:
self.rule = Initializer(_init, treat_sequences_as_mappings=False)
self._rule = Initializer(_init, treat_sequences_as_mappings=False)
else:
self.rule = Initializer(_init)
self._rule = Initializer(_init)

kwargs.setdefault('ctype', Constraint)
ActiveIndexedComponent.__init__(self, *args, **kwargs)
Expand All @@ -663,7 +663,7 @@ def construct(self, data=None):
for _set in self._anonymous_sets:
_set.construct()

rule = self.rule
rule = self._rule
try:
# We do not (currently) accept data for constructing Constraints
index = None
Expand Down Expand Up @@ -719,9 +719,9 @@ def construct(self, data=None):
timer.report()

def _getitem_when_not_present(self, idx):
if self.rule is None:
if self._rule is None:
raise KeyError(idx)
con = self._setitem_when_not_present(idx, self.rule(self.parent_block(), idx))
con = self._setitem_when_not_present(idx, self._rule(self.parent_block(), idx))
if con is None:
raise KeyError(idx)
return con
Expand All @@ -746,6 +746,20 @@ def _pprint(self):
],
)

@property
def rule(self):
return self._rule

@rule.setter
@deprecated(
f"The 'Constraint.rule' attribute will be made "
"read-only in a future Pyomo release.",
version='6.9.3.dev0',
remove_in='6.11',
)
def rule(self, rule):
self._rule = rule

def display(self, prefix="", ostream=None):
"""
Print component state information
Expand Down Expand Up @@ -971,14 +985,14 @@ def __init__(self, **kwargs):

super().__init__(Set(dimen=1), **kwargs)

self.rule = Initializer(
self._rule = Initializer(
_rule, treat_sequences_as_mappings=False, allow_generators=True
)
# HACK to make the "counted call" syntax work. We wait until
# after the base class is set up so that is_indexed() is
# reliable.
if self.rule is not None and type(self.rule) is IndexedCallInitializer:
self.rule = CountedCallInitializer(self, self.rule, self._starting_index)
if self._rule is not None and type(self._rule) is IndexedCallInitializer:
self._rule = CountedCallInitializer(self, self._rule, self._starting_index)

def construct(self, data=None):
"""
Expand All @@ -995,8 +1009,8 @@ def construct(self, data=None):
for _set in self._anonymous_sets:
_set.construct()

if self.rule is not None:
_rule = self.rule(self.parent_block(), ())
if self._rule is not None:
_rule = self._rule(self.parent_block(), ())
for cc in iter(_rule):
if cc is ConstraintList.End:
break
Expand Down
59 changes: 42 additions & 17 deletions pyomo/core/base/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,27 +340,27 @@ class IndexedCallInitializer(InitializerBase):
def __init__(self, _fcn):
self._fcn = _fcn

def __call__(self, parent, idx):
def __call__(self, parent, idx, **kwargs):
# Note: this is called by a component using data from a Set (so
# any tuple-like type should have already been checked and
# converted to a tuple; or flattening is turned off and it is
# the user's responsibility to sort things out.
if idx.__class__ is tuple:
return self._fcn(parent, *idx)
return self._fcn(parent, *idx, **kwargs)
else:
return self._fcn(parent, idx)
return self._fcn(parent, idx, **kwargs)


class ParameterizedIndexedCallInitializer(IndexedCallInitializer):
"""IndexedCallInitializer that accepts additional arguments"""

__slots__ = ()

def __call__(self, parent, idx, *args):
def __call__(self, parent, idx, *args, **kwargs):
if idx.__class__ is tuple:
return self._fcn(parent, *args, *idx)
return self._fcn(parent, *args, *idx, **kwargs)
else:
return self._fcn(parent, *args, idx)
return self._fcn(parent, *args, idx, **kwargs)


class CountedCallGenerator(object):
Expand Down Expand Up @@ -481,8 +481,8 @@ def __init__(self, _fcn, constant=True):
self._fcn = _fcn
self._constant = constant

def __call__(self, parent, idx):
return self._fcn(parent)
def __call__(self, parent, idx, **kwargs):
return self._fcn(parent, **kwargs)

def constant(self):
"""Return True if this initializer is constant across all indices"""
Expand All @@ -494,8 +494,8 @@ class ParameterizedScalarCallInitializer(ScalarCallInitializer):

__slots__ = ()

def __call__(self, parent, idx, *args):
return self._fcn(parent, *args)
def __call__(self, parent, idx, *args, **kwargs):
return self._fcn(parent, *args, **kwargs)


class DefaultInitializer(InitializerBase):
Expand Down Expand Up @@ -523,9 +523,9 @@ def __init__(self, initializer, default, exceptions):
self._default = default
self._exceptions = exceptions

def __call__(self, parent, index):
def __call__(self, parent, index, **kwargs):
try:
return self._initializer(parent, index)
return self._initializer(parent, index, **kwargs)
except self._exceptions:
return self._default

Expand All @@ -542,7 +542,7 @@ def indices(self):


class ParameterizedInitializer(InitializerBase):
"""Base class for all Initializer objects"""
"""Wrapper to provide additional positional arguments to Initializer objects"""

__slots__ = ('_base_initializer',)

Expand All @@ -565,8 +565,33 @@ def indices(self):
"""
return self._base_initializer.indices()

def __call__(self, parent, idx, *args):
return self._base_initializer(parent, idx)(parent, *args)
def __call__(self, parent, idx, *args, **kwargs):
return self._base_initializer(parent, idx)(parent, *args, **kwargs)


class PartialInitializer(InitializerBase):
"""Partial wrapper of an InitializerBase that supplies additional arguments"""

__slots__ = ('_fcn',)

def __init__(self, _fcn, *args, **kwargs):
self._fcn = functools.partial(_fcn, *args, **kwargs)

def constant(self):
return self._fcn.func.constant()

def contains_indices(self):
return self._fcn.func.contains_indices()

def indices(self):
return self._fcn.func.indices()

def __call__(self, parent, idx, *args, **kwargs):
# Note that the Initializer.__call__ API is different from the
# rule API. As a result, we cannot just inherit from
# IndexedCallInitializer and must instead implement our own
# __call__ here.
return self._fcn(parent, idx, *args, **kwargs)


_bound_sequence_types = collections.defaultdict(None.__class__)
Expand Down Expand Up @@ -618,8 +643,8 @@ def __init__(self, arg, obj=NOTSET):
arg, treat_sequences_as_mappings=treat_sequences_as_mappings
)

def __call__(self, parent, index):
val = self._initializer(parent, index)
def __call__(self, parent, index, **kwargs):
val = self._initializer(parent, index, **kwargs)
if _bound_sequence_types[val.__class__]:
return val
if _bound_sequence_types[val.__class__] is None:
Expand Down
Loading
Loading