Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Molecule graph environment #139

Open
wants to merge 127 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 103 commits
Commits
Show all changes
127 commits
Select commit Hold shift + click to select a range
130338d
implemented rotation masks
AlexandraVolokhova Feb 28, 2023
fe9d5ac
add masks to featuraser
AlexandraVolokhova Feb 28, 2023
6c9de56
implemented apply rotations
AlexandraVolokhova Feb 28, 2023
99fb7c7
add simple tests for apply rotations
AlexandraVolokhova Mar 2, 2023
d1713cb
fix bug in torsios, add test with AD
AlexandraVolokhova Mar 2, 2023
217099c
add a comment to ConformerDataset
AlexandraVolokhova Apr 11, 2023
b000332
implemented rotation masks
AlexandraVolokhova Feb 28, 2023
03a445c
add masks to featuraser
AlexandraVolokhova Feb 28, 2023
b9be53c
implemented apply rotations
AlexandraVolokhova Feb 28, 2023
c80b2f8
add simple tests for apply rotations
AlexandraVolokhova Mar 2, 2023
ea8e410
fix bug in torsios, add test with AD
AlexandraVolokhova Mar 2, 2023
7cf2f3e
add a comment to ConformerDataset
AlexandraVolokhova Apr 11, 2023
38afb29
updated setup
michalkoziarski Apr 26, 2023
f0d8225
fixing tests
michalkoziarski Apr 26, 2023
8b7502e
black
michalkoziarski Apr 26, 2023
34341df
Merge remote-tracking branch 'origin/molecule_graph_env' into torch_a…
michalkoziarski Apr 26, 2023
6f7b738
Merge pull request #120 from alexhernandezgarcia/torch_ani_proxy
michalkoziarski Apr 26, 2023
f0d53ab
WiP molecule TorchANI proxy
michalkoziarski May 16, 2023
e87e441
added configs
michalkoziarski May 17, 2023
1d095be
updated proxy defaults
michalkoziarski May 17, 2023
879dd38
overwritten deepcopy
michalkoziarski May 17, 2023
c2897b7
optional batching
michalkoziarski May 18, 2023
e51b8f0
scaled energy
michalkoziarski May 18, 2023
1465827
updated config
michalkoziarski May 18, 2023
4b8816a
fixed docstring
michalkoziarski May 18, 2023
9b4fa15
energy divider as an argument
michalkoziarski May 31, 2023
feee48b
removed unused import
michalkoziarski May 31, 2023
38fc208
Merge pull request #124 from alexhernandezgarcia/torch_ani_proxy
michalkoziarski May 31, 2023
004f905
added aromatic bond type
michalkoziarski Jun 1, 2023
6a908de
XTB proxy
michalkoziarski Jun 5, 2023
032759b
black
michalkoziarski Jun 5, 2023
541e316
XTB using command line interface instead of Python API
michalkoziarski Jun 6, 2023
fe6cecc
cleaning scratch by default
michalkoziarski Jun 8, 2023
eaccf01
tblite
michalkoziarski Jun 9, 2023
91ef7f8
conversion
michalkoziarski Jun 9, 2023
c564ce7
Merge pull request #131 from alexhernandezgarcia/xtb_proxy
AlexandraVolokhova Jun 9, 2023
2baddc8
Merge branch 'batch_class' of github.com:alexhernandezgarcia/gflownet…
AlexandraVolokhova Jun 9, 2023
5e29fbf
rearranged conformer configs and environments
michalkoziarski Jun 15, 2023
3deaeda
added wurlitzer (to supress XTB output)
michalkoziarski Jun 15, 2023
7475f65
conformer environment
michalkoziarski Jun 15, 2023
f15cb4c
deepcopy workaround for XTB proxy
michalkoziarski Jun 15, 2023
caf6537
simplified cpu casting
michalkoziarski Jun 15, 2023
8353087
conformer molecules split into separate files
michalkoziarski Jun 15, 2023
c4b97d8
conformer's state2proxy returning 3D coordinates
michalkoziarski Jun 15, 2023
f6d693b
fixed typo
michalkoziarski Jun 15, 2023
d5839b1
typo
michalkoziarski Jun 15, 2023
20a8eac
atom_positions_dataset removed from attributes
michalkoziarski Jun 15, 2023
9695baf
dataset file containing multiple molecules
michalkoziarski Jun 15, 2023
6223cef
subtracting constant energy term
michalkoziarski Jun 15, 2023
fb404b1
updated conformer conda env
michalkoziarski Jun 16, 2023
15f905f
estimating min value for rejection sampling
michalkoziarski Jun 16, 2023
f6166c9
black
michalkoziarski Jun 16, 2023
1c67145
ray setup for cluster
michalkoziarski Jun 16, 2023
7a6fa8e
updated torchani proxy input format
michalkoziarski Jun 22, 2023
bea25ac
number of jobs in config
michalkoziarski Jun 22, 2023
a1b89dc
policy class moved to a separate file
michalkoziarski Jun 22, 2023
b775dd9
Merge remote-tracking branch 'origin/batch_class' into molecule_graph…
michalkoziarski Jun 22, 2023
db1a9da
disabled ray logging
michalkoziarski Jun 22, 2023
bd820ef
updated conformer env to return numpy arrays in *2proxy methods
michalkoziarski Jun 22, 2023
9603ffe
Merge remote-tracking branch 'origin/batch_class' into molecule_graph…
michalkoziarski Jun 23, 2023
4bfb86b
XTB proxy renamed to TBLite
michalkoziarski Jun 23, 2023
3d3c809
re-added XTB proxy
michalkoziarski Jun 23, 2023
16c83af
constant energy term computation in a proxy instead of environment
michalkoziarski Jun 23, 2023
4de0a0d
removed print
michalkoziarski Jun 23, 2023
75a5152
added constant subtraction in base conformer proxy class
michalkoziarski Jun 23, 2023
0edb340
changed default beta
michalkoziarski Jun 23, 2023
a13e6cf
method passed in xtb config
michalkoziarski Jun 23, 2023
8dee210
added xtb to requirements
michalkoziarski Jun 23, 2023
f8842d0
decreased default beta
michalkoziarski Jun 24, 2023
38e9534
normalizing energies to (0, 1) range
michalkoziarski Jun 24, 2023
58c4164
method dictionary for XTB
michalkoziarski Jun 24, 2023
ae2e70d
joblib instead of ray
michalkoziarski Jun 26, 2023
88ce96c
Merge remote-tracking branch 'origin/main' into molecule_graph_env
michalkoziarski Jun 26, 2023
e5962a5
normalization controlled by an argument
michalkoziarski Jun 27, 2023
9d5c3ab
renamed utils/molecule/xtb to xtb_cli
michalkoziarski Jun 27, 2023
e8596bc
function for finding rotatable bonds
michalkoziarski Jun 27, 2023
bc81236
dynamically computing torsion angles (instead of using dataset)
michalkoziarski Jun 27, 2023
b3f4e64
race condition fix and temporary workaround for white spaces in hydra…
michalkoziarski Jul 26, 2023
583854e
implemented ns
AlexandraVolokhova Aug 2, 2023
6988ea6
fix n_samples for ns
AlexandraVolokhova Aug 2, 2023
90e7155
suppressing ultranest output; better state conversion; more verbose p…
michalkoziarski Aug 8, 2023
38f5d62
more consistent configs
michalkoziarski Aug 8, 2023
a4dd973
commented out print changed to TODO
michalkoziarski Aug 8, 2023
2aba2fa
Merge pull request #171 from alexhernandezgarcia/nested_sampling
michalkoziarski Aug 8, 2023
3b6c11d
Merge branch 'molecule_graph_env' of https://github.com/alexhernandez…
michalkoziarski Aug 8, 2023
afa1488
removed unused import
michalkoziarski Aug 9, 2023
452a45a
outlier removal and clamping of energies
michalkoziarski Aug 9, 2023
49450aa
more robust conformer sampling (higher number of dimensions)
michalkoziarski Aug 9, 2023
6045a20
quantiles used for outlier detection; increased number of samples
michalkoziarski Aug 9, 2023
103eaca
updated comment
michalkoziarski Aug 9, 2023
6757870
scaling to higher dimensionality
michalkoziarski Aug 9, 2023
1b8e144
Merge branch 'rf_gfn' of https://github.com/alexhernandezgarcia/gflow…
michalkoziarski Aug 10, 2023
5269c0d
Merge branch 'rf_gfn' into molecule_graph_env
michalkoziarski Aug 10, 2023
62a2a6d
Merge branch 'rf_gfn-mk' into molecule_graph_env
michalkoziarski Aug 10, 2023
31ea9b0
updated config
michalkoziarski Aug 10, 2023
2e44acc
EGNN-based policy
michalkoziarski Aug 11, 2023
7065121
added missing __init__.py
michalkoziarski Aug 12, 2023
54cbaf7
better device support
michalkoziarski Aug 12, 2023
23bb90f
better device support
michalkoziarski Aug 12, 2023
6448a50
better device support
michalkoziarski Aug 12, 2023
625a74f
Merge branch 'main' into molecule_graph_env
michalkoziarski Sep 7, 2023
f4da186
removed ConformersDataset
michalkoziarski Sep 7, 2023
0897394
remove old_conformer
michalkoziarski Sep 7, 2023
d3bf9f4
moved install script
michalkoziarski Sep 7, 2023
98d154e
Merge branch 'main' into molecule_graph_env
michalkoziarski Sep 7, 2023
699bee7
Merge branch 'main' of https://github.com/alexhernandezgarcia/gflownet
michalkoziarski Sep 14, 2023
e3c6aac
fixed rotatatable bonds
AlexandraVolokhova Sep 18, 2023
3dda1f8
add test
AlexandraVolokhova Sep 18, 2023
e792e49
isort fix
michalkoziarski Sep 18, 2023
34d70a0
formatting changes & typo fixes
michalkoziarski Sep 19, 2023
dc7233e
black
michalkoziarski Sep 19, 2023
e5aa6dd
Merge pull request #213 from alexhernandezgarcia/torsion_angles_detec…
AlexandraVolokhova Sep 19, 2023
4384c4e
fixed function name
michalkoziarski Sep 19, 2023
9d7e02c
Merge branch 'torsion_angles_detection_fix' of https://github.com/ale…
michalkoziarski Sep 19, 2023
d1d0e1b
Merge branch 'torsion_angles_detection_fix_mk' into torsion_angles_de…
michalkoziarski Sep 19, 2023
f4c2f35
Merge github.com:alexhernandezgarcia/gflownet into torsion_angles_det…
AlexandraVolokhova Sep 19, 2023
286c5ec
Merge branch 'main' into molecule_graph_env
michalkoziarski Sep 19, 2023
9032cb0
Merge branch 'torsion_angles_detection_fix' of github.com:alexhernand…
AlexandraVolokhova Sep 19, 2023
582f95a
Merge branch 'molecule_graph_env' into torsion_angles_detection_fix
michalkoziarski Sep 19, 2023
65c8482
Merge branch 'torsion_angles_detection_fix' of github.com:alexhernand…
AlexandraVolokhova Sep 19, 2023
00f614a
add hydrogens fix
AlexandraVolokhova Sep 19, 2023
0ae2a8a
fix ordering bug, add check for hydrogen ta
AlexandraVolokhova Sep 19, 2023
d38aab8
Merge branch 'torsion_angles_detection_fix' of https://github.com/ale…
michalkoziarski Sep 20, 2023
2875505
fix another bug
AlexandraVolokhova Sep 20, 2023
23149eb
Merge branch 'torsion_angles_detection_fix' of https://github.com/ale…
michalkoziarski Sep 20, 2023
9b45f15
black & isort
michalkoziarski Sep 20, 2023
c2fb4e1
Merge pull request #212 from alexhernandezgarcia/torsion_angles_detec…
michalkoziarski Sep 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- base

