Commit 2f2d356e authored by AUTOMATIC's avatar AUTOMATIC

call torch_gc before/after each gpu gradio operation

parent c1c27dad
import os import os
import threading import threading
from modules import devices
from modules.paths import script_path from modules.paths import script_path
import signal import signal
...@@ -47,6 +48,8 @@ def wrap_queued_call(func): ...@@ -47,6 +48,8 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc()
shared.state.sampling_step = 0 shared.state.sampling_step = 0
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.job_no = 0 shared.state.job_no = 0
...@@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func): ...@@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func):
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
devices.torch_gc()
return res return res
return modules.ui.wrap_gradio_call(f) return modules.ui.wrap_gradio_call(f)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment