Skip to content

Commit

Permalink
update data dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
lijialin03 committed Jan 3, 2025
1 parent 9b9fb5e commit 165e5ab
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions deepxde/data/mf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,24 @@ def train_next_batch(self, batch_size=None):
self.geom.uniform_points(self.num_lo, True),
self.geom.uniform_points(self.num_hi, True),
)
)
).astype(config.real(np))
else:
self.X_train = np.vstack(
(
self.geom.random_points(self.num_lo, random=self.dist_train),
self.geom.random_points(self.num_hi, random=self.dist_train),
)
)
).astype(config.real(np))
y_lo_train = self.func_lo(self.X_train)
y_hi_train = self.func_hi(self.X_train)
self.y_train = [y_lo_train, y_hi_train]
return self.X_train, self.y_train

@run_if_any_none("X_test", "y_test")
def test(self):
self.X_test = self.geom.uniform_points(self.num_test, True)
self.X_test = self.geom.uniform_points(self.num_test, True).astype(
config.real(np)
)
y_lo_test = self.func_lo(self.X_test)
y_hi_test = self.func_hi(self.X_test)
self.y_test = [y_lo_test, y_hi_test]
Expand Down Expand Up @@ -84,20 +86,20 @@ def __init__(
standardize=False,
):
if X_lo_train is not None:
self.X_lo_train = X_lo_train
self.X_hi_train = X_hi_train
self.y_lo_train = y_lo_train
self.y_hi_train = y_hi_train
self.X_hi_test = X_hi_test
self.y_hi_test = y_hi_test
self.X_lo_train = X_lo_train.astype(config.real(np))
self.X_hi_train = X_hi_train.astype(config.real(np))
self.y_lo_train = y_lo_train.astype(config.real(np))
self.y_hi_train = y_hi_train.astype(config.real(np))
self.X_hi_test = X_hi_test.astype(config.real(np))
self.y_hi_test = y_hi_test.astype(config.real(np))
elif fname_lo_train is not None:
data = np.loadtxt(fname_lo_train)
data = np.loadtxt(fname_lo_train).astype(config.real(np))
self.X_lo_train = data[:, col_x]
self.y_lo_train = data[:, col_y]
data = np.loadtxt(fname_hi_train)
data = np.loadtxt(fname_hi_train).astype(config.real(np))
self.X_hi_train = data[:, col_x]
self.y_hi_train = data[:, col_y]
data = np.loadtxt(fname_hi_test)
data = np.loadtxt(fname_hi_test).astype(config.real(np))
self.X_hi_test = data[:, col_x]
self.y_hi_test = data[:, col_y]
else:
Expand All @@ -117,7 +119,10 @@ def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
return [loss_lo, loss_hi]

def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
return [bkd.as_tensor(0, dtype=config.real(bkd.lib)), loss_fn(targets[1], outputs[1])]
return [
bkd.as_tensor(0, dtype=config.real(bkd.lib)),
loss_fn(targets[1], outputs[1]),
]

@run_if_any_none("X_train", "y_train")
def train_next_batch(self, batch_size=None):
Expand Down

0 comments on commit 165e5ab

Please sign in to comment.