ToastFreeware Gitweb - philipp/winterrodeln/wrpylib.git/blob - scripts/update_public_transport_times.py
Add date as command line parameter.
[philipp/winterrodeln/wrpylib.git] / scripts / update_public_transport_times.py
1 #!/usr/bin/python
2 import argparse
3 import sys
4 from collections import defaultdict
5 from copy import deepcopy
6 from datetime import datetime, date, time, timedelta
7 from typing import List, Iterable, Optional, Dict, FrozenSet, Final
8
9 import jsonschema
10 import numpy as np
11 import partridge as ptg
12 from partridge.gtfs import Feed
13 from partridge.readers import read_service_ids_by_date
14 from termcolor import cprint  # python3-termcolor
15
16 from wrpylib.cli_tools import unified_diff, input_yes_no_quit, Choice
17 from wrpylib.json_tools import order_json_keys, format_json
18 from wrpylib.lib_update_public_transport import default_query_date
19 from wrpylib.mwapi import WikiSite, page_json
20 from wrpylib.sledrun_json import Sledrun, PublicTransportStop, PublicTransportStopLine
21
# Restore the removed alias `np.unicode` to prevent "AttributeError: module 'numpy' has no attribute 'unicode'"
# (presumably raised by partridge, which still refers to it). The alias is set to `np.str_` rather than
# `np.unicode_`: both name the same type in numpy 1.x, but `np.unicode_` no longer exists in numpy 2.
np.unicode = np.str_
23
24
# Basic GTFS `routes.txt` route_type codes mapped to the German category names used in the sledrun JSON.
ROUTE_TYPE: Final = {
    0: 'Straßenbahn',  # tram / light rail
    1: 'U-Bahn',  # subway / metro
    2: 'Bahn',  # rail
    3: 'Bus',  # bus
    4: 'Fähre',  # ferry
    5: 'Kabelstraßenbahn',  # cable tram
    6: 'Seilbahn',  # aerial lift
    7: 'Standseilbahn',  # funicular
    11: 'Oberleitungsbus',  # trolleybus
    12: 'Einschienenbahn',  # monorail
}
37
38
def format_time(day: date, seconds: float) -> str:
    """Return the moment `seconds` after midnight of `day` as an ISO string truncated to minutes.

    Values of 24 hours or more roll over into the following day(s), e.g. 25:30 becomes 01:30 of the next day.
    """
    midnight = datetime(day.year, day.month, day.day)
    moment = midnight + timedelta(seconds=seconds)
    return moment.isoformat(timespec='minutes')
42
43
def _has_time(seconds: Optional[float]) -> bool:
    """Return whether a GTFS stop time (seconds after the start of the service day) is present.

    GTFS allows empty arrival/departure times at stops that are not timepoints. Such values are assumed to reach
    this code as NaN (or None) -- NOTE(review): confirm against partridge's time parsing.
    """
    return seconds is not None and not np.isnan(seconds)


