Spaces:
Paused
Paused
Commit
·
50e2fd2
1
Parent(s):
14d86a4
Changed Generate stream to async
Browse files- main/api.py +7 -4
main/api.py
CHANGED
@@ -2,7 +2,8 @@ import os
|
|
2 |
from pathlib import Path
|
3 |
from threading import Thread
|
4 |
import torch
|
5 |
-
from typing import Optional,
|
|
|
6 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
|
7 |
from .utils.logging import setup_logger
|
8 |
|
@@ -248,12 +249,12 @@ class LLMApi:
|
|
248 |
self.logger.error(f"Error generating response: {str(e)}")
|
249 |
raise
|
250 |
|
251 |
-
def generate_stream(
|
252 |
self,
|
253 |
prompt: str,
|
254 |
system_message: Optional[str] = None,
|
255 |
max_new_tokens: Optional[int] = None
|
256 |
-
) ->
|
257 |
"""
|
258 |
Generate a streaming response for the given prompt.
|
259 |
"""
|
@@ -287,10 +288,12 @@ class LLMApi:
|
|
287 |
thread = Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
|
288 |
thread.start()
|
289 |
|
290 |
-
#
|
291 |
for new_text in streamer:
|
292 |
self.logger.debug(f"Generated chunk: {new_text[:50]}...")
|
293 |
yield new_text
|
|
|
|
|
294 |
|
295 |
except Exception as e:
|
296 |
self.logger.error(f"Error in streaming generation: {str(e)}")
|
|
|
2 |
from pathlib import Path
|
3 |
from threading import Thread
|
4 |
import torch
|
5 |
+
from typing import Optional, List, AsyncIterator
|
6 |
+
import asyncio
|
7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
|
8 |
from .utils.logging import setup_logger
|
9 |
|
|
|
249 |
self.logger.error(f"Error generating response: {str(e)}")
|
250 |
raise
|
251 |
|
252 |
+
async def generate_stream(
|
253 |
self,
|
254 |
prompt: str,
|
255 |
system_message: Optional[str] = None,
|
256 |
max_new_tokens: Optional[int] = None
|
257 |
+
) -> AsyncIterator[str]:
|
258 |
"""
|
259 |
Generate a streaming response for the given prompt.
|
260 |
"""
|
|
|
288 |
thread = Thread(target=self.generation_model.generate, kwargs=generation_kwargs)
|
289 |
thread.start()
|
290 |
|
291 |
+
# Use async generator to yield chunks
|
292 |
for new_text in streamer:
|
293 |
self.logger.debug(f"Generated chunk: {new_text[:50]}...")
|
294 |
yield new_text
|
295 |
+
# Add a small delay to allow other tasks to run
|
296 |
+
await asyncio.sleep(0)
|
297 |
|
298 |
except Exception as e:
|
299 |
self.logger.error(f"Error in streaming generation: {str(e)}")
|