diff --git a/web_demo.py b/web_demo.py
index 668dcf4..9c9cb2f 100644
--- a/web_demo.py
+++ b/web_demo.py
@@ -158,6 +158,8 @@ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
res = res.replace('', '')
res = res.replace('', '')
answer = res.replace('', '')
+ if device == "cuda":
+ torch.cuda.empty_cache()
return 0, answer, None, None
except Exception as err:
print(err)
diff --git a/web_demo_2.5.py b/web_demo_2.5.py
index 6f6b81a..9bfa39e 100644
--- a/web_demo_2.5.py
+++ b/web_demo_2.5.py
@@ -150,6 +150,8 @@ def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
res = res.replace('', '')
res = res.replace('', '')
answer = res.replace('', '')
+ if device == "cuda":
+ torch.cuda.empty_cache()
return 0, answer, None, None
except Exception as err:
print(err)