def _lines_for_stop(feed: Feed, ifopt_stop_id: str, service_date: date,
                    service_ids_by_date: Dict[date, FrozenSet[str]]) -> Optional[List[PublicTransportStopLine]]:
    """Build the `lines` entries of one public transport stop from the GTFS feed.

    :param feed: loaded GTFS feed.
    :param ifopt_stop_id: IFOPT id of the stop, e.g. "at:47:61646". GTFS stops whose id equals it or starts with
        "<ifopt_stop_id>:" are taken into account.
    :param service_date: day whose trips are listed.
    :param service_ids_by_date: GTFS service ids running on each date.
    :return: one entry per route serving the stop on `service_date`, or None if the feed has no matching stop.
    :raises KeyError: if `service_date` is not a key of `service_ids_by_date`, or if a route's `route_type` is not
        listed in `ROUTE_TYPE` (e.g. extended GTFS route types).
    """
    stops = feed.stops
    stop_times = feed.stop_times

    selected_stops = stops[stops.stop_id.str.startswith(f'{ifopt_stop_id}:') | (stops.stop_id == ifopt_stop_id)]
    if len(selected_stops) == 0:
        return None

    # Stop times at the selected stop(s), restricted to trips running on service_date.
    selected_stop_times = stop_times.merge(selected_stops.stop_id)
    selected_trips = feed.trips.merge(selected_stop_times)
    selected_trips = selected_trips[selected_trips.service_id.isin(service_ids_by_date[service_date])]
    selected_routes = feed.routes.merge(selected_trips.route_id.drop_duplicates())
    selected_routes = selected_routes.merge(feed.agency)

    # First and last stop (by stop_sequence) of every selected trip; their names serve as origin / direction.
    selected_trip_stop_times = stop_times.merge(selected_trips.trip_id.drop_duplicates())
    stop_sequences = selected_trip_stop_times.groupby('trip_id')['stop_sequence']
    first_stops = selected_trip_stop_times.loc[stop_sequences.idxmin()].merge(stops)
    last_stops = selected_trip_stop_times.loc[stop_sequences.idxmax()].merge(stops)
    # Index once by trip_id instead of filtering the whole frame for every trip (keeps the first row per trip).
    first_stop_by_trip = first_stops.drop_duplicates('trip_id').set_index('trip_id')
    last_stop_by_trip = last_stops.drop_duplicates('trip_id').set_index('trip_id')

    lines: List[PublicTransportStopLine] = []
    for route in selected_routes.itertuples():
        departures_dict: Dict[str, List[str]] = defaultdict(list)
        arrivals_dict: Dict[str, List[str]] = defaultdict(list)

        for trip in selected_trips[selected_trips.route_id == route.route_id].itertuples():
            trip_first_stop = first_stop_by_trip.loc[trip.trip_id]
            trip_last_stop = last_stop_by_trip.loc[trip.trip_id]
            # A trip arrives from its origin unless it starts here, and departs towards its destination unless it
            # ends here. Times missing in the feed are skipped: timedelta() in format_time rejects NaN.
            if trip.stop_id != trip_first_stop['stop_id'] and _has_time(trip.arrival_time):
                arrivals_dict[trip_first_stop['stop_name']].append(format_time(service_date, trip.arrival_time))
            if trip.stop_id != trip_last_stop['stop_id'] and _has_time(trip.departure_time):
                departures_dict[trip_last_stop['stop_name']].append(format_time(service_date, trip.departure_time))

        schedule = {
            'service_date': service_date.isoformat(),
            # NOTE(review): always 'work_day', even if service_date is a weekend or holiday -- confirm.
            'day_type': 'work_day',
            'departure': [{'direction': k, 'datetime': sorted(v)} for k, v in departures_dict.items()],
            'arrival': [{'origin': k, 'datetime': sorted(v)} for k, v in arrivals_dict.items()],
        }

        lines.append({
            'vao_line_id': f'vvt-{route.route_id}',
            'line': route.route_short_name,
            'category': ROUTE_TYPE[route.route_type],
            'operator': route.agency_name,
            'schedules': [schedule],
        })
    return lines


