An error occurred while running the metagrating case #156

Open
xungu1 opened this issue Dec 27, 2024 · 3 comments

@xungu1

xungu1 commented Dec 27, 2024

Hello, I got an error the first time I ran the metagrating example from invrs_gym, and the same problem appears in the other examples. Is there a problem with how I set things up?
Could it be caused by the Python version? I am using Python 3.8.

Here is my error message:
```
Traceback (most recent call last):
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 198, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "E:\Object\simulation_up\compare\gym-main\gym-main\src\demo_1.py", line 26, in <module>
    from invrs_gym import challenges
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_import_hook.py", line 21, in do_import
    module = self._system_import(name, *args, **kwargs)
  File "E:\Object\simulation_up\compare\gym-main\gym-main\src\invrs_gym\__init__.py", line 9, in <module>
    from invrs_gym import challenges as challenges
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_import_hook.py", line 21, in do_import
    module = self._system_import(name, *args, **kwargs)
  File "E:\Object\simulation_up\compare\gym-main\gym-main\src\invrs_gym\challenges\__init__.py", line 1, in <module>
    from invrs_gym.challenges.bayer.challenge import bayer_sorter as bayer_sorter
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_import_hook.py", line 21, in do_import
    module = self._system_import(name, *args, **kwargs)
  File "E:\Object\simulation_up\compare\gym-main\gym-main\src\invrs_gym\challenges\bayer\challenge.py", line 10, in <module>
    from fmmax import basis, fmm  # type: ignore[import-untyped]
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_import_hook.py", line 21, in do_import
    module = self._system_import(name, *args, **kwargs)
  File "D:\test_code\virtuallab\test_file\lib\site-packages\fmmax\__init__.py", line 5, in <module>
    from . import (
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_import_hook.py", line 21, in do_import
    module = self._system_import(name, *args, **kwargs)
  File "D:\test_code\virtuallab\test_file\lib\site-packages\fmmax\fields.py", line 11, in <module>
    from fmmax import basis, fft, fmm, scattering, utils
  File "D:\BaiduNetdiskDownload\Pycharm Pro 2023\PyCharm 2023.1\plugins\python\helpers\pydev\_pydev_bundle\pydev_import_hook.py", line 21, in do_import
    module = self._system_import(name, *args, **kwargs)
  File "D:\test_code\virtuallab\test_file\lib\site-packages\fmmax\fmm.py", line 47, in <module>
    formulation: Formulation | VectorFn,
TypeError: unsupported operand type(s) for |: 'EnumMeta' and '_CallableGenericAlias'
```

@mfschubert
Member

Thanks for reporting this. It seems the Python version specified in the `pyproject.toml` is incorrect, since I am using some Python 3.10 language features. I'll make an update.

Would it be possible for you to use Python 3.10 or newer?
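For context on the error: the `TypeError` is raised by the `Formulation | VectorFn` annotation in `fmmax/fmm.py`, which Python evaluates while the module is being imported. The `X | Y` union syntax (PEP 604) only works for an enum class combined with a `typing.Callable[...]` alias on Python 3.10 and newer; on older interpreters `|` is undefined for that pair. The sketch below assumes only the standard library: `Formulation` and `VectorFn` are hypothetical stand-ins for the fmmax types, and `require_python` and `union_annotation` are illustrative helpers, not part of fmmax or invrs_gym. It shows the failing operation with a fallback, plus an early version check that fails with a readable message instead of deep inside an import.

```python
import enum
import sys
from typing import Callable, Union

# Assumption: 3.10 is the minimum because fmmax evaluates `X | Y` annotations at import time.
MIN_PYTHON = (3, 10)


def require_python(minimum=MIN_PYTHON, version_info=sys.version_info):
    """Raise a readable error if the interpreter is older than `minimum`."""
    if tuple(version_info[:2]) < tuple(minimum):
        found = ".".join(str(v) for v in version_info[:3])
        needed = ".".join(str(v) for v in minimum)
        raise RuntimeError(
            f"Python {needed}+ is required, but this interpreter is Python {found}."
        )


# Hypothetical stand-ins for the two types named in the traceback.
class Formulation(enum.Enum):
    FFT = "fft"
    POL = "pol"


VectorFn = Callable[[int], int]


def union_annotation(a, b):
    """Build the union `a | b`, falling back to typing.Union on older Python."""
    try:
        # Python 3.10+: the typing alias supplies __ror__, so this succeeds.
        return a | b
    except TypeError:
        # Python < 3.10: this is the TypeError from the report, so use Union instead.
        return Union[a, b]
```

Calling `require_python()` at the top of a script, before `from invrs_gym import challenges`, would replace the import-time `TypeError` with a clear message. At the packaging level, a `requires-python = ">=3.10"` entry in `pyproject.toml` lets pip refuse the installation on older interpreters.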

@xungu1
Author

xungu1 commented Dec 29, 2024

Thank you for your answer. After switching to Python 3.10, the program runs normally. I have also tried the example program for optimizing the metagrating. Can it be used to optimize a diffractive splitter? If so, how should I modify the program?
```python
from invrs_gym import challenges

challenge = challenges.metagrating()

import jax

params = challenge.component.init(jax.random.PRNGKey(0))

def loss_fn(params):
    response, aux = challenge.component.response(params)
    loss = challenge.loss(response)
    metrics = challenge.metrics(response, params=params, aux=aux)
    efficiency = metrics["average_efficiency"]
    return loss, (response, efficiency)

import invrs_opt

opt = invrs_opt.density_lbfgsb(beta=4)
state = opt.init(params)  # Initialize optimizer state using the initial parameters.

@jax.jit
def step_fn(state):
    params = opt.params(state)
    (value, (_, efficiency)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
    state = opt.update(grad=grad, value=value, params=params, state=state)
    return state, (params, efficiency)

# Call step_fn repeatedly to optimize, and store the results of each evaluation.
efficiencies = []
for _ in range(65):
    state, (params, efficiency) = step_fn(state)
    efficiencies.append(efficiency)

import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

ax = plt.subplot(121)
ax.plot(onp.asarray(efficiencies) * 100)
ax.set_xlabel("Step")
ax.set_ylabel("Diffraction efficiency into +1 order (%)")

ax = plt.subplot(122)
im = ax.imshow(1 - params.array, cmap="gray")
im.set_clim([-2, 1])

contours = measure.find_contours(onp.asarray(params.array))
for c in contours:
    ax.plot(c[:, 1], c[:, 0], "k", lw=1)

ax.set_xticks([])
ax.set_yticks([])
plt.show()

print(f"Final efficiency: {efficiencies[-1] * 100:.1f}%")
```

@mfschubert
Member

You should be able to design a diffractive splitter by replacing `challenges.metagrating()` with `challenges.diffractive_splitter()`. Check `challenges/__init__.py` for a consolidated view of all the challenge constructors (https://github.com/invrs-io/gym/blob/main/src/invrs_gym/challenges/__init__.py).

FYI if you surround your code with triple backticks it will format nicely: https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax#quoting-code
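Following the suggestion above, only the constructor call has to change. As a small sketch, the constructor can be looked up by name so one script can switch between the challenges listed in `challenges/__init__.py`. The helper `make_challenge` is hypothetical and not part of invrs_gym; it assumes each constructor can be called with no arguments, as `metagrating()` and `diffractive_splitter()` are in this thread. The optional `module` argument exists only so the helper can be exercised without invrs_gym installed.

```python
import importlib
from typing import Any


def make_challenge(name: str, module: Any = None) -> Any:
    """Construct a challenge by constructor name, e.g. "diffractive_splitter".

    Hypothetical helper: `module` defaults to `invrs_gym.challenges`.
    """
    if module is None:
        module = importlib.import_module("invrs_gym.challenges")
    constructor = getattr(module, name, None)
    if not callable(constructor):
        # List public callables (constructors) to make typos easy to spot;
        # submodules such as `bayer` are not callable and are skipped.
        available = sorted(
            n for n in dir(module)
            if not n.startswith("_") and callable(getattr(module, n))
        )
        raise ValueError(f"Unknown challenge {name!r}; available: {available}")
    return constructor()
```

With this helper, `challenge = make_challenge("diffractive_splitter")` takes the place of `challenge = challenges.metagrating()` and the optimization loop stays the same. The `"average_efficiency"` metric key, the "+1 order" axis label, and the `params.array` plotting come from the metagrating example and may need adjusting for another challenge; printing the keys of the dictionary returned by `challenge.metrics(...)` once, outside the jitted step, is an easy way to check which metrics are available.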
