From 0d44cc02aaf4a2a00b5fc10f5b0c23ee432e50a4 Mon Sep 17 00:00:00 2001 From: Zijian Zhang Date: Fri, 24 Jan 2025 11:48:31 +0800 Subject: [PATCH] Fix save_path problem when batch infer --- sample_video.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sample_video.py b/sample_video.py index 2130896..14778d4 100644 --- a/sample_video.py +++ b/sample_video.py @@ -18,7 +18,7 @@ def main(): # Create save folder to save the samples save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' - if not os.path.exists(args.save_path): + if not os.path.exists(save_path): os.makedirs(save_path, exist_ok=True) # Load models @@ -50,9 +50,9 @@ def main(): for i, sample in enumerate(samples): sample = samples[i].unsqueeze(0) time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") - save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4" - save_videos_grid(sample, save_path, fps=24) - logger.info(f'Sample save to: {save_path}') + cur_save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4" + save_videos_grid(sample, cur_save_path, fps=24) + logger.info(f'Sample save to: {cur_save_path}') if __name__ == "__main__": main()