Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Bug in plotting and computing tilted plane intersections of transformed 0 thickness geometries.
- `Simulation.to_gdspy()` and `Simulation.to_gdstk()` now place polygons in GDS layer `(0, 0)` when no `gds_layer_dtype_map` is provided instead of erroring.
- `task_id` now properly stored in `JaxSimulationData`.

## [2.7.0rc1] - 2024-04-22

Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
14 changes: 11 additions & 3 deletions tidy3d/plugins/adjoint/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def run(
callback_url=callback_url,
verbose=verbose,
)
# TODO: add task_id
return JaxSimulationData.from_sim_data(sim_data, jax_info)


Expand Down Expand Up @@ -151,7 +152,9 @@ def run_fwd(
)

res = RunResidual(fwd_task_id=task_id)
jax_sim_data_orig = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
jax_sim_data_orig = JaxSimulationData.from_sim_data(
sim_data_orig, jax_info_orig, task_id=task_id
)

return jax_sim_data_orig, (res,)

Expand Down Expand Up @@ -410,6 +413,7 @@ def run_async(
task_name = str(_task_name_orig(i))
sim_data_tidy3d = batch_data_tidy3d[task_name]
jax_info = jax_infos[str(task_name)]
# TODO: add task_id
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
jax_batch_data.append(jax_sim_data)

Expand Down Expand Up @@ -450,8 +454,10 @@ def run_async_fwd(
batch_data_orig = [sim_data for _, sim_data in batch_data_orig.items()]

jax_batch_data_orig = []
for sim_data_orig, jax_info_orig in zip(batch_data_orig, jax_infos_orig):
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
for sim_data_orig, jax_info_orig, task_id in zip(batch_data_orig, jax_infos_orig, fwd_task_ids):
jax_sim_data = JaxSimulationData.from_sim_data(
sim_data_orig, jax_info_orig, task_id=task_id
)
jax_batch_data_orig.append(jax_sim_data)

residual = RunResidualBatch(fwd_task_ids=fwd_task_ids)
Expand Down Expand Up @@ -626,6 +632,7 @@ def run_local(
)

# convert back to jax type and return
# TODO: add task_id
return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)


Expand Down Expand Up @@ -779,6 +786,7 @@ def run_async_local(
task_name = _task_name_orig_local(i, task_name_suffix)
sim_data_tidy3d = batch_data_tidy3d[task_name]
jax_info = jax_infos[str(task_name)]
# TODO: add task_id
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
jax_batch_data.append(jax_sim_data)

Expand Down