diff --git a/web_demo.py b/web_demo.py index 668dcf4..9c9cb2f 100644 --- a/web_demo.py +++ b/web_demo.py @@ -158,6 +158,8 @@ def chat(img, msgs, ctx, params=None, vision_hidden_states=None): res = res.replace('', '') res = res.replace('', '') answer = res.replace('', '') + if device == "cuda": + torch.cuda.empty_cache() return 0, answer, None, None except Exception as err: print(err) diff --git a/web_demo_2.5.py b/web_demo_2.5.py index 6f6b81a..9bfa39e 100644 --- a/web_demo_2.5.py +++ b/web_demo_2.5.py @@ -150,6 +150,8 @@ def chat(img, msgs, ctx, params=None, vision_hidden_states=None): res = res.replace('', '') res = res.replace('', '') answer = res.replace('', '') + if device == "cuda": + torch.cuda.empty_cache() return 0, answer, None, None except Exception as err: print(err)