Skip to content

feat: oscillation scaling #3745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,15 @@
maxTimeStepCuts="100"
maxSubSteps="1000"
newtonMaxIter="10"
logLevel="1"/>
<LinearSolverParameters directParallel="0"/>
minTimeStepIncreaseInterval="0"
logLevel="1"
oscillationScaling="1"
oscillationFraction="0.1"
oscillationScalingFactor="0.9"
oscillationCheckDepth="2"/>
<LinearSolverParameters solverType="gmres"
preconditionerType="mgr"/>

</CompositionalMultiphaseFVM>
</Solvers>
<Included>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,7 @@ void LinearSolverParametersInput::print()
}
}
TableLayout const tableLayout = TableLayout( GEOS_FMT( "{}: linear solver", getParent().getName() ),
{ TableLayout::Column()
.setName( "Parameter" )
.setValuesAlignment( TableLayout::Alignment::left ),
{ TableLayout::Column().setName( "Parameter" ).setValuesAlignment( TableLayout::Alignment::left ),
"Value" } );
TableTextFormatter const tableFormatter( tableLayout );
GEOS_LOG_RANK_0( tableFormatter.toString( tableData ));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,32 @@ NonlinearSolverParameters::NonlinearSolverParameters( string const & name,
setInputFlag( dataRepository::InputFlags::OPTIONAL ).
setDescription( "Nonlinear acceleration type for sequential solver." );

registerWrapper( viewKeysStruct::oscillationScalingString(), &m_oscillationScaling ).
setApplyDefaultValue( 0 ).
setInputFlag( dataRepository::InputFlags::OPTIONAL ).
setDescription( "Flag to enable oscillation detection and scaling. "
"If set to 1, oscillation detection is enabled and the solution will be scaled if oscillations are detected." );

registerWrapper( viewKeysStruct::oscillationScalingFactorString(), &m_oscillationScalingFactor ).
setApplyDefaultValue( 0.5 ).
setInputFlag( dataRepository::InputFlags::OPTIONAL ).
setDescription( "Scaling factor to dump solution oscillations." );

registerWrapper( viewKeysStruct::oscillationCheckDepthString(), &m_oscillationCheckDepth ).
setApplyDefaultValue( 3 ).
setInputFlag( dataRepository::InputFlags::OPTIONAL ).
setDescription( "Depth in solution history to check for oscillations." );

registerWrapper( viewKeysStruct::oscillationToleranceString(), &m_oscillationTolerance ).
setApplyDefaultValue( 0.01 ).
setInputFlag( dataRepository::InputFlags::OPTIONAL ).
setDescription( "Tolerance for oscillation detection." );

registerWrapper( viewKeysStruct::oscillationFractionString(), &m_oscillationFraction ).
setApplyDefaultValue( 0.05 ).
setInputFlag( dataRepository::InputFlags::OPTIONAL ).
setDescription( "Fraction of dofs oscillating to declare oscillation." );

addLogLevel< logInfo::Convergence >();
addLogLevel< logInfo::NonlinearSolver >();
addLogLevel< logInfo::LineSearch >();
Expand All @@ -195,6 +221,25 @@ void NonlinearSolverParameters::postInputInitialization()
GEOS_ERROR_IF_LE_MSG( m_lineSearchResidualFactor, 0.0,
getWrapperDataContext( viewKeysStruct::lineSearchResidualFactorString() ) << ": should be positive" );

if( m_oscillationScaling > 0 )
{
// check oscillation parameters
GEOS_ERROR_IF_LE_MSG( m_oscillationScalingFactor, 0.0,
getWrapperDataContext( viewKeysStruct::oscillationScalingFactorString() ) << ": should be positive" );
GEOS_ERROR_IF_GT_MSG( m_oscillationScalingFactor, 1.0,
getWrapperDataContext( viewKeysStruct::oscillationScalingFactorString() ) << ": can not be more than 1.0" );
GEOS_ERROR_IF_LT_MSG( m_oscillationCheckDepth, 2,
getWrapperDataContext( viewKeysStruct::oscillationCheckDepthString() ) << ": can not be less than 2" );
GEOS_ERROR_IF_LE_MSG( m_oscillationTolerance, 0.0,
getWrapperDataContext( viewKeysStruct::oscillationToleranceString() ) << ": should be positive" );
GEOS_ERROR_IF_GE_MSG( m_oscillationTolerance, 1.0,
getWrapperDataContext( viewKeysStruct::oscillationToleranceString() ) << ": can not be more than 1.0" );
GEOS_ERROR_IF_LT_MSG( m_oscillationFraction, 0.0,
getWrapperDataContext( viewKeysStruct::oscillationFractionString() ) << ": can not be negative" );
GEOS_ERROR_IF_GT_MSG( m_oscillationFraction, 1.0,
getWrapperDataContext( viewKeysStruct::oscillationFractionString() ) << ": can not be more than 1.0" );
}

if( getLogLevel() > 0 )
{
print();
Expand Down Expand Up @@ -237,8 +282,18 @@ void NonlinearSolverParameters::print() const
tableData.addRow( "Sequential convergence criterion", m_sequentialConvergenceCriterion );
tableData.addRow( "Subcycling", m_subcyclingOption );
}
tableData.addRow( "Oscillation detection and scaling", m_oscillationScaling );
if( m_oscillationScaling > 0 )
{
tableData.addRow( " Scaling factor", m_oscillationScalingFactor );
tableData.addRow( " Check depth", m_oscillationCheckDepth );
tableData.addRow( " Tolerance", m_oscillationTolerance );
tableData.addRow( " Fraction of dofs oscillating", m_oscillationFraction );
}

TableLayout const tableLayout = TableLayout( GEOS_FMT( "{}: nonlinear solver", getParent().getName() ),
{ "Parameter", "Value" } );
{ TableLayout::Column().setName( "Parameter" ).setValuesAlignment( TableLayout::Alignment::left ),
"Value" } );
TableTextFormatter const tableFormatter( tableLayout );
GEOS_LOG_RANK_0( tableFormatter.toString( tableData ));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ class NonlinearSolverParameters : public dataRepository::Group
static constexpr char const * sequentialConvergenceCriterionString() { return "sequentialConvergenceCriterion"; }
static constexpr char const * subcyclingOptionString() { return "subcycling"; }
static constexpr char const * nonlinearAccelerationTypeString() { return "nonlinearAccelerationType"; }

static constexpr char const * oscillationScalingString() { return "oscillationScaling"; }
static constexpr char const * oscillationScalingFactorString() { return "oscillationScalingFactor"; }
static constexpr char const * oscillationCheckDepthString() { return "oscillationCheckDepth"; }
static constexpr char const * oscillationToleranceString() { return "oscillationTolerance"; }
static constexpr char const * oscillationFractionString() { return "oscillationFraction"; }
} viewKeys;

