{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Gradient Visualization for Radial 1D Stellar Wind Simulation\n",
    "\n",
    "This notebook runs a spherically symmetric (radial 1D) stellar wind simulation with `astronomix`. It compares the final density, pressure and velocity profiles against the analytic Weaver wind-bubble solution and a higher-resolution run. It then compares the sensitivity of the final state with respect to the terminal wind velocity $v_\\infty$ from forward-mode autodiff (`jax.jacfwd`) against a central finite-difference estimate."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ==== GPU selection ====\n",
    "from autocvd import autocvd\n",
    "autocvd(num_gpus = 1)\n",
    "# =======================\n",
    "\n",
    "# numerics\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "# uncomment to fall back to the CPU (e.g. with an outdated NVIDIA driver)\n",
    "# jax.config.update('jax_platform_name', 'cpu')\n",
    "# jax.config.update('jax_disable_jit', True)\n",
    "\n",
    "# 64-bit precision\n",
    "jax.config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "# debug nans\n",
    "# jax.config.update(\"jax_debug_nans\", True)\n",
    "\n",
    "# timing\n",
    "from timeit import default_timer as timer\n",
    "\n",
    "# plotting\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.gridspec import GridSpec\n",
    "\n",
    "# fluids\n",
    "from astronomix import WindParams\n",
    "from astronomix import SimulationConfig\n",
    "from astronomix import get_helper_data\n",
    "from astronomix import SimulationParams\n",
    "from astronomix import time_integration\n",
    "from astronomix import construct_primitive_state\n",
    "from astronomix import get_registered_variables\n",
    "from astronomix.option_classes import WindConfig\n",
    "from astronomix.option_classes.simulation_config import (\n",
    "    HLL,\n",
    "    OPEN_BOUNDARY,\n",
    "    REFLECTIVE_BOUNDARY,\n",
    "    SPHERICAL,\n",
    "    finalize_config,\n",
    ")\n",
    "\n",
    "# units\n",
    "from astronomix import CodeUnits\n",
    "from astropy import units as u\n",
    "import astropy.constants as c\n",
    "from astropy.constants import m_p\n",
    "\n",
    "# wind-specific\n",
    "from astronomix._physics_modules._stellar_wind.weaver import Weaver"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the stellar wind simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "👷 Setting up simulation...\n"
     ]
    }
   ],
   "source": [
    "print(\"👷 Setting up simulation...\")\n",
    "\n",
    "# simulation settings\n",
    "gamma = 5/3\n",
    "\n",
    "# spatial domain\n",
    "geometry = SPHERICAL\n",
    "box_size = 1.0\n",
    "num_cells = 401\n",
    "\n",
    "# informational only: not passed to the config; for spherical geometry the\n",
    "# boundaries are set automatically (see the finalize_config output below)\n",
    "left_boundary = REFLECTIVE_BOUNDARY\n",
    "right_boundary = OPEN_BOUNDARY\n",
    "\n",
    "# activate stellar wind\n",
    "stellar_wind = True\n",
    "num_injection_cells = 10\n",
    "\n",
    "# fixed-timestep options (currently disabled in the config below)\n",
    "fixed_timestep = True\n",
    "num_timesteps = 10000\n",
    "\n",
    "# setup simulation config\n",
    "config = SimulationConfig(\n",
    "    runtime_debugging = True,\n",
    "    geometry = geometry,\n",
    "    box_size = box_size,\n",
    "    num_cells = num_cells,\n",
    "    wind_config = WindConfig(\n",
    "        stellar_wind = stellar_wind,\n",
    "        num_injection_cells = num_injection_cells,\n",
    "        trace_wind_density = False,\n",
    "    ),\n",
    "    # fixed_timestep = fixed_timestep,\n",
    "    # num_timesteps = num_timesteps,\n",
    "    # first_order_fallback = True,\n",
    ")\n",
    "\n",
    "helper_data = get_helper_data(config)\n",
    "\n",
    "registered_variables = get_registered_variables(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# high-resolution reference run with the same wind setup; trace_wind_density\n",
    "# matches the fiducial config because its `registered_variables` are reused here\n",
    "config_high_res = SimulationConfig(\n",
    "    riemann_solver = HLL,\n",
    "    geometry = geometry,\n",
    "    box_size = box_size,\n",
    "    num_cells = 2001,\n",
    "    wind_config = WindConfig(\n",
    "        stellar_wind = stellar_wind,\n",
    "        num_injection_cells = num_injection_cells,\n",
    "        trace_wind_density = False,\n",
    "    ),\n",
    "    # fixed_timestep = fixed_timestep,\n",
    "    # num_timesteps = num_timesteps,\n",
    "    # first_order_fallback = True,\n",
    ")\n",
    "\n",
    "helper_data_high_res = get_helper_data(config_high_res)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting the simulation parameters and initial state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "For spherical geometry, only HLL is currently supported. Also, only the unsplit mode has been tested.\n",
      "Setting unsplit mode for spherical geometry\n",
      "Setting MUSCL time integrator for spherical geometry\n",
      "Automatically setting reflective left and open right boundary for spherical geometry.\n",
      "For stellar wind simulations, we need source term aware timesteps, turning on.\n",
      "For spherical geometry, only HLL is currently supported. Also, only the unsplit mode has been tested.\n",
      "Setting unsplit mode for spherical geometry\n",
      "Setting MUSCL time integrator for spherical geometry\n",
      "Automatically setting reflective left and open right boundary for spherical geometry.\n",
      "For stellar wind simulations, we need source term aware timesteps, turning on.\n"
     ]
    }
   ],
   "source": [
    "# code units\n",
    "code_length = 3 * u.parsec\n",
    "code_mass = 1e-3 * u.M_sun\n",
    "code_velocity = 1 * u.km / u.s\n",
    "code_units = CodeUnits(code_length, code_mass, code_velocity)\n",
    "\n",
    "# time domain\n",
    "C_CFL = 0.8\n",
    "t_final = 2.5 * 1e4 * u.yr\n",
    "t_end = t_final.to(code_units.code_time).value\n",
    "dt_max = 0.1 * t_end\n",
    "\n",
    "# wind parameters\n",
    "M_star = 40 * u.M_sun\n",
    "wind_final_velocity = 2000 * u.km / u.s\n",
    "# mass-loss rate: a fraction 2.965e-3 of M_star per Myr\n",
    "wind_mass_loss_rate = 2.965e-3 / (1e6 * u.yr) * M_star\n",
    "\n",
    "wind_params = WindParams(\n",
    "    wind_mass_loss_rate = wind_mass_loss_rate.to(code_units.code_mass / code_units.code_time).value,\n",
    "    wind_final_velocity = wind_final_velocity.to(code_units.code_velocity).value\n",
    ")\n",
    "\n",
    "params = SimulationParams(\n",
    "    C_cfl = C_CFL,\n",
    "    dt_max = dt_max,\n",
    "    gamma = gamma,\n",
    "    t_end = t_end,\n",
    "    wind_params=wind_params\n",
    ")\n",
    "\n",
    "# the high-resolution run uses identical physical parameters\n",
    "params_high_res = params\n",
    "\n",
    "# homogeneous initial state\n",
    "rho_0 = 2 * c.m_p / u.cm**3\n",
    "p_0 = 3e4 * u.K / u.cm**3 * c.k_B\n",
"\n",
    "# uniform ambient medium at rest with density rho_0 and pressure p_0\n",
    "def homogeneous_initial_state(sim_config):\n",
    "    \"\"\"Build the homogeneous, at-rest primitive state on the grid of `sim_config`.\"\"\"\n",
    "    n = sim_config.num_cells\n",
    "    return construct_primitive_state(\n",
    "        config = sim_config,\n",
    "        registered_variables = registered_variables,\n",
    "        density = jnp.ones(n) * rho_0.to(code_units.code_density).value,\n",
    "        velocity_x = jnp.zeros(n),\n",
    "        gas_pressure = jnp.ones(n) * p_0.to(code_units.code_pressure).value\n",
    "    )\n",
    "\n",
    "# get initial state, then finalize the config for the actual state shape\n",
    "initial_state = homogeneous_initial_state(config)\n",
    "config = finalize_config(config, initial_state.shape)\n",
    "\n",
    "# initial state high res\n",
    "initial_state_high_res = homogeneous_initial_state(config_high_res)\n",
    "config_high_res = finalize_config(config_high_res, initial_state_high_res.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulation and Gradient\n",
    "\n",
    "Run the fiducial ($N = 401$) and high-resolution ($N = 2001$) simulations. Then differentiate the final state of the fiducial run with respect to the terminal wind velocity $v_\\infty$, both with forward-mode autodiff and with a central finite difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dv = 0.1 km / s\n"
     ]
    }
   ],
   "source": [
    "final_state = time_integration(initial_state, config, params, registered_variables)\n",
    "\n",
    "# high res final state\n",
    "final_state_high_res = time_integration(initial_state_high_res, config_high_res, params_high_res, registered_variables)\n",
    "\n",
    "def integrator(velocity):\n",
    "    \"\"\"Final state of the fiducial run as a function of the terminal wind velocity.\n",
    "\n",
    "    `velocity` is in code units; all other parameters are taken from `params`.\n",
    "    \"\"\"\n",
    "    wind_params_v = WindParams(\n",
    "        wind_mass_loss_rate = params.wind_params.wind_mass_loss_rate,\n",
    "        wind_final_velocity = velocity\n",
    "    )\n",
    "    params_v = SimulationParams(\n",
    "        C_cfl = params.C_cfl,\n",
    "        dt_max = params.dt_max,\n",
    "        gamma = params.gamma,\n",
    "        t_end = params.t_end,\n",
    "        wind_params = wind_params_v\n",
    "    )\n",
    "    return time_integration(initial_state, config, params_v, registered_variables)\n",
    "\n",
"# forward-mode sensitivity of the final state w.r.t. the terminal wind velocity\n",
    "v_inf = params.wind_params.wind_final_velocity\n",
    "vel_sens = jax.jacfwd(integrator)(v_inf)\n",
    "\n",
    "# central finite-difference estimate for comparison; dv is in code units\n",
    "dv = 0.1\n",
    "print(f\"dv = {(dv * code_units.code_velocity).to(u.km/u.s)}\")\n",
    "vel_sens_fd = (integrator(v_inf + dv) - integrator(v_inf - dv)) / (2 * dv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization\n",
    "\n",
    "The upper row shows density, pressure and velocity at $t_\\mathrm{end}$ for the fiducial run, the Weaver solution and the high-resolution run. The lower row shows the sensitivities of the same quantities with respect to $v_\\infty$, from autodiff and from finite differences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "👷 generating plots\n",
      "0.00852260137538079 code_length / code_velocity\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2542770/3031789449.py:149: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n",
      " plt.tight_layout()\n"
     ]
    },
    {
     "data": {
      "image/png": "",
      "text/plain": [
       "<Figure size 1400x450 with 6 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def _profile_panel(ax, title, ylabel, legend_loc, sim, weaver, sim_high_res, n_high_res, ylim=None):\n",
    "    \"\"\"Draw one upper-row panel on a log y-axis.\n",
    "\n",
    "    `sim`, `weaver` and `sim_high_res` are (radius, value) pairs for the fiducial run,\n",
    "    the Weaver solution and the high-resolution run, respectively.\n",
    "    \"\"\"\n",
    "    ax.set_yscale(\"log\")\n",
    "    ax.plot(*sim, label=\"astronomix\")\n",
    "    ax.plot(*weaver, \"--\", label=\"Weaver solution\")\n",
    "    ax.plot(*sim_high_res, \"-.\", label=\"astronomix, N = {}\".format(n_high_res))\n",
    "    ax.set_title(title)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_xlim(0, 3)\n",
    "    if ylim is not None:\n",
    "        ax.set_ylim(*ylim)\n",
    "    # the lower row has the same radial range and carries the x labels\n",
    "    ax.set_xticks([])\n",
    "    ax.legend(loc=legend_loc)\n",
    "\n",
    "\n",
    "def plot_weaver_comparison(axs, final_state, params, helper_data, code_units, rho_0, p_0):\n",
    "    \"\"\"Compare the final density, pressure and velocity with the Weaver solution.\n",
    "\n",
    "    Draws on axs[0] (density), axs[1] (pressure) and axs[2] (velocity). Besides its\n",
    "    arguments, this reads the notebook globals `final_state_high_res`,\n",
    "    `helper_data_high_res`, `config_high_res` and `registered_variables`.\n",
    "    \"\"\"\n",
    "    print(\"👷 generating plots\")\n",
    "\n",
    "    rho = final_state[registered_variables.density_index] * code_units.code_density\n",
    "    vel = final_state[registered_variables.velocity_index] * code_units.code_velocity\n",
    "    p = final_state[registered_variables.pressure_index] * code_units.code_pressure\n",
    "    r_pc = (helper_data.geometric_centers * code_units.code_length).to(u.parsec)\n",
    "\n",
    "    rho_high_res = final_state_high_res[registered_variables.density_index] * code_units.code_density\n",
    "    vel_high_res = final_state_high_res[registered_variables.velocity_index] * code_units.code_velocity\n",
    "    p_high_res = final_state_high_res[registered_variables.pressure_index] * code_units.code_pressure\n",
    "    r_high_res_pc = (helper_data_high_res.geometric_centers * code_units.code_length).to(u.parsec)\n",
    "\n",
    "    # analytic Weaver solution at the final simulation time\n",
    "    weaver = Weaver(\n",
    "        params.wind_params.wind_final_velocity * code_units.code_velocity,\n",
    "        params.wind_params.wind_mass_loss_rate * code_units.code_mass / code_units.code_time,\n",
    "        rho_0,\n",
    "        p_0\n",
    "    )\n",
    "    current_time = params.t_end * code_units.code_time\n",
    "    print(current_time)\n",
    "\n",
    "    r_min, r_max = 0.01 * u.parsec, 3.5 * u.parsec\n",
    "    r_density_weaver, density_weaver = weaver.get_density_profile(r_min, r_max, current_time)\n",
    "    r_velocity_weaver, velocity_weaver = weaver.get_velocity_profile(r_min, r_max, current_time)\n",
    "    r_pressure_weaver, pressure_weaver = weaver.get_pressure_profile(r_min, r_max, current_time)\n",
    "\n",
    "    n_high_res = config_high_res.num_cells\n",
    "\n",
    "    _profile_panel(\n",
    "        axs[0], \"density\", r\"$\\rho$ in m$_p$ cm$^{-3}$\", \"upper left\",\n",
    "        (r_pc, (rho / m_p).to(u.cm**-3)),\n",
    "        (r_density_weaver.to(u.parsec), (density_weaver / m_p).to(u.cm**-3)),\n",
    "        (r_high_res_pc, (rho_high_res / m_p).to(u.cm**-3)),\n",
    "        n_high_res,\n",
    "    )\n",
    "    _profile_panel(\n",
    "        axs[1], \"pressure\", r\"$p$/k$_\\mathrm{B}$ in K cm$^{-3}$\", \"upper left\",\n",
    "        (r_pc, (p / c.k_B).to(u.K / u.cm**3)),\n",
    "        (r_pressure_weaver.to(u.parsec), (pressure_weaver / c.k_B).to(u.cm**-3 * u.K)),\n",
    "        (r_high_res_pc, (p_high_res / c.k_B).to(u.K / u.cm**3)),\n",
    "        n_high_res,\n",
    "    )\n",
    "    _profile_panel(\n",
    "        axs[2], \"velocity\", \"v in km/s\", \"upper right\",\n",
    "        (r_pc, vel.to(u.km / u.s)),\n",
    "        (r_velocity_weaver.to(u.parsec), velocity_weaver.to(u.km / u.s)),\n",
    "        (r_high_res_pc, vel_high_res.to(u.km / u.s)),\n",
    "        n_high_res,\n",
    "        ylim=(1, 1e4),\n",
    "    )\n",
    "\n",
    "\n",
    "def sensitivity_plot(axs, vel_sens, vel_sens_fd):\n",
    "    \"\"\"Plot autodiff vs. finite-difference sensitivities w.r.t. v_inf on symlog y-axes.\n",
    "\n",
    "    axs[0], axs[1] and axs[2] receive density, pressure and velocity, respectively.\n",
    "    Reads the notebook globals `helper_data`, `code_units` and `registered_variables`.\n",
    "    \"\"\"\n",
    "    r_pc = (helper_data.geometric_centers * code_units.code_length).to(u.parsec)\n",
    "\n",
    "    # (state index, quantity label, legend location, number of major y ticks)\n",
    "    panels = [\n",
    "        (registered_variables.density_index, r\"d$\\rho$/dv$_\\infty$\", \"upper left\", 3),\n",
    "        (registered_variables.pressure_index, r\"dp/dv$_\\infty$\", \"lower right\", 6),\n",
    "        (registered_variables.velocity_index, r\"dv/dv$_\\infty$\", \"upper right\", 3),\n",
    "    ]\n",
    "    for ax, (index, quantity, legend_loc, n_ticks) in zip(axs, panels):\n",
    "        ax.plot(r_pc, vel_sens[index], label=quantity + \" autodiff\")\n",
    "        ax.plot(r_pc, vel_sens_fd[index], \"--\", label=quantity + \" finite diff.\")\n",
    "        ax.set_ylabel(quantity)\n",
    "        ax.legend(loc=legend_loc)\n",
    "        ax.set_yscale(\"symlog\")\n",
    "        ax.set_xlim(0, 3)\n",
    "        ax.set_xlabel(\"r in pc\")\n",
    "        ax.yaxis.set_label_coords(-0.15, 0.5)\n",
    "        # must come after set_yscale, which installs the scale's default locators\n",
    "        ax.yaxis.set_major_locator(plt.MaxNLocator(n_ticks))\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(14, 4.5))\n",
    "gs = GridSpec(2, 3, height_ratios=[3, 2], figure=fig, hspace=0.1, wspace=0.3)\n",
    "\n",
    "# upper row: profiles vs. Weaver solution; lower row: sensitivities w.r.t. v_inf\n",
    "axs_upper = [fig.add_subplot(gs[0, i]) for i in range(3)]\n",
    "axs_lower = [fig.add_subplot(gs[1, i]) for i in range(3)]\n",
    "\n",
    "plot_weaver_comparison(axs_upper, final_state, params, helper_data, code_units, rho_0, p_0)\n",
    "sensitivity_plot(axs_lower, vel_sens, vel_sens_fd)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig(\"../figures/gradients_through_stellar_wind.pdf\", bbox_inches=\"tight\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "f1uids",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}