# Source code for pdevisualizer.solver

"""
Unified API for PDE solving and visualization.

This module provides a clean, professional interface for solving different types
of PDEs with flexible boundary conditions and initial conditions.
"""

import numpy as np
from typing import Optional, Union, Dict, Any, Tuple, List, Sequence
from enum import Enum
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Import our existing solvers
from .heat2d import solve_heat, animate_heat, step_heat
from .wave2d import solve_wave, animate_wave, step_wave, step_wave_first
from .wave2d import create_gaussian_pulse, create_circular_wave
from .boundary_conditions import (
    BoundarySpec,
    solve_heat_with_boundaries,
    solve_wave_with_boundaries,
    BoundaryType,  # Use the enum from boundary_conditions
)


class EquationType(Enum):
    """Kinds of partial differential equation understood by ``PDESolver``.

    Each member's value is the lowercase name users may pass as a string,
    so ``EquationType("wave")`` resolves to ``EquationType.WAVE``.
    """

    HEAT = "heat"
    WAVE = "wave"
# BoundaryType is imported from the boundary_conditions module and is shared by
# the heat and wave solvers; BoundaryCondition wraps one uniform setting of it.
class BoundaryCondition:
    """Uniform boundary condition specification.

    Create instances with the ``dirichlet``, ``neumann``, ``periodic`` or
    ``absorbing`` constructors. Use :meth:`to_boundary_spec` to convert one
    for the ``boundary_conditions`` backend.

    Attributes
    ----------
    type : BoundaryType
        Kind of boundary condition.
    value : float
        Boundary value for Dirichlet, flux for Neumann. The periodic and
        absorbing constructors leave it at its default of 0.0.
    """

    def __init__(self, boundary_type: BoundaryType, value: float = 0.0) -> None:
        self.type = boundary_type
        self.value = value

    def __repr__(self) -> str:
        # Gives a readable form for PDESolver.boundary in REPLs, logs and
        # assertion failures instead of the default object address.
        return f"{self.__class__.__name__}(type={self.type!r}, value={self.value!r})"

    @classmethod
    def dirichlet(cls, value: float = 0.0) -> "BoundaryCondition":
        """Fixed value boundary condition."""
        return cls(BoundaryType.DIRICHLET, value)

    @classmethod
    def neumann(cls, flux: float = 0.0) -> "BoundaryCondition":
        """Fixed flux boundary condition (insulated if flux=0)."""
        return cls(BoundaryType.NEUMANN, flux)

    @classmethod
    def periodic(cls) -> "BoundaryCondition":
        """Periodic boundary condition (value stays at 0.0)."""
        return cls(BoundaryType.PERIODIC)

    @classmethod
    def absorbing(cls) -> "BoundaryCondition":
        """Absorbing boundary condition, intended for waves (value stays at 0.0)."""
        return cls(BoundaryType.ABSORBING)

    def to_boundary_spec(self):
        """Convert to a BoundarySpec via ``BoundarySpec.uniform(type, value)``."""
        return BoundarySpec.uniform(self.type, self.value)
