Skip to content

Commit

Permalink
fix: changed field initialization to default_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianZach committed Sep 19, 2024
1 parent e9f4240 commit 2728569
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions ssax/objectives/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ class Beale(ObjectiveFn):

dim: int = 2
optimal_value: float = 0.0
optimizers: jax.Array = jnp.array([(3.0, 0.5)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(3.0, 0.5)]))


@classmethod
def create(cls,
Expand All @@ -85,7 +86,7 @@ class Branin(ObjectiveFn):

dim: int = 2
optimal_value: float = 0.397887
optimizers: jax.Array = jnp.array([(-jnp.pi, 12.275), (jnp.pi, 2.275), (9.42478, 2.475)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(-jnp.pi, 12.275), (jnp.pi, 2.275), (9.42478, 2.475)]))

@classmethod
def create(cls,
Expand Down Expand Up @@ -113,7 +114,7 @@ class Bukin(ObjectiveFn):

dim: int = 2
optimal_value: float = 0.0
optimizers: jnp.array = jnp.array([(-10.0, 1.0)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(-10.0, 1.0)]))

@classmethod
def create(cls,
Expand All @@ -136,7 +137,7 @@ class Cosine8(ObjectiveFn):

dim: int = 8
optimal_value: float = 0.8
optimizers: jnp.array = jnp.zeros(8)
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.zeros(8))

@classmethod
def create(cls,
Expand All @@ -157,7 +158,7 @@ class DropWave(ObjectiveFn):

dim: int = 2
optimal_value: float = -1.0
optimizers: jnp.array = jnp.zeros(2)
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.zeros(2))

@classmethod
def create(cls,
Expand All @@ -181,7 +182,7 @@ class EggHolder(ObjectiveFn):

dim: int = 2
optimal_value: float = -959.6407
optimizers: jnp.array = jnp.array([(512.0, 404.2319)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(512.0, 404.2319)]))

@classmethod
def create(cls,
Expand All @@ -205,12 +206,12 @@ class HolderTable(ObjectiveFn):

dim: int = 2
optimal_value: float = -19.2085
optimizers: jnp.array = jnp.array([
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([
(8.05502, 9.66459),
(-8.05502, -9.66459),
(-8.05502, 9.66459),
(8.05502, -9.66459),
])
]))

@classmethod
def create(cls,
Expand All @@ -234,7 +235,7 @@ class SixHumpCamel(ObjectiveFn):

dim: int = 2
optimal_value: float = -1.0316
optimizers: jnp.array = jnp.array([(0.0898, -0.7126), (-0.0898, 0.7126)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(0.0898, -0.7126), (-0.0898, 0.7126)]))

@classmethod
def create(cls,
Expand All @@ -260,7 +261,7 @@ class ThreeHumpCamel(ObjectiveFn):

dim: int = 2
optimal_value: float = 0.0
optimizers: jnp.array = jnp.array([(0.0, 0.0)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(0.0, 0.0)]))

@classmethod
def create(cls,
Expand Down Expand Up @@ -369,7 +370,7 @@ class Michalewicz(ObjectiveFn):

dim: int = 2 # NOTE: hard fixed dim = 2 for now
optimal_value: float = -1.8013
optimizers: jnp.array = jnp.array([(2.20290552, 1.57079633)])
optimizers: jax.Array = struct.field(default_factory=lambda: jnp.array([(2.20290552, 1.57079633)]))

@classmethod
def create(cls,
Expand Down

0 comments on commit 2728569

Please sign in to comment.