_target_: gflownet.envs.alaninedipeptide.AlanineDipeptide
_target_: gflownet.envs.conformers.alaninedipeptide.AlanineDipeptide

path_to_dataset: './data/alanine_dipeptide_conformers_1.npy'
url_to_dataset: 'https://drive.google.com/uc?id=1r1KRGcpBhR3xaS8yt2i64dfMnJGgNj4C'
Expand Down
36 changes: 36 additions & 0 deletions config/env/conformers/conformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defaults:
- base

_target_: gflownet.envs.conformers.conformer.Conformer

# smiles: 'CC(C(=O)NC)NC(=O)C' # alanine dipeptide
# smiles: 'CC(C)Cc1ccc(cc1)[C@@H](C)C(=O)O' # ibuprofen
smiles: 'O=C(c1ccc2n1CCC2C(=O)O)c3ccccc3' # ketorolac
# smiles: 'CCCCCC1=CC(=C(C(=C1)O)C2C=C(CCC2C(=C)C)C)O' # cannabidiol
# smiles: 'CN1C2CCC1C(C(C2)OC(=O)C3=CC=CC=C3)C(=O)OC' # cocaine
n_torsion_angles: 2
reward_sampling_method: nested

id: conformer
policy_encoding_dim_per_angle: null
# Fixed length of trajectories
length_traj: 10
vonmises_min_concentration: 1e-3
# Parameters of the fixed policy output distribution
n_comp: 3
fixed_distribution:
vonmises_mean: 0.0
vonmises_concentration: 0.5
# Parameters of the random policy output distribution
random_distribution:
vonmises_mean: 0.0
vonmises_concentration: 0.001
# Buffer
buffer:
data_path: null
train: null
test:
type: grid
n: 1000
output_csv: conformer_test.csv
output_pkl: conformer_test.pkl
52 changes: 52 additions & 0 deletions config/experiments/conformer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# @package _global_

