In [ ]:
# Import necessary libraries
import os
import re

import matplotlib.pyplot as plt
import openai
import pandas as pd
import psycopg2
import streamlit as st
# Credentials and connection settings come from the environment so that no real
# secret has to live in this file. The fallbacks are the values this script used
# to hard-code, so an unconfigured run connects exactly as before.
openai.api_key = os.environ.get("OPENAI_API_KEY", "a")

# Connect to the PostgreSQL database. PGHOST/PGDATABASE/PGUSER/PGPASSWORD are the
# standard libpq variable names; they are read explicitly because keyword
# arguments passed to connect() take precedence over libpq's own env lookup.
conn = psycopg2.connect(
    host=os.environ.get("PGHOST", "localhost"),
    database=os.environ.get("PGDATABASE", "adventureworks"),
    user=os.environ.get("PGUSER", "postgres"),
    password=os.environ.get("PGPASSWORD", "postgres"),
)
# The app executes SQL written by a language model, so open the session
# read-only: INSERT/UPDATE/DELETE/DDL statements are rejected by the server.
# NOTE(review): this is a guard rail, not a sandbox; a dedicated read-only
# database role is the proper long-term fix.
conn.set_session(readonly=True)

# Single cursor shared by every query the app runs.
cur = conn.cursor()
def get_schema(schema_name="humanresources"):
    """Describe the tables of one PostgreSQL schema for use in the AI prompt.

    Args:
        schema_name: Schema to inspect; defaults to 'humanresources'.

    Returns:
        dict mapping each table name in ``schema_name`` (alphabetical order)
        to the list of its column names in table-definition order.
    """
    # information_schema.tables also lists views, so those are described too.
    cur.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = %s ORDER BY table_name",
        (schema_name,),
    )
    table_names = [row[0] for row in cur.fetchall()]

    schema = {}
    for table in table_names:
        # Filter on the schema as well as the table name: the same table name
        # can exist in other schemas, and matching on the name alone would
        # merge their columns into this entry.
        cur.execute(
            "SELECT column_name FROM information_schema.columns "
            "WHERE table_schema = %s AND table_name = %s "
            "ORDER BY ordinal_position",
            (schema_name, table),
        )
        schema[table] = [row[0] for row in cur.fetchall()]
    return schema
def _extract_sql(text):
    """Return the SQL statement embedded in a model reply.

    Prefers a ```sql fenced block, then any ``` fenced block, and finally falls
    back to the whole reply. An unterminated fence runs to the end of the text.
    """
    for pattern in (r"```sql\b\s*(.*?)(?:```|$)", r"```\s*(.*?)(?:```|$)"):
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()
    return text.strip()


def _qualify_tables(sql, table_names, schema_name):
    """Prefix bare table names that follow FROM/JOIN with ``schema_name``.

    Only identifiers that are known table names (compared case-insensitively)
    and are not already followed by a ``.`` are rewritten. Column references
    such as ``EXTRACT(YEAR FROM hiredate)``, sub-queries (``FROM (SELECT ...``)
    and already-qualified names are therefore left untouched. Only the first
    table of a comma-separated FROM list is handled.
    """
    known = {name.lower() for name in table_names}

    def _qualify(match):
        keyword, gap, name = match.groups()
        if name.lower() not in known:
            return match.group(0)
        return f"{keyword}{gap}{schema_name}.{name}"

    return re.sub(
        r"\b(FROM|JOIN)(\s+)([A-Za-z_][A-Za-z0-9_]*)\b(?!\s*\.)",
        _qualify,
        sql,
        flags=re.IGNORECASE,
    )


def get_query_from_ai(input_question, schema, schema_name="humanresources"):
    """Ask the OpenAI chat model for a SQL query answering ``input_question``.

    Args:
        input_question: The user's natural-language question.
        schema: Mapping of table name -> column names (as built by
            ``get_schema``). It is embedded in the system prompt, and its keys
            decide which table references get schema-qualified.
        schema_name: PostgreSQL schema prefixed to those table names.

    Returns:
        The SQL text extracted from the reply, with known tables written as
        ``<schema_name>.<table>``.
    """
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-1106",
        temperature=0,
        messages=[
            {"role": "system", "content": f"You are a helpful assistant that suggests SQL queries based on user input. You have access to the following schema: {schema}."},
            {"role": "user", "content": input_question},
        ],
    )
    reply = response.choices[0].message['content']
    sql = _extract_sql(reply)
    # The tables live outside the default search_path, so qualify them.
    return _qualify_tables(sql, schema, schema_name)
def execute_query(query):
    """Run ``query`` on the shared cursor and return what it produced.

    Returns:
        The bare value when the result is exactly one row with one column,
        otherwise the list of row tuples from ``fetchall()``. If executing or
        fetching raises, the exception's message is returned as a string so
        the UI can display it.
    """
    try:
        cur.execute(query)
        rows = cur.fetchall()
    except Exception as exc:
        # A failed statement leaves the psycopg2 transaction aborted, and every
        # later command on this connection (including get_schema) would fail
        # until it is rolled back.
        try:
            cur.connection.rollback()
        except psycopg2.Error:
            # Best effort: if the connection itself is unusable, still report
            # the original error rather than the rollback failure.
            pass
        return str(exc)
    if len(rows) == 1 and len(rows[0]) == 1:
        return rows[0][0]
    return rows
def chatbot(input_question):
    """Answer a natural-language question against the 'humanresources' schema.

    Builds a SQL query with OpenAI, runs it, and shapes the result for display.

    Args:
        input_question: The user's question. Its wording selects the output
            form: "table"/"tabular" -> DataFrame; "pie chart"/"bar chart" ->
            chart.

    Returns:
        * the value from ``execute_query`` unchanged when it is not a list of
          row tuples (i.e. a single value or an error message string);
        * a ``pd.DataFrame`` of the rows when a table is requested, when the
          result has fewer than two columns, or when no chart type is asked for;
        * otherwise a matplotlib Figure (pie or bar) that uses the first column
          as labels and the second column as values.
    """
    schema = get_schema()
    ai_suggested_query = get_query_from_ai(input_question, schema)
    result = execute_query(ai_suggested_query)

    if not (isinstance(result, list) and all(isinstance(row, tuple) for row in result)):
        return result

    df = pd.DataFrame(result)
    question = input_question.lower()
    if 'table' in question or 'tabular' in question:
        return df

    wants_pie = 'pie chart' in question
    wants_bar = 'bar chart' in question
    # Nothing to chart (a single column, or no chart type requested): return
    # the rows so the answer is never empty or a blank figure.
    if len(df.columns) < 2 or not (wants_pie or wants_bar):
        return df

    # First column -> labels (index); the next column -> plotted values.
    df = df.set_index(df.columns[0])
    labels = [str(label) for label in df.index]
    values = df[df.columns[0]]

    # Build the Figure directly instead of via plt.subplots(): pyplot keeps
    # every figure it creates registered until plt.close(), so answers kept in
    # the session history would otherwise pile up there.
    fig = plt.Figure(figsize=(5, 5))
    ax = fig.add_subplot()
    if wants_pie:
        ax.pie(values, labels=labels, autopct='%1.1f%%')
    else:
        ax.bar(labels, values)
    return fig
def on_input_change():
    """Answer the question currently typed into the ``input`` widget.

    Serves as the text input's ``on_change`` callback. For a non-empty
    question it runs the chatbot, records the exchange in
    ``st.session_state.conversation_history``, and empties the input box.
    """
    question = st.session_state.input
    if not question:
        return
    answer = chatbot(question)
    st.session_state.conversation_history.append(
        {'question': question, 'answer': answer}
    )
    st.session_state.input = ''
def _render_answer(answer):
    """Display one stored answer as a chart, a table, or inline text."""
    # Charts and tables get the label on its own line, followed by the widget.
    if isinstance(answer, plt.Figure):
        st.markdown('**Answer:**')
        st.pyplot(answer)
        return
    if isinstance(answer, pd.DataFrame):
        st.markdown('**Answer:**')
        st.dataframe(answer)
        return
    # Anything else (single values, row lists, error messages) goes inline.
    st.markdown(f'**Answer:** {answer}')


st.title('SQL Query Chatbot')

# Session state persists across reruns; create the history only if missing.
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = []

# Replay every recorded exchange, oldest first.
for exchange in st.session_state.conversation_history:
    st.markdown(f'**Question:** {exchange["question"]}')
    _render_answer(exchange["answer"])

# on_input_change handles the question when the user submits the box.
input_question = st.text_input('Enter your question:', key='input', on_change=on_input_change)
1. Importing Libraries:
streamlit: A Python library for creating web applications with minimal effort.
openai: The OpenAI Python client, used to send the user's question to the GPT-3.5 chat model (through the pre-1.0 `openai.ChatCompletion` interface).
psycopg2: A PostgreSQL adapter for Python, used for connecting to a PostgreSQL database.
pandas: A data manipulation and analysis library.
matplotlib.pyplot: A plotting library, used here to draw pie and bar charts of query results.
2. Set OpenAI API Key:
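openai.api_key: The credential used to authenticate requests to the OpenAI API. Never commit a real key to source code; load it from an environment variable or a secrets store instead.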
3. Database Connection:
Connects to the PostgreSQL database "adventureworks" on the local server using psycopg2, and creates a single cursor that every later query reuses.
4. Define Functions:
get_schema():
Returns a dictionary that maps each table in the PostgreSQL 'humanresources' schema to its list of column names.
get_query_from_ai(input_question, schema):
Sends the user's question to the GPT-3.5 chat model, including the schema in the system prompt.
Extracts the SQL query from the model's ```sql code block and prefixes table names with the 'humanresources.' schema.
execute_query(query):
Executes a SQL query on the shared cursor and returns the result: a single value when the query yields exactly one cell, otherwise the list of rows.
If the query fails, it returns the error message as a string instead of raising an exception.
chatbot(input_question):
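Ties the functions above together: gets the schema, asks OpenAI for a SQL query, executes it, and returns the result.
If the question mentions "table" or "tabular", the rows are returned as a DataFrame. If it asks for a "pie chart" or "bar chart" and the result has at least two columns, a matplotlib figure is returned.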
on_input_change():
Callback for the text input: runs the chatbot on the submitted question, appends the question and answer to the conversation history, and clears the input box.
5. Streamlit App:
Set Title:
Sets the title of the Streamlit app to "SQL Query Chatbot".
Initialize Conversation History:
Initializes and maintains a conversation history in the session state.
Display Conversation History:
For each chat in the conversation history, displays the question and answer.
If the answer is a figure, it displays the chart. If it's a DataFrame, it displays the table.
Text Input Field:
Creates a text input field where the user types a question in plain English.
Calls the on_input_change function on input change.
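6. Run the Streamlit App:
Start the app with `streamlit run <script>.py`. Streamlit re-runs the script from top to bottom on every interaction, redraws the conversation history, and then waits for the next question.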
This code combines Streamlit for the web interface, OpenAI GPT-3.5 for generating SQL queries based on user input, and PostgreSQL for executing those queries against a specific database schema. The app provides a conversational interface for users to interact with the SQL Query Chatbot.