/**
Expand Down Expand Up @@ -352,7 +358,22 @@ class NonlinearSolverParameters : public dataRepository::Group
NonlinearAccelerationType m_nonlinearAccelerationType;

/// Value used to make sure that residual normalizers are not too small when computing residual norm
real64 m_minNormalizer = 1e-12;
real64 m_minNormalizer;

/// Flag to enable oscillation detection and scaling
integer m_oscillationScaling;

/// Oscillation scaling factor
real64 m_oscillationScalingFactor;

/// Oscillation check depth in solution history
integer m_oscillationCheckDepth;

/// Tolerance for oscillation detection
real64 m_oscillationTolerance;

/// Fraction of dofs oscillating to declare oscillation
real64 m_oscillationFraction;
};

ENUM_STRINGS( NonlinearSolverParameters::LineSearchAction,
Expand Down
93 changes: 89 additions & 4 deletions src/coreComponents/physicsSolvers/PhysicsSolverBase.cpp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A unit test or integrated test to exercise the detection logic would be good. Ideally this could be exercised in isolation from a full simulation run, but this isn't always possible..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i enabled it for one example from integrated tests

Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,11 @@ void PhysicsSolverBase::implicitStepSetup( real64 const & GEOS_UNUSED_PARAM( tim
real64 const & GEOS_UNUSED_PARAM( dt ),
DomainPartition & GEOS_UNUSED_PARAM( domain ) )
{
GEOS_THROW( "PhysicsSolverBase::ImplicitStepSetup called!. Should be overridden.", std::runtime_error );
// clean the solution history
while( m_solutionHistory.size() > 0 )
{
m_solutionHistory.eraseArray( 0 );
}
}

void PhysicsSolverBase::setupDofs( DomainPartition const & GEOS_UNUSED_PARAM( domain ),
Expand Down Expand Up @@ -1396,9 +1400,29 @@ bool PhysicsSolverBase::checkSystemSolution( DomainPartition & GEOS_UNUSED_PARAM

real64 PhysicsSolverBase::scalingForSystemSolution( DomainPartition & GEOS_UNUSED_PARAM( domain ),
DofManager const & GEOS_UNUSED_PARAM( dofManager ),
arrayView1d< real64 const > const & GEOS_UNUSED_PARAM( localSolution ) )
arrayView1d< real64 const > const & localSolution )
{
return 1.0;
real64 scalingFactor = 1.0;

// Check for oscillations
if( m_nonlinearSolverParameters.m_oscillationScaling )
{
if( detectOscillations() )
{
scalingFactor *= m_nonlinearSolverParameters.m_oscillationScalingFactor;
GEOS_LOG_LEVEL_RANK_0( logInfo::NonlinearSolver,
GEOS_FMT( " {}: oscillation detected, scaling factor set to {}", getName(), scalingFactor ) );
}

m_solutionHistory.appendArray( localSolution.begin(), localSolution.end());
if( m_solutionHistory.size() > m_nonlinearSolverParameters.m_oscillationCheckDepth )
{
// remove the oldest solution from the history
m_solutionHistory.eraseArray( 0 );
}
}

return scalingFactor;
}

void PhysicsSolverBase::applySystemSolution( DofManager const & GEOS_UNUSED_PARAM( dofManager ),
Expand Down Expand Up @@ -1432,7 +1456,11 @@ void PhysicsSolverBase::resetConfigurationToBeginningOfStep( DomainPartition & G

void PhysicsSolverBase::resetStateToBeginningOfStep( DomainPartition & GEOS_UNUSED_PARAM( domain ) )
{
GEOS_ERROR( "PhysicsSolverBase::ResetStateToBeginningOfStep called!. Should be overridden." );
// clean the solution history
while( m_solutionHistory.size() > 0 )
{
m_solutionHistory.eraseArray( 0 );
}
}

bool PhysicsSolverBase::resetConfigurationToDefault( DomainPartition & GEOS_UNUSED_PARAM( domain ) ) const
Expand Down Expand Up @@ -1512,6 +1540,63 @@ void PhysicsSolverBase::saveSequentialIterationState( DomainPartition & GEOS_UNU
GEOS_ERROR( "Call to PhysicsSolverBase::saveSequentialIterationState. Method should be overloaded by the solver" );
}

// Detect oscillations for all dofs in the solution history
bool PhysicsSolverBase::detectOscillations() const
{
// grab the parameters
integer const oscillationCheckDepth = m_nonlinearSolverParameters.m_oscillationCheckDepth;
real64 const oscillationTolerance = m_nonlinearSolverParameters.m_oscillationTolerance;
real64 const oscillationFraction = m_nonlinearSolverParameters.m_oscillationFraction;

if( m_solutionHistory.size() < oscillationCheckDepth )
return false; // not enough history to check oscillations

RAJA::ReduceSum< parallelDeviceReduce, localIndex > oscillationCount = 0;

auto const solutionHistory = m_solutionHistory.toViewConst();
localIndex const numDofs = m_solutionHistory[0].size();
localIndex const historySize = m_solutionHistory.size();

RAJA::forall< parallelDevicePolicy<> >( RAJA::TypedRangeSegment< localIndex >( 0, numDofs ),
[=] GEOS_HOST_DEVICE ( localIndex const dof )
{
bool oscillationDetected = true;
for( localIndex i = historySize - 1; i > historySize - oscillationCheckDepth; --i )
{
real64 dxCur = solutionHistory[i][dof];
real64 dxPrev = solutionHistory[i-1][dof];

if( LvArray::math::abs( dxCur ) < oscillationTolerance || LvArray::math::abs( dxPrev ) < oscillationTolerance )
{
oscillationDetected = false;
break; // solution changes are too small
}

real64 maxAbs = LvArray::math::max( LvArray::math::abs( dxCur ), LvArray::math::abs( dxPrev ) );
if( LvArray::math::abs( dxCur + dxPrev ) / maxAbs > oscillationTolerance )
{
oscillationDetected = false;
break; // solution changes are not oscillating
}

if( dxCur * dxPrev > 0 )
{
oscillationDetected = false;
break; // sign is not oscillating
}
}

if( oscillationDetected )
{
oscillationCount += 1;
}
} );

real64 const f = static_cast< real64 >( MpiWrapper::sum( oscillationCount.get() ) ) / MpiWrapper::sum( numDofs );

return f > oscillationFraction;
}

#if defined(GEOS_USE_PYGEOSX)
PyTypeObject * PhysicsSolverBase::getPythonType() const
{ return python::getPySolverType(); }
Expand Down
9 changes: 9 additions & 0 deletions src/coreComponents/physicsSolvers/PhysicsSolverBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,12 @@ class PhysicsSolverBase : public ExecutableGroup
return getConstitutiveModel< CONSTITUTIVE_TYPE >( subRegion, getConstitutiveName< CONSTITUTIVE_TYPE >( subRegion ) );
}

/**
* @brief Detect oscillations in the solution
* @return true if oscillations are detected, false otherwise
*/
bool detectOscillations() const;