class InitialConditions:
    """Helper class for creating common initial conditions.

    Every helper accepts ``shape`` either as an ``(nx, ny)`` tuple or as a
    single ``int`` meaning a square ``(n, n)`` grid.
    """

    @staticmethod
    def _as_shape(shape: Union[int, Tuple[int, int]]) -> Tuple[int, ...]:
        """Normalize ``shape``: an integer ``n`` becomes ``(n, n)``, anything else a tuple."""
        if isinstance(shape, (int, np.integer)):
            return (shape, shape)
        return tuple(shape)

    @staticmethod
    def zeros(shape: Union[int, Tuple[int, int]]) -> np.ndarray:
        """Zero initial condition (float64)."""
        return np.zeros(InitialConditions._as_shape(shape))

    @staticmethod
    def constant(shape: Union[int, Tuple[int, int]], value: float) -> np.ndarray:
        """Constant field initial condition.

        The result is always float64. ``np.full`` would otherwise infer an
        integer dtype from an int ``value``, and integer fields cannot receive
        the fractional in-place updates of a time-stepping solver.
        """
        return np.full(InitialConditions._as_shape(shape), value, dtype=float)

    @staticmethod
    def gaussian_pulse(
        shape: Union[int, Tuple[int, int]],
        center: Tuple[float, float],
        sigma: float,
        amplitude: float = 1.0,
    ) -> np.ndarray:
        """Gaussian pulse initial condition.

        ``center``, ``sigma`` and ``amplitude`` are forwarded unchanged to
        ``create_gaussian_pulse`` from the wave2d module.
        """
        nx, ny = InitialConditions._as_shape(shape)
        return create_gaussian_pulse((nx, ny), center, sigma, amplitude)

    @staticmethod
    def circular_wave(
        shape: Union[int, Tuple[int, int]],
        center: Tuple[float, float],
        radius: float,
        amplitude: float = 1.0,
    ) -> np.ndarray:
        """Circular wave initial condition.

        ``center``, ``radius`` and ``amplitude`` are forwarded unchanged to
        ``create_circular_wave`` from the wave2d module.
        """
        nx, ny = InitialConditions._as_shape(shape)
        return create_circular_wave((nx, ny), center, radius, amplitude)

    @staticmethod
    def multiple_sources(shape: Union[int, Tuple[int, int]], sources) -> np.ndarray:
        """Multiple point sources initial condition.

        Parameters:
        -----------
        shape : tuple or int
            Grid shape (nx, ny), or n for an (n, n) grid
        sources : iterable
            (x, y, amplitude) tuples. Coordinates are truncated with ``int()``.
            Sources outside the grid are skipped. A later source on the same
            cell overwrites an earlier one.
        """
        nx, ny = InitialConditions._as_shape(shape)
        u0 = np.zeros((nx, ny))
        for x, y, amplitude in sources:
            # Range-check the raw coordinates so that, for example, x=-0.5 is
            # rejected instead of truncating to cell 0.
            if 0 <= x < nx and 0 <= y < ny:
                u0[int(x), int(y)] = amplitude
        return u0

    @staticmethod
    def sine_wave(
        shape: Union[int, Tuple[int, int]],
        wavelength: float,
        amplitude: float = 1.0,
        direction: str = "x",
    ) -> np.ndarray:
        """Sinusoidal wave initial condition.

        Parameters:
        -----------
        shape : tuple or int
            Grid shape (nx, ny), or n for an (n, n) grid
        wavelength : float
            Wavelength in grid units (must be non-zero)
        amplitude : float
            Wave amplitude
        direction : str
            Wave direction ('x', 'y', or 'diagonal')

        Raises:
        -------
        ValueError
            If ``wavelength`` is zero or ``direction`` is not recognised.
        """
        if wavelength == 0:
            # Dividing by zero would silently fill the field with NaN.
            raise ValueError("wavelength must be non-zero")
        nx, ny = InitialConditions._as_shape(shape)
        x = np.linspace(0, nx - 1, nx)
        y = np.linspace(0, ny - 1, ny)
        X, Y = np.meshgrid(x, y, indexing="ij")
        if direction == "x":
            return amplitude * np.sin(2 * np.pi * X / wavelength)
        elif direction == "y":
            return amplitude * np.sin(2 * np.pi * Y / wavelength)
        elif direction == "diagonal":
            # Projection onto the unit vector (1, 1)/sqrt(2) is (X + Y)/sqrt(2).
            return amplitude * np.sin(2 * np.pi * (X + Y) / (wavelength * np.sqrt(2)))
        else:
            raise ValueError("direction must be 'x', 'y', or 'diagonal'")