defaults:
- override /env: conformers/conformer
- override /gflownet: trajectorybalance
- override /policy: conformers/mlp
- override /proxy: conformers/tblite
- override /logger: wandb

# Environment
env:
length_traj: 10
policy_encoding_dim_per_angle: 10
policy_type: mlp
n_comp: 5
vonmises_min_concentration: 4
reward_func: boltzmann
reward_beta: 32
reward_sampling_method: nested

# GFlowNet hyperparameters
gflownet:
random_action_prob: 0.1
optimizer:
batch_size:
forward: 100
backward_dataset: 0
backward_replay: 0
lr: 0.00001
z_dim: 16
lr_z_mult: 1000
n_train_steps: 40000
lr_decay_period: 1000000

# WandB
logger:
lightweight: True
project_name: "gflownet"
tags:
- gflownet
- continuous
- molecule
test:
period: 5000
n: 10000
checkpoints:
period: 5000

# Hydra
hydra:
run:
dir: ${user.logdir.root}/molecule/${now:%Y-%m-%d_%H-%M-%S}
4 changes: 2 additions & 2 deletions config/experiments/icml23/alaninedipeptide.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# @package _global_

defaults:
- override /env: alaninedipeptide
- override /env: conformers/alaninedipeptide
- override /gflownet: trajectorybalance
- override /proxy: molecule
- override /proxy: conformers/rf
- override /logger: wandb
- override /user: sasha

