Skip to content

Commit

Permalink
Raise an exception when the specified module does not support evaluat…
Browse files Browse the repository at this point in the history
…ion.
  • Loading branch information
haoyuying authored Apr 13, 2021
1 parent 30aace4 commit af6dd63
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
7 changes: 6 additions & 1 deletion paddlehub/finetune/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def __init__(self,

if not isinstance(self.model, paddle.nn.Layer):
raise TypeError('The model {} is not a `paddle.nn.Layer` object.'.format(self.model.__name__))


if self.local_rank == 0 and not os.path.exists(self.checkpoint_dir):
os.makedirs(self.checkpoint_dir)
Expand Down Expand Up @@ -178,6 +179,9 @@ def train(self,
collate_fn(callable): function to generate mini-batch data by merging the sample list.
None for only stack each fields of sample in axis 0(same as :attr::`np.stack(..., axis=0)`). Default None
'''
if eval_dataset is not None and not hasattr(self.model, 'validation_step'):
raise NotImplementedError('The specified finetuning model does not support evaluation.')

batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader(
Expand Down Expand Up @@ -298,6 +302,7 @@ def evaluate(self,
with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)

loss = result.get('loss', None)
metrics = result.get('metrics', {})
bs = batch[0].shape[0]
Expand Down Expand Up @@ -363,7 +368,7 @@ def validation_step(self, batch: Any, batch_idx: int):
batch_idx(int) : The index of batch.
'''
if self.nranks > 1:
result = self.model._layers.validation_step(batch, batch_idx)
result = self.model._layers.validation_step(batch, batch_idx)
else:
result = self.model.validation_step(batch, batch_idx)
return result
Expand Down
19 changes: 4 additions & 15 deletions paddlehub/module/cv_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,21 +643,8 @@ def training_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
Returns:
results(dict): The model outputs, such as loss.
'''

return self.validation_step(batch, batch_idx)

def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
"""
One step for validation, which should be called as forward computation.
Args:
batch(list[paddle.Tensor]): The one batch data, which contains images and labels.
batch_idx(int): The index of batch.
Returns:
results(dict) : The model outputs, such as metrics.
"""

label = batch[1].astype('int64')
criterionCE = nn.loss.CrossEntropyLoss()
Expand All @@ -666,10 +653,12 @@ def validation_step(self, batch: List[paddle.Tensor], batch_idx: int) -> dict:
for i in range(len(logits)):
logit = logits[i]
if logit.shape[-2:] != label.shape[-2:]:
logit = F.resize_bilinear(logit, label.shape[-2:])
logit = F.interpolate(logit, label.shape[-2:], mode='bilinear')

logit = logit.transpose([0,2,3,1])
loss_ce = criterionCE(logit, label)
loss += loss_ce / len(logits)

return {"loss": loss}

def predict(self, images: Union[str, np.ndarray], batch_size: int = 1, visualization: bool = True, save_path: str = 'seg_result') -> List[np.ndarray]:
Expand Down

0 comments on commit af6dd63

Please sign in to comment.