Source code for astronomix._physics_modules._cosmic_rays.cr_fluid_equations
# general
import jax.numpy as jnp
import jax
from functools import partial
# typing
from jaxtyping import Array, Float, jaxtyped
from beartype import beartype as typechecker
from typing import Union
# astronomix classes
from astronomix.variable_registry.registered_variables import RegisteredVariables
# NOTE: currently only supports 1d setups, TODO: generalize
# HERE SET MANUALLY, SHOULD COME FROM
# THE SIMULATION PARAMS
gamma_gas = 5 / 3
gamma_cr = 4 / 3
# @jaxtyped(typechecker=typechecker)
[docs]
@partial(jax.jit, static_argnames=["registered_variables"])
def total_energy_from_primitives_with_crs(
primitive_state: Float[Array, "num_vars num_cells"],
registered_variables: RegisteredVariables,
) -> Float[Array, "num_cells"]:
"""
Calculates the total energy density from primitive variables in a system with cosmic rays.
Args:
primitive_state: Array of primitive variables
registered_variables: Object containing indices for accessing different physical quantities
Returns:
Total energy density array
"""
# get the cosmic ray pressure
cosmic_ray_pressure = (
primitive_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
)
# get the cosmic ray energy (density)
cosmic_ray_energy = cosmic_ray_pressure / (gamma_cr - 1)
# get the gas pressure
gas_pressure = (
primitive_state[registered_variables.pressure_index] - cosmic_ray_pressure
)
# get the gas energy
rho_gas = primitive_state[registered_variables.density_index]
velocity = primitive_state[registered_variables.velocity_index]
gas_energy = gas_pressure / (gamma_gas - 1) + 0.5 * rho_gas * velocity**2
# total energy
E_tot = gas_energy + cosmic_ray_energy
return E_tot
# @jaxtyped(typechecker=typechecker)
[docs]
@partial(jax.jit, static_argnames=["registered_variables"])
def gas_pressure_from_primitives_with_crs(
primitive_state: Float[Array, "num_vars num_cells"],
registered_variables: RegisteredVariables,
) -> Float[Array, "num_cells"]:
"""
Calculates the gas pressure from the primitive state when cosmic rays
are considered in the simulation.
Args:
primitive_state: Array of primitive variables
registered_variables: Object containing indices for accessing different physical quantities
Returns:
gas pressure
"""
# get the cosmic ray pressure
cosmic_ray_pressure = (
primitive_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
)
# return the gas pressure
return primitive_state[registered_variables.pressure_index] - cosmic_ray_pressure
# TODO: make 2D and 3D ready
# @jaxtyped(typechecker=typechecker)
[docs]
@partial(jax.jit, static_argnames=["registered_variables"])
def total_pressure_from_conserved_with_crs(
conserved_state: Float[Array, "num_vars num_cells"],
registered_variables: RegisteredVariables,
) -> Float[Array, "num_cells"]:
"""
Calculates the total pressure from the conserved state when cosmic rays
are considered in the simulation.
Args:
primitive_state: Array of primitive variables
registered_variables: Object containing indices for accessing different physical quantities
Returns:
total pressure
"""
# get the cosmic ray pressure
cosmic_ray_pressure = (
conserved_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
)
# get the cosmic ray energy (density)
cosmic_ray_energy = cosmic_ray_pressure / (gamma_cr - 1)
# get the gas energy
gas_energy = (
conserved_state[registered_variables.pressure_index] - cosmic_ray_energy
)
# get the gas pressure
rho_gas = conserved_state[registered_variables.density_index]
velocity = conserved_state[registered_variables.velocity_index] / rho_gas
gas_pressure = (gas_energy - 0.5 * rho_gas * velocity**2) * (gamma_gas - 1)
# get the total pressure
total_pressure = cosmic_ray_pressure + gas_pressure
return total_pressure
# @jaxtyped(typechecker=typechecker)
[docs]
@partial(jax.jit, static_argnames=["registered_variables"])
def speed_of_sound_crs(
primitive_state: Float[Array, "num_vars num_cells"],
registered_variables: RegisteredVariables,
) -> Float[Array, "num_cells"]:
"""
Calculates the speed of sound from the primitive state
when cosmic rays are considered in the simulation, where
c_s = sqrt((gamma_gas * P_gas + gamma_cr * P_CR) / rho)
Args:
primitive_state: Array of primitive variables
registered_variables: Object containing indices for accessing different physical quantities
Returns:
sound speed
"""
# get the cosmic ray pressure
cosmic_ray_pressure = (
primitive_state[registered_variables.cosmic_ray_n_index] ** gamma_cr
)
# get the gas pressure
gas_pressure = (
primitive_state[registered_variables.pressure_index] - cosmic_ray_pressure
)
return jnp.sqrt(
(gamma_gas * gas_pressure + gamma_cr * cosmic_ray_pressure)
/ primitive_state[registered_variables.density_index]
)