class PDESolver:
    """
    Unified interface for solving 2D partial differential equations.

    A solver instance holds the grid, boundary condition, initial conditions
    and equation parameters for the heat or the wave equation. ``solve`` and
    ``animate`` check the explicit-scheme stability condition before
    dispatching to the matching backend.
    """

    # Names accepted by set_parameters().
    _KNOWN_PARAMETERS = ("alpha", "c", "dt", "dx", "dy")
    # Time step and grid spacing must be strictly positive.
    _POSITIVE_PARAMETERS = ("dt", "dx", "dy")
    # Physical coefficients must be non-negative. A negative alpha or c makes
    # the stability factor negative, which would falsely pass validate_stability().
    _NON_NEGATIVE_PARAMETERS = ("alpha", "c")

    def __init__(
        self,
        equation: Union[str, EquationType],
        grid_shape: Tuple[int, int] = (100, 100),
        spacing: Tuple[float, float] = (1.0, 1.0),
        boundary: Optional[BoundaryCondition] = None,
    ):
        """
        Initialize the PDE solver.

        Parameters:
        -----------
        equation : str or EquationType
            Type of PDE to solve ('heat' or 'wave', case-insensitive)
        grid_shape : tuple or int
            Grid dimensions (nx, ny); an int n means (n, n)
        spacing : tuple
            Grid spacing (dx, dy), both strictly positive
        boundary : BoundaryCondition, optional
            Boundary condition (default: Dirichlet with value=0)

        Raises:
        -------
        ValueError
            If ``equation`` is not a known type or the spacing is not positive.
        """
        if isinstance(equation, str):
            equation = EquationType(equation.lower())
        self.equation = equation

        # Store the shape as a tuple so it compares equal to ndarray.shape even
        # when the caller passes a list such as [100, 100].
        if isinstance(grid_shape, (int, np.integer)):
            grid_shape = (grid_shape, grid_shape)
        self.grid_shape = tuple(grid_shape)

        dx, dy = spacing
        if not (dx > 0 and dy > 0):
            # Zero spacing would otherwise surface later as ZeroDivisionError
            # inside get_stability_info().
            raise ValueError(f"Grid spacing must be positive, got dx={dx}, dy={dy}")
        self.dx, self.dy = dx, dy

        if boundary is None:
            boundary = BoundaryCondition.dirichlet(0.0)
        self.boundary = boundary

        self._initial_conditions: Optional[np.ndarray] = None
        self._initial_velocity: Optional[np.ndarray] = None

        # Default equation-specific parameters.
        self.parameters: Dict[str, Any] = {}
        if self.equation == EquationType.HEAT:
            self.parameters = {"alpha": 1.0, "dt": 0.1}
        elif self.equation == EquationType.WAVE:
            self.parameters = {"c": 1.0, "dt": 0.05}

    @staticmethod
    def _as_float_field(field) -> np.ndarray:
        """Return a private copy of ``field``; integer/bool data is promoted to float64."""
        arr = np.asarray(field)
        if np.issubdtype(arr.dtype, np.inexact):
            return arr.copy()
        # Time stepping produces fractional values, which an integer array
        # would truncate or reject on in-place update.
        return arr.astype(float)

    def set_initial_conditions(self, u0: np.ndarray, v0: Optional[np.ndarray] = None):
        """
        Set initial conditions for the PDE.

        Parameters:
        -----------
        u0 : array_like
            Initial field (temperature for heat, amplitude for wave)
        v0 : array_like, optional
            Initial velocity (wave equation only; ignored for heat)

        Raises:
        -------
        ValueError
            If a field's shape differs from ``grid_shape``.
        """
        u0 = np.asarray(u0)
        if u0.shape != self.grid_shape:
            raise ValueError(
                f"Initial condition shape {u0.shape} doesn't match "
                f"grid shape {self.grid_shape}"
            )
        self._initial_conditions = self._as_float_field(u0)

        if self.equation == EquationType.WAVE and v0 is not None:
            v0 = np.asarray(v0)
            if v0.shape != self.grid_shape:
                raise ValueError(
                    f"Initial velocity shape {v0.shape} doesn't match "
                    f"grid shape {self.grid_shape}"
                )
            self._initial_velocity = self._as_float_field(v0)
        else:
            self._initial_velocity = None

    def set_parameters(self, **params):
        """
        Set equation-specific parameters.

        For heat equation: alpha (thermal diffusivity), dt (time step)
        For wave equation: c (wave speed), dt (time step)
        Either equation: dx, dy (grid spacing, also stored in ``parameters``)

        All values are validated before any is applied, so a rejected call
        leaves the solver unchanged.

        Raises:
        -------
        ValueError
            For an unknown name, a non-positive dt/dx/dy, or a negative alpha/c.
        """
        for key in params:
            if key not in self._KNOWN_PARAMETERS:
                raise ValueError(f"Unknown parameter: {key}")
        for key, value in params.items():
            # Written as `not (...)` so that NaN is rejected as well.
            if key in self._POSITIVE_PARAMETERS and not value > 0:
                raise ValueError(f"Parameter {key!r} must be positive, got {value}")
            if key in self._NON_NEGATIVE_PARAMETERS and not value >= 0:
                raise ValueError(f"Parameter {key!r} must be non-negative, got {value}")

        self.parameters.update(params)
        if "dx" in params:
            self.dx = params["dx"]
        if "dy" in params:
            self.dy = params["dy"]

    def get_stability_info(self) -> Dict[str, Any]:
        """
        Get stability condition information for current parameters.

        Returns:
        --------
        dict
            Keys: 'condition' (str), 'current_factor' (float), 'limit' (float),
            'is_stable' (bool) and 'safety_margin' (limit - factor).
        """
        dt = self.parameters.get("dt", 0.1)
        if self.equation == EquationType.HEAT:
            alpha = self.parameters.get("alpha", 1.0)
            factor = alpha * dt * (1 / self.dx**2 + 1 / self.dy**2)
            limit = 0.5
            condition = f"α * dt * (1/dx² + 1/dy²) ≤ {limit}"
        elif self.equation == EquationType.WAVE:
            c = self.parameters.get("c", 1.0)
            factor = c * dt * np.sqrt(1 / self.dx**2 + 1 / self.dy**2)
            limit = 1.0
            condition = f"c * dt * √(1/dx² + 1/dy²) ≤ {limit}"
        else:
            # Unreachable for valid EquationType values; kept for type checkers.
            factor = 0.0
            limit = 1.0
            condition = "Unknown equation type"

        factor = float(factor)
        return {
            "condition": condition,
            "current_factor": factor,
            "limit": limit,
            # bool() so the wave branch doesn't leak a numpy bool_.
            "is_stable": bool(factor <= limit),
            "safety_margin": limit - factor,
        }

    def validate_stability(self):
        """Raise ValueError if the current parameters violate the stability condition."""
        info = self.get_stability_info()
        if not info["is_stable"]:
            raise ValueError(
                f"Stability condition violated for {self.equation.value} equation!\n"
                f"Condition: {info['condition']}\n"
                f"Current factor: {info['current_factor']:.4f} > {info['limit']}\n"
                f"Reduce dt or increase dx/dy."
            )

    def _check_ready(self) -> None:
        """Ensure initial conditions are set and the scheme is stable before running."""
        if self._initial_conditions is None:
            raise ValueError("Initial conditions not set. Use set_initial_conditions().")
        self.validate_stability()

    def _has_default_boundary(self) -> bool:
        """True for Dirichlet(0.0), the case routed to the plain solve_heat/solve_wave backends."""
        return self.boundary.type == BoundaryType.DIRICHLET and self.boundary.value == 0.0

    def solve(self, steps: int = 100) -> np.ndarray:
        """
        Solve the PDE for the specified number of time steps.

        Parameters:
        -----------
        steps : int
            Number of time steps to solve

        Returns:
        --------
        numpy.ndarray
            Final solution field

        Raises:
        -------
        ValueError
            If initial conditions are missing or the parameters are unstable.
        """
        self._check_ready()
        dt = self.parameters.get("dt", 0.1)
        # Dirichlet(0.0) uses the original solvers (better performance); every
        # other boundary goes through the flexible boundary_conditions solvers.
        use_default = self._has_default_boundary()

        if self.equation == EquationType.HEAT:
            alpha = self.parameters.get("alpha", 1.0)
            if use_default:
                return solve_heat(
                    self._initial_conditions, α=alpha, dt=dt, dx=self.dx, dy=self.dy, steps=steps
                )
            return solve_heat_with_boundaries(
                self._initial_conditions,
                self.boundary.to_boundary_spec(),
                α=alpha,
                dt=dt,
                dx=self.dx,
                dy=self.dy,
                steps=steps,
            )

        if self.equation == EquationType.WAVE:
            c = self.parameters.get("c", 1.0)
            if use_default:
                return solve_wave(
                    self._initial_conditions,
                    v0=self._initial_velocity,
                    c=c,
                    dt=dt,
                    dx=self.dx,
                    dy=self.dy,
                    steps=steps,
                )
            return solve_wave_with_boundaries(
                self._initial_conditions,
                self.boundary.to_boundary_spec(),
                v0=self._initial_velocity,
                c=c,
                dt=dt,
                dx=self.dx,
                dy=self.dy,
                steps=steps,
            )

        raise ValueError(f"Unknown equation type: {self.equation}")

    def animate(
        self, frames: int = 100, interval: int = 50, save_path: Optional[str] = None
    ) -> FuncAnimation:
        """
        Create an animation of the PDE solution.

        The animation backends (animate_heat / animate_wave) receive no
        boundary argument. A non-default boundary condition is therefore not
        applied here, and a UserWarning is emitted; use solve() when boundary
        handling matters.

        Parameters:
        -----------
        frames : int
            Number of animation frames
        interval : int
            Time between frames in milliseconds
        save_path : str, optional
            Path to save animation (e.g., 'animation.gif'), written with the
            "pillow" writer

        Returns:
        --------
        matplotlib.animation.FuncAnimation
            Animation object
        """
        self._check_ready()

        if not self._has_default_boundary():
            import warnings

            warnings.warn(
                "animate() does not support custom boundary conditions; the "
                f"configured {self.boundary.type.value} boundary is ignored. "
                "Use solve() for boundary-aware results.",
                stacklevel=2,
            )

        dt = self.parameters.get("dt", 0.1)
        if self.equation == EquationType.HEAT:
            alpha = self.parameters.get("alpha", 1.0)
            anim = animate_heat(
                self._initial_conditions,
                α=alpha,
                dt=dt,
                dx=self.dx,
                dy=self.dy,
                frames=frames,
                interval=interval,
            )
        elif self.equation == EquationType.WAVE:
            c = self.parameters.get("c", 1.0)
            anim = animate_wave(
                self._initial_conditions,
                v0=self._initial_velocity,
                c=c,
                dt=dt,
                dx=self.dx,
                dy=self.dy,
                frames=frames,
                interval=interval,
            )
        else:
            raise ValueError(f"Unknown equation type: {self.equation}")

        if save_path:
            anim.save(save_path, writer="pillow")
            print(f"Animation saved to {save_path}")

        return anim

    def info(self) -> str:
        """
        Get a summary of the current solver configuration.

        Returns:
        --------
        str
            Multi-line summary of grid, boundary, parameters, stability and
            whether initial conditions (and, for waves, velocity) are set.
        """
        stability = self.get_stability_info()
        lines = [
            "",
            "PDE Solver Configuration:",
            "========================",
            f"Equation Type: {self.equation.value.title()}",
            f"Grid Shape: {self.grid_shape}",
            f"Grid Spacing: dx={self.dx}, dy={self.dy}",
            f"Boundary Condition: {self.boundary.type.value} (value={self.boundary.value})",
            "",
            "Parameters:",
            *(f"  {k}: {v}" for k, v in self.parameters.items()),
            "",
            "Stability:",
            f"  Condition: {stability['condition']}",
            f"  Current Factor: {stability['current_factor']:.4f}",
            f"  Limit: {stability['limit']}",
            f"  Status: {'✅ STABLE' if stability['is_stable'] else '❌ UNSTABLE'}",
            f"  Safety Margin: {stability['safety_margin']:.4f}",
            "",
            f"Initial Conditions: {'✅ Set' if self._initial_conditions is not None else '❌ Not Set'}",
        ]
        if self.equation == EquationType.WAVE:
            lines.append(
                f"Initial Velocity: {'✅ Set' if self._initial_velocity is not None else '❌ Not Set'}"
            )
        return "\n".join(lines) + "\n"