tnk2908 commited on
Commit
c1b5167
·
1 Parent(s): cc8b2eb

Check cuda availability before loading models

Browse files
Files changed (2) hide show
  1. config.ini +1 -1
  2. model_factory.py +6 -0
config.ini CHANGED
@@ -22,7 +22,7 @@ opt_13b = str:facebook/opt-13b
22
  [models.params]
23
  dtype = str:bfloat16
24
  load_device = str:cpu
25
- run_device = str:cpu
26
 
27
  [encrypt.default]
28
  gen_model = str:gpt2
 
22
  [models.params]
23
  dtype = str:bfloat16
24
  load_device = str:cpu
25
+ run_device = str:cuda
26
 
27
  [encrypt.default]
28
  gen_model = str:gpt2
model_factory.py CHANGED
@@ -32,6 +32,12 @@ class ModelFactory:
32
 
33
  load_device = GlobalConfig.get("models.params", "load_device")
34
  run_device = GlobalConfig.get("models.params", "run_device")
 
 
 
 
 
 
35
  if load_device is not None:
36
  cls.load_device = torch.device(str(load_device))
37
  if run_device is not None:
 
32
 
33
  load_device = GlobalConfig.get("models.params", "load_device")
34
  run_device = GlobalConfig.get("models.params", "run_device")
35
+ if not torch.cuda.is_available():
36
+ if load_device == "cuda" or run_device == "cuda":
37
+ print("cuda is not available, use cpu instead")
38
+ load_device = "cpu"
39
+ run_device = "cpu"
40
+
41
  if load_device is not None:
42
  cls.load_device = torch.device(str(load_device))
43
  if run_device is not None: