# Source code for astronomix._fluid_equations.total_quantities

# general
import jax
import jax.numpy as jnp
from functools import partial

# type checking
from beartype import beartype as typechecker
from jaxtyping import Array, Float, jaxtyped
from typing import Union

# astronomix containers
from astronomix._physics_modules._cosmic_rays.cr_fluid_equations import (
    total_energy_from_primitives_with_crs,
)
from astronomix.variable_registry.registered_variables import RegisteredVariables
from astronomix.option_classes.simulation_config import STATE_TYPE, SimulationConfig
from astronomix.data_classes.simulation_helper_data import HelperData

# astronomix functions
from astronomix._physics_modules._self_gravity._poisson_solver import (
    _compute_gravitational_potential,
)
from astronomix._fluid_equations._equations import (
    get_absolute_velocity,
    total_energy_from_primitives,
)


# @jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["config", "registered_variables"])
def calculate_internal_energy(state, helper_data, gamma, config, registered_variables):
    """
    Integrate the internal (thermal) energy of the gas over the domain.

    The pressure read from ``state`` is turned into an energy density via
    ``p / (gamma - 1)``. When cosmic rays are enabled, the cosmic-ray part
    ``n_cr ** (4/3)`` is subtracted from the pressure first.

    Args:
        state: The state array, indexed along its first axis by the registered
            variable indices (read via ``pressure_index``, so presumably the
            primitive state — confirm with callers).
        helper_data: Helper data; only ``cell_volumes`` is read, and only in 1D.
        gamma: The adiabatic index of the gas.
        config: The simulation configuration (static under jit).
        registered_variables: The registered variables (static under jit).

    Returns:
        The summed internal energy. In 1D each cell is weighted by
        ``helper_data.cell_volumes``; otherwise by ``grid_spacing ** dimensionality``.
    """
    pressure = state[registered_variables.pressure_index]

    if config.cosmic_ray_config.cosmic_rays:
        # Strip the cosmic-ray contribution (n_cr ** gamma_cr with gamma_cr = 4/3).
        pressure = pressure - state[registered_variables.cosmic_ray_n_index] ** (4 / 3)

    if config.dimensionality == 1:
        cell_volume = helper_data.cell_volumes
    else:
        cell_volume = config.grid_spacing**config.dimensionality

    return jnp.sum(pressure / (gamma - 1) * cell_volume)
# @jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["config", "registered_variables"])
def calculate_radial_momentum(state, helper_data, config, registered_variables):
    """
    Integrate the radial momentum ``rho * (u . r_hat)`` over the domain.

    ``r_hat`` is the unit vector from the box center (``box_size / 2`` along
    every axis) to each cell's geometric center. A cell located exactly at the
    box center has no defined radial direction; it is assigned ``r_hat = 0``
    and therefore contributes nothing.

    Args:
        state: The state array, indexed along its first axis by the registered
            variable indices (density and velocity components are read).
        helper_data: Helper data; ``geometric_centers`` is always read and
            ``cell_volumes`` is read in 1D.
        config: The simulation configuration (static under jit).
        registered_variables: The registered variables (static under jit).

    Returns:
        The summed radial momentum. In 1D each cell is weighted by
        ``helper_data.cell_volumes``; otherwise by ``grid_spacing ** dimensionality``.
    """
    rho = state[registered_variables.density_index]

    box_center = jnp.zeros(config.dimensionality) + config.box_size / 2
    offset = helper_data.geometric_centers - box_center
    # NOTE(review): if geometric_centers has shape (N,) in the 1D case, axis=-1
    # spans all cells rather than a per-cell coordinate axis — confirm the
    # intended 1D geometry.
    distance = jnp.linalg.norm(offset, axis=-1, keepdims=True)
    # A cell centred exactly on the box center (e.g. any odd number of cells
    # per axis) has zero distance. Dividing by it gives 0/0 = NaN, which would
    # poison the whole sum. Divide by 1 there instead: the offset is 0, so
    # r_hat = 0 for that cell. All other cells are unaffected.
    safe_distance = jnp.where(distance > 0, distance, 1.0)
    r_hat = offset / safe_distance

    if config.dimensionality == 1:
        u = state[registered_variables.velocity_index]
    else:
        first = registered_variables.velocity_index.x
        u = state[first : first + config.dimensionality]

    # Move the velocity-component axis last to line up with r_hat's last axis.
    u_radial = jnp.sum(jnp.moveaxis(u, 0, -1) * r_hat, axis=-1)
    radial_momentum = rho * u_radial

    if config.dimensionality == 1:
        return jnp.sum(radial_momentum * helper_data.cell_volumes)
    else:
        return jnp.sum(radial_momentum * config.grid_spacing**config.dimensionality)
