Spaces:
Running
Running
Check cuda availability before loading models
Browse files- config.ini +1 -1
- model_factory.py +6 -0
config.ini
CHANGED
@@ -22,7 +22,7 @@ opt_13b = str:facebook/opt-13b
|
|
22 |
[models.params]
|
23 |
dtype = str:bfloat16
|
24 |
load_device = str:cpu
|
25 |
-
run_device = str:
|
26 |
|
27 |
[encrypt.default]
|
28 |
gen_model = str:gpt2
|
|
|
22 |
[models.params]
|
23 |
dtype = str:bfloat16
|
24 |
load_device = str:cpu
|
25 |
+
run_device = str:cuda
|
26 |
|
27 |
[encrypt.default]
|
28 |
gen_model = str:gpt2
|
model_factory.py
CHANGED
@@ -32,6 +32,12 @@ class ModelFactory:
|
|
32 |
|
33 |
load_device = GlobalConfig.get("models.params", "load_device")
|
34 |
run_device = GlobalConfig.get("models.params", "run_device")
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
if load_device is not None:
|
36 |
cls.load_device = torch.device(str(load_device))
|
37 |
if run_device is not None:
|
|
|
32 |
|
33 |
load_device = GlobalConfig.get("models.params", "load_device")
|
34 |
run_device = GlobalConfig.get("models.params", "run_device")
|
35 |
+
if not torch.cuda.is_available():
|
36 |
+
if load_device == "cuda" or run_device == "cuda":
|
37 |
+
print("cuda is not available, use cpu instead")
|
38 |
+
load_device = "cpu"
|
39 |
+
run_device = "cpu"
|
40 |
+
|
41 |
if load_device is not None:
|
42 |
cls.load_device = torch.device(str(load_device))
|
43 |
if run_device is not None:
|