Expand Down
16 changes: 16 additions & 0 deletions config/policy/conformers/egnn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_target_: gflownet.policy.conformers.egnn.EGNNPolicy

forward:
n_gnn_layers: 7
n_node_mlp_layers: 2
n_pool_mlp_layers: 2
egnn_hidden_dim: 128
node_mlp_hidden_dim: 128
pool_mlp_hidden_dim: 128
backward:
n_gnn_layers: 7
n_node_mlp_layers: 2
n_pool_mlp_layers: 2
egnn_hidden_dim: 128
node_mlp_hidden_dim: 128
pool_mlp_hidden_dim: 128
13 changes: 13 additions & 0 deletions config/policy/conformers/mlp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_target_: gflownet.policy.base.Policy

forward:
type: mlp
n_hid: 512
n_layers: 5
checkpoint: forward
backward:
type: mlp
n_hid: 512
n_layers: 5
shared_weights: False
checkpoint: backward
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: gflownet.proxy.molecule.RFMoleculeEnergy
_target_: gflownet.proxy.conformers.rf.RFMoleculeEnergy

path_to_model: './data/random_forest_reward_100.pkl'
url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF'
url_to_model: 'https://drive.google.com/uc?id=1OpQNC8WWIsMh8K4olfSaQRFlj3emYThF'
1 change: 1 addition & 0 deletions config/proxy/conformers/tblite.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
_target_: gflownet.proxy.conformers.tblite.TBLiteMoleculeEnergy
4 changes: 4 additions & 0 deletions config/proxy/conformers/torchani.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_target_: gflownet.proxy.conformers.torchani.TorchANIMoleculeEnergy