def update_sledrun(site: WikiSite, title: str, service_date: date, feed: Feed,
                   service_ids_by_date: Dict[date, FrozenSet[str]]):
    """Update the public transport schedules of one sled run page and save it after confirmation.

    The page `<title>/Rodelbahn.json` is loaded. Every public transport stop that has an `ifopt_stop_id` matching
    stops in the GTFS feed gets its `lines` replaced by the lines serving it on `service_date`. If anything
    changed, the result is validated against the sled run schema and the diff is shown. The page is edited only
    when the user accepts; choosing "quit" exits the program.

    :param site: wiki to read from and write to.
    :param title: title of the sled run page.
    :param service_date: day whose schedule is written.
    :param feed: loaded GTFS feed.
    :param service_ids_by_date: GTFS service ids running on each date.
    """
    cprint(title, 'green')
    sledrun_json_page = site.query_page(f'{title}/Rodelbahn.json')
    sledrun: Sledrun = page_json(sledrun_json_page)
    sledrun_orig = deepcopy(sledrun)

    pt_stops: List[PublicTransportStop] = sledrun.get('public_transport_stops', [])
    for pt_stop in pt_stops:
        ifopt_stop_id = pt_stop.get('ifopt_stop_id')  # e.g. "at:47:61646"
        if ifopt_stop_id is None:
            continue
        lines = _lines_for_stop(feed, ifopt_stop_id, service_date, service_ids_by_date)
        if lines is not None:
            pt_stop['lines'] = lines

    if sledrun == sledrun_orig:
        return

    schema = site.sledrun_schema()
    jsonschema.validate(instance=sledrun, schema=schema)
    sledrun_ordered = order_json_keys(sledrun, schema)
    assert sledrun_ordered == sledrun  # reordering the keys must not change the content
    sledrun_orig_str = format_json(sledrun_orig)
    sledrun_str = format_json(sledrun_ordered)

    unified_diff(sledrun_orig_str, sledrun_str)
    choice = input_yes_no_quit('Do you accept the changes [yes, no, quit]? ', None)
    if choice == Choice.no:
        return
    elif choice == Choice.quit:
        sys.exit(0)

    site(
        'edit',
        pageid=sledrun_json_page['pageid'],
        text=sledrun_str,
        summary='Fahrplan zu Haltestellen ergänzt.',
        bot=1,
        baserevid=sledrun_json_page['revisions'][0]['revid'],
        nocreate=1,
        token=site.token(),
    )
135
136
def get_all_sledrun_titles(site: WikiSite) -> Iterable[str]:
    """Lazily yield the title of every member of the category "Kategorie:Rodelbahn".

    `site.query` yields result chunks; the titles of each chunk's `categorymembers` are produced in order.
    """
    chunks = site.query(list='categorymembers', cmtitle='Kategorie:Rodelbahn', cmlimit='max')
    yield from (member['title'] for chunk in chunks for member in chunk['categorymembers'])
141
def update_public_transport_times(ini_files: List[str], service_date: date, gtfs_file: str,
                                  sledrun_title: Optional[str]):
    """Update the public transport schedules of one sled run or of all sled runs.

    :param ini_files: configuration files for the wiki connection (passed to WikiSite).
    :param service_date: day whose GTFS services are used for the schedules.
    :param gtfs_file: path of the GTFS feed.
    :param sledrun_title: title of a single sled run page; if None, every page of "Kategorie:Rodelbahn" is processed.
    :raises ValueError: if the GTFS feed has no service on `service_date`.
    """
    feed = ptg.load_feed(gtfs_file)
    service_ids_by_date = read_service_ids_by_date(gtfs_file)
    # Fail before touching the wiki: without services on that day every page would either fail on lookup or get
    # all of its lines removed.
    if not service_ids_by_date.get(service_date):
        raise ValueError(f'The GTFS file {gtfs_file} has no service on {service_date.isoformat()}.')

    site = WikiSite(ini_files)
    titles = get_all_sledrun_titles(site) if sledrun_title is None else [sledrun_title]
    for title in titles:
        update_sledrun(site, title, service_date, feed, service_ids_by_date)
152
153
def main():
    """Parse the command line arguments and run the update of the public transport times."""
    default_date = default_query_date(date.today())
    parser = argparse.ArgumentParser(description='Update public transport bus stop time info in sledrun JSON files.')
    add = parser.add_argument
    add('--sledrun', help='If given, work on a single sled run page, otherwise at the whole category.')
    add('--date', type=date.fromisoformat, default=default_date, help='Working week date to query the database.')
    add('gtfs_file', help='GTFS file.')
    add('inifile', nargs='+', help='inifile.ini, see: https://www.winterrodeln.org/trac/wiki/ConfigIni')
    args = parser.parse_args()
    update_public_transport_times(
        ini_files=args.inifile,
        service_date=args.date,
        gtfs_file=args.gtfs_file,
        sledrun_title=args.sledrun,
    )


# Run only when executed as a script, not when imported.
if __name__ == '__main__':
    main()