/// Courant–Friedrichs–Lewy factor for the timestep
real64 m_cflFactor;

Expand Down Expand Up @@ -1090,6 +1096,9 @@ class PhysicsSolverBase : public ExecutableGroup
/// Timers for the aggregate profiling of the solver
std::map< std::string, std::chrono::system_clock::duration > m_timers;

/// History of the solution vector, used for oscillation detection
ArrayOfArrays< real64 > m_solutionHistory;

private:
/// List of names of regions the solver will be applied to
string_array m_targetRegionNames;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1377,10 +1377,12 @@ void CompositionalMultiphaseBase::initializePostInitialConditionsPreSubGroups()
}

void
CompositionalMultiphaseBase::implicitStepSetup( real64 const & GEOS_UNUSED_PARAM( time_n ),
real64 const & GEOS_UNUSED_PARAM( dt ),
CompositionalMultiphaseBase::implicitStepSetup( real64 const & time_n,
real64 const & dt,
DomainPartition & domain )
{
PhysicsSolverBase::implicitStepSetup( time_n, dt, domain );

forDiscretizationOnMeshTargets( domain.getMeshBodies(), [&]( string const &,
MeshLevel & mesh,
string_array const & regionNames )
Expand Down Expand Up @@ -2473,6 +2475,8 @@ void CompositionalMultiphaseBase::resetStateToBeginningOfStep( DomainPartition &
{
GEOS_MARK_FUNCTION;

PhysicsSolverBase::resetStateToBeginningOfStep( domain );

forDiscretizationOnMeshTargets( domain.getMeshBodies(), [&]( string const &,
MeshLevel & mesh,
string_array const & regionNames )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,7 @@ real64 CompositionalMultiphaseFVM::scalingForSystemSolution( DomainPartition & d
else
{
string const dofKey = dofManager.getKey( viewKeyStruct::elemDofFieldString() );
real64 scalingFactor = 1.0;
real64 scalingFactor = CompositionalMultiphaseBase::scalingForSystemSolution( domain, dofManager, localSolution );
real64 minPresScalingFactor = 1.0, minCompDensScalingFactor = 1.0, minTempScalingFactor = 1.0;

stdVector< MpiWrapper::PairType< real64, globalIndex > > regionDeltaPresMaxLoc;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ class CoupledSolver : public PhysicsSolverBase
DofManager const & dofManager,
arrayView1d< real64 const > const & localSolution ) override
{
real64 scalingFactor = 1e9;
real64 scalingFactor = PhysicsSolverBase::scalingForSystemSolution( domain, dofManager, localSolution );
forEachArgInTuple( m_solvers, [&]( auto & solver, auto )
{
real64 const singlePhysicsScalingFactor = solver->scalingForSystemSolution( domain, dofManager, localSolution );
Expand Down
Loading