dreyyyy commited on
Commit
66a6701
·
verified ·
1 Parent(s): 1312014

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +30 -0
handler.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import MarianMTModel, MarianTokenizer
2
+ from typing import Any, List, Dict
3
+
4
+ class EndpointHandler:
5
+ def __init__(self, path=""):
6
+ # Load the model and tokenizer
7
+ self.model = MarianMTModel.from_pretrained(path)
8
+ self.tokenizer = MarianTokenizer.from_pretrained(path)
9
+
10
+ def __call__(self, data: Any) -> List[Dict[str, str]]:
11
+ """
12
+ Args:
13
+ data (dict): The request payload with an "inputs" key containing the text to translate.
14
+ Returns:
15
+ List[Dict]: A list containing the translated text.
16
+ """
17
+ # Get the input text from the request
18
+ text = data.get("inputs", "")
19
+
20
+ # Tokenize the input text
21
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True)
22
+
23
+ # Perform the translation
24
+ translated = self.model.generate(**inputs)
25
+
26
+ # Decode the translated text
27
+ translated_text = self.tokenizer.decode(translated[0], skip_special_tokens=True)
28
+
29
+ # Return the translated text as a response
30
+ return [{"translation_text": translated_text}]