diff --git a/defog_data/supplementary.py b/defog_data/supplementary.py
index add5ec6..909546a 100644
--- a/defog_data/supplementary.py
+++ b/defog_data/supplementary.py
@@ -100,35 +100,33 @@ def load_embeddings(emb_path: str) -> tuple[dict, dict]:
     },
     "atis": {
         "GPE": [
+            "airport_service.city_code,text,The city code where airport service is available",
+            "airport.airport_location,text,The airport location",
+            "airport.country_name,text,The country the airport is located in",
+            "airport.state_code,text,The state the airport is located in",
             "city.city_code,text,The city code",
             "city.city_name,text,The city name",
-            "city.state_code,text,The state code",
             "city.country_name,text,The country name",
+            "city.state_code,text,The state code",
+            "ground_service.city_code,text,The city code where ground service is available",
+            "state.country_name,text,The country name",
             "state.state_code,text,The state code",
             "state.state_name,text,The state name",
-            "state.country_name,text,The country name",
-            "airport.airport_location,text,The airport location",
-            "airport.country_name,text,The country the airport is located in",
-            "airport.state_code,text,The state the airport is located in",
-            "flight_stop.stop_airport,text,The 3-letter airport code for the stop",
-            "ground_service.city_code,text,The city code where ground service is available",
-            "ground_service.airport_code,text,The airport code where ground service is available",
-            "airport_service.city_code,text,The city code where airport service is available",
-            "airport_service.airport_code,text,The airport code where airport service is available",
         ],
         "ORG": [
+            "airline.airline_code,text,Code assigned to airline",
+            "airline.airline_name,text,The airline's name",
+            "airport_service.airport_code,text,The airport code where airport service is available",
+            "airport.airport_code,text,The 3-letter airport code",
+            "airport.airport_name,text,The name of the airport",
+            "dual_carrier.main_airline,text,The name of the main airline operating the flight",
             "fare.fare_airline,text,The airline's name",
             "fare.from_airport,text,The 3-letter airport code for the departure location",
             "fare.to_airport,text,The 3-letter airport code for the arrival location",
             "flight.airline_code,text,Code assigned to airline",
             "flight.from_airport,text,The 3-letter airport code for the departure location",
             "flight.to_airport,text,The 3-letter airport code for the arrival location",
-            "flight.airline_flight,text,Code assigned to the flight",
-            "airline.airline_code,text,Code assigned to airline",
-            "airline.airline_name,text,The airline's name",
-            "airport.airport_name,text,The name of the airport",
-            "airport.airport_code,text,The 3-letter airport code",
-            "dual_carrier.main_airline,text,The name of the main airline operating the flight",
+            "ground_service.airport_code,text,The airport code where ground service is available",
         ],
     },
     "yelp": {
@@ -157,7 +155,6 @@ def load_embeddings(emb_path: str) -> tuple[dict, dict]:
         "ORG": [
             "restaurant.name,text,The name of the restaurant",
             "restaurant.id,bigint,The ID of the restaurant",
-            "restaurant.name,text,The name of the restaurant",
         ],
         "PER": [],
     },
diff --git a/tests.py b/tests.py
index 6363546..203166d 100644
--- a/tests.py
+++ b/tests.py
@@ -1,6 +1,7 @@
 import os
 import unittest
 from defog_data.metadata import get_db, dbs
+from defog_data.supplementary import columns_ner
 
 
 class TestDB(unittest.TestCase):
@@ -157,6 +158,25 @@ def test_yelp(self):
         ]
         self.assertEqual(list(db_schema.keys()), expected_tables)
 
+    def test_supplementary_columns_ner(self):
+        """Check that no `table.column` name is listed twice within one database."""
+        for db_name, ner_mapping in columns_ner.items():
+            # subTest keeps checking the remaining databases after one fails.
+            with self.subTest(db_name=db_name):
+                seen = set()
+                repeated = set()
+                for column_str_list in ner_mapping.values():
+                    for column_str in column_str_list:
+                        # Entries look like "table.column,type,description".
+                        column_name = column_str.split(",", 1)[0]
+                        if column_name in seen:
+                            repeated.add(column_name)
+                        seen.add(column_name)
+                self.assertFalse(
+                    repeated,
+                    f"Column names repeated in {db_name}: {sorted(repeated)}",
+                )
+
 
 if __name__ == "__main__":
     unittest.main()