model: ANI2x
use_ensemble: True
3 changes: 3 additions & 0 deletions config/proxy/conformers/xtb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_target_: gflownet.proxy.conformers.xtb.XTBMoleculeEnergy

method: gfnff
2 changes: 1 addition & 1 deletion gflownet/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__all__ = ["base", "grid", "aptamers"]
__all__ = ["base", "grid", "aptamers", "alaninedipeptide"]
20 changes: 20 additions & 0 deletions gflownet/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,26 @@ def statebatch2policy(
tfloat(states, float_type=self.float, device=self.device)
)

def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]:
"""
Prepares a batch of states in "GFlowNet format" for the proxy. Typically,
this will be the same as the statebatch2proxy, but in cases in which proxy
input is already processed (e.g., conformers, with list of torsion angles
converted to 3D positions of atoms), this can be overwritten to preserve KDE.
"""
return self.statebatch2proxy(states)

def statetorch2kde(
self, states: TensorType["batch_size", "state_dim"]
) -> TensorType["batch_size", "state_proxy_dim"]:
"""
Prepares a batch of states in torch "GFlowNet format" for the KDE. Typically,
this will be the same as the statetorch2proxy, but in cases in which proxy
input is already processed (e.g., conformers, with list of torsion angles
converted to 3D positions of atoms), this can be overwritten to preserve KDE.
"""
return self.statetorch2proxy(states)

