Skip to content

Commit a845950

Browse files
authored
Tau2 Environment Setup (#12)
* starting * Comment out abstract methods in EnvironmentAdapter * adding missing files * remove readme * delete test file for now, not being used * adjust airline environment * remove dead code
1 parent ad8b9d1 commit a845950

9 files changed

Lines changed: 206683 additions & 0 deletions

File tree

examples/tau2_mcp/airplane_environment/airline_environment.py

Lines changed: 530 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
2+
3+
from pydantic import BaseModel, Field
4+
5+
from .utils import AIRLINE_DB_PATH
6+
from .db import DB
7+
8+
FlightType = Literal["round_trip", "one_way"]
9+
CabinClass = Literal["business", "economy", "basic_economy"]
10+
Insurance = Literal["yes", "no"]
11+
12+
13+
MembershipLevel = Annotated[
14+
Literal["gold", "silver", "regular"], Field(description="Membership level")
15+
]
16+
17+
18+
class AirportCode(BaseModel):
19+
iata: str = Field(description="IATA code")
20+
city: str = Field(description="City name")
21+
22+
23+
AirportInfo = Annotated[list[AirportCode], Field(description="Airport information")]
24+
25+
26+
class Name(BaseModel):
27+
first_name: str = Field(description="The person's first name")
28+
last_name: str = Field(description="The person's last name")
29+
30+
31+
class Address(BaseModel):
32+
address1: str = Field(description="Primary address line")
33+
address2: Optional[str] = Field(
34+
None, description="Secondary address line (optional)"
35+
)
36+
city: str = Field(description="City name")
37+
country: str = Field(description="Country name")
38+
state: str = Field(description="State or province name")
39+
zip: str = Field(description="Postal code")
40+
41+
42+
# Payment Related Models
43+
class Payment(BaseModel):
44+
payment_id: str = Field(description="Unique identifier for the payment")
45+
amount: int = Field(description="Payment amount in dollars")
46+
47+
48+
class PaymentMethodBase(BaseModel):
49+
source: str = Field(description="Type of payment method")
50+
id: str = Field(description="Unique identifier for the payment method")
51+
52+
53+
class CreditCard(PaymentMethodBase):
54+
source: Literal["credit_card"] = Field(
55+
description="Indicates this is a credit card payment method"
56+
)
57+
brand: str = Field(description="Credit card brand (e.g., visa, mastercard)")
58+
last_four: str = Field(description="Last four digits of the credit card")
59+
60+
61+
class GiftCard(PaymentMethodBase):
62+
source: Literal["gift_card"] = Field(
63+
description="Indicates this is a gift card payment method"
64+
)
65+
amount: float = Field(description="Gift card value amount")
66+
id: str = Field(description="Unique identifier for the gift card")
67+
68+
69+
class Certificate(PaymentMethodBase):
70+
source: Literal["certificate"] = Field(
71+
description="Indicates this is a certificate payment method"
72+
)
73+
amount: float = Field(description="Certificate value amount")
74+
75+
76+
PaymentMethod = Union[CreditCard, GiftCard, Certificate]
77+
78+
79+
class Passenger(BaseModel):
80+
first_name: str = Field(description="Passenger's first name")
81+
last_name: str = Field(description="Passenger's last name")
82+
dob: str = Field(description="Date of birth in YYYY-MM-DD format")
83+
84+
85+
SeatPrices = Annotated[
86+
dict[CabinClass, int], Field(description="Prices for different cabin classes")
87+
]
88+
AvailableSeats = Annotated[
89+
dict[CabinClass, int],
90+
Field(description="Available seats for different cabin classes"),
91+
]
92+
93+
94+
class FlightDateStatusAvailable(BaseModel):
95+
status: Literal["available"] = Field(
96+
description="Indicates flight is available for booking"
97+
)
98+
available_seats: AvailableSeats = Field(description="Available seats by class")
99+
prices: SeatPrices = Field(description="Current prices by class")
100+
101+
102+
class FlightDataStatusOnTime(BaseModel):
103+
status: Literal["on time"] = Field(description="Indicates flight is on time")
104+
estimated_departure_time_est: str = Field(
105+
description="Estimated departure time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T06:04:00"
106+
)
107+
estimated_arrival_time_est: str = Field(
108+
description="Estimated arrival time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T07:30:00"
109+
)
110+
111+
112+
class FlightDataStatusFlying(BaseModel):
113+
status: Literal["flying"] = Field(description="Indicates flight is in flight")
114+
actual_departure_time_est: str = Field(
115+
description="Actual departure time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T06:04:00"
116+
)
117+
estimated_arrival_time_est: str = Field(
118+
description="Estimated arrival time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T07:30:00"
119+
)
120+
121+
122+
class FlightDateStatusLanded(BaseModel):
123+
status: Literal["landed"] = Field(description="Indicates flight has landed")
124+
actual_departure_time_est: str = Field(
125+
description="Actual departure time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T06:04:00"
126+
)
127+
actual_arrival_time_est: str = Field(
128+
description="Actual arrival time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T07:30:00"
129+
)
130+
131+
132+
class FlightDateStatusCancelled(BaseModel):
133+
status: Literal["cancelled"] = Field(description="Indicates flight was cancelled")
134+
135+
136+
class FlightDateStatusDelayed(BaseModel):
137+
status: Literal["delayed"] = Field(description="Indicates flight was delayed")
138+
estimated_departure_time_est: str = Field(
139+
description="Estimated departure time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T06:04:00"
140+
)
141+
estimated_arrival_time_est: str = Field(
142+
description="Estimated arrival time in EST in the format YYYY-MM-DDTHH:MM:SS, e.g 2024-05-15T07:30:00"
143+
)
144+
145+
146+
FlightDateStatus = Union[
147+
FlightDateStatusAvailable,
148+
FlightDateStatusLanded,
149+
FlightDateStatusCancelled,
150+
FlightDateStatusDelayed,
151+
FlightDataStatusFlying,
152+
FlightDataStatusOnTime,
153+
]
154+
155+
156+
class FlightBase(BaseModel):
157+
flight_number: str = Field(description="Unique flight identifier")
158+
origin: str = Field(description="IATA code for origin airport")
159+
destination: str = Field(description="IATA code for destination airport")
160+
161+
162+
class Flight(FlightBase):
163+
scheduled_departure_time_est: str = Field(
164+
description="Scheduled departure time in EST in the format HH:MM:SS, e.g 06:00:00"
165+
)
166+
scheduled_arrival_time_est: str = Field(
167+
description="Scheduled arrival time in EST in the format HH:MM:SS, e.g 07:00:00"
168+
)
169+
dates: Dict[str, FlightDateStatus] = Field(
170+
description="Flight status by date (YYYY-MM-DD)"
171+
)
172+
173+
174+
class DirectFlight(FlightBase):
175+
status: Literal["available"] = Field(
176+
description="Indicates flight is available for booking"
177+
)
178+
scheduled_departure_time_est: str = Field(
179+
description="Scheduled departure time in EST in the format HH:MM:SS, e.g 06:00:00"
180+
)
181+
scheduled_arrival_time_est: str = Field(
182+
description="Scheduled arrival time in EST in the format HH:MM:SS, e.g 07:00:00"
183+
)
184+
date: Optional[str] = Field(
185+
description="Flight date in YYYY-MM-DD format", default=None
186+
)
187+
available_seats: AvailableSeats = Field(description="Available seats by class")
188+
prices: SeatPrices = Field(description="Current prices by class")
189+
190+
191+
class ReservationFlight(FlightBase):
192+
date: str = Field(description="Flight date in YYYY-MM-DD format")
193+
price: int = Field(description="Flight price in dollars.")
194+
195+
196+
class FlightInfo(BaseModel):
197+
flight_number: str = Field(description="Flight number, such as 'HAT001'.")
198+
date: str = Field(
199+
description="The date for the flight in the format 'YYYY-MM-DD', such as '2024-05-01'."
200+
)
201+
202+
203+
class User(BaseModel):
204+
user_id: str = Field(description="Unique identifier for the user")
205+
name: Name = Field(description="User's full name")
206+
address: Address = Field(description="User's address information")
207+
email: str = Field(description="User's email address")
208+
dob: str = Field(
209+
description="User's date of birth in the format YYYY-MM-DD, e.g 1990-04-05"
210+
)
211+
payment_methods: Dict[str, PaymentMethod] = Field(
212+
description="User's saved payment methods"
213+
)
214+
saved_passengers: List[Passenger] = Field(
215+
description="User's saved passenger information"
216+
)
217+
membership: MembershipLevel = Field(description="User's membership level")
218+
reservations: List[str] = Field(description="List of user's reservation IDs")
219+
220+
221+
# Reservation Models
222+
class Reservation(BaseModel):
223+
reservation_id: str = Field(description="Unique identifier for the reservation")
224+
user_id: str = Field(description="ID of the user who made the reservation")
225+
origin: str = Field(description="IATA code for trip origin")
226+
destination: str = Field(description="IATA code for trip destination")
227+
flight_type: FlightType = Field(description="Type of trip")
228+
cabin: CabinClass = Field(description="Selected cabin class")
229+
flights: List[ReservationFlight] = Field(
230+
description="List of flights in the reservation"
231+
)
232+
passengers: List[Passenger] = Field(
233+
description="List of passengers on the reservation"
234+
)
235+
payment_history: List[Payment] = Field(
236+
description="History of payments for this reservation"
237+
)
238+
created_at: str = Field(
239+
description="Timestamp when reservation was created in the format YYYY-MM-DDTHH:MM:SS"
240+
)
241+
total_baggages: int = Field(description="Total number of bags in reservation")
242+
nonfree_baggages: int = Field(description="Number of paid bags in reservation")
243+
insurance: Insurance = Field(description="Whether travel insurance was purchased")
244+
status: Optional[Literal["cancelled"]] = Field(
245+
description="Status of the reservation", default=None
246+
)
247+
248+
249+
class FlightDB(DB):
250+
"""Database of all flights, users, and reservations."""
251+
252+
flights: Dict[str, Flight] = Field(
253+
description="Dictionary of all flights indexed by flight number"
254+
)
255+
users: Dict[str, User] = Field(
256+
description="Dictionary of all users indexed by user ID"
257+
)
258+
reservations: Dict[str, Reservation] = Field(
259+
description="Dictionary of all reservations indexed by reservation ID"
260+
)
261+
262+
def get_statistics(self) -> dict[str, Any]:
263+
"""Get the statistics of the database."""
264+
num_flights = len(self.flights)
265+
num_flights_instances = sum(
266+
len(flight.dates) for flight in self.flights.values()
267+
)
268+
num_users = len(self.users)
269+
num_reservations = len(self.reservations)
270+
return {
271+
"num_flights": num_flights,
272+
"num_flights_instances": num_flights_instances,
273+
"num_users": num_users,
274+
"num_reservations": num_reservations,
275+
}
276+
277+
278+
def get_db():
279+
return FlightDB.load(AIRLINE_DB_PATH)
280+
281+
282+
if __name__ == "__main__":
283+
db = get_db()
284+
print(db.get_statistics())

0 commit comments

Comments
 (0)