In [ ]:
from openai import OpenAI # Works only with openai version >= 1.2.0
from dotenv import load_dotenv,find_dotenv
import pandas as pd
# Locate a .env file with find_dotenv() and load its entries into the process
# environment *before* constructing the client.
# NOTE(review): presumably this is where OPENAI_API_KEY comes from — confirm.
load_dotenv(dotenv_path=find_dotenv())
# Bare OpenAI client built with default settings; it is not referenced by the
# cells visible in this part of the notebook.
client = OpenAI()
from langchain.agents.agent_types import AgentType
from langchain.chat_models import ChatOpenAI
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
In [ ]:
# Load the trip-record parquet file from the current working directory into a
# DataFrame.
df = pd.read_parquet(path="yellow_tripdata_2023-01.parquet")
In [ ]:
# DataFrame.info() writes its summary straight to stdout and returns None, so
# it is called directly; wrapping it in print() appended a stray "None" line
# after the summary.
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 3066766 entries, 0 to 3066765 Data columns (total 19 columns): # Column Dtype --- ------ ----- 0 VendorID int64 1 tpep_pickup_datetime datetime64[ns] 2 tpep_dropoff_datetime datetime64[ns] 3 passenger_count float64 4 trip_distance float64 5 RatecodeID float64 6 store_and_fwd_flag object 7 PULocationID int64 8 DOLocationID int64 9 payment_type int64 10 fare_amount float64 11 extra float64 12 mta_tax float64 13 tip_amount float64 14 tolls_amount float64 15 improvement_surcharge float64 16 total_amount float64 17 congestion_surcharge float64 18 airport_fee float64 dtypes: datetime64[ns](2), float64(12), int64(4), object(1) memory usage: 444.6+ MB None
In [ ]:
# Build a LangChain agent that answers questions about `df` using the
# OpenAI-functions agent type; the chat model is created with temperature=0.
# verbose=True is what produces the step-by-step trace when the agent runs.
# NOTE(review): echoing `agent` at the end of the cell renders the whole
# AgentExecutor repr, which (per the recorded output) includes the model's
# `openai_api_key` field and the DataFrame contents — scrub the output before
# sharing this notebook.
agent = create_pandas_dataframe_agent(
    llm=ChatOpenAI(model="gpt-3.5-turbo-0613", temperature=0),
    df=df,
    agent_type=AgentType.OPENAI_FUNCTIONS,
    verbose=True,
)
agent
Out[ ]:
AgentExecutor(verbose=True, agent=OpenAIFunctionsAgent(llm=ChatOpenAI(client=<openai.resources.chat.completions.Completions object at 0x000001B9017D29B0>, async_client=<openai.resources.chat.completions.AsyncCompletions object at 0x000001B901824940>, model_name='gpt-3.5-turbo-0613', temperature=0.0, openai_api_key='sk-xxxxxx', openai_proxy=''), tools=[PythonAstREPLTool(locals={'df': VendorID tpep_pickup_datetime tpep_dropoff_datetime passenger_count \ 0 2 2023-01-01 00:32:10 2023-01-01 00:40:36 1.0 1 2 2023-01-01 00:55:08 2023-01-01 01:01:27 1.0 2 2 2023-01-01 00:25:04 2023-01-01 00:37:49 1.0 3 1 2023-01-01 00:03:48 2023-01-01 00:13:25 0.0 4 2 2023-01-01 00:10:29 2023-01-01 00:21:19 1.0 ... ... ... ... ... 3066761 2 2023-01-31 23:58:34 2023-02-01 00:12:33 NaN 3066762 2 2023-01-31 23:31:09 2023-01-31 23:50:36 NaN 3066763 2 2023-01-31 23:01:05 2023-01-31 23:25:36 NaN 3066764 2 2023-01-31 23:40:00 2023-01-31 23:53:00 NaN 3066765 2 2023-01-31 23:07:32 2023-01-31 23:21:56 NaN trip_distance RatecodeID store_and_fwd_flag PULocationID \ 0 0.97 1.0 N 161 1 1.10 1.0 N 43 2 2.51 1.0 N 48 3 1.90 1.0 N 138 4 1.43 1.0 N 107 ... ... ... ... ... 3066761 3.05 NaN None 107 3066762 5.80 NaN None 112 3066763 4.67 NaN None 114 3066764 3.15 NaN None 230 3066765 2.85 NaN None 262 DOLocationID payment_type fare_amount extra mta_tax tip_amount \ 0 141 2 9.30 1.00 0.5 0.00 1 237 1 7.90 1.00 0.5 4.00 2 238 1 14.90 1.00 0.5 15.00 3 7 1 12.10 7.25 0.5 0.00 4 79 1 11.40 1.00 0.5 3.28 ... ... ... ... ... ... ... 3066761 48 0 15.80 0.00 0.5 3.96 3066762 75 0 22.43 0.00 0.5 2.64 3066763 239 0 17.61 0.00 0.5 5.32 3066764 79 0 18.15 0.00 0.5 4.43 3066765 143 0 15.97 0.00 0.5 2.00 tolls_amount improvement_surcharge total_amount \ 0 0.0 1.0 14.30 1 0.0 1.0 16.90 2 0.0 1.0 34.90 3 0.0 1.0 20.85 4 0.0 1.0 19.68 ... ... ... ... 
3066761 0.0 1.0 23.76 3066762 0.0 1.0 29.07 3066763 0.0 1.0 26.93 3066764 0.0 1.0 26.58 3066765 0.0 1.0 21.97 congestion_surcharge airport_fee 0 2.5 0.00 1 2.5 0.00 2 2.5 0.00 3 0.0 1.25 4 2.5 0.00 ... ... ... 3066761 NaN NaN 3066762 NaN NaN 3066763 NaN NaN 3066764 NaN NaN 3066765 NaN NaN [3066766 rows x 19 columns]})], prompt=ChatPromptTemplate(input_variables=['agent_scratchpad', 'input'], input_types={'agent_scratchpad': typing.List[typing.Union[langchain.schema.messages.AIMessage, langchain.schema.messages.HumanMessage, langchain.schema.messages.ChatMessage, langchain.schema.messages.SystemMessage, langchain.schema.messages.FunctionMessage, langchain.schema.messages.ToolMessage]]}, messages=[SystemMessage(content='\nYou are working with a pandas dataframe in Python. The name of the dataframe is `df`.\nThis is the result of `print(df.head())`:\n| | VendorID | tpep_pickup_datetime | tpep_dropoff_datetime | passenger_count | trip_distance | RatecodeID | store_and_fwd_flag | PULocationID | DOLocationID | payment_type | fare_amount | extra | mta_tax | tip_amount | tolls_amount | improvement_surcharge | total_amount | congestion_surcharge | airport_fee |\n|---:|-----------:|:-----------------------|:------------------------|------------------:|----------------:|-------------:|:---------------------|---------------:|---------------:|---------------:|--------------:|--------:|----------:|-------------:|---------------:|------------------------:|---------------:|-----------------------:|--------------:|\n| 0 | 2 | 2023-01-01 00:32:10 | 2023-01-01 00:40:36 | 1 | 0.97 | 1 | N | 161 | 141 | 2 | 9.3 | 1 | 0.5 | 0 | 0 | 1 | 14.3 | 2.5 | 0 |\n| 1 | 2 | 2023-01-01 00:55:08 | 2023-01-01 01:01:27 | 1 | 1.1 | 1 | N | 43 | 237 | 1 | 7.9 | 1 | 0.5 | 4 | 0 | 1 | 16.9 | 2.5 | 0 |\n| 2 | 2 | 2023-01-01 00:25:04 | 2023-01-01 00:37:49 | 1 | 2.51 | 1 | N | 48 | 238 | 1 | 14.9 | 1 | 0.5 | 15 | 0 | 1 | 34.9 | 2.5 | 0 |\n| 3 | 1 | 2023-01-01 00:03:48 | 2023-01-01 00:13:25 | 0 | 1.9 | 1 | N 
| 138 | 7 | 1 | 12.1 | 7.25 | 0.5 | 0 | 0 | 1 | 20.85 | 0 | 1.25 |\n| 4 | 2 | 2023-01-01 00:10:29 | 2023-01-01 00:21:19 | 1 | 1.43 | 1 | N | 107 | 79 | 1 | 11.4 | 1 | 0.5 | 3.28 | 0 | 1 | 19.68 | 2.5 | 0 |'), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')), MessagesPlaceholder(variable_name='agent_scratchpad')])), tools=[PythonAstREPLTool(locals={'df': VendorID tpep_pickup_datetime tpep_dropoff_datetime passenger_count \ 0 2 2023-01-01 00:32:10 2023-01-01 00:40:36 1.0 1 2 2023-01-01 00:55:08 2023-01-01 01:01:27 1.0 2 2 2023-01-01 00:25:04 2023-01-01 00:37:49 1.0 3 1 2023-01-01 00:03:48 2023-01-01 00:13:25 0.0 4 2 2023-01-01 00:10:29 2023-01-01 00:21:19 1.0 ... ... ... ... ... 3066761 2 2023-01-31 23:58:34 2023-02-01 00:12:33 NaN 3066762 2 2023-01-31 23:31:09 2023-01-31 23:50:36 NaN 3066763 2 2023-01-31 23:01:05 2023-01-31 23:25:36 NaN 3066764 2 2023-01-31 23:40:00 2023-01-31 23:53:00 NaN 3066765 2 2023-01-31 23:07:32 2023-01-31 23:21:56 NaN trip_distance RatecodeID store_and_fwd_flag PULocationID \ 0 0.97 1.0 N 161 1 1.10 1.0 N 43 2 2.51 1.0 N 48 3 1.90 1.0 N 138 4 1.43 1.0 N 107 ... ... ... ... ... 3066761 3.05 NaN None 107 3066762 5.80 NaN None 112 3066763 4.67 NaN None 114 3066764 3.15 NaN None 230 3066765 2.85 NaN None 262 DOLocationID payment_type fare_amount extra mta_tax tip_amount \ 0 141 2 9.30 1.00 0.5 0.00 1 237 1 7.90 1.00 0.5 4.00 2 238 1 14.90 1.00 0.5 15.00 3 7 1 12.10 7.25 0.5 0.00 4 79 1 11.40 1.00 0.5 3.28 ... ... ... ... ... ... ... 3066761 48 0 15.80 0.00 0.5 3.96 3066762 75 0 22.43 0.00 0.5 2.64 3066763 239 0 17.61 0.00 0.5 5.32 3066764 79 0 18.15 0.00 0.5 4.43 3066765 143 0 15.97 0.00 0.5 2.00 tolls_amount improvement_surcharge total_amount \ 0 0.0 1.0 14.30 1 0.0 1.0 16.90 2 0.0 1.0 34.90 3 0.0 1.0 20.85 4 0.0 1.0 19.68 ... ... ... ... 
3066761 0.0 1.0 23.76 3066762 0.0 1.0 29.07 3066763 0.0 1.0 26.93 3066764 0.0 1.0 26.58 3066765 0.0 1.0 21.97 congestion_surcharge airport_fee 0 2.5 0.00 1 2.5 0.00 2 2.5 0.00 3 0.0 1.25 4 2.5 0.00 ... ... ... 3066761 NaN NaN 3066762 NaN NaN 3066763 NaN NaN 3066764 NaN NaN 3066765 NaN NaN [3066766 rows x 19 columns]})])
In [ ]:
# Ask the agent a question about the DataFrame. `Chain.run` is a deprecated
# convenience wrapper; `invoke` is the supported entry point. The executor's
# prompt takes a single "input" variable and the final answer is returned
# under the "output" key, so indexing it yields the same string `run` returned.
agent.invoke({"input": "how many rows are there?"})["output"]
> Entering new AgentExecutor chain... Invoking: `python_repl_ast` with `{'query': 'df.shape[0]'}` 3066766There are 3,066,766 rows in the dataframe. > Finished chain.
Out[ ]:
'There are 3,066,766 rows in the dataframe.'