Skip to content

Commit

Permalink
observe cell traitlet
Browse files Browse the repository at this point in the history
  • Loading branch information
eimrek committed Apr 8, 2024
1 parent 1dc64ca commit 31cc09f
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
14 changes: 14 additions & 0 deletions example/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@
"display(bz)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# update cell:\n",
"bz.cell = [\n",
" [5.0, 0.0, 0.0],\n",
" [0.0, 1.0, 0.0],\n",
" [0.0, 0.0, 1.0],\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
43 changes: 30 additions & 13 deletions src/widget_bzvisualizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pathlib

import anywidget
import traitlets
import traitlets as tl

from . import utils

Expand All @@ -16,17 +16,26 @@ class BZVisualizer(anywidget.AnyWidget):
_esm = pathlib.Path(__file__).parent / "static" / "widget.js"
_css = pathlib.Path(__file__).parent / "static" / "widget.css"

seekpath_data = traitlets.Dict({}).tag(sync=True)
# primary input parameters, not directly used in the JS app
cell = tl.List().tag(sync=True)
rel_coords = tl.List().tag(sync=True)
atom_numbers = tl.List().tag(sync=True)

# parameters passed to the js BZVisualizer
show_axes = traitlets.Bool(True).tag(sync=True)
show_bvectors = traitlets.Bool(True).tag(sync=True)
show_pathpoints = traitlets.Bool(False).tag(sync=True)
disable_interact_overlay = traitlets.Bool(False).tag(sync=True)
# auxiliary traitlet to easily manage the previous ones
system = tl.Dict({}).tag(sync=True)

# Parameters to control the size of the div-container
width = traitlets.Unicode("100%").tag(sync=True)
height = traitlets.Unicode("400px").tag(sync=True)
# Data used in the JS app
seekpath_data = tl.Dict({}).tag(sync=True)

# optional parameters passed to the JS BZVisualizer
show_axes = tl.Bool(True).tag(sync=True)
show_bvectors = tl.Bool(True).tag(sync=True)
show_pathpoints = tl.Bool(False).tag(sync=True)
disable_interact_overlay = tl.Bool(False).tag(sync=True)

# parameters to control the size of the div-container
width = tl.Unicode("100%").tag(sync=True)
height = tl.Unicode("400px").tag(sync=True)

def __init__(
self,
Expand All @@ -40,6 +49,14 @@ def __init__(
The traitlets defined above can be set as a kwargs.
"""
super().__init__(**kwargs)
self.seekpath_data = utils.get_seekpath_data_for_visualizer(
cell, rel_coords, atom_numbers
)
self.system = {
"cell": cell,
"rel_coords": rel_coords,
"atom_numbers": atom_numbers,
}
self.seekpath_data = utils.get_seekpath_data_for_visualizer(self.system)

@tl.observe("cell")
def _cell_changed(self, change):
self.system[change["name"]] = change["new"]
self.seekpath_data = utils.get_seekpath_data_for_visualizer(self.system)
8 changes: 6 additions & 2 deletions src/widget_bzvisualizer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
from seekpath.brillouinzone.brillouinzone import get_BZ


def get_seekpath_data_for_visualizer(cell, relcoords, atomic_numbers):
system = (np.array(cell), np.array(relcoords), np.array(atomic_numbers))
def get_seekpath_data_for_visualizer(system):
system = (
np.array(system["cell"]),
np.array(system["rel_coords"]),
np.array(system["atom_numbers"]),
)
res = seekpath.get_explicit_k_path(system, with_time_reversal=False)

b1, b2, b3 = res["reciprocal_primitive_lattice"]
Expand Down

0 comments on commit 31cc09f

Please sign in to comment.