diff --git a/gempy_engine/API/interp_single/_interp_scalar_field.py b/gempy_engine/API/interp_single/_interp_scalar_field.py
index fd02674..be749af 100644
--- a/gempy_engine/API/interp_single/_interp_scalar_field.py
+++ b/gempy_engine/API/interp_single/_interp_scalar_field.py
@@ -47,6 +47,7 @@ def interpolate_scalar_field(solver_input: SolverInput, options: InterpolationOp
     match weights_cached:
         case None:
+            foo = solver_input.weights_x0
             weights = _solve_and_store_weights(
                 solver_input=solver_input,
                 kernel_options=options.kernel_options,
@@ -87,7 +88,6 @@ def _solve_interpolation(interp_input: SolverInput, kernel_options: KernelOption
     if kernel_options.optimizing_condition_number:
         _optimize_nuggets_against_condition_number(A_matrix, interp_input, kernel_options)
 
-    # TODO: Smooth should be taken from options
     weights = solver_interface.kernel_reduction(
         cov=A_matrix,
         b=b_vector,
diff --git a/gempy_engine/API/interp_single/_interp_single_feature.py b/gempy_engine/API/interp_single/_interp_single_feature.py
index 41370a7..c1ccd53 100644
--- a/gempy_engine/API/interp_single/_interp_single_feature.py
+++ b/gempy_engine/API/interp_single/_interp_single_feature.py
@@ -108,6 +108,7 @@ def input_preprocess(data_shape: TensorsStructure, interpolation_input: Interpol
         xyz_to_interpolate=grid_internal,
         fault_internal=fault_values
    )
+    solver_input.weights_x0 = interpolation_input.weights
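+    # * Presumably a warm start: weights carried on the interpolation input become
+    # * the initial guess x0 for the iterative solver (see kernel_reduction).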
 
     return solver_input
diff --git a/gempy_engine/core/backend_tensor.py b/gempy_engine/core/backend_tensor.py
index 1c3cab6..d734e9b 100644
--- a/gempy_engine/core/backend_tensor.py
+++ b/gempy_engine/core/backend_tensor.py
@@ -12,6 +12,8 @@
 if is_pytorch_installed:
     import torch
 
+
+PYKEOPS = DEFAULT_PYKEOPS
 # * Import a copy of numpy as tfnp
 from importlib.util import find_spec, module_from_spec
@@ -44,7 +46,7 @@ def get_backend_string(cls) -> str:
 
     @classmethod
     def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = True, dtype: Optional[str] = None):
-        cls._change_backend(engine_backend, pykeops_enabled=DEFAULT_PYKEOPS, use_gpu=use_gpu, dtype=dtype)
+        cls._change_backend(engine_backend, pykeops_enabled=PYKEOPS, use_gpu=use_gpu, dtype=dtype)
 
     @classmethod
     def _change_backend(cls, engine_backend: AvailableBackends, pykeops_enabled: bool = False, use_gpu: bool = True, dtype: Optional[str] = None):
@@ -100,6 +102,12 @@ def _change_backend(cls, engine_backend: AvailableBackends, pykeops_enabled: boo
                 if (pykeops_enabled):
                     import pykeops
                     cls._wrap_pykeops_functions()
+
+                if (use_gpu):
+                    cls.use_gpu = True
+                    # cls.tensor_backend_pointer['active_backend'].set_default_device("cuda")
+                else:
+                    cls.use_gpu = False
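+                    # * Read back later by kernel_reduction to choose the "GPU" or
+                    # * "CPU" pykeops backend for the CG solver.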
 
             case (_):
                 raise AttributeError(
diff --git a/gempy_engine/core/data/interpolation_input.py b/gempy_engine/core/data/interpolation_input.py
index 427e961..cf00c7f 100644
--- a/gempy_engine/core/data/interpolation_input.py
+++ b/gempy_engine/core/data/interpolation_input.py
@@ -1,6 +1,6 @@
 import pprint
-from dataclasses import dataclass
-from typing import Optional
+from dataclasses import dataclass, field
+from typing import Optional, Union
 
 import numpy as np
 
@@ -21,23 +21,9 @@ class InterpolationInput:
     surface_points: SurfacePoints
     orientations: Orientations
     _original_grid: EngineGrid
-
-    @property
-    def original_grid(self):
-        return self._original_grid
-
-    def set_grid_to_original(self):
-        self._grid = self._original_grid
-
     _grid: EngineGrid
-    @property
-    def grid(self):
-        return self._grid
-
-    def set_temp_grid(self, value):
-        self._grid = value
+    weights: Union[list[np.ndarray], np.ndarray] = field(default_factory=list)
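+    # * Presumably one weights array per stack (or a single array); it is sliced by
+    # * stack_number in from_interpolation_input_subset below.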
     _unit_values: Optional[np.ndarray] = None
     segmentation_function: Optional[callable] = None  # * From scalar field to values
@@ -52,7 +38,10 @@
     def __init__(self, surface_points: SurfacePoints, orientations: Orientations, grid: EngineGrid,
                  unit_values: Optional[np.ndarray] = None, segmentation_function: Optional[callable] = None,
-                 stack_relation: StackRelationType = StackRelationType.ERODE):
+                 stack_relation: StackRelationType = StackRelationType.ERODE, weights: Optional[list[np.ndarray]] = None):
+        if weights is None:
+            weights = []
+
         self.surface_points = surface_points
         self._original_grid = grid
         self._grid = grid
@@ -60,6 +49,7 @@ def __init__(self, surface_points: SurfacePoints, orientations: Orientations, gr
         self.unit_values = unit_values
         self.segmentation_function = segmentation_function
         self.stack_relation = stack_relation
+        self.weights = weights
 
     # @ on
@@ -92,6 +82,7 @@ def from_interpolation_input_subset(cls, all_interpolation_input: "Interpolation
             grid=grid,
             unit_values=unit_values,
             stack_relation=stack_structure.active_masking_descriptor,
+            weights=(all_interpolation_input.weights[stack_number] if stack_number < len(all_interpolation_input.weights) else None)
         )
 
         # ! Setting this on the constructor does not work with data classes.
@@ -116,6 +107,19 @@ def from_schema(cls, schema: InterpolationInputSchema) -> "InterpolationInput":
             grid=grid
         )
 
+    @property
+    def original_grid(self):
+        return self._original_grid
+
+    def set_grid_to_original(self):
+        self._grid = self._original_grid
+
+    @property
+    def grid(self):
+        return self._grid
+
+    def set_temp_grid(self, value):
+        self._grid = value
 
     @property
     def slice_feature(self):
diff --git a/gempy_engine/modules/solver/_pykeops_solvers/_conjugate_gradient.py b/gempy_engine/modules/solver/_pykeops_solvers/_conjugate_gradient.py
index 45caefb..6b5ce74 100644
--- a/gempy_engine/modules/solver/_pykeops_solvers/_conjugate_gradient.py
+++ b/gempy_engine/modules/solver/_pykeops_solvers/_conjugate_gradient.py
@@ -1,64 +1,290 @@
 from pykeops.common.utils import get_tools
-
 from gempy_engine.core.backend_tensor import BackendTensor
+import warnings
+
+
+def ConjugateGradientSolver(binding, linop, b, eps=1e-6, x0=None,
+                            regularization=None, preconditioning=None,
+                            adaptive_tolerance=True, max_iterations=5000,
+                            verbose=False):
+    """
+    Robust conjugate gradient solver for ill-conditioned linear systems using PyKeOps.
+
+    Solves the linear system linop(a) = b, where linop represents a symmetric
+    positive definite linear operator.
+
+    Enhanced with stability features for ill-conditioned kriging matrices:
+    - Adaptive regularization
+    - Preconditioning support
+    - Robust convergence criteria
+    - Residual monitoring and restart capability
+
+    Args:
+        binding: PyKeOps backend binding (CPU/GPU)
+        linop: Linear operator function (symmetric positive definite)
+        b: Right-hand side vector/tensor
+        eps: Base convergence tolerance (default: 1e-6)
+        x0: Initial guess (optional, defaults to zero vector)
+        regularization: Regularization strategy ('auto', 'fixed', None)
+        preconditioning: Preconditioner function (optional)
+        adaptive_tolerance: Whether to use adaptive convergence criteria
+        max_iterations: Maximum iterations (default: 5000)
+        verbose: Whether to print extra per-iteration diagnostics
+
+    Returns:
+        a: Solution vector where linop(a) = b
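+
+    Example (a sketch, assuming a SPD kernel closure ``K`` and a torch column vector ``b``):
+        >>> # a = ConjugateGradientSolver("torch", K, b, eps=1e-6)
+        >>> # rel_err = (((K(a) - b) ** 2).sum() / (b ** 2).sum()).sqrt()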
Reducing.") + alp = alp / abs(alp) * 1e6 # Limit step size + a += alp * p r -= alp * Mp - nr2new = (r**2).sum() + nr2new = (r ** 2).sum() + + # ------------------------------------------------------------------------- + # ENHANCED CONVERGENCE CRITERIA + # ------------------------------------------------------------------------- + + # Relative convergence check (important for ill-conditioned systems) + relative_residual = nr2new / initial_nr2 + absolute_residual = nr2new - # Check for convergence - if nr2new < delta: + if adaptive_tolerance: + # Use both absolute and relative criteria + converged = (absolute_residual < delta) or (relative_residual < eps ** 2) + else: + converged = absolute_residual < delta + + if converged: + print(f"Converged at iteration {k}") + print(f" Absolute residual: {absolute_residual:.2e}") + print(f" Relative residual: {relative_residual:.2e}") break - # Check for divergence - if nr2new > prev_nr2: + # ------------------------------------------------------------------------- + # STAGNATION AND DIVERGENCE DETECTION + # ------------------------------------------------------------------------- + + residual_history.append(float(nr2new)) + + # Check for stagnation (typical in ill-conditioned systems) + residual_change = abs(nr2new - prev_nr2) / max(prev_nr2, 1e-16) + if residual_change < stagnation_threshold: + stagnation_counter += 1 + else: + stagnation_counter = 0 + + # Handle stagnation with restart + if stagnation_counter >= stagnation_tolerance: + if k - last_restart > restart_threshold: + print(f"Stagnation detected at iteration {k}. Restarting CG...") + # Restart: reset search direction to steepest descent + p = tools.copy(r) + stagnation_counter = 0 + last_restart = k + else: + print(f"Persistent stagnation detected. Algorithm may have converged " + f"to machine precision. Current residual: {nr2new:.2e}") + break + + # Enhanced divergence detection + if nr2new > prev_nr2 * 1.1: # Allow small increases due to rounding consecutive_divergence += 1 if consecutive_divergence >= divergence_tolerance: - print(f"Diverging for {divergence_tolerance} consecutive iterations at iteration {k}. Stopping algorithm.") + print(f"Algorithm diverging for {divergence_tolerance} iterations.") + print(f"This may indicate severe ill-conditioning.") + print(f"Consider stronger regularization or preconditioning.") break else: - consecutive_divergence = 0 # Reset counter if no divergence in this iteration + consecutive_divergence = 0 + + # ------------------------------------------------------------------------- + # PREPARE NEXT ITERATION WITH STABILITY CHECKS + # ------------------------------------------------------------------------- + + beta = nr2new / nr2 + + # Prevent beta from becoming too large (Fletcher-Reeves restart condition) + if beta > 1.0: + if verbose: + print(f"Large beta detected ({beta:.2e}) at iteration {k}. 
Restarting.") + p = tools.copy(r) # Restart with steepest descent + else: + p = r + beta * p + # Update for next iteration - p = r + (nr2new / nr2) * p nr2 = nr2new - prev_nr2 = nr2 # Update previous nr2 + prev_nr2 = nr2 + + # ------------------------------------------------------------------------- + # PROGRESS MONITORING + # ------------------------------------------------------------------------- - # Print every 100 iterations - if k % 1000 == 0: - print(f"Iteration {k}, Residual Norm: {nr2}") + if k % 500 == 0: # More frequent reporting for ill-conditioned systems + print(f"Iteration {k:4d} | Residual: {nr2:.2e} | " + f"Relative: {relative_residual:.2e} | " + f"Stagnation: {stagnation_counter}") k += 1 - # Print final number of iterations - print(f"Final Iteration {k}, Residual Norm: {nr2}") - return a + # ============================================================================= + # FINAL DIAGNOSTICS AND WARNINGS + # ============================================================================= + + final_relative_residual = nr2 / initial_nr2 + + print(f"\n=== Solver Summary ===") + print(f"Final iteration: {k}") + print(f"Final absolute residual: {nr2:.2e}") + print(f"Final relative residual: {final_relative_residual:.2e}") + print(f"Convergence target: {delta:.2e}") + + # Warn about potential issues + if k >= max_iterations: + warnings.warn("Maximum iterations reached. Solution may not be converged.") + if final_relative_residual > 1e-3: + warnings.warn(f"High relative residual ({final_relative_residual:.2e}). " + f"Matrix may be severely ill-conditioned.") + + if len(residual_history) > 100: + # Check convergence rate + recent_residuals = residual_history[-50:] + if len(recent_residuals) > 10: + avg_reduction = (recent_residuals[0] / recent_residuals[-1]) ** (1 / len(recent_residuals)) + if avg_reduction < 1.01: + warnings.warn("Very slow convergence detected. Consider preconditioning.") + + return a diff --git a/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py b/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py index e5561ec..2208032 100644 --- a/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py +++ b/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py @@ -5,7 +5,8 @@ from pykeops.common.parse_type import get_type from pykeops.torch.generic.generic_red import GenredAutograd -from gempy_engine.modules.solver._pykeops_solvers._conjugate_gradient import ConjugateGradientSolver +from ._conjugate_gradient import ConjugateGradientSolver +from ._nystrom import create_adaptive_nystrom_preconditioner class KernelSolveAutograd(torch.autograd.Function): @@ -113,7 +114,28 @@ def linop(var): return res global copy - result = ConjugateGradientSolver("torch", linop, varinv.data, eps, x0=x0) + if False: # * This does not work for gpu and so far it seems not to be specially better than direct solvers + preconditioner = create_adaptive_nystrom_preconditioner( + binding="torch", + linop=linop, + x_sample=varinv.data, + strategy="conservative", + ) + else: + preconditioner = None + + result = ConjugateGradientSolver( + binding="torch", + linop=linop, + b=varinv.data, + eps=1e-4, + x0=x0, + regularization=None, + preconditioning=preconditioner, + adaptive_tolerance=False, + max_iterations=500, + verbose=False + ) # relying on the 'ctx.saved_variables' attribute is necessary if you want to be able to differentiate the output # of the backward once again. It helps pytorch to keep track of 'who is who'. 
+
+    # =============================================================================
+    # PRECONDITIONING SETUP
+    # =============================================================================
+
+    if preconditioning is not None:
+        # Apply the preconditioner to both sides of the equation:
+        # M^(-1) * A * x = M^(-1) * b, where M is the preconditioner
+        b_preconditioned = preconditioning(b)
+
+        def preconditioned_linop(x):
+            return preconditioning(effective_linop(x))
+
+        working_linop = preconditioned_linop
+        working_b = b_preconditioned
+    else:
+        working_linop = effective_linop
+        working_b = b
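+    # ! Caveat: left preconditioning M^(-1) * A is only symmetric in the M-inner
+    # ! product, so the preconditioner supplied here should itself be symmetric
+    # ! positive definite for plain CG to remain well behaved.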
+
+    # =============================================================================
+    # ENHANCED CONJUGATE GRADIENT SETUP
+    # =============================================================================
+
+    # Compute the initial residual
+    r = tools.copy(working_b) - working_linop(a)
+    nr2 = (r ** 2).sum()
+
+    # Store the initial residual for the relative convergence check
+    initial_nr2 = nr2
 
-    r = tools.copy(b) - linop(a)  # Update the residual based on the initial guess
-    nr2 = (r**2).sum()
     if nr2 < delta:
         return a
+
     p = tools.copy(r)
+
+    # =============================================================================
+    # ENHANCED ITERATION CONTROL
+    # =============================================================================
+
     k = 1
-    prev_nr2 = nr2  # Initialize previous nr2 for divergence check
-    max_iterations = 5000
-    divergence_tolerance = 20  # Number of consecutive iterations allowed for increase in residual
-    consecutive_divergence = 0  # Counter for consecutive divergence
+    prev_nr2 = nr2
+    consecutive_divergence = 0
+    divergence_tolerance = 10  # Reduced for ill-conditioned systems
+
+    # Stagnation detection for ill-conditioned systems
+    stagnation_threshold = 1e-12
+    stagnation_counter = 0
+    stagnation_tolerance = 50
+
+    # Restart mechanism
+    restart_threshold = max_iterations // 4  # Restart every 25% of max iterations
+    last_restart = 0
+
+    # Quality monitoring
+    residual_history = []
+
+    # =============================================================================
+    # ROBUST CONJUGATE GRADIENT LOOP
+    # =============================================================================
 
     while k < max_iterations:
-        Mp = linop(p)
-        alp = nr2 / (p * Mp).sum()
+
+        # -------------------------------------------------------------------------
+        # CORE CG STEP WITH NUMERICAL STABILITY CHECKS
+        # -------------------------------------------------------------------------
+
+        Mp = working_linop(p)
+
+        # Check for numerical breakdown in the denominator
+        denominator = (p * Mp).sum()
+        if abs(denominator) < 1e-14:
+            warnings.warn(f"Numerical breakdown detected at iteration {k}. "
+                          f"Denominator too small: {denominator}")
+            break
+
+        alp = nr2 / denominator
+
+        # Safeguard against excessive step sizes
+        if abs(alp) > 1e6:
+            warnings.warn(f"Excessive step size detected: {alp}. Reducing.")
+            alp = alp / abs(alp) * 1e6  # Limit step size
+
         a += alp * p
         r -= alp * Mp
-        nr2new = (r**2).sum()
+        nr2new = (r ** 2).sum()
+
+        # -------------------------------------------------------------------------
+        # ENHANCED CONVERGENCE CRITERIA
+        # -------------------------------------------------------------------------
+
+        # Relative convergence check (important for ill-conditioned systems)
+        relative_residual = nr2new / initial_nr2
+        absolute_residual = nr2new
 
-        # Check for convergence
-        if nr2new < delta:
+        if adaptive_tolerance:
+            # Use both absolute and relative criteria
+            converged = (absolute_residual < delta) or (relative_residual < eps ** 2)
+        else:
+            converged = absolute_residual < delta
+
+        if converged:
+            print(f"Converged at iteration {k}")
+            print(f"  Absolute residual: {absolute_residual:.2e}")
+            print(f"  Relative residual: {relative_residual:.2e}")
             break
 
-        # Check for divergence
-        if nr2new > prev_nr2:
+        # -------------------------------------------------------------------------
+        # STAGNATION AND DIVERGENCE DETECTION
+        # -------------------------------------------------------------------------
+
+        residual_history.append(float(nr2new))
+
+        # Check for stagnation (typical in ill-conditioned systems)
+        residual_change = abs(nr2new - prev_nr2) / max(prev_nr2, 1e-16)
+        if residual_change < stagnation_threshold:
+            stagnation_counter += 1
+        else:
+            stagnation_counter = 0
+
+        # Handle stagnation with a restart
+        if stagnation_counter >= stagnation_tolerance:
+            if k - last_restart > restart_threshold:
+                print(f"Stagnation detected at iteration {k}. Restarting CG...")
+                # Restart: reset the search direction to steepest descent
+                p = tools.copy(r)
+                stagnation_counter = 0
+                last_restart = k
+            else:
+                print(f"Persistent stagnation detected. Algorithm may have converged "
+                      f"to machine precision. Current residual: {nr2new:.2e}")
+                break
+
+        # Enhanced divergence detection
+        if nr2new > prev_nr2 * 1.1:  # Allow small increases due to rounding
             consecutive_divergence += 1
             if consecutive_divergence >= divergence_tolerance:
-                print(f"Diverging for {divergence_tolerance} consecutive iterations at iteration {k}. Stopping algorithm.")
+                print(f"Algorithm diverging for {divergence_tolerance} iterations.")
+                print(f"This may indicate severe ill-conditioning.")
+                print(f"Consider stronger regularization or preconditioning.")
                 break
         else:
-            consecutive_divergence = 0  # Reset counter if no divergence in this iteration
+            consecutive_divergence = 0
+
+        # -------------------------------------------------------------------------
+        # PREPARE NEXT ITERATION WITH STABILITY CHECKS
+        # -------------------------------------------------------------------------
+
+        beta = nr2new / nr2
+
+        # Prevent beta from becoming too large (Fletcher-Reeves restart condition)
+        if beta > 1.0:
+            if verbose:
+                print(f"Large beta detected ({beta:.2e}) at iteration {k}. Restarting.")
+            p = tools.copy(r)  # Restart with steepest descent
+        else:
+            p = r + beta * p
 
         # Update for next iteration
-        p = r + (nr2new / nr2) * p
         nr2 = nr2new
-        prev_nr2 = nr2  # Update previous nr2 for divergence check
+        prev_nr2 = nr2
+
+        # -------------------------------------------------------------------------
+        # PROGRESS MONITORING
+        # -------------------------------------------------------------------------
 
-        # Print every 100 iterations
-        if k % 1000 == 0:
-            print(f"Iteration {k}, Residual Norm: {nr2}")
+        if k % 500 == 0:  # More frequent reporting for ill-conditioned systems
+            print(f"Iteration {k:4d} | Residual: {nr2:.2e} | "
+                  f"Relative: {relative_residual:.2e} | "
+                  f"Stagnation: {stagnation_counter}")
 
         k += 1
-    # Print final number of iterations
-    print(f"Final Iteration {k}, Residual Norm: {nr2}")
-    return a
+
+    # =============================================================================
+    # FINAL DIAGNOSTICS AND WARNINGS
+    # =============================================================================
+
+    final_relative_residual = nr2 / initial_nr2
+
+    print(f"\n=== Solver Summary ===")
+    print(f"Final iteration: {k}")
+    print(f"Final absolute residual: {nr2:.2e}")
+    print(f"Final relative residual: {final_relative_residual:.2e}")
+    print(f"Convergence target: {delta:.2e}")
+
+    # Warn about potential issues
+    if k >= max_iterations:
+        warnings.warn("Maximum iterations reached. Solution may not be converged.")
+    if final_relative_residual > 1e-3:
+        warnings.warn(f"High relative residual ({final_relative_residual:.2e}). "
+                      f"Matrix may be severely ill-conditioned.")
+
+    if len(residual_history) > 100:
+        # Check the convergence rate over the last iterations
+        recent_residuals = residual_history[-50:]
+        if len(recent_residuals) > 10:
+            avg_reduction = (recent_residuals[0] / recent_residuals[-1]) ** (1 / len(recent_residuals))
+            if avg_reduction < 1.01:
+                warnings.warn("Very slow convergence detected. Consider preconditioning.")
+
+    return a
diff --git a/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py b/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py
index e5561ec..2208032 100644
--- a/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py
+++ b/gempy_engine/modules/solver/_pykeops_solvers/_kernel_solve_autograd.py
@@ -5,7 +5,8 @@
 from pykeops.common.parse_type import get_type
 from pykeops.torch.generic.generic_red import GenredAutograd
 
-from gempy_engine.modules.solver._pykeops_solvers._conjugate_gradient import ConjugateGradientSolver
+from ._conjugate_gradient import ConjugateGradientSolver
+from ._nystrom import create_adaptive_nystrom_preconditioner
 
 
 class KernelSolveAutograd(torch.autograd.Function):
@@ -113,7 +114,28 @@ def linop(var):
             return res
 
         global copy
-        result = ConjugateGradientSolver("torch", linop, varinv.data, eps, x0=x0)
+        if False:  # * This does not work on GPU and, so far, does not seem noticeably better than direct solvers
+            preconditioner = create_adaptive_nystrom_preconditioner(
+                binding="torch",
+                linop=linop,
+                x_sample=varinv.data,
+                strategy="conservative",
+            )
+        else:
+            preconditioner = None
+
+        result = ConjugateGradientSolver(
+            binding="torch",
+            linop=linop,
+            b=varinv.data,
+            eps=1e-4,
+            x0=x0,
+            regularization=None,
+            preconditioning=preconditioner,
+            adaptive_tolerance=False,
+            max_iterations=500,
+            verbose=False
+        )
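+        # * With adaptive_tolerance=False the stopping target is delta = n * eps ** 2
+        # * (here n * 1e-8); tune eps above rather than delta directly.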
+ """ + tools = get_tools(binding) + n = tools.size(x_sample) + + selected_pivots = [] + diagonal_residuals = [] + + # Initialize: compute diagonal elements + print("Computing initial diagonal elements...") + for i in range(n): + ei = _create_unit_vector(binding, x_sample, i) + diag_elem = _extract_diagonal_element(linop, ei, i) + diagonal_residuals.append(float(diag_elem)) + + # Convert to tensor for easier manipulation + if binding == 'torch': + import torch + diag_residuals = torch.tensor(diagonal_residuals, dtype=x_sample.dtype, device=x_sample.device) + + # Greedy selection loop + for k in range(min(target_rank, max_rank, n)): + # Find the pivot with largest residual diagonal + if binding == 'torch': + pivot_idx = torch.argmax(diag_residuals).item() + max_residual = diag_residuals[pivot_idx].item() + else: + pivot_idx = diagonal_residuals.index(max(diagonal_residuals)) + max_residual = diagonal_residuals[pivot_idx] + + # Check stopping criterion + if max_residual < tolerance: + print(f"Stopping pivot selection: max residual {max_residual:.2e} < tolerance {tolerance:.2e}") + break + + selected_pivots.append(pivot_idx) + print(f"Selected pivot {k + 1}: index {pivot_idx}, residual: {max_residual:.2e}") + + # Update residual diagonal elements + if k < min(target_rank, max_rank, n) - 1: # Don't update on last iteration + _update_residual_diagonal(binding, linop, x_sample, pivot_idx, + selected_pivots, diag_residuals if binding == 'torch' else diagonal_residuals) + + return selected_pivots + + +def _random_pivot_selection(n, target_rank): + """ + Random pivot selection (for comparison/fallback). + """ + import random + indices = list(range(n)) + random.shuffle(indices) + return indices[:target_rank] + + +def _diagonal_pivot_selection(binding, linop, x_sample, target_rank): + """ + Select pivots based on largest diagonal elements. + """ + tools = get_tools(binding) + n = tools.size(x_sample) + + diagonal_elements = [] + for i in range(n): + ei = _create_unit_vector(binding, x_sample, i) + diag_elem = _extract_diagonal_element(linop, ei, i) + diagonal_elements.append((float(diag_elem), i)) + + # Sort by diagonal value (descending) + diagonal_elements.sort(reverse=True) + + return [idx for _, idx in diagonal_elements[:target_rank]] + + +def _create_unit_vector(binding, x_sample, index): + """ + Create the i-th unit vector with same properties as x_sample. + """ + tools = get_tools(binding) + + if binding == 'torch': + import torch + ei = torch.zeros_like(x_sample) + if len(x_sample.shape) == 1: + ei[index] = 1.0 + else: + flat_ei = ei.view(-1) + flat_ei[index] = 1.0 + ei = flat_ei.view(x_sample.shape) + else: + ei = tools.zeros_like(x_sample) + if len(ei.shape) == 1: + ei[index] = 1.0 + else: + flat_ei = ei.view(-1) + flat_ei[index] = 1.0 + + return ei + + +def _extract_diagonal_element(linop, unit_vector, index): + """ + Extract diagonal element by applying linop to unit vector. + """ + result = linop(unit_vector) + + if len(result.shape) == 1: + return result[index] + else: + flat_result = result.view(-1) + return flat_result[index] + + +def _update_residual_diagonal(binding, linop, x_sample, new_pivot, selected_pivots, diag_residuals): + """ + Update residual diagonal after adding a new pivot (for greedy selection). 
+    """
+    tools = get_tools(binding)
+    n = tools.size(x_sample)
+
+    print(f"Creating Nyström preconditioner for system size: {n}")
+
+    # Set reasonable defaults for rank
+    if max_rank is None:
+        max_rank = min(n // 4, 500)  # Limit computational cost
+
+    if rank is None:
+        rank = min(n // 10, 100)  # Start with modest rank
+
+    print(f"Target rank: {rank}, Max rank: {max_rank}")
+
+    try:
+        # Step 1: Select pivot points
+        pivot_indices = _select_pivots(binding, linop, x_sample, rank,
+                                       pivoting_strategy, max_rank, tolerance)
+
+        actual_rank = len(pivot_indices)
+        print(f"Selected {actual_rank} pivot points")
+
+        # Step 2: Construct the Nyström approximation
+        C = _construct_nystrom_factor(binding, linop, x_sample, pivot_indices)
+
+        # Step 3: Create the preconditioner
+        preconditioner = _create_nystrom_preconditioner(binding, C, tolerance)
+
+        return preconditioner
+
+    except Exception as e:
+        print(f"Nyström preconditioner construction failed: {e}")
+        print("Falling back to identity preconditioner")
+        return _create_identity_preconditioner()
+
+
+def _select_pivots(binding, linop, x_sample, target_rank, strategy, max_rank, tolerance):
+    """
+    Select pivot points for the Nyström approximation using various strategies.
+    """
+    tools = get_tools(binding)
+    n = tools.size(x_sample)
+
+    if strategy == 'greedy':
+        return _greedy_pivot_selection(binding, linop, x_sample, target_rank, max_rank, tolerance)
+    elif strategy == 'random':
+        return _random_pivot_selection(n, target_rank)
+    elif strategy == 'diagonal':
+        return _diagonal_pivot_selection(binding, linop, x_sample, target_rank)
+    else:
+        raise ValueError(f"Unknown pivoting strategy: {strategy}")
+
+
+def _greedy_pivot_selection(binding, linop, x_sample, target_rank, max_rank, tolerance):
+    """
+    Greedy pivot selection based on diagonal residuals (pivoted Cholesky).
+    """
+    tools = get_tools(binding)
+    n = tools.size(x_sample)
+
+    selected_pivots = []
+    diagonal_residuals = []
+
+    # Initialize: compute diagonal elements
+    print("Computing initial diagonal elements...")
+    for i in range(n):
+        ei = _create_unit_vector(binding, x_sample, i)
+        diag_elem = _extract_diagonal_element(linop, ei, i)
+        diagonal_residuals.append(float(diag_elem))
+
+    # Convert to a tensor for easier manipulation
+    if binding == 'torch':
+        import torch
+        diag_residuals = torch.tensor(diagonal_residuals, dtype=x_sample.dtype, device=x_sample.device)
+
+    # Greedy selection loop
+    for k in range(min(target_rank, max_rank, n)):
+        # Find the pivot with the largest residual diagonal
+        if binding == 'torch':
+            pivot_idx = torch.argmax(diag_residuals).item()
+            max_residual = diag_residuals[pivot_idx].item()
+        else:
+            pivot_idx = diagonal_residuals.index(max(diagonal_residuals))
+            max_residual = diagonal_residuals[pivot_idx]
+
+        # Check stopping criterion
+        if max_residual < tolerance:
+            print(f"Stopping pivot selection: max residual {max_residual:.2e} < tolerance {tolerance:.2e}")
+            break
+
+        selected_pivots.append(pivot_idx)
+        print(f"Selected pivot {k + 1}: index {pivot_idx}, residual: {max_residual:.2e}")
+
+        # Update residual diagonal elements
+        if k < min(target_rank, max_rank, n) - 1:  # Don't update on the last iteration
+            _update_residual_diagonal(binding, linop, x_sample, pivot_idx,
+                                      selected_pivots, diag_residuals if binding == 'torch' else diagonal_residuals)
+
+    return selected_pivots
+
+
+def _random_pivot_selection(n, target_rank):
+    """
+    Random pivot selection (for comparison/fallback).
+    """
+    import random
+    indices = list(range(n))
+    random.shuffle(indices)
+    return indices[:target_rank]
+
+
+def _diagonal_pivot_selection(binding, linop, x_sample, target_rank):
+    """
+    Select pivots based on the largest diagonal elements.
+    """
+    tools = get_tools(binding)
+    n = tools.size(x_sample)
+
+    diagonal_elements = []
+    for i in range(n):
+        ei = _create_unit_vector(binding, x_sample, i)
+        diag_elem = _extract_diagonal_element(linop, ei, i)
+        diagonal_elements.append((float(diag_elem), i))
+
+    # Sort by diagonal value (descending)
+    diagonal_elements.sort(reverse=True)
+
+    return [idx for _, idx in diagonal_elements[:target_rank]]
+
+
+def _create_unit_vector(binding, x_sample, index):
+    """
+    Create the i-th unit vector with the same properties as x_sample.
+    """
+    tools = get_tools(binding)
+
+    if binding == 'torch':
+        import torch
+        ei = torch.zeros_like(x_sample)
+        if len(x_sample.shape) == 1:
+            ei[index] = 1.0
+        else:
+            flat_ei = ei.view(-1)
+            flat_ei[index] = 1.0
+            ei = flat_ei.view(x_sample.shape)
+    else:
+        ei = tools.zeros_like(x_sample)
+        if len(ei.shape) == 1:
+            ei[index] = 1.0
+        else:
+            flat_ei = ei.view(-1)
+            flat_ei[index] = 1.0
+
+    return ei
+
+
+def _extract_diagonal_element(linop, unit_vector, index):
+    """
+    Extract a diagonal element by applying linop to a unit vector.
+    """
+    result = linop(unit_vector)
+
+    if len(result.shape) == 1:
+        return result[index]
+    else:
+        flat_result = result.view(-1)
+        return flat_result[index]
+
+
+def _update_residual_diagonal(binding, linop, x_sample, new_pivot, selected_pivots, diag_residuals):
+    """
+    Update the residual diagonal after adding a new pivot (for greedy selection).
+
+    This implements the pivoted Cholesky update:
+        diag_residual[i] -= (K[i, new_pivot])^2 / K[new_pivot, new_pivot]
+    """
+    # Get the new pivot column
+    pivot_vector = _create_unit_vector(binding, x_sample, new_pivot)
+    pivot_column = linop(pivot_vector)
+
+    # Get the diagonal element of the new pivot
+    pivot_diag = _extract_diagonal_element(linop, pivot_vector, new_pivot)
+
+    if abs(pivot_diag) < 1e-12:
+        print(f"Warning: very small pivot diagonal {pivot_diag}")
+        return
+
+    # Update the residual diagonal
+    tools = get_tools(binding)
+    n = tools.size(x_sample)
+
+    for i in range(n):
+        if i not in selected_pivots:  # Don't update already selected pivots
+            # Extract K[i, new_pivot]
+            if len(pivot_column.shape) == 1:
+                kij = pivot_column[i]
+            else:
+                flat_column = pivot_column.view(-1)
+                kij = flat_column[i]
+
+            # Update residual: diag[i] -= kij^2 / pivot_diag
+            update = (kij * kij) / pivot_diag
+
+            if binding == 'torch':
+                diag_residuals[i] -= update
+                # Ensure non-negative
+                diag_residuals[i] = torch.maximum(diag_residuals[i], torch.tensor(0.0, device=diag_residuals.device))
+            else:
+                diag_residuals[i] -= float(update)
+                diag_residuals[i] = max(diag_residuals[i], 0.0)
+
+
+def _construct_nystrom_factor(binding, linop, x_sample, pivot_indices):
+    """
+    Construct the Nyström factor C such that K ≈ C * C^T.
+
+        C[i, j] = K[i, pivot_j] / sqrt(K[pivot_j, pivot_j])
+    """
+    tools = get_tools(binding)
+    n = tools.size(x_sample)
+    rank = len(pivot_indices)
+
+    print(f"Constructing Nyström factor: {n} × {rank}")
+
+    # Initialize the factor matrix
+    if binding == 'torch':
+        import torch
+        C = torch.zeros(n, rank, dtype=x_sample.dtype, device=x_sample.device)
+    else:
+        # For other backends, build column by column
+        C_columns = []
+
+    # Construct each column of C
+    for j, pivot_idx in enumerate(pivot_indices):
+        # Get the pivot column from the kernel matrix
+        pivot_vector = _create_unit_vector(binding, x_sample, pivot_idx)
+        kernel_column = linop(pivot_vector)
+
+        # Get the diagonal element for normalization
+        pivot_diag = _extract_diagonal_element(linop, pivot_vector, pivot_idx)
+
+        if abs(pivot_diag) < 1e-12:
+            print(f"Warning: very small pivot diagonal {pivot_diag} at pivot {j}")
+            pivot_diag = 1e-12  # Regularize
+
+        # Normalize: C[:, j] = K[:, pivot_idx] / sqrt(K[pivot_idx, pivot_idx])
+        normalizer = (abs(pivot_diag)) ** 0.5
+
+        if binding == 'torch':
+            if len(kernel_column.shape) > 1:
+                kernel_column = kernel_column.view(-1)
+            C[:, j] = kernel_column / normalizer
+        else:
+            if len(kernel_column.shape) > 1:
+                kernel_column = kernel_column.view(-1)
+            normalized_column = kernel_column / normalizer
+            C_columns.append(normalized_column)
+
+        if (j + 1) % 10 == 0:
+            print(f"  Constructed {j + 1}/{rank} columns")
+
+    if binding != 'torch':
+        # Stack columns for other backends
+        C = tools.stack(C_columns, axis=1) if hasattr(tools, 'stack') else C_columns
+
+    return C
+
+
+def _create_nystrom_preconditioner(binding, C, regularization=1e-6):
+    """
+    Create a preconditioner from the Nyström factor.
+
+    The preconditioner solves (C * C^T + λI) * x = b
+    using the Woodbury matrix identity.
+    """
+    tools = get_tools(binding)
+
+    if binding == 'torch':
+        import torch
+
+        # Precompute: (C^T * C + λI)^(-1)
+        CtC = torch.matmul(C.T, C)
+        regularized_CtC = CtC + regularization * torch.eye(CtC.shape[0],
+                                                           dtype=C.dtype, device=C.device)
+
+        try:
+            CtC_inv = torch.linalg.inv(regularized_CtC)
+            print(f"Nyström preconditioner ready with rank {C.shape[1]}")
+        except Exception as e:
+            print(f"Matrix inversion failed: {e}, using pseudo-inverse")
+            CtC_inv = torch.linalg.pinv(regularized_CtC)
+
+        def nystrom_preconditioner(x):
+            """
+            Apply the Nyström preconditioner using the Woodbury identity:
+                (C*C^T + λI)^(-1) * x = (1/λ) * (x - C * (C^T*C + λI)^(-1) * C^T * x)
+            """
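+            # Only the rank × rank system (C^T*C + λI) is ever inverted here,
+            # which is what makes applying the preconditioner cheap.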
+ """ + tools = get_tools(binding) + + if binding == 'torch': + import torch + + # Precompute: (C^T * C + λI)^(-1) + CtC = torch.matmul(C.T, C) + regularized_CtC = CtC + regularization * torch.eye(CtC.shape[0], + dtype=C.dtype, device=C.device) + + try: + CtC_inv = torch.linalg.inv(regularized_CtC) + print(f"Nyström preconditioner ready with rank {C.shape[1]}") + except Exception as e: + print(f"Matrix inversion failed: {e}, using pseudo-inverse") + CtC_inv = torch.linalg.pinv(regularized_CtC) + + def nystrom_preconditioner(x): + """ + Apply Nyström preconditioner using Woodbury identity: + (C*C^T + λI)^(-1) * x = (1/λ) * (x - C * (C^T*C + λI)^(-1) * C^T * x) + """ + if len(x.shape) > 1: + x_flat = x.view(-1) + else: + x_flat = x + + # Woodbury formula application + Ctx = torch.matmul(C.T, x_flat) + temp = torch.matmul(CtC_inv, Ctx) + Ctemp = torch.matmul(C, temp) + + result = (x_flat - Ctemp) / regularization + + return result.view(x.shape) if len(x.shape) > 1 else result + + else: + # For other backends, create a simpler version + print("Using simplified Nyström preconditioner for non-torch backend") + + def nystrom_preconditioner(x): + # Simple scaling approximation + return x / regularization + + return nystrom_preconditioner + + +def _create_identity_preconditioner(): + """Create a trivial identity preconditioner.""" + + def identity_preconditioner(x): + return x + + return identity_preconditioner diff --git a/gempy_engine/modules/solver/_torch_solvers.py b/gempy_engine/modules/solver/_torch_solvers.py index a14cffc..2f04e31 100644 --- a/gempy_engine/modules/solver/_torch_solvers.py +++ b/gempy_engine/modules/solver/_torch_solvers.py @@ -4,8 +4,8 @@ bt = BackendTensor -def pykeops_torch_cg(b, cov, x0): - if PYKEOPS_SOLVER:=False: +def pykeops_torch_cg(b, cov, x0, use_gpu): + if PYKEOPS_SOLVER:=False: # * Default pykeops solver. 
         case (AvailableBackends.numpy, True, Solvers.PYKEOPS_CG):
             w = pykeops_numpy_cg(b, cov, dtype)
         case (AvailableBackends.numpy, True, Solvers.DEFAULT):
@@ -35,7 +38,7 @@
         case (AvailableBackends.numpy, False, Solvers.DEFAULT):
             w = numpy_solve(b, cov, dtype)
             if compute_condition_number:
-                _compute_conditional_number(cov)
+                kernel_options.condition_number = _compute_conditional_number(cov)
         case (AvailableBackends.numpy, _, Solvers.DEFAULT | Solvers.SCIPY_CG):
             w = numpy_cg(b, cov)
         case (AvailableBackends.numpy, _, Solvers.GMRES):
@@ -49,7 +52,7 @@
 
 
-def _compute_conditional_number(cov):
+def _compute_conditional_number(cov, plot=False):
     cond_number = np.linalg.cond(cov)
     svd = np.linalg.svd(cov)
     eigvals = np.linalg.eigvals(cov)
@@ -58,15 +61,18 @@
     idx = np.where(eigvals > 800)
     print(idx)
 
-    import matplotlib.pyplot as plt
     if not is_positive_definite:  # ! Careful numpy False warnings
         warnings.warn('The covariance matrix is not positive definite')
 
-    # Plotting the histogram
-    plt.hist(eigvals, bins=50, color='blue', alpha=0.7, log=True)
-    plt.xlabel('Eigenvalue')
-    plt.ylabel('Frequency')
-    plt.title('Histogram of Eigenvalues')
-    plt.show()
+    if plot:
+        import matplotlib.pyplot as plt
+        # Plot a histogram of the eigenvalues
+        plt.hist(eigvals, bins=50, color='blue', alpha=0.7, log=True)
+        plt.xlabel('Eigenvalue')
+        plt.ylabel('Frequency')
+        plt.title('Histogram of Eigenvalues')
+        plt.show()
+
+    return cond_number
diff --git a/gempy_engine/modules/weights_cache/weights_cache_interface.py b/gempy_engine/modules/weights_cache/weights_cache_interface.py
index ac43a1f..0ff3aa9 100644
--- a/gempy_engine/modules/weights_cache/weights_cache_interface.py
+++ b/gempy_engine/modules/weights_cache/weights_cache_interface.py
@@ -29,7 +29,12 @@ def initialize_cache_dir(disk_cache_dir=None):
             os.makedirs(WeightCache.disk_cache_dir, exist_ok=True)
 
         WeightCache._check_and_cleanup_cache()
-    
+
+    @staticmethod
+    def clear_cache():
+        WeightCache.memory_cache = {}
+        WeightCache._check_and_cleanup_cache()
+
     @staticmethod
     def _check_and_cleanup_cache():
         total_size = 0