Skip to content

[BUG] Fix condition number computation and optional plotting #17

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gempy_engine/API/interp_single/_interp_scalar_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def interpolate_scalar_field(solver_input: SolverInput, options: InterpolationOp

match weights_cached:
case None:
foo = solver_input.weights_x0
weights = _solve_and_store_weights(
solver_input=solver_input,
kernel_options=options.kernel_options,
Expand Down Expand Up @@ -87,7 +88,6 @@ def _solve_interpolation(interp_input: SolverInput, kernel_options: KernelOption
if kernel_options.optimizing_condition_number:
_optimize_nuggets_against_condition_number(A_matrix, interp_input, kernel_options)

# TODO: Smooth should be taken from options
weights = solver_interface.kernel_reduction(
cov=A_matrix,
b=b_vector,
Expand Down
1 change: 1 addition & 0 deletions gempy_engine/API/interp_single/_interp_single_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def input_preprocess(data_shape: TensorsStructure, interpolation_input: Interpol
xyz_to_interpolate=grid_internal,
fault_internal=fault_values
)
solver_input.weights_x0 = interpolation_input.weights

return solver_input

Expand Down
10 changes: 9 additions & 1 deletion gempy_engine/core/backend_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

if is_pytorch_installed:
import torch

PYKEOPS= DEFAULT_PYKEOPS

# * Import a copy of numpy as tfnp
from importlib.util import find_spec, module_from_spec
Expand Down Expand Up @@ -44,7 +46,7 @@ def get_backend_string(cls) -> str:

@classmethod
def change_backend_gempy(cls, engine_backend: AvailableBackends, use_gpu: bool = True, dtype: Optional[str] = None):
cls._change_backend(engine_backend, pykeops_enabled=DEFAULT_PYKEOPS, use_gpu=use_gpu, dtype=dtype)
cls._change_backend(engine_backend, pykeops_enabled=PYKEOPS, use_gpu=use_gpu, dtype=dtype)

@classmethod
def _change_backend(cls, engine_backend: AvailableBackends, pykeops_enabled: bool = False, use_gpu: bool = True, dtype: Optional[str] = None):
Expand Down Expand Up @@ -100,6 +102,12 @@ def _change_backend(cls, engine_backend: AvailableBackends, pykeops_enabled: boo
if (pykeops_enabled):
import pykeops
cls._wrap_pykeops_functions()

if (use_gpu):
cls.use_gpu = True
# cls.tensor_backend_pointer['active_backend'].set_default_device("cuda")
else:
cls.use_gpu = False

case (_):
raise AttributeError(
Expand Down
40 changes: 22 additions & 18 deletions gempy_engine/core/data/interpolation_input.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pprint
from dataclasses import dataclass
from typing import Optional
from dataclasses import dataclass, field
from typing import Optional, Union

import numpy as np

Expand All @@ -21,23 +21,9 @@ class InterpolationInput:
surface_points: SurfacePoints
orientations: Orientations
_original_grid: EngineGrid

@property
def original_grid(self):
return self._original_grid

def set_grid_to_original(self):
self._grid = self._original_grid


_grid: EngineGrid
@property
def grid(self):
return self._grid

def set_temp_grid(self, value):
self._grid = value

weights: Union[list[np.ndarray] | np.ndarray] = field(default_factory=lambda: [])
_unit_values: Optional[np.ndarray] = None
segmentation_function: Optional[callable] = None # * From scalar field to values

Expand All @@ -52,14 +38,18 @@ def set_temp_grid(self, value):

def __init__(self, surface_points: SurfacePoints, orientations: Orientations, grid: EngineGrid,
unit_values: Optional[np.ndarray] = None, segmentation_function: Optional[callable] = None,
stack_relation: StackRelationType = StackRelationType.ERODE):
stack_relation: StackRelationType = StackRelationType.ERODE, weights: list[np.ndarray] = None):
if weights is None:
weights = []

self.surface_points = surface_points
self._original_grid = grid
self._grid = grid
self.orientations = orientations
self.unit_values = unit_values
self.segmentation_function = segmentation_function
self.stack_relation = stack_relation
self.weights = weights

# @ on

Expand Down Expand Up @@ -92,6 +82,7 @@ def from_interpolation_input_subset(cls, all_interpolation_input: "Interpolation
grid=grid,
unit_values=unit_values,
stack_relation=stack_structure.active_masking_descriptor,
weights=(all_interpolation_input.weights[stack_number] if stack_number < len(all_interpolation_input.weights) else None)
)

# ! Setting this on the constructor does not work with data classes.
Expand All @@ -116,6 +107,19 @@ def from_schema(cls, schema: InterpolationInputSchema) -> "InterpolationInput":
grid=grid
)

@property
def original_grid(self):
return self._original_grid

def set_grid_to_original(self):
self._grid = self._original_grid

@property
def grid(self):
return self._grid

def set_temp_grid(self, value):
self._grid = value

@property
def slice_feature(self):
Expand Down
Loading