diff --git a/CHANGELOG.md b/CHANGELOG.md index 2fce45d752..58ca03e7ab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Bug in plotting and computing tilted plane intersections of transformed 0 thickness geometries. - `Simulation.to_gdspy()` and `Simulation.to_gdstk()` now place polygons in GDS layer `(0, 0)` when no `gds_layer_dtype_map` is provided instead of erroring. +- `task_id` now properly stored in `JaxSimulationData`. ## [2.7.0rc1] - 2024-04-22 diff --git a/docs/faq b/docs/faq index 79c529e8dd..d2671008d9 160000 --- a/docs/faq +++ b/docs/faq @@ -1 +1 @@ -Subproject commit 79c529e8dd46205074f22cd1924ad8b7d51c6d67 +Subproject commit d2671008d950672548ed029e27c1417f6b8b2b8c diff --git a/docs/notebooks b/docs/notebooks index 8ea4d93998..e6f5a19c84 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 8ea4d93998f5db72a9437c0d50bb6b5e175353d6 +Subproject commit e6f5a19c8494a8d4f7106864a8fcb85b25e1d040 diff --git a/tidy3d/plugins/adjoint/web.py b/tidy3d/plugins/adjoint/web.py index 38a6d9b1da..b99696ffd2 100644 --- a/tidy3d/plugins/adjoint/web.py +++ b/tidy3d/plugins/adjoint/web.py @@ -123,6 +123,7 @@ def run( callback_url=callback_url, verbose=verbose, ) + # TODO: add task_id return JaxSimulationData.from_sim_data(sim_data, jax_info) @@ -151,7 +152,9 @@ def run_fwd( ) res = RunResidual(fwd_task_id=task_id) - jax_sim_data_orig = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig) + jax_sim_data_orig = JaxSimulationData.from_sim_data( + sim_data_orig, jax_info_orig, task_id=task_id + ) return jax_sim_data_orig, (res,) @@ -410,6 +413,7 @@ def run_async( task_name = str(_task_name_orig(i)) sim_data_tidy3d = batch_data_tidy3d[task_name] jax_info = jax_infos[str(task_name)] + # TODO: add task_id jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info) jax_batch_data.append(jax_sim_data) @@ -450,8 +454,10 @@ def run_async_fwd( batch_data_orig = [sim_data for _, sim_data in batch_data_orig.items()] jax_batch_data_orig = [] - for sim_data_orig, jax_info_orig in zip(batch_data_orig, jax_infos_orig): - jax_sim_data = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig) + for sim_data_orig, jax_info_orig, task_id in zip(batch_data_orig, jax_infos_orig, fwd_task_ids): + jax_sim_data = JaxSimulationData.from_sim_data( + sim_data_orig, jax_info_orig, task_id=task_id + ) jax_batch_data_orig.append(jax_sim_data) residual = RunResidualBatch(fwd_task_ids=fwd_task_ids) @@ -626,6 +632,7 @@ def run_local( ) # convert back to jax type and return + # TODO: add task_id return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info) @@ -779,6 +786,7 @@ def run_async_local( task_name = _task_name_orig_local(i, task_name_suffix) sim_data_tidy3d = batch_data_tidy3d[task_name] jax_info = jax_infos[str(task_name)] + # TODO: add task_id jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info) jax_batch_data.append(jax_sim_data)