Shekswess committed on
Commit 2996fd9 · 1 Parent(s): 405ff53
Files changed (2)
  1. app.py +253 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,253 @@
+ """The UI file for the SynthGenAI package."""
+
+ import os
+ import asyncio
+
+ import gradio as gr
+
+ # Config models and dataset generators from the installed synthgenai package.
+ from synthgenai.data_model import DatasetConfig, DatasetGeneratorConfig, LLMConfig
+ from synthgenai.dataset_generator import (
+     InstructionDatasetGenerator,
+     PreferenceDatasetGenerator,
+     RawDatasetGenerator,
+     SentimentAnalysisDatasetGenerator,
+     SummarizationDatasetGenerator,
+     TextClassificationDatasetGenerator,
+ )
+
+
+ def validate_inputs(*args):
+     """
+     Validate that all required inputs are filled.
+
+     Args:
+         *args: The input values to validate.
+
+     Returns:
+         bool: True if all required inputs are filled, False otherwise.
+     """
+     for arg in args:
+         if not arg:
+             return False
+     return True
+
+
+ def generate_synthetic_dataset(
+     llm_model,
+     temperature,
+     top_p,
+     max_tokens,
+     api_base,
+     api_key,
+     dataset_type,
+     topic,
+     domains,
+     language,
+     additional_description,
+     num_entries,
+     hf_token,
+     hf_repo_name,
+     llm_env_vars,
+ ):
+     """
+     Generate a dataset based on the provided parameters.
+
+     Args:
+         llm_model (str): The LLM model to use.
+         temperature (float): The temperature for the LLM.
+         top_p (float): The top_p value for the LLM.
+         max_tokens (int): The maximum number of tokens for the LLM.
+         api_base (str): The API base URL.
+         api_key (str): The API key.
+         dataset_type (str): The type of dataset to generate.
+         topic (str): The topic of the dataset.
+         domains (str): The domains for the dataset.
+         language (str): The language of the dataset.
+         additional_description (str): Additional description for the dataset.
+         num_entries (int): The number of entries in the dataset.
+         hf_token (str): The Hugging Face token.
+         hf_repo_name (str): The Hugging Face repository name.
+         llm_env_vars (str): Comma-separated environment variables for the LLM.
+
+     Returns:
+         str: A message indicating the result of the dataset generation.
+     """
+     # Validate inputs before touching the environment, so missing fields
+     # return a message instead of raising on empty values.
+     if not validate_inputs(
+         llm_model,
+         temperature,
+         top_p,
+         max_tokens,
+         dataset_type,
+         topic,
+         domains,
+         language,
+         num_entries,
+         hf_token,
+         hf_repo_name,
+         llm_env_vars,
+     ):
+         return "All fields except API Base and API Key must be filled."
+
+     os.environ["HF_TOKEN"] = hf_token
+
+     # Export the comma-separated KEY=VALUE pairs needed by the LLM provider.
+     for var in llm_env_vars.split(","):
+         if not var.strip():
+             continue
+         key, value = var.split("=", 1)
+         os.environ[key.strip()] = value.strip()
+
+     if api_base and api_key:
+         llm_config = LLMConfig(
+             model=llm_model,
+             temperature=temperature,
+             top_p=top_p,
+             max_tokens=max_tokens,
+             api_base=api_base,
+             api_key=api_key,
+         )
+     else:
+         llm_config = LLMConfig(
+             model=llm_model,
+             temperature=temperature,
+             top_p=top_p,
+             max_tokens=max_tokens,
+         )
+
+     dataset_config = DatasetConfig(
+         topic=topic,
+         domains=domains.split(","),
+         language=language,
+         additional_description=additional_description,
+         num_entries=num_entries,
+     )
+
+     dataset_generator_config = DatasetGeneratorConfig(
+         llm_config=llm_config,
+         dataset_config=dataset_config,
+     )
+
+     if dataset_type == "Raw":
+         generator = RawDatasetGenerator(dataset_generator_config)
+     elif dataset_type == "Instruction":
+         generator = InstructionDatasetGenerator(dataset_generator_config)
+     elif dataset_type == "Preference":
+         generator = PreferenceDatasetGenerator(dataset_generator_config)
+     elif dataset_type == "Sentiment Analysis":
+         generator = SentimentAnalysisDatasetGenerator(dataset_generator_config)
+     elif dataset_type == "Summarization":
+         generator = SummarizationDatasetGenerator(dataset_generator_config)
+     elif dataset_type == "Text Classification":
+         generator = TextClassificationDatasetGenerator(dataset_generator_config)
+     else:
+         return "Invalid dataset type"
+
+     dataset = asyncio.run(generator.agenerate_dataset())
+     dataset.save_dataset(hf_repo_name=hf_repo_name)
+     return "Dataset generated and saved successfully."
+
+
+ def ui_main():
+     """
+     Launch the Gradio UI for the SynthGenAI dataset generator.
+     """
+     with gr.Blocks(
+         title="SynthGenAI Dataset Generator",
+         css="footer {visibility: hidden}",
+         theme="ParityError/Interstellar",
+     ) as demo:
+         gr.Markdown(
+             """
+             <div style="text-align: center;">
+                 <img src="https://raw.githubusercontent.com/Shekswess/synthgenai/refs/heads/main/docs/assets/logo_header.png" alt="Header Image" style="display: block; margin-left: auto; margin-right: auto; width: 50%;"/>
+                 <h1>SynthGenAI Dataset Generator</h1>
+             </div>
+             """
+         )
+
+         with gr.Row():
+             llm_model = gr.Textbox(
+                 label="LLM Model", placeholder="model_provider/model_name"
+             )
+             temperature = gr.Slider(
+                 label="Temperature", minimum=0.0, maximum=1.0, step=0.1, value=0.5
+             )
+             top_p = gr.Slider(
+                 label="Top P", minimum=0.0, maximum=1.0, step=0.1, value=0.9
+             )
+             max_tokens = gr.Number(label="Max Tokens", value=2048)
+             api_base = gr.Textbox(label="API Base", placeholder="API Base - Optional")
+             api_key = gr.Textbox(
+                 label="API Key", placeholder="Your API Key - Optional", type="password"
+             )
+
+         with gr.Row():
+             dataset_type = gr.Dropdown(
+                 label="Dataset Type",
+                 choices=[
+                     "Raw",
+                     "Instruction",
+                     "Preference",
+                     "Sentiment Analysis",
+                     "Summarization",
+                     "Text Classification",
+                 ],
+             )
+             topic = gr.Textbox(label="Topic", placeholder="Dataset topic")
+             domains = gr.Textbox(label="Domains", placeholder="Comma-separated domains")
+             language = gr.Textbox(
+                 label="Language", placeholder="Language", value="English"
+             )
+             additional_description = gr.Textbox(
+                 label="Additional Description",
+                 placeholder="Additional description",
+                 value="",
+             )
+             num_entries = gr.Number(label="Number of Entries", value=1000)
+
+         with gr.Row():
+             hf_token = gr.Textbox(
+                 label="Hugging Face Token",
+                 placeholder="Your HF Token",
+                 type="password",
+                 value=None,
+             )
+             hf_repo_name = gr.Textbox(
+                 label="Hugging Face Repo Name",
+                 placeholder="organization_or_user_name/dataset_name",
+                 value=None,
+             )
+             llm_env_vars = gr.Textbox(
+                 label="LLM Environment Variables",
+                 placeholder="Comma-separated environment variables (e.g., KEY1=VALUE1, KEY2=VALUE2)",
+                 value=None,
+             )
+
+         generate_button = gr.Button("Generate Dataset")
+         output = gr.Textbox(label="Operation Result", value="")
+
+         generate_button.click(
+             generate_synthetic_dataset,
+             inputs=[
+                 llm_model,
+                 temperature,
+                 top_p,
+                 max_tokens,
+                 api_base,
+                 api_key,
+                 dataset_type,
+                 topic,
+                 domains,
+                 language,
+                 additional_description,
+                 num_entries,
+                 hf_token,
+                 hf_repo_name,
+                 llm_env_vars,
+             ],
+             outputs=output,
+         )
+
+     demo.launch(inbrowser=True, favicon_path=None)
+
+
+ if __name__ == "__main__":
+     ui_main()
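
The "LLM Environment Variables" textbox is consumed by the loop in `generate_synthetic_dataset` above: comma-separated `KEY=VALUE` pairs, each exported via `os.environ` before the generator runs. A minimal illustration of that format and parsing follows; the variable names and values are placeholders, not anything the app requires:

```python
import os

# Placeholder example of the textbox value (comma-separated KEY=VALUE pairs).
llm_env_vars = "OPENAI_API_KEY=sk-placeholder, OPENAI_ORGANIZATION=my-org"

# Same split-on-comma / split-on-equals parsing as generate_synthetic_dataset above.
for var in llm_env_vars.split(","):
    key, value = var.split("=", 1)
    os.environ[key.strip()] = value.strip()

print(os.environ["OPENAI_ORGANIZATION"])  # -> my-org
```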
requirements.txt ADDED
@@ -0,0 +1 @@
+ synthgenai
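
The only dependency is `synthgenai` itself, which provides the config models and generators imported in `app.py`. For orientation, the click handler above reduces to roughly the following programmatic flow. This is a hedged sketch, not an authoritative example from the package's documentation: the import paths are assumed to mirror those used in `app.py`, and the topic, domains, and repo name are hypothetical values.

```python
import asyncio

# Assumed import paths, mirroring the imports in app.py above.
from synthgenai.data_model import DatasetConfig, DatasetGeneratorConfig, LLMConfig
from synthgenai.dataset_generator import RawDatasetGenerator

# app.py sets HF_TOKEN and any provider keys via os.environ before this point.
llm_config = LLMConfig(
    model="model_provider/model_name",  # placeholder, as in the UI field
    temperature=0.5,
    top_p=0.9,
    max_tokens=2048,
)
dataset_config = DatasetConfig(
    topic="placeholder topic",          # hypothetical values for illustration
    domains=["placeholder domain"],
    language="English",
    additional_description="",
    num_entries=1000,
)
config = DatasetGeneratorConfig(llm_config=llm_config, dataset_config=dataset_config)

generator = RawDatasetGenerator(config)
dataset = asyncio.run(generator.agenerate_dataset())  # async generation, as in the UI handler
dataset.save_dataset(hf_repo_name="user_name/dataset_name")  # saves under the given HF repo name
```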