def policy2state(self, state_policy: List) -> List:
"""
Converts the model (e.g. one-hot encoding) version of a state given as
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from gflownet.envs.ctorus import ContinuousTorus
from gflownet.utils.molecule import constants
from gflownet.utils.molecule.atom_positions_dataset import AtomPositionsDataset
from gflownet.utils.molecule.conformer_base import ConformerBase
from gflownet.utils.molecule.datasets import AtomPositionsDataset
from gflownet.utils.molecule.rdkit_conformer import RDKitConformer


class AlanineDipeptide(ContinuousTorus):
Expand All @@ -26,7 +26,7 @@ def __init__(
path_to_dataset, url_to_dataset
)
atom_positions = self.atom_positions_dataset.sample()
self.conformer = ConformerBase(
self.conformer = RDKitConformer(
atom_positions, constants.ad_smiles, constants.ad_free_tas
)
n_dim = len(self.conformer.freely_rotatable_tas)
Expand Down
138 changes: 138 additions & 0 deletions gflownet/envs/conformers/conformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import copy
from typing import List, Optional, Tuple

import numpy as np
import numpy.typing as npt
from rdkit import Chem
from rdkit.Chem import AllChem
from torchtyping import TensorType

from gflownet.envs.ctorus import ContinuousTorus
from gflownet.utils.molecule.constants import ad_atom_types
from gflownet.utils.molecule.featurizer import MolDGLFeaturizer
from gflownet.utils.molecule.rdkit_conformer import RDKitConformer
from gflownet.utils.molecule.rotatable_bonds import find_rotor_from_smile


class Conformer(ContinuousTorus):
"""
Extension of continuous torus to conformer generation. Based on AlanineDipeptide,
but accepts any molecule (defined by SMILES and freely rotatable torsion angles).
"""

def __init__(
self,
smiles: str,
n_torsion_angles: Optional[int] = 2,
torsion_indices: Optional[List[int]] = None,
policy_type: str = "mlp",
**kwargs,
):
if torsion_indices is None:
# We hard code default torsion indices for Alanine Dipeptide to preserve
# backward compatibility.
if smiles == "CC(C(=O)NC)NC(=O)C" and n_torsion_angles == 2:
torsion_indices = [2, 1]
else:
torsion_indices = list(range(n_torsion_angles))

atom_positions = Conformer._get_positions(smiles)
torsion_angles = Conformer._get_torsion_angles(smiles, torsion_indices)
self.conformer = RDKitConformer(atom_positions, smiles, torsion_angles)

# Conversions
self.statebatch2oracle = self.statebatch2proxy
self.statetorch2oracle = self.statetorch2proxy
if policy_type == "gnn":
self.statebatch2policy = self.statebatch2policy_gnn
elif policy_type != "mlp":
raise ValueError(
f"Unrecognized policy_type = {policy_type}, expected either 'mlp' or 'gnn'."
)

super().__init__(n_dim=len(self.conformer.freely_rotatable_tas), **kwargs)

self.sync_conformer_with_state()

self.graph = MolDGLFeaturizer(ad_atom_types).mol2dgl(self.conformer.rdk_mol)

@staticmethod
def _get_positions(smiles: str) -> npt.NDArray:
mol = Chem.MolFromSmiles(smiles)
mol = Chem.AddHs(mol)
AllChem.EmbedMolecule(mol, randomSeed=0)
return mol.GetConformer().GetPositions()

@staticmethod
def _get_torsion_angles(smiles: str, indices: List[int]) -> List[Tuple[int]]:
torsion_angles = find_rotor_from_smile(smiles)
torsion_angles = [torsion_angles[i] for i in indices]
return torsion_angles

def sync_conformer_with_state(self, state: List = None):
if state is None:
state = self.state
for idx, ta in enumerate(self.conformer.freely_rotatable_tas):
self.conformer.set_torsion_angle(ta, state[idx])
return self.conformer

def statebatch2proxy(self, states: List[List]) -> npt.NDArray:
"""
Returns a list of proxy states, each being a numpy array with dimensionality
(n_atoms, 4), in which the first column encodes atomic number, and the last
three columns encode atom positions.
"""
states_proxy = []
for st in states:
conf = self.sync_conformer_with_state(st)
states_proxy.append(
np.concatenate(
[
conf.get_atomic_numbers()[..., np.newaxis],
conf.get_atom_positions(),
],
axis=1,
)
)
return np.array(states_proxy)

def statetorch2proxy(self, states: TensorType["batch", "state_dim"]) -> npt.NDArray:
return self.statebatch2proxy(states.cpu().numpy())

def statebatch2policy_gnn(self, states: List[List]) -> npt.NDArray[np.float32]:
"""
Returns an array of GNN-format policy inputs with dimensionality
(n_states, n_atoms, 4), in which the first three columns encode atom positions,
and the last column encodes current timestep.
"""
policy_input = []
for state in states:
conformer = self.sync_conformer_with_state(state)
positions = conformer.get_atom_positions()
policy_input.append(
np.concatenate(
[positions, np.full((positions.shape[0], 1), state[-1])],
axis=1,
)
)
return np.array(policy_input)

def statebatch2kde(self, states: List[List]) -> npt.NDArray[np.float32]:
return np.array(states)[:, :-1]

def statetorch2kde(
self, states: TensorType["batch_size", "state_dim"]
) -> TensorType["batch_size", "state_proxy_dim"]:
return states.cpu().numpy()[:, :-1]

def __deepcopy__(self, memo):
cls = self.__class__
new_instance = cls.__new__(cls)

for attr_name, attr_value in self.__dict__.items():
if attr_name != "conformer":
setattr(new_instance, attr_name, copy.copy(attr_value))

new_instance.conformer = self.conformer

return new_instance
Loading