#agents.py
import os
import csv
from prompts import *
from groq import Groq
from openai import OpenAI
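# prompts.py is assumed to define ANALYZER_SYSTEM_PROMPT, ANALYZER_USER_PROMPT,
# GENERATOR_SYSTEM_PROMPT and GENERATOR_USER_PROMPT, which the agents below rely on.
# A minimal illustrative shape (not the actual prompt text):
#   ANALYZER_USER_PROMPT = "Analyze this sample data:\n{sample_data}"
#   GENERATOR_USER_PROMPT = "Generate {num_rows} new rows based on:\n{analysis_result}\n{sample_data}"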
#from dotenv import load_dotenv
#load_dotenv()
# Set up the API key
#client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY")) # Open AI ChatGPT
#MODEL = "gpt-3.5-turbo"
#client = Groq(api_key=os.environ.get("GROQ_API_KEY")) # Groq API
#MODEL = "llama3-8b-8192"
# Set up the API key and select the LLM provider
if not os.getenv("GROQ_API_KEY"):
    if not os.getenv("OPENAI_API_KEY"):
        select = int(input("Define the LLM provider (0 - OpenAI, 1 - Groq, 2 - Ollama): "))
        if select == 1:
            print("\nGroq API selected.")
            os.environ["GROQ_API_KEY"] = input("\nPlease enter your Groq API key: ")
            client = Groq()  # Groq API
            MODEL = "llama3-8b-8192"
        elif select == 0:
            print("\nOpenAI API selected.")
            os.environ["OPENAI_API_KEY"] = input("\nPlease enter your OpenAI API key: ")
            client = OpenAI()  # OpenAI ChatGPT
            MODEL = "gpt-3.5-turbo"
        else:
            # Ollama exposes an OpenAI-compatible endpoint, so the OpenAI client is reused
            os.environ["OPENAI_API_KEY"] = "NA"
            os.environ["OPENAI_API_BASE"] = "http://172.26.40.58:5000"
            os.environ["OPENAI_MODEL_NAME"] = "crewai-phi3"
            client = OpenAI(
                base_url="http://localhost:11434/v1",
                api_key="ollama"
            )
            MODEL = "crewai-phi3"
    else:
        # OPENAI_API_KEY is already set, so default to OpenAI
        client = OpenAI()  # OpenAI ChatGPT
        MODEL = "gpt-3.5-turbo"
else:
    client = Groq()  # Groq API
    MODEL = "llama3-8b-8192"
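# Optional sanity check (an added sketch, not part of the original flow): both the
# Groq and OpenAI SDKs expose the chat.completions.create() interface used below,
# so a one-token call verifies the selected key/model pair before the agents run.
# The SMOKE_TEST environment variable is a hypothetical opt-in switch.
if os.getenv("SMOKE_TEST"):
    try:
        client.chat.completions.create(
            model=MODEL,
            max_tokens=1,
            messages=[{"role": "user", "content": "ping"}],
        )
        print(f"\nConnectivity check passed for model {MODEL}.")
    except Exception as exc:
        raise SystemExit(f"Could not reach the selected LLM provider: {exc}")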
# Function to read CSV file from the user
def read_csv(file_path):
    data = []
    with open(file_path, "r", newline="") as csvfile:
        csv_reader = csv.reader(csvfile)
        for row in csv_reader:
            data.append(row)
    return data
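# Example usage (hypothetical file): read_csv("./app/data/sample.csv") returns a list
# of rows such as [["name", "age"], ["Alice", "34"], ...], where every cell is a string.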
# Function to save generated data to a CSV file
# (writes a new file when headers are given, otherwise appends to the existing one)
def save_to_csv(data, output_file, headers=None):
    mode = 'w' if headers else 'a'
    with open(output_file, mode, newline="") as f:
        writer = csv.writer(f)
        if headers:
            writer.writerow(headers)
        for row in csv.reader(data.splitlines()):
            writer.writerow(row)
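# Example usage (illustrative): the first call, made with headers, creates or overwrites
# the file; later calls without headers append rows parsed from a CSV-formatted string.
#   save_to_csv("", "new_dataset.csv", headers=["name", "age"])
#   save_to_csv("Alice,34\nBob,29", "new_dataset.csv")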
# Create the Analyzer Agent
def analyzer_agent(sample_data):
    message = client.chat.completions.create(
        model=MODEL,
        max_tokens=400,
        temperature=0.1,
        messages=[
            {"role": "user", "content": ANALYZER_SYSTEM_PROMPT},
            {"role": "user", "content": ANALYZER_USER_PROMPT.format(sample_data=sample_data)}
        ]
    )
    return message.choices[0].message.content
# Create the Generator Agent
def generator_agent(analysis_result, sample_data, num_rows=30):
    message = client.chat.completions.create(
        model=MODEL,
        max_tokens=1500,
        temperature=1,
        messages=[
            {"role": "user", "content": GENERATOR_SYSTEM_PROMPT},
            {"role": "user", "content": GENERATOR_USER_PROMPT.format(
                num_rows=num_rows,
                analysis_result=analysis_result,
                sample_data=sample_data
            )}
        ]
    )
    return message.choices[0].message.content
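# Note: the Generator Agent is expected to return plain CSV rows with no header line,
# because save_to_csv() feeds its output directly into csv.reader(data.splitlines()).
# Illustrative expected output for num_rows=2 on a name/age dataset: "Carol,41\nDave,29"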
# Main execution flow
# Path of the directory mapped into the container
mapped_directory = "./app/data"
# Get input from the user
file_path = input("\nEnter the name of your CSV file: ")
file_path = os.path.join(mapped_directory, file_path)
desired_rows = int(input("Enter the number of rows you want in the new dataset: "))
sample_data = read_csv(file_path)
sample_data_str = "\n".join([",".join(row) for row in sample_data])  # Convert the 2D list into a single string
print("\nLaunching team of Agents...")
# Analyze the sample data using the Analyzer Agent
analysis_result = analyzer_agent(sample_data_str)
print("\n### Analyzer Agent output: ###\n")
print(analysis_result)
print("\n--------------------------------------------\n\nGenerating new data...")
# Set up the output file
output_file = os.path.join(mapped_directory, "new_dataset.csv")
headers = sample_data[0]
# Create the output file with headers
save_to_csv("", output_file, headers)
batch_size = 10 # Number of rows to generate in each batch
generated_rows = 0 # Counter to keep track of how many rows have been generated
# Generate data in batches until we reach the desired number of rows
while generated_rows < desired_rows:
    # Calculate how many rows to generate in this batch
    rows_to_generate = min(batch_size, desired_rows - generated_rows)
    # Generate a batch of data using the Generator Agent
    generated_data = generator_agent(analysis_result, sample_data_str, rows_to_generate)
    # Append the generated data to the output file
    save_to_csv(generated_data, output_file)
    # Update the count of generated rows
    generated_rows += rows_to_generate
    # Print progress update
    print(f"Generated {generated_rows} rows out of {desired_rows}")
# Inform the user that the process is complete
print(f"\nGenerated data has been saved to {output_file}")