AurelioAguirre committed on
Commit
50e2fd2
·
1 Parent(s): 14d86a4

Changed generate_stream to async

Browse files
Files changed (1) hide show
  1. main/api.py +7 -4
main/api.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  from pathlib import Path
3
  from threading import Thread
4
  import torch
5
- from typing import Optional, Iterator, List
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
  from .utils.logging import setup_logger
8
 
@@ -248,12 +249,12 @@ class LLMApi:
248
  self.logger.error(f"Error generating response: {str(e)}")
249
  raise
250
 
251
- def generate_stream(
252
  self,
253
  prompt: str,
254
  system_message: Optional[str] = None,
255
  max_new_tokens: Optional[int] = None
256
- ) -> Iterator[str]:
257
  """
258
  Generate a streaming response for the given prompt.
259
  """
@@ -287,10 +288,12 @@ class LLMApi:
287
  thread = Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
288
  thread.start()
289
 
290
- # Yield the generated text in chunks
291
  for new_text in streamer:
292
  self.logger.debug(f"Generated chunk: {new_text[:50]}...")
293
  yield new_text
 
 
294
 
295
  except Exception as e:
296
  self.logger.error(f"Error in streaming generation: {str(e)}")
 
2
  from pathlib import Path
3
  from threading import Thread
4
  import torch
5
+ from typing import Optional, List, AsyncIterator
6
+ import asyncio
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
8
  from .utils.logging import setup_logger
9
 
 
249
  self.logger.error(f"Error generating response: {str(e)}")
250
  raise
251
 
252
+ async def generate_stream(
253
  self,
254
  prompt: str,
255
  system_message: Optional[str] = None,
256
  max_new_tokens: Optional[int] = None
257
+ ) -> AsyncIterator[str]:
258
  """
259
  Generate a streaming response for the given prompt.
260
  """
 
288
  thread = Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
289
  thread.start()
290
 
291
+ # Use async generator to yield chunks
292
  for new_text in streamer:
293
  self.logger.debug(f"Generated chunk: {new_text[:50]}...")
294
  yield new_text
295
+ # Add a small delay to allow other tasks to run
296
+ await asyncio.sleep(0)
297
 
298
  except Exception as e:
299
  self.logger.error(f"Error in streaming generation: {str(e)}")