# @jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["config", "registered_variables"])
def calculate_kinetic_energy(state, helper_data, config, registered_variables):
    """
    Integrate the kinetic energy ``0.5 * rho * |u|**2`` over the domain.

    Args:
        state: The state array, indexed along its first axis by the registered
            variable indices.
        helper_data: Helper data; only ``cell_volumes`` is read, and only in 1D.
        config: The simulation configuration (static under jit).
        registered_variables: The registered variables (static under jit).

    Returns:
        The summed kinetic energy. In 1D each cell is weighted by
        ``helper_data.cell_volumes``; otherwise by ``grid_spacing ** dimensionality``.
    """
    density = state[registered_variables.density_index]
    speed = get_absolute_velocity(state, config, registered_variables)

    if config.dimensionality == 1:
        weights = helper_data.cell_volumes
    else:
        weights = config.grid_spacing**config.dimensionality

    return jnp.sum(0.5 * density * speed**2 * weights)
# @jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["config", "registered_variables"])
def calculate_gravitational_energy(
    state, helper_data, gravitational_constant, config, registered_variables
):
    """
    Integrate the gravitational self-energy ``0.5 * rho * phi`` over the domain.

    The potential ``phi`` is obtained from ``_compute_gravitational_potential``
    applied to the density field.

    Args:
        state: The state array, indexed along its first axis by the registered
            variable indices (only the density is read).
        helper_data: Helper data; only ``cell_volumes`` is read, and only in 1D.
        gravitational_constant: The gravitational constant passed to the
            potential solver.
        config: The simulation configuration (static under jit).
        registered_variables: The registered variables (static under jit).

    Returns:
        The summed gravitational energy. In 1D each cell is weighted by
        ``helper_data.cell_volumes``; otherwise by ``grid_spacing ** dimensionality``.
    """
    density = state[registered_variables.density_index]
    phi = _compute_gravitational_potential(
        density, config.grid_spacing, config, gravitational_constant
    )

    if config.dimensionality == 1:
        weights = helper_data.cell_volumes
    else:
        weights = config.grid_spacing**config.dimensionality

    return jnp.sum(0.5 * density * phi * weights)
# @jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["config", "registered_variables"])
def calculate_total_energy(
    primitive_state: STATE_TYPE,
    helper_data: HelperData,
    gamma: Union[float, Float[Array, ""]],
    gravitational_constant: Union[float, Float[Array, ""]],
    config: SimulationConfig,
    registered_variables: RegisteredVariables,
) -> Float[Array, ""]:
    """
    Calculate the total energy in the domain.

    The per-cell energy density is built from the primitive variables:
    ``total_energy_from_primitives_with_crs`` when cosmic rays are enabled,
    otherwise ``total_energy_from_primitives(rho, |u|, p, gamma)``. With self
    gravity enabled, ``0.5 * rho * phi`` is added, where ``phi`` comes from
    ``_compute_gravitational_potential``.

    Args:
        primitive_state: The primitive state array.
        helper_data: The helper data; ``cell_volumes`` is read in 1D.
        gamma: The adiabatic index (used when cosmic rays are disabled).
        gravitational_constant: The gravitational constant (used only when
            ``config.self_gravity`` is set).
        config: The simulation configuration (static under jit).
        registered_variables: The registered variables (static under jit).

    Returns:
        The total energy. In 1D each cell is weighted by
        ``helper_data.cell_volumes``; otherwise by ``grid_spacing ** dimensionality``.
    """
    rho = primitive_state[registered_variables.density_index]

    if config.cosmic_ray_config.cosmic_rays:
        energy = total_energy_from_primitives_with_crs(
            primitive_state, registered_variables
        )
    else:
        # Speed and pressure are only needed for the plain-gas energy.
        u = get_absolute_velocity(primitive_state, config, registered_variables)
        p = primitive_state[registered_variables.pressure_index]
        energy = total_energy_from_primitives(rho, u, p, gamma)

    if config.self_gravity:
        potential = _compute_gravitational_potential(
            rho, config.grid_spacing, config, gravitational_constant
        )
        # The factor 0.5 avoids double counting pairwise interactions.
        energy = energy + 0.5 * rho * potential

    if config.dimensionality == 1:
        return jnp.sum(energy * helper_data.cell_volumes)
    else:
        return jnp.sum(energy * config.grid_spacing**config.dimensionality)
# @jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["config"])
def calculate_total_mass(
    primitive_state: "STATE_TYPE",
    helper_data: "HelperData",
    config: "SimulationConfig",
) -> 'Float[Array, ""]':
    """
    Calculate the total mass in the domain.

    The density (index 0 along the first axis of ``primitive_state``) is
    integrated over all cells. In 1D each cell is weighted by
    ``helper_data.cell_volumes``; otherwise by the uniform cell volume
    ``config.grid_spacing ** config.dimensionality``.

    Args:
        primitive_state: The primitive state array.
        helper_data: The helper data; ``cell_volumes`` is read in 1D.
        config: The simulation configuration (static under jit).

    Returns:
        The total mass.
    """
    # No registered_variables are passed here, so the density is assumed to
    # live at index 0 of the state.
    density = primitive_state[0]

    if config.dimensionality == 1:
        return jnp.sum(density * helper_data.cell_volumes)

    # Weight by the volume of a single cell, as the other integrals in this
    # module do. Multiplying by box_size ** dimensionality (the whole box
    # volume) would overcount the mass by the total number of cells.
    return jnp.sum(density) * config.grid_spacing**config.dimensionality