
Transit

Tables

Data models for various GTFS tables using pandera library.

The module includes the following classes:

  • AgenciesTable: Optional. Represents the Agency table in the GTFS dataset.
  • WranglerStopsTable: Represents the Stops table in the GTFS dataset.
  • RoutesTable: Represents the Routes table in the GTFS dataset.
  • WranglerShapesTable: Represents the Shapes table in the GTFS dataset.
  • WranglerStopTimesTable: Represents the Stop Times table in the GTFS dataset.
  • WranglerTripsTable: Represents the Trips table in the GTFS dataset.

Each table model leverages the Pydantic data models defined in the records module to define the data model for the corresponding table. The classes also include additional configurations, such as uniqueness constraints.

Validating a table against the WranglerStopsTable

from network_wrangler.models.gtfs.tables import WranglerStopsTable
from network_wrangler.utils.models import validate_df_to_model

validated_stops_df = validate_df_to_model(stops_df, WranglerStopsTable)

network_wrangler.models.gtfs.tables.AgenciesTable

Bases: DataFrameModel



Represents the Agency table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#agencytxt

Attributes:

  • agency_id (str): The agency_id. Primary key. Required to be unique.
  • agency_name (str): The agency name.
  • agency_url (str): The agency URL.
  • agency_timezone (str): The agency timezone.
  • agency_lang (str): The agency language.
  • agency_phone (str): The agency phone number.
  • agency_fare_url (str): The agency fare URL.
  • agency_email (str): The agency email.

Source code in network_wrangler/models/gtfs/tables.py
class AgenciesTable(DataFrameModel):
    """Represents the Agency table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#agencytxt>

    Attributes:
        agency_id (str): The agency_id. Primary key. Required to be unique.
        agency_name (str): The agency name.
        agency_url (str): The agency URL.
        agency_timezone (str): The agency timezone.
        agency_lang (str): The agency language.
        agency_phone (str): The agency phone number.
        agency_fare_url (str): The agency fare URL.
        agency_email (str): The agency email.
    """

    agency_id: Series[str] = Field(coerce=True, nullable=False, unique=True)
    agency_name: Series[str] = Field(coerce=True, nullable=True)
    agency_url: Series[HttpURL] = Field(coerce=True, nullable=True)
    agency_timezone: Series[str] = Field(coerce=True, nullable=True)
    agency_lang: Series[str] = Field(coerce=True, nullable=True)
    agency_phone: Series[str] = Field(coerce=True, nullable=True)
    agency_fare_url: Series[str] = Field(coerce=True, nullable=True)
    agency_email: Series[str] = Field(coerce=True, nullable=True)

    class Config:
        """Config for the AgenciesTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["agency_id"]
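
The agency_id constraints above (non-null, unique per Field(nullable=False, unique=True)) can be restated with plain pandas. The DataFrame below is illustrative sample data, not from the source:

```python
import pandas as pd

# Two illustrative agency rows; only the columns relevant to the checks.
agencies = pd.DataFrame(
    {
        "agency_id": ["metro", "bus_co"],
        "agency_name": ["Metro Transit", "Bus Company"],
    }
)

# What Field(nullable=False, unique=True) enforces for agency_id:
assert agencies["agency_id"].notna().all()
assert agencies["agency_id"].is_unique
```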

network_wrangler.models.gtfs.tables.FrequenciesTable

Bases: DataFrameModel



Represents the Frequency table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#frequenciestxt

The primary key of this table is a composite key of trip_id and start_time.

Attributes:

  • trip_id (str): Foreign key to trip_id in the trips table.
  • start_time (TimeString): The start time in HH:MM:SS format.
  • end_time (TimeString): The end time in HH:MM:SS format.
  • headway_secs (int): The headway in seconds.

Source code in network_wrangler/models/gtfs/tables.py
class FrequenciesTable(DataFrameModel):
    """Represents the Frequency table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#frequenciestxt>

    The primary key of this table is a composite key of `trip_id` and `start_time`.

    Attributes:
        trip_id (str): Foreign key to `trip_id` in the trips table.
        start_time (TimeString): The start time in HH:MM:SS format.
        end_time (TimeString): The end time in HH:MM:SS format.
        headway_secs (int): The headway in seconds.
    """

    trip_id: Series[str] = Field(nullable=False, coerce=True)
    start_time: Series[TimeString] = Field(
        nullable=False, coerce=True, default=DEFAULT_TIMESPAN[0]
    )
    end_time: Series[TimeString] = Field(nullable=False, coerce=True, default=DEFAULT_TIMESPAN[1])
    headway_secs: Series[int] = Field(
        coerce=True,
        ge=1,
        nullable=False,
    )

    class Config:
        """Config for the FrequenciesTable data model."""

        coerce = True
        add_missing_columns = True
        unique: ClassVar[list[str]] = ["trip_id", "start_time"]
        _pk: ClassVar[TablePrimaryKeys] = ["trip_id", "start_time"]
        _fk: ClassVar[TableForeignKeys] = {"trip_id": ("trips", "trip_id")}
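
In plain pandas terms, the composite primary key and the headway bound declared above reduce to two checks (illustrative data):

```python
import pandas as pd

# Illustrative frequencies rows: two entries for trip t1, one for t2.
frequencies = pd.DataFrame(
    {
        "trip_id": ["t1", "t1", "t2"],
        "start_time": ["06:00:00", "09:00:00", "06:00:00"],
        "headway_secs": [600, 900, 300],
    }
)

# The composite primary key (trip_id, start_time) must be unique:
assert not frequencies.duplicated(subset=["trip_id", "start_time"]).any()
# And headway_secs must be at least 1, per Field(ge=1):
assert (frequencies["headway_secs"] >= 1).all()
```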

network_wrangler.models.gtfs.tables.RoutesTable

Bases: DataFrameModel



Represents the Routes table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#routestxt

Attributes:

  • route_id (str): The route_id. Primary key. Required to be unique.
  • route_short_name (str | None): The route short name.
  • route_long_name (str | None): The route long name.
  • route_type (RouteType): The route type. Required. Values: 0 = Tram, Streetcar, Light rail; 1 = Subway, Metro; 2 = Rail; 3 = Bus.
  • agency_id (str | None): The agency_id. Foreign key to agency_id in the agencies table.
  • route_desc (str | None): The route description.
  • route_url (str | None): The route URL.
  • route_color (str | None): The route color.
  • route_text_color (str | None): The route text color.

Source code in network_wrangler/models/gtfs/tables.py
class RoutesTable(DataFrameModel):
    """Represents the Routes table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#routestxt>

    Attributes:
        route_id (str): The route_id. Primary key. Required to be unique.
        route_short_name (str | None): The route short name.
        route_long_name (str | None): The route long name.
        route_type (RouteType): The route type. Required. Values can be:
            - 0: Tram, Streetcar, Light rail
            - 1: Subway, Metro
            - 2: Rail
            - 3: Bus
        agency_id (str | None): The agency_id. Foreign key to agency_id in the agencies table.
        route_desc (str | None): The route description.
        route_url (str | None): The route URL.
        route_color (str | None): The route color.
        route_text_color (str | None): The route text color.
    """

    route_id: Series[str] = Field(nullable=False, unique=True, coerce=True)
    route_short_name: Series[str] = Field(nullable=True, coerce=True)
    route_long_name: Series[str] = Field(nullable=True, coerce=True)
    route_type: Series[Category] = Field(
        dtype_kwargs={"categories": RouteType}, coerce=True, nullable=False
    )

    # Optional Fields
    agency_id: Series[str] | None = Field(nullable=True, coerce=True)
    route_desc: Series[str] | None = Field(nullable=True, coerce=True)
    route_url: Series[str] | None = Field(nullable=True, coerce=True)
    route_color: Series[str] | None = Field(nullable=True, coerce=True)
    route_text_color: Series[str] | None = Field(nullable=True, coerce=True)

    class Config:
        """Config for the RoutesTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["route_id"]
        _fk: ClassVar[TableForeignKeys] = {"agency_id": ("agencies", "agency_id")}
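
The _fk entry maps routes.agency_id to agencies.agency_id. The leading-underscore _pk/_fk ClassVars appear to be Wrangler's own conventions consumed by its validation code rather than a pandera built-in; in plain pandas terms (with illustrative data) the constraint is set membership, skipping nulls since agency_id is nullable:

```python
import pandas as pd

# Illustrative parent and child tables for the foreign-key relationship.
agencies = pd.DataFrame({"agency_id": ["metro"]})
routes = pd.DataFrame(
    {"route_id": ["r1", "r2"], "agency_id": ["metro", "metro"], "route_type": [3, 3]}
)

# Every non-null routes.agency_id must exist in agencies.agency_id:
fk_ok = routes["agency_id"].dropna().isin(agencies["agency_id"]).all()
assert fk_ok
```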

network_wrangler.models.gtfs.tables.ShapesTable

Bases: DataFrameModel



Represents the Shapes table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#shapestxt

Attributes:

  • shape_id (str): The shape_id. Primary key. Required to be unique.
  • shape_pt_lat (float): The shape point latitude.
  • shape_pt_lon (float): The shape point longitude.
  • shape_pt_sequence (int): The shape point sequence.
  • shape_dist_traveled (float | None): The shape distance traveled.

Source code in network_wrangler/models/gtfs/tables.py
class ShapesTable(DataFrameModel):
    """Represents the Shapes table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#shapestxt>

    Attributes:
        shape_id (str): The shape_id. Primary key. Required to be unique.
        shape_pt_lat (float): The shape point latitude.
        shape_pt_lon (float): The shape point longitude.
        shape_pt_sequence (int): The shape point sequence.
        shape_dist_traveled (float | None): The shape distance traveled.
    """

    shape_id: Series[str] = Field(nullable=False, coerce=True)
    shape_pt_lat: Series[float] = Field(coerce=True, nullable=False, ge=-90, le=90)
    shape_pt_lon: Series[float] = Field(coerce=True, nullable=False, ge=-180, le=180)
    shape_pt_sequence: Series[int] = Field(coerce=True, nullable=False, ge=0)

    # Optional
    shape_dist_traveled: Series[float] | None = Field(coerce=True, nullable=True, ge=0)

    class Config:
        """Config for the ShapesTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["shape_id", "shape_pt_sequence"]
        _fk: ClassVar[TableForeignKeys] = {}
        unique: ClassVar[list[str]] = ["shape_id", "shape_pt_sequence"]
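
The ShapesTable primary key is the composite (shape_id, shape_pt_sequence), enforced via the unique config rather than per-column uniqueness. The checks below restate the coordinate bounds and the composite-key rule in plain pandas over illustrative data:

```python
import pandas as pd

# Three illustrative points along one shape.
shapes = pd.DataFrame(
    {
        "shape_id": ["s1", "s1", "s1"],
        "shape_pt_lat": [44.97, 44.98, 44.99],
        "shape_pt_lon": [-93.27, -93.26, -93.25],
        "shape_pt_sequence": [0, 1, 2],
    }
)

# Bounds mirroring Field(ge=-90, le=90) and Field(ge=-180, le=180):
assert shapes["shape_pt_lat"].between(-90, 90).all()
assert shapes["shape_pt_lon"].between(-180, 180).all()

# The (shape_id, shape_pt_sequence) composite key must be unique:
assert not shapes.duplicated(subset=["shape_id", "shape_pt_sequence"]).any()
```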

network_wrangler.models.gtfs.tables.StopTimesTable

Bases: DataFrameModel



Represents the Stop Times table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#stop_timestxt

The primary key of this table is a composite key of trip_id and stop_sequence.

Attributes:

  • trip_id (str): Foreign key to trip_id in the trips table.
  • stop_id (str): Foreign key to stop_id in the stops table.
  • stop_sequence (int): The stop sequence.
  • pickup_type (PickupDropoffType): The pickup type. Values: 0 = Regularly scheduled pickup; 1 = No pickup available; 2 = Must phone agency to arrange pickup; 3 = Must coordinate with driver to arrange pickup.
  • drop_off_type (PickupDropoffType): The drop off type. Values: 0 = Regularly scheduled drop off; 1 = No drop off available; 2 = Must phone agency to arrange drop off; 3 = Must coordinate with driver to arrange drop off.
  • arrival_time (TimeString): The arrival time in HH:MM:SS format.
  • departure_time (TimeString): The departure time in HH:MM:SS format.
  • shape_dist_traveled (float | None): The shape distance traveled.
  • timepoint (TimepointType | None): The timepoint type. Values: 0 = the stop is not a timepoint; 1 = the stop is a timepoint.

Source code in network_wrangler/models/gtfs/tables.py
class StopTimesTable(DataFrameModel):
    """Represents the Stop Times table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#stop_timestxt>

    The primary key of this table is a composite key of `trip_id` and `stop_sequence`.

    Attributes:
        trip_id (str): Foreign key to `trip_id` in the trips table.
        stop_id (str): Foreign key to `stop_id` in the stops table.
        stop_sequence (int): The stop sequence.
        pickup_type (PickupDropoffType): The pickup type. Values can be:
            - 0: Regularly scheduled pickup
            - 1: No pickup available
            - 2: Must phone agency to arrange pickup
            - 3: Must coordinate with driver to arrange pickup
        drop_off_type (PickupDropoffType): The drop off type. Values can be:
            - 0: Regularly scheduled drop off
            - 1: No drop off available
            - 2: Must phone agency to arrange drop off
            - 3: Must coordinate with driver to arrange drop off
        arrival_time (TimeString): The arrival time in HH:MM:SS format.
        departure_time (TimeString): The departure time in HH:MM:SS format.
        shape_dist_traveled (float | None): The shape distance traveled.
        timepoint (TimepointType | None): The timepoint type. Values can be:
            - 0: The stop is not a timepoint
            - 1: The stop is a timepoint
    """

    trip_id: Series[str] = Field(nullable=False, coerce=True)
    stop_id: Series[str] = Field(nullable=False, coerce=True)
    stop_sequence: Series[int] = Field(nullable=False, coerce=True, ge=0)
    pickup_type: Series[Category] = Field(
        dtype_kwargs={"categories": PickupDropoffType},
        nullable=True,
        coerce=True,
    )
    drop_off_type: Series[Category] = Field(
        dtype_kwargs={"categories": PickupDropoffType},
        nullable=True,
        coerce=True,
    )
    arrival_time: Series[pa.Timestamp] = Field(nullable=True, default=pd.NaT, coerce=True)
    departure_time: Series[pa.Timestamp] = Field(nullable=True, default=pd.NaT, coerce=True)

    # Optional
    shape_dist_traveled: Series[float] | None = Field(coerce=True, nullable=True, ge=0)
    timepoint: Series[Category] | None = Field(
        dtype_kwargs={"categories": TimepointType}, coerce=True, default=0
    )

    class Config:
        """Config for the StopTimesTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["trip_id", "stop_sequence"]
        _fk: ClassVar[TableForeignKeys] = {
            "trip_id": ("trips", "trip_id"),
            "stop_id": ("stops", "stop_id"),
        }

        unique: ClassVar[list[str]] = ["trip_id", "stop_sequence"]

    @pa.dataframe_parser
    def parse_times(cls, df):
        """Parse time strings to timestamps."""
        # Convert string times to timestamps
        if "arrival_time" in df.columns and "departure_time" in df.columns:
            # Convert string times to timestamps using str_to_time_series
            df["arrival_time"] = str_to_time_series(df["arrival_time"])
            df["departure_time"] = str_to_time_series(df["departure_time"])

        return df

network_wrangler.models.gtfs.tables.StopTimesTable.parse_times

parse_times(df)

Parse time strings to timestamps.

Source code in network_wrangler/models/gtfs/tables.py
@pa.dataframe_parser
def parse_times(cls, df):
    """Parse time strings to timestamps."""
    # Convert string times to timestamps
    if "arrival_time" in df.columns and "departure_time" in df.columns:
        # Convert string times to timestamps using str_to_time_series
        df["arrival_time"] = str_to_time_series(df["arrival_time"])
        df["departure_time"] = str_to_time_series(df["departure_time"])

    return df
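
Note that GTFS allows stop times past midnight (e.g., 25:30:00 for a trip that continues after the end of the service day). The str_to_time_series helper used above is defined elsewhere in Wrangler; the sketch below only illustrates the over-24-hour parsing concern using pandas Timedelta, as an assumption-level stand-in rather than the library's actual conversion (which produces datetime64 values):

```python
import pandas as pd

# GTFS times are HH:MM:SS strings and may exceed 24:00:00 for trips that
# run past midnight; pd.to_timedelta handles such values directly.
times = pd.Series(["08:15:00", "25:30:00"])
deltas = pd.to_timedelta(times)

assert deltas.iloc[0] == pd.Timedelta(hours=8, minutes=15)
assert deltas.iloc[1] > pd.Timedelta(hours=24)
```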

network_wrangler.models.gtfs.tables.StopsTable

Bases: DataFrameModel



Represents the Stops table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#stopstxt

Attributes:

  • stop_id (str): The stop_id. Primary key. Required to be unique.
  • stop_lat (float): The stop latitude.
  • stop_lon (float): The stop longitude.
  • wheelchair_boarding (int | None): The wheelchair boarding.
  • stop_code (str | None): The stop code.
  • stop_name (str | None): The stop name.
  • tts_stop_name (str | None): The text-to-speech stop name.
  • stop_desc (str | None): The stop description.
  • zone_id (str | None): The zone id.
  • stop_url (str | None): The stop URL.
  • location_type (LocationType | None): The location type. Values: 0 = stop platform; 1 = station; 2 = entrance/exit; 3 = generic node; 4 = boarding area. A blank value defaults to a stop platform.
  • parent_station (str | None): The stop_id of the parent station.
  • stop_timezone (str | None): The stop timezone.

Source code in network_wrangler/models/gtfs/tables.py
class StopsTable(DataFrameModel):
    """Represents the Stops table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#stopstxt>

    Attributes:
        stop_id (str): The stop_id. Primary key. Required to be unique.
        stop_lat (float): The stop latitude.
        stop_lon (float): The stop longitude.
        wheelchair_boarding (int | None): The wheelchair boarding.
        stop_code (str | None): The stop code.
        stop_name (str | None): The stop name.
        tts_stop_name (str | None): The text-to-speech stop name.
        stop_desc (str | None): The stop description.
        zone_id (str | None): The zone id.
        stop_url (str | None): The stop URL.
        location_type (LocationType | None): The location type. Values can be:
            - 0: stop platform
            - 1: station
            - 2: entrance/exit
            - 3: generic node
            - 4: boarding area
            Default of blank assumes a stop platform.
        parent_station (str | None): The `stop_id` of the parent station.
        stop_timezone (str | None): The stop timezone.
    """

    stop_id: Series[str] = Field(coerce=True, nullable=False, unique=True)
    stop_lat: Series[float] = Field(coerce=True, nullable=False, ge=-90, le=90)
    stop_lon: Series[float] = Field(coerce=True, nullable=False, ge=-180, le=180)

    # Optional Fields
    wheelchair_boarding: Series[Category] | None = Field(
        dtype_kwargs={"categories": WheelchairAccessible}, coerce=True, default=0
    )
    stop_code: Series[str] | None = Field(nullable=True, coerce=True)
    stop_name: Series[str] | None = Field(nullable=True, coerce=True)
    tts_stop_name: Series[str] | None = Field(nullable=True, coerce=True)
    stop_desc: Series[str] | None = Field(nullable=True, coerce=True)
    zone_id: Series[str] | None = Field(nullable=True, coerce=True)
    stop_url: Series[str] | None = Field(nullable=True, coerce=True)
    location_type: Series[Category] | None = Field(
        dtype_kwargs={"categories": LocationType},
        nullable=True,
        coerce=True,
        default=0,
    )
    parent_station: Series[str] | None = Field(nullable=True, coerce=True)
    stop_timezone: Series[str] | None = Field(nullable=True, coerce=True)

    class Config:
        """Config for the StopsTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["stop_id"]
        _fk: ClassVar[TableForeignKeys] = {"parent_station": ("stops", "stop_id")}
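
parent_station is a self-referential foreign key: per the _fk config, every non-null value must be an existing stop_id in the same table. In plain pandas, with illustrative data:

```python
import pandas as pd

# One illustrative station with two child platforms.
stops = pd.DataFrame(
    {
        "stop_id": ["stn", "plat_a", "plat_b"],
        "stop_lat": [44.97, 44.97, 44.97],
        "stop_lon": [-93.27, -93.27, -93.27],
        "parent_station": [pd.NA, "stn", "stn"],
    }
)

# Every non-null parent_station must refer back to an existing stop_id:
assert stops["parent_station"].dropna().isin(stops["stop_id"]).all()
```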

network_wrangler.models.gtfs.tables.TripsTable

Bases: DataFrameModel



Represents the Trips table in the GTFS dataset.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#tripstxt

Attributes:

  • trip_id (str): Primary key. Required to be unique.
  • shape_id (str): Foreign key to shape_id in the shapes table.
  • direction_id (DirectionID): The direction id. Required. Values: 0 = Outbound; 1 = Inbound.
  • service_id (str): The service id.
  • route_id (str): The route id. Foreign key to route_id in the routes table.
  • trip_short_name (str | None): The trip short name.
  • trip_headsign (str | None): The trip headsign.
  • block_id (str | None): The block id.
  • wheelchair_accessible (int | None): The wheelchair accessible. Values: 0 = No information; 1 = Allowed; 2 = Not allowed.
  • bikes_allowed (int | None): The bikes allowed. Values: 0 = No information; 1 = Allowed; 2 = Not allowed.

Source code in network_wrangler/models/gtfs/tables.py
class TripsTable(DataFrameModel):
    """Represents the Trips table in the GTFS dataset.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#tripstxt>

    Attributes:
        trip_id (str): Primary key. Required to be unique.
        shape_id (str): Foreign key to `shape_id` in the shapes table.
        direction_id (DirectionID): The direction id. Required. Values can be:
            - 0: Outbound
            - 1: Inbound
        service_id (str): The service id.
        route_id (str): The route id. Foreign key to `route_id` in the routes table.
        trip_short_name (str | None): The trip short name.
        trip_headsign (str | None): The trip headsign.
        block_id (str | None): The block id.
        wheelchair_accessible (int | None): The wheelchair accessible. Values can be:
            - 0: No information
            - 1: Allowed
            - 2: Not allowed
        bikes_allowed (int | None): The bikes allowed. Values can be:
            - 0: No information
            - 1: Allowed
            - 2: Not allowed
    """

    trip_id: Series[str] = Field(nullable=False, unique=True, coerce=True)
    shape_id: Series[str] = Field(nullable=False, coerce=True)
    direction_id: Series[Category] = Field(
        dtype_kwargs={"categories": DirectionID}, coerce=True, nullable=False, default=0
    )
    service_id: Series[str] = Field(nullable=False, coerce=True, default="1")
    route_id: Series[str] = Field(nullable=False, coerce=True)

    # Optional Fields
    trip_short_name: Series[str] | None = Field(nullable=True, coerce=True)
    trip_headsign: Series[str] | None = Field(nullable=True, coerce=True)
    block_id: Series[str] | None = Field(nullable=True, coerce=True)
    wheelchair_accessible: Series[Category] | None = Field(
        dtype_kwargs={"categories": WheelchairAccessible}, coerce=True, default=0
    )
    bikes_allowed: Series[Category] | None = Field(
        dtype_kwargs={"categories": BikesAllowed},
        coerce=True,
        default=0,
    )

    class Config:
        """Config for the TripsTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["trip_id"]
        _fk: ClassVar[TableForeignKeys] = {"route_id": ("routes", "route_id")}
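
direction_id, wheelchair_accessible, and bikes_allowed are validated as pandas Categoricals with a fixed category set, which is what dtype_kwargs={"categories": ...} configures. The sketch below shows the core behavior of such a dtype in plain pandas (values outside the declared categories become NaN rather than passing through); it is an illustration, not the library's own coercion path:

```python
import pandas as pd

# Cast to a categorical dtype whose only valid categories are 0 and 1,
# mirroring DirectionID. The out-of-range value 2 becomes NaN.
direction = pd.Series([0, 1, 2]).astype(pd.CategoricalDtype(categories=[0, 1]))

assert direction.iloc[0] == 0
assert pd.isna(direction.iloc[2])
```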

network_wrangler.models.gtfs.tables.WranglerFrequenciesTable

Bases: FrequenciesTable



Wrangler flavor of GTFS FrequenciesTable.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#frequenciestxt

The primary key of this table is a composite key of trip_id and start_time.

Attributes:

  • trip_id (str): Foreign key to trip_id in the trips table.
  • start_time (datetime): The start time in datetime format.
  • end_time (datetime): The end time in datetime format.
  • headway_secs (int): The headway in seconds.

Source code in network_wrangler/models/gtfs/tables.py
class WranglerFrequenciesTable(FrequenciesTable):
    """Wrangler flavor of GTFS FrequenciesTable.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#frequenciestxt>

    The primary key of this table is a composite key of `trip_id` and `start_time`.

    Attributes:
        trip_id (str): Foreign key to `trip_id` in the trips table.
        start_time (datetime.datetime): The start time in datetime format.
        end_time (datetime.datetime): The end time in datetime format.
        headway_secs (int): The headway in seconds.
    """

    projects: Series[str] = Field(coerce=True, default="")
    start_time: Series = Field(
        nullable=False, coerce=True, default=str_to_time(DEFAULT_TIMESPAN[0])
    )
    end_time: Series = Field(nullable=False, coerce=True, default=str_to_time(DEFAULT_TIMESPAN[1]))

    class Config:
        """Config for the FrequenciesTable data model."""

        coerce = True
        add_missing_columns = True
        unique: ClassVar[list[str]] = ["trip_id", "start_time"]
        _pk: ClassVar[TablePrimaryKeys] = ["trip_id", "start_time"]
        _fk: ClassVar[TableForeignKeys] = {"trip_id": ("trips", "trip_id")}

    @pa.parser("start_time")
    def st_to_timestamp(cls, series: Series) -> Series[Timestamp]:
        """Check that start time is timestamp."""
        series = series.fillna(str_to_time(DEFAULT_TIMESPAN[0]))
        if series.dtype == "datetime64[ns]":
            return series
        series = str_to_time_series(series)
        return series.astype("datetime64[ns]")

    @pa.parser("end_time")
    def et_to_timestamp(cls, series: Series) -> Series[Timestamp]:
        """Check that start time is timestamp."""
        series = series.fillna(str_to_time(DEFAULT_TIMESPAN[1]))
        if series.dtype == "datetime64[ns]":
            return series
        return str_to_time_series(series)

network_wrangler.models.gtfs.tables.WranglerFrequenciesTable.et_to_timestamp

et_to_timestamp(series)

Check that end time is a timestamp.

Source code in network_wrangler/models/gtfs/tables.py
@pa.parser("end_time")
def et_to_timestamp(cls, series: Series) -> Series[Timestamp]:
    """Check that start time is timestamp."""
    series = series.fillna(str_to_time(DEFAULT_TIMESPAN[1]))
    if series.dtype == "datetime64[ns]":
        return series
    return str_to_time_series(series)

network_wrangler.models.gtfs.tables.WranglerFrequenciesTable.st_to_timestamp

st_to_timestamp(series)

Check that start time is a timestamp.

Source code in network_wrangler/models/gtfs/tables.py
@pa.parser("start_time")
def st_to_timestamp(cls, series: Series) -> Series[Timestamp]:
    """Check that start time is timestamp."""
    series = series.fillna(str_to_time(DEFAULT_TIMESPAN[0]))
    if series.dtype == "datetime64[ns]":
        return series
    series = str_to_time_series(series)
    return series.astype("datetime64[ns]")
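
The two parsers above share one pattern: pass datetime64 input through untouched, otherwise fill gaps and parse. A simplified stand-in (using pd.to_datetime in place of Wrangler's str_to_time_series, so unlike the real helper it will not handle GTFS times of 24:00:00 or later) looks like:

```python
import pandas as pd

def to_timestamp_series(series: pd.Series, default: str = "00:00:00") -> pd.Series:
    """Idempotent parse: datetime64 input passes through; strings are
    gap-filled with a default and parsed. A sketch of the parser pattern
    above, not the library's implementation."""
    if series.dtype == "datetime64[ns]":
        return series
    series = series.fillna(default)
    return pd.to_datetime(series, format="%H:%M:%S")

parsed = to_timestamp_series(pd.Series(["06:00:00", None]))
assert parsed.dtype == "datetime64[ns]"
# Re-parsing is a no-op, matching the early return in the real parsers:
assert to_timestamp_series(parsed).equals(parsed)
```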

network_wrangler.models.gtfs.tables.WranglerShapesTable

Bases: ShapesTable



Wrangler flavor of GTFS ShapesTable.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#shapestxt

Attributes:

  • shape_id (str): The shape_id. Primary key. Required to be unique.
  • shape_pt_lat (float): The shape point latitude.
  • shape_pt_lon (float): The shape point longitude.
  • shape_pt_sequence (int): The shape point sequence.
  • shape_dist_traveled (float | None): The shape distance traveled.
  • shape_model_node_id (int): The model_node_id of the shape point. Foreign key to the model_node_id in the nodes table.
  • projects (str): A comma-separated string value for projects that have been applied to this shape.

Source code in network_wrangler/models/gtfs/tables.py
class WranglerShapesTable(ShapesTable):
    """Wrangler flavor of GTFS ShapesTable.

     For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#shapestxt>

    Attributes:
        shape_id (str): The shape_id. Primary key. Required to be unique.
        shape_pt_lat (float): The shape point latitude.
        shape_pt_lon (float): The shape point longitude.
        shape_pt_sequence (int): The shape point sequence.
        shape_dist_traveled (float | None): The shape distance traveled.
        shape_model_node_id (int): The model_node_id of the shape point. Foreign key to the model_node_id in the nodes table.
        projects (str): A comma-separated string value for projects that have been applied to this shape.
    """

    shape_model_node_id: Series[int] = Field(coerce=True, nullable=False)
    projects: Series[str] = Field(coerce=True, default="")
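
The projects field packs an audit trail into one comma-separated string per row, with the empty-string default meaning no projects have touched the shape. Splitting it back out is straightforward; the data below is illustrative:

```python
import pandas as pd

# Illustrative projects column: empty, one project, two projects.
shapes_projects = pd.Series(["", "proj_a", "proj_a,proj_b"])

# Split on commas, dropping the empty pieces an empty string produces:
as_lists = shapes_projects.apply(lambda s: [p for p in s.split(",") if p])

assert as_lists.iloc[0] == []
assert as_lists.iloc[2] == ["proj_a", "proj_b"]
```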

network_wrangler.models.gtfs.tables.WranglerStopTimesTable

Bases: StopTimesTable



Wrangler flavor of GTFS StopTimesTable.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#stop_timestxt

The primary key of this table is a composite key of trip_id and stop_sequence.

Attributes:

  • trip_id (str): Foreign key to trip_id in the trips table.
  • stop_id (int): Foreign key to stop_id in the stops table.
  • stop_sequence (int): The stop sequence.
  • pickup_type (PickupDropoffType): The pickup type. Values: 0 = Regularly scheduled pickup; 1 = No pickup available; 2 = Must phone agency to arrange pickup; 3 = Must coordinate with driver to arrange pickup.
  • drop_off_type (PickupDropoffType): The drop off type. Values: 0 = Regularly scheduled drop off; 1 = No drop off available; 2 = Must phone agency to arrange drop off; 3 = Must coordinate with driver to arrange drop off.
  • shape_dist_traveled (float | None): The shape distance traveled.
  • timepoint (TimepointType | None): The timepoint type. Values: 0 = the stop is not a timepoint; 1 = the stop is a timepoint.
  • projects (str): A comma-separated string value for projects that have been applied to this stop.

Source code in network_wrangler/models/gtfs/tables.py
class WranglerStopTimesTable(StopTimesTable):
    """Wrangler flavor of GTFS StopTimesTable.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#stop_timestxt>

    The primary key of this table is a composite key of `trip_id` and `stop_sequence`.

    Attributes:
        trip_id (str): Foreign key to `trip_id` in the trips table.
        stop_id (int): Foreign key to `stop_id` in the stops table.
        stop_sequence (int): The stop sequence.
        pickup_type (PickupDropoffType): The pickup type. Values can be:
            - 0: Regularly scheduled pickup
            - 1: No pickup available
            - 2: Must phone agency to arrange pickup
            - 3: Must coordinate with driver to arrange pickup
        drop_off_type (PickupDropoffType): The drop off type. Values can be:
            - 0: Regularly scheduled drop off
            - 1: No drop off available
            - 2: Must phone agency to arrange drop off
            - 3: Must coordinate with driver to arrange drop off
        shape_dist_traveled (float | None): The shape distance traveled.
        timepoint (TimepointType | None): The timepoint type. Values can be:
            - 0: The stop is not a timepoint
            - 1: The stop is a timepoint
        projects (str): A comma-separated string value for projects that have been applied to this stop.
    """

    stop_id: Series[int] = Field(nullable=False, coerce=True, description="The model_node_id.")
    projects: Series[str] = Field(coerce=True, default="")
    arrival_time: Series[pa.Timestamp] = Field(nullable=True, default=pd.NaT, coerce=True)
    departure_time: Series[pa.Timestamp] = Field(nullable=True, default=pd.NaT, coerce=True)

    class Config:
        """Config for the StopTimesTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["trip_id", "stop_sequence"]
        _fk: ClassVar[TableForeignKeys] = {
            "trip_id": ("trips", "trip_id"),
            "stop_id": ("stops", "stop_id"),
        }

        unique: ClassVar[list[str]] = ["trip_id", "stop_sequence"]

    @pa.dataframe_parser
    def parse_times(cls, df):
        """Parse time strings to timestamps."""
        # Convert string times to timestamps
        if "arrival_time" in df.columns and "departure_time" in df.columns:
            # Convert string times to timestamps using str_to_time_series
            df["arrival_time"] = str_to_time_series(df["arrival_time"])
            df["departure_time"] = str_to_time_series(df["departure_time"])

        return df
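`str_to_time_series` is the library's converter; plain `pd.to_datetime` would not suffice because GTFS allows hour values of 24 or more for trips that run past midnight. A rough, dependency-free sketch of the per-value parsing logic (the `parse_gtfs_time` helper is illustrative, not the library's implementation):

```python
from datetime import timedelta

def parse_gtfs_time(value: str) -> timedelta:
    """Parse an HH:MM:SS GTFS time string into an offset from the start
    of the service day; hours may exceed 23 (e.g. '25:30:00' for a stop
    served at 1:30 AM the next calendar day)."""
    hours, minutes, seconds = (int(part) for part in value.split(":"))
    return timedelta(hours=hours, minutes=minutes, seconds=seconds)
```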

network_wrangler.models.gtfs.tables.WranglerStopTimesTable.parse_times

parse_times(df)

Parse time strings to timestamps.

Source code in network_wrangler/models/gtfs/tables.py
@pa.dataframe_parser
def parse_times(cls, df):
    """Parse time strings to timestamps."""
    # Convert string times to timestamps
    if "arrival_time" in df.columns and "departure_time" in df.columns:
        # Convert string times to timestamps using str_to_time_series
        df["arrival_time"] = str_to_time_series(df["arrival_time"])
        df["departure_time"] = str_to_time_series(df["departure_time"])

    return df

network_wrangler.models.gtfs.tables.WranglerStopsTable

Bases: StopsTable



Wrangler flavor of GTFS StopsTable.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#stopstxt

Attributes:

  • stop_id (int): The stop_id. Primary key. Required to be unique. Wrangler assumes that this is a reference to a roadway node and as such must be an integer.
  • stop_lat (float): The stop latitude.
  • stop_lon (float): The stop longitude.
  • wheelchair_boarding (int | None): The wheelchair boarding.
  • stop_code (str | None): The stop code.
  • stop_name (str | None): The stop name.
  • tts_stop_name (str | None): The text-to-speech stop name.
  • stop_desc (str | None): The stop description.
  • zone_id (str | None): The zone id.
  • stop_url (str | None): The stop URL.
  • location_type (LocationType | None): The location type. Values can be: 0 = stop platform; 1 = station; 2 = entrance/exit; 3 = generic node; 4 = boarding area. Default of blank assumes a stop platform.
  • parent_station (int | None): The stop_id of the parent station. Since stop_id is an integer in Wrangler, this field is also an integer.
  • stop_timezone (str | None): The stop timezone.
  • stop_id_GTFS (str | None): The stop_id from the GTFS data.
  • projects (str): A comma-separated string value for projects that have been applied to this stop.

Source code in network_wrangler/models/gtfs/tables.py
class WranglerStopsTable(StopsTable):
    """Wrangler flavor of GTFS StopsTable.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#stopstxt>

    Attributes:
        stop_id (int): The stop_id. Primary key. Required to be unique. **Wrangler assumes that this is a reference to a roadway node and as such must be an integer**
        stop_lat (float): The stop latitude.
        stop_lon (float): The stop longitude.
        wheelchair_boarding (int | None): The wheelchair boarding.
        stop_code (str | None): The stop code.
        stop_name (str | None): The stop name.
        tts_stop_name (str | None): The text-to-speech stop name.
        stop_desc (str | None): The stop description.
        zone_id (str | None): The zone id.
        stop_url (str | None): The stop URL.
        location_type (LocationType | None): The location type. Values can be:
            - 0: stop platform
            - 1: station
            - 2: entrance/exit
            - 3: generic node
            - 4: boarding area
            Default of blank assumes a stop platform.
        parent_station (int | None): The `stop_id` of the parent station. **Since stop_id is an integer in Wrangler, this field is also an integer**
        stop_timezone (str | None): The stop timezone.
        stop_id_GTFS (str | None): The stop_id from the GTFS data.
        projects (str): A comma-separated string value for projects that have been applied to this stop.
    """

    stop_id: Series[int] = Field(
        coerce=True, nullable=False, unique=True, description="The model_node_id."
    )
    stop_id_GTFS: Series[str] = Field(
        coerce=True,
        nullable=True,
        description="The stop_id from the GTFS data",
    )
    stop_lat: Series[float] = Field(coerce=True, nullable=True, ge=-90, le=90)
    stop_lon: Series[float] = Field(coerce=True, nullable=True, ge=-180, le=180)
    projects: Series[str] = Field(coerce=True, default="")
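The `ge`/`le` arguments on `stop_lat` and `stop_lon` bound coordinates to valid WGS84 ranges. A plain-Python equivalent of that bounds check (the `valid_stop_coords` helper is illustrative, not part of the library):

```python
def valid_stop_coords(stop_lat: float, stop_lon: float) -> bool:
    """Mirror the pandera Field bounds on WranglerStopsTable:
    latitude in [-90, 90] and longitude in [-180, 180]."""
    return -90.0 <= stop_lat <= 90.0 and -180.0 <= stop_lon <= 180.0
```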

network_wrangler.models.gtfs.tables.WranglerTripsTable

Bases: TripsTable



Represents the Trips table in the Wrangler feed, adding projects list.

For field definitions, see the GTFS reference: https://gtfs.org/documentation/schedule/reference/#tripstxt

Attributes:

  • trip_id (str): Primary key. Required to be unique.
  • shape_id (str): Foreign key to shape_id in the shapes table.
  • direction_id (DirectionID): The direction id. Required. Values can be: 0 = outbound; 1 = inbound.
  • service_id (str): The service id.
  • route_id (str): The route id. Foreign key to route_id in the routes table.
  • trip_short_name (str | None): The trip short name.
  • trip_headsign (str | None): The trip headsign.
  • block_id (str | None): The block id.
  • wheelchair_accessible (int | None): Wheelchair accessibility. Values can be: 0 = no information; 1 = allowed; 2 = not allowed.
  • bikes_allowed (int | None): Whether bikes are allowed. Values can be: 0 = no information; 1 = allowed; 2 = not allowed.
  • projects (str): A comma-separated string value for projects that have been applied to this trip.

Source code in network_wrangler/models/gtfs/tables.py
class WranglerTripsTable(TripsTable):
    """Represents the Trips table in the Wrangler feed, adding projects list.

    For field definitions, see the GTFS reference: <https://gtfs.org/documentation/schedule/reference/#tripstxt>

    Attributes:
        trip_id (str): Primary key. Required to be unique.
        shape_id (str): Foreign key to `shape_id` in the shapes table.
        direction_id (DirectionID): The direction id. Required. Values can be:
            - 0: Outbound
            - 1: Inbound
        service_id (str): The service id.
        route_id (str): The route id. Foreign key to `route_id` in the routes table.
        trip_short_name (str | None): The trip short name.
        trip_headsign (str | None): The trip headsign.
        block_id (str | None): The block id.
        wheelchair_accessible (int | None): The wheelchair accessible. Values can be:
            - 0: No information
            - 1: Allowed
            - 2: Not allowed
        bikes_allowed (int | None): The bikes allowed. Values can be:
            - 0: No information
            - 1: Allowed
            - 2: Not allowed
        projects (str): A comma-separated string value for projects that have been applied to this trip.
    """

    projects: Series[str] = Field(coerce=True, default="")

    class Config:
        """Config for the WranglerTripsTable data model."""

        coerce = True
        add_missing_columns = True
        _pk: ClassVar[TablePrimaryKeys] = ["trip_id"]
        _fk: ClassVar[TableForeignKeys] = {"route_id": ("routes", "route_id")}
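The `_fk` entry declares that every `route_id` in trips must exist in the routes table; `DBModelMixin` validates this when a table attribute is set. The core of such a referential check, sketched without pandas (the `check_route_fk` helper is hypothetical, not the library's API):

```python
def check_route_fk(trip_route_ids: list[str], route_ids: list[str]) -> list[str]:
    """Return route_ids referenced by trips that are missing from the
    routes table -- the kind of foreign-key violation DBModelMixin
    flags on assignment. Empty result means the FK constraint holds."""
    return sorted(set(trip_route_ids) - set(route_ids))
```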

Data Model for Pure GTFS Feed (not wrangler-flavored).

network_wrangler.models.gtfs.gtfs.FERRY_ROUTE_TYPES module-attribute

FERRY_ROUTE_TYPES = [FERRY]

GTFS route types which trigger ‘ferry_only’ link creation in add_stations_and_links_to_roadway_network()

network_wrangler.models.gtfs.gtfs.MIXED_TRAFFIC_ROUTE_TYPES module-attribute

MIXED_TRAFFIC_ROUTE_TYPES = [
    TRAM,
    BUS,
    CABLE_TRAM,
    TROLLEYBUS,
]

GTFS route types that operate in mixed traffic so stops are nodes that are drive-accessible.

See GTFS routes.txt

  • TRAM = Tram, Streetcar, Light rail, operates in mixed traffic AND at stations
  • CABLE_TRAM = street-level rail with underground cable
  • TROLLEYBUS = electric buses with overhead wires

network_wrangler.models.gtfs.gtfs.RAIL_ROUTE_TYPES module-attribute

RAIL_ROUTE_TYPES = [
    TRAM,
    SUBWAY,
    RAIL,
    CABLE_TRAM,
    AERIAL_LIFT,
    FUNICULAR,
    MONORAIL,
]

GTFS route types which trigger ‘rail_only’ link creation in add_stations_and_links_to_roadway_network()

network_wrangler.models.gtfs.gtfs.STATION_ROUTE_TYPES module-attribute

STATION_ROUTE_TYPES = [
    TRAM,
    SUBWAY,
    RAIL,
    FERRY,
    CABLE_TRAM,
    AERIAL_LIFT,
    FUNICULAR,
    MONORAIL,
]

GTFS route types that operate at stations.
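The names in these lists correspond to the numeric `route_type` codes in GTFS routes.txt (TRAM = 0, SUBWAY = 1, RAIL = 2, BUS = 3, FERRY = 4, CABLE_TRAM = 5, AERIAL_LIFT = 6, FUNICULAR = 7, TROLLEYBUS = 11, MONORAIL = 12). A sketch of how these constant lists could drive link-creation flags (the `link_flags` helper is illustrative, not part of the library):

```python
# GTFS route_type codes per gtfs.org/documentation/schedule/reference/#routestxt
TRAM, SUBWAY, RAIL, BUS, FERRY = 0, 1, 2, 3, 4
CABLE_TRAM, AERIAL_LIFT, FUNICULAR = 5, 6, 7
TROLLEYBUS, MONORAIL = 11, 12

RAIL_ROUTE_TYPES = [TRAM, SUBWAY, RAIL, CABLE_TRAM, AERIAL_LIFT, FUNICULAR, MONORAIL]
FERRY_ROUTE_TYPES = [FERRY]
MIXED_TRAFFIC_ROUTE_TYPES = [TRAM, BUS, CABLE_TRAM, TROLLEYBUS]

def link_flags(route_type: int) -> dict:
    """Classify a route_type against the constant lists above. Note a
    type can be in more than one list (e.g. TRAM is rail AND mixed-traffic)."""
    return {
        "rail_only": route_type in RAIL_ROUTE_TYPES,
        "ferry_only": route_type in FERRY_ROUTE_TYPES,
        "mixed_traffic": route_type in MIXED_TRAFFIC_ROUTE_TYPES,
    }
```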

network_wrangler.models.gtfs.gtfs.GtfsModel

Bases: DBModelMixin



Wrapper class around GTFS feed.

This is the pure GTFS model version of Feed.

Most functionality derives from mixin class DBModelMixin which provides:

  • validation of tables to schemas when setting a table attribute (e.g. self.trips = trips_df)
  • validation of fks when setting a table attribute (e.g. self.trips = trips_df)
  • hashing and deep copy functionality
  • overload of __eq__ to apply only to tables in table_names.
  • convenience methods for accessing tables

Attributes:

  • table_names (list[str]): list of table names in GTFS feed.
  • tables (list[DataFrame]): list of tables as dataframes.
  • agency (DataFrame[AgenciesTable]): agency dataframe.
  • stop_times (DataFrame[StopTimesTable]): stop_times dataframe.
  • stops (DataFrame[WranglerStopsTable]): stops dataframe.
  • shapes (DataFrame[ShapesTable]): shapes dataframe.
  • trips (DataFrame[TripsTable]): trips dataframe.
  • frequencies (Optional[DataFrame[FrequenciesTable]]): frequencies dataframe.
  • routes (DataFrame[RoutesTable]): routes dataframe.
  • net (Optional[TransitNetwork]): TransitNetwork object.

Source code in network_wrangler/models/gtfs/gtfs.py
class GtfsModel(DBModelMixin):
    """Wrapper class around GTFS feed.

    This is the pure GTFS model version of [Feed][network_wrangler.transit.feed.feed.Feed]

    Most functionality derives from mixin class
    [`DBModelMixin`][network_wrangler.models._base.db.DBModelMixin] which provides:

    - validation of tables to schemas when setting a table attribute (e.g. self.trips = trips_df)
    - validation of fks when setting a table attribute (e.g. self.trips = trips_df)
    - hashing and deep copy functionality
    - overload of __eq__ to apply only to tables in table_names.
    - convenience methods for accessing tables

    Attributes:
        table_names (list[str]): list of table names in GTFS feed.
        tables (list[DataFrame]): list tables as dataframes.
        agency (DataFrame[AgenciesTable]): agency dataframe
        stop_times (DataFrame[StopTimesTable]): stop_times dataframe
        stops (DataFrame[WranglerStopsTable]): stops dataframe
        shapes (DataFrame[ShapesTable]): shapes dataframe
        trips (DataFrame[TripsTable]): trips dataframe
        frequencies (Optional[DataFrame[FrequenciesTable]]): frequencies dataframe
        routes (DataFrame[RoutesTable]): route dataframe
        net (Optional[TransitNetwork]): TransitNetwork object
    """

    # the ordering here matters because the stops need to be added before stop_times if
    # stop times needs to be converted
    _table_models: ClassVar[dict] = {
        "agency": AgenciesTable,
        "frequencies": FrequenciesTable,
        "routes": RoutesTable,
        "shapes": ShapesTable,
        "stops": StopsTable,
        "trips": TripsTable,
        "stop_times": StopTimesTable,
    }

    table_names: ClassVar[list[str]] = [
        "agency",
        "routes",
        "shapes",
        "stops",
        "trips",
        "stop_times",
    ]

    optional_table_names: ClassVar[list[str]] = ["frequencies"]

    def __init__(self, **kwargs):
        """Initialize GTFS model."""
        self.initialize_tables(**kwargs)

        # Set extra provided attributes.
        extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names}
        for k, v in extra_attr.items():
            self.__setattr__(k, v)

    @property
    def summary(self) -> dict:
        """A high level summary of the GTFS model object and public attributes."""
        summary_dict = {}
        for table_name in self._table_models:
            if hasattr(self, table_name):
                table = getattr(self, table_name)
                table_type = type(table)
                summary_dict[table_name] = (
                    f"{len(getattr(self, table_name)):,} {table_name} (type={table_type})"
                )
            else:
                summary_dict[table_name] = "not set"

        return summary_dict

    def __repr__(self) -> str:
        """Return a string representation of the GtfsModel with table summaries."""
        lines = ["GtfsModel:"]

        for k, v in self.summary.items():
            lines.append(f"  {k}: {v}")

        return "\n".join(lines)

network_wrangler.models.gtfs.gtfs.GtfsModel.summary property

summary

A high level summary of the GTFS model object and public attributes.

network_wrangler.models.gtfs.gtfs.GtfsModel.__init__

__init__(**kwargs)

Initialize GTFS model.

Source code in network_wrangler/models/gtfs/gtfs.py
def __init__(self, **kwargs):
    """Initialize GTFS model."""
    self.initialize_tables(**kwargs)

    # Set extra provided attributes.
    extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names}
    for k, v in extra_attr.items():
        self.__setattr__(k, v)

network_wrangler.models.gtfs.gtfs.GtfsModel.__repr__

__repr__()

Return a string representation of the GtfsModel with table summaries.

Source code in network_wrangler/models/gtfs/gtfs.py
def __repr__(self) -> str:
    """Return a string representation of the GtfsModel with table summaries."""
    lines = ["GtfsModel:"]

    for k, v in self.summary.items():
        lines.append(f"  {k}: {v}")

    return "\n".join(lines)

network_wrangler.models.gtfs.gtfs.GtfsValidationError

Bases: Exception



Exception raised for errors in the GTFS feed.

Source code in network_wrangler/models/gtfs/gtfs.py
class GtfsValidationError(Exception):
    """Exception raised for errors in the GTFS feed."""

Feed

Main functionality for GTFS tables including Feed object.

network_wrangler.transit.feed.feed.Feed

Bases: DBModelMixin



Wrapper class around Wrangler flavored GTFS feed.

Most functionality derives from mixin class DBModelMixin which provides:

  • validation of tables to schemas when setting a table attribute (e.g. self.trips = trips_df)
  • validation of fks when setting a table attribute (e.g. self.trips = trips_df)
  • hashing and deep copy functionality
  • overload of __eq__ to apply only to tables in table_names.
  • convenience methods for accessing tables

What is Wrangler-flavored GTFS?

A Wrangler-flavored GTFS feed differs from a GTFS feed in the following ways:

  • frequencies.txt is required
  • shapes.txt requires additional field, shape_model_node_id, corresponding to model_node_id in the RoadwayNetwork
  • stops.txt - stop_id is required to be an int

Attributes:

  • table_names (list[str]): list of table names in GTFS feed.
  • tables (list[DataFrame]): list of tables as dataframes.
  • stop_times (DataFrame[WranglerStopTimesTable]): stop_times dataframe with roadway node_ids.
  • stops (DataFrame[WranglerStopsTable]): stops dataframe.
  • shapes (DataFrame[WranglerShapesTable]): shapes dataframe.
  • trips (DataFrame[WranglerTripsTable]): trips dataframe.
  • frequencies (DataFrame[WranglerFrequenciesTable]): frequencies dataframe.
  • routes (DataFrame[RoutesTable]): routes dataframe.
  • agencies (Optional[DataFrame[AgenciesTable]]): agencies dataframe.
  • net (Optional[TransitNetwork]): TransitNetwork object.

Source code in network_wrangler/transit/feed/feed.py
class Feed(DBModelMixin):
    """Wrapper class around Wrangler flavored GTFS feed.

    Most functionality derives from mixin class
    [`DBModelMixin`][network_wrangler.models._base.db.DBModelMixin] which provides:

    - validation of tables to schemas when setting a table attribute (e.g. self.trips = trips_df)
    - validation of fks when setting a table attribute (e.g. self.trips = trips_df)
    - hashing and deep copy functionality
    - overload of __eq__ to apply only to tables in table_names.
    - convenience methods for accessing tables

    !!! note "What is Wrangler-flavored GTFS?"

        A Wrangler-flavored GTFS feed differs from a GTFS feed in the following ways:

        * `frequencies.txt` is required
        * `shapes.txt` requires additional field, `shape_model_node_id`, corresponding to `model_node_id` in the `RoadwayNetwork`
        * `stops.txt` - `stop_id` is required to be an int

    Attributes:
        table_names (list[str]): list of table names in GTFS feed.
        tables (list[DataFrame]): list tables as dataframes.
        stop_times (DataFrame[WranglerStopTimesTable]): stop_times dataframe with roadway node_ids
        stops (DataFrame[WranglerStopsTable]): stops dataframe
        shapes (DataFrame[WranglerShapesTable]): shapes dataframe
        trips (DataFrame[WranglerTripsTable]): trips dataframe
        frequencies (DataFrame[WranglerFrequenciesTable]): frequencies dataframe
        routes (DataFrame[RoutesTable]): route dataframe
        agencies (Optional[DataFrame[AgenciesTable]]): agencies dataframe
        net (Optional[TransitNetwork]): TransitNetwork object
    """

    # the ordering here matters because the stops need to be added before stop_times if
    # stop times needs to be converted
    _table_models: ClassVar[dict] = {
        "agencies": AgenciesTable,
        "frequencies": WranglerFrequenciesTable,
        "routes": RoutesTable,
        "shapes": WranglerShapesTable,
        "stops": WranglerStopsTable,
        "trips": WranglerTripsTable,
        "stop_times": WranglerStopTimesTable,
    }

    # Define the converters if the table needs to be converted to a Wrangler table.
    # Format: "table_name": converter_function
    _converters: ClassVar[dict[str, Callable]] = {}

    table_names: ClassVar[list[str]] = [
        "frequencies",
        "routes",
        "shapes",
        "stops",
        "trips",
        "stop_times",
    ]

    optional_table_names: ClassVar[list[str]] = ["agencies"]

    def __init__(self, **kwargs):
        """Create a Feed object from a dictionary of DataFrames representing a GTFS feed.

        Args:
            kwargs: A dictionary containing DataFrames representing the tables of a GTFS feed.
        """
        self._net = None
        self.feed_path: Path = None
        self.initialize_tables(**kwargs)

        # Set extra provided attributes but just FYI in logger.
        extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names}
        if extra_attr:
            WranglerLogger.info(f"Adding additional attributes to Feed: {extra_attr.keys()}")
        for k, v in extra_attr.items():
            self.__setattr__(k, v)

    @property
    def summary(self) -> dict:
        """A high level summary of the GTFS model object and public attributes."""
        summary_dict = {}
        for table_name in self._table_models:
            if hasattr(self, table_name):
                table = getattr(self, table_name)
                table_type = type(table)
                summary_dict[table_name] = (
                    f"{len(getattr(self, table_name)):,} {table_name} (type={table_type})"
                )
            else:
                summary_dict[table_name] = "not set"

        return summary_dict

    def __repr__(self) -> str:
        """Return a string representation of the Feed with table summaries."""
        lines = ["Feed (Wrangler GTFS):"]

        for k, v in self.summary.items():
            lines.append(f"  {k}: {v}")

        # Add note about model_node_ids if stops have them
        if hasattr(self, "stops") and self.stops is not None and "stop_id" in self.stops.columns:
            # In Feed, stop_id contains the model_node_id
            unique_nodes = len(self.stops.stop_id.unique())
            lines.append(f"  Model nodes: {unique_nodes} unique nodes referenced")

        return "\n".join(lines)

    def set_by_id(
        self,
        table_name: str,
        set_df: pd.DataFrame,
        id_property: str = "index",
        properties: list[str] | None = None,
    ):
        """Set one or more property values based on an ID property for a given table.

        Args:
            table_name (str): Name of the table to modify.
            set_df (pd.DataFrame): DataFrame with columns `<id_property>` and `value` containing
                values to set for the specified property where `<id_property>` is unique.
            id_property: Property to use as ID to set by. Defaults to "index".
            properties: List of properties to set which are in set_df. If not specified, will set
                all properties.
        """
        if not set_df[id_property].is_unique:
            msg = f"{id_property} must be unique in set_df."
            _dupes = set_df[id_property][set_df[id_property].duplicated()]
            WranglerLogger.error(msg + f"Found duplicates: {_dupes.sum()}")

            raise ValueError(msg)
        table_df = self.get_table(table_name)
        updated_df = update_df_by_col_value(table_df, set_df, id_property, properties=properties)
        self.__dict__[table_name] = updated_df

network_wrangler.transit.feed.feed.Feed.summary property

summary

A high level summary of the GTFS model object and public attributes.

network_wrangler.transit.feed.feed.Feed.__init__

__init__(**kwargs)

Create a Feed object from a dictionary of DataFrames representing a GTFS feed.

Parameters:

  • kwargs: A dictionary containing DataFrames representing the tables of a GTFS feed. Default: {}.
Source code in network_wrangler/transit/feed/feed.py
def __init__(self, **kwargs):
    """Create a Feed object from a dictionary of DataFrames representing a GTFS feed.

    Args:
        kwargs: A dictionary containing DataFrames representing the tables of a GTFS feed.
    """
    self._net = None
    self.feed_path: Path = None
    self.initialize_tables(**kwargs)

    # Set extra provided attributes but just FYI in logger.
    extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names}
    if extra_attr:
        WranglerLogger.info(f"Adding additional attributes to Feed: {extra_attr.keys()}")
    for k, v in extra_attr.items():
        self.__setattr__(k, v)

network_wrangler.transit.feed.feed.Feed.__repr__

__repr__()

Return a string representation of the Feed with table summaries.

Source code in network_wrangler/transit/feed/feed.py
def __repr__(self) -> str:
    """Return a string representation of the Feed with table summaries."""
    lines = ["Feed (Wrangler GTFS):"]

    for k, v in self.summary.items():
        lines.append(f"  {k}: {v}")

    # Add note about model_node_ids if stops have them
    if hasattr(self, "stops") and self.stops is not None and "stop_id" in self.stops.columns:
        # In Feed, stop_id contains the model_node_id
        unique_nodes = len(self.stops.stop_id.unique())
        lines.append(f"  Model nodes: {unique_nodes} unique nodes referenced")

    return "\n".join(lines)

network_wrangler.transit.feed.feed.Feed.set_by_id

set_by_id(
    table_name, set_df, id_property="index", properties=None
)

Set one or more property values based on an ID property for a given table.

Parameters:

  • table_name (str): Name of the table to modify. Required.
  • set_df (DataFrame): DataFrame with columns <id_property> and value containing values to set for the specified property, where <id_property> is unique. Required.
  • id_property (str): Property to use as ID to set by. Defaults to "index".
  • properties (list[str] | None): List of properties to set which are in set_df. If not specified, will set all properties. Defaults to None.
Source code in network_wrangler/transit/feed/feed.py
def set_by_id(
    self,
    table_name: str,
    set_df: pd.DataFrame,
    id_property: str = "index",
    properties: list[str] | None = None,
):
    """Set one or more property values based on an ID property for a given table.

    Args:
        table_name (str): Name of the table to modify.
        set_df (pd.DataFrame): DataFrame with columns `<id_property>` and `value` containing
            values to set for the specified property where `<id_property>` is unique.
        id_property: Property to use as ID to set by. Defaults to "index".
        properties: List of properties to set which are in set_df. If not specified, will set
            all properties.
    """
    if not set_df[id_property].is_unique:
        msg = f"{id_property} must be unique in set_df."
        _dupes = set_df[id_property][set_df[id_property].duplicated()]
        WranglerLogger.error(msg + f"Found duplicates: {_dupes.sum()}")

        raise ValueError(msg)
    table_df = self.get_table(table_name)
    updated_df = update_df_by_col_value(table_df, set_df, id_property, properties=properties)
    self.__dict__[table_name] = updated_df
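`update_df_by_col_value` performs the actual column-wise update. A rough dict-based analogue of the `set_by_id` behavior, sketched without pandas (the `update_by_id` helper is hypothetical, not the library's API):

```python
def update_by_id(records: list[dict], set_records: list[dict], id_property: str) -> list[dict]:
    """Overwrite fields on rows whose id_property matches an entry in
    set_records, mirroring Feed.set_by_id; ids in set_records must be
    unique, and unmatched rows pass through unchanged."""
    ids = [r[id_property] for r in set_records]
    if len(ids) != len(set(ids)):
        raise ValueError(f"{id_property} must be unique in set_records.")
    patches = {r[id_property]: r for r in set_records}
    updated = []
    for row in records:
        patch = patches.get(row[id_property], {})
        updated.append({**row, **{k: v for k, v in patch.items() if k != id_property}})
    return updated
```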

network_wrangler.transit.feed.feed.merge_shapes_to_stop_times

merge_shapes_to_stop_times(stop_times, shapes, trips)

Add shape_id and shape_pt_sequence to stop_times dataframe.

Parameters:

  • stop_times (DataFrame[WranglerStopTimesTable]): stop_times dataframe to add shape_id and shape_pt_sequence to. Required.
  • shapes (DataFrame[WranglerShapesTable]): shapes dataframe to add to stop_times. Required.
  • trips (DataFrame[WranglerTripsTable]): trips dataframe to link stop_times to shapes. Required.

Returns:

  • DataFrame[WranglerStopTimesTable]: stop_times dataframe with shape_id and shape_pt_sequence added.

Source code in network_wrangler/transit/feed/feed.py
def merge_shapes_to_stop_times(
    stop_times: DataFrame[WranglerStopTimesTable],
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
) -> DataFrame[WranglerStopTimesTable]:
    """Add shape_id and shape_pt_sequence to stop_times dataframe.

    Args:
        stop_times: stop_times dataframe to add shape_id and shape_pt_sequence to.
        shapes: shapes dataframe to add to stop_times.
        trips: trips dataframe to link stop_times to shapes

    Returns:
        stop_times dataframe with shape_id and shape_pt_sequence added.
    """
    stop_times_w_shape_id = stop_times.merge(
        trips[["trip_id", "shape_id"]], on="trip_id", how="left"
    )

    stop_times_w_shapes = stop_times_w_shape_id.merge(
        shapes,
        how="left",
        left_on=["shape_id", "stop_id"],
        right_on=["shape_id", "shape_model_node_id"],
    )
    stop_times_w_shapes = stop_times_w_shapes.drop(columns=["shape_model_node_id"])
    return stop_times_w_shapes

network_wrangler.transit.feed.feed.stop_count_by_trip

stop_count_by_trip(stop_times)

Returns dataframe with trip_id and stop_count from stop_times.

Source code in network_wrangler/transit/feed/feed.py
def stop_count_by_trip(
    stop_times: DataFrame[WranglerStopTimesTable],
) -> pd.DataFrame:
    """Returns dataframe with trip_id and stop_count from stop_times."""
    stops_count = stop_times.groupby("trip_id").size()
    return stops_count.reset_index(name="stop_count")
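The groupby-and-size pattern above can be sketched with hypothetical toy data:

```python
import pandas as pd

# Hypothetical stop_times with two trips of different lengths.
stop_times = pd.DataFrame(
    {"trip_id": ["t1", "t1", "t1", "t2", "t2"], "stop_sequence": [1, 2, 3, 1, 2]}
)

# Count rows per trip and name the resulting column, as stop_count_by_trip does.
stop_counts = stop_times.groupby("trip_id").size().reset_index(name="stop_count")
```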

Filters and queries of a gtfs frequencies table.

network_wrangler.transit.feed.frequencies.frequencies_for_trips

frequencies_for_trips(frequencies, trips)

Filter frequencies dataframe to records associated with trips table.

Source code in network_wrangler/transit/feed/frequencies.py
def frequencies_for_trips(
    frequencies: DataFrame[WranglerFrequenciesTable], trips: DataFrame[WranglerTripsTable]
) -> DataFrame[WranglerFrequenciesTable]:
    """Filter frequenceis dataframe to records associated with trips table."""
    _sel_trips = trips.trip_id.unique().tolist()
    filtered_frequencies = frequencies[frequencies.trip_id.isin(_sel_trips)]
    WranglerLogger.debug(
        f"Filtered frequencies to {len(filtered_frequencies)}/{len(frequencies)} \
                         records that referenced one of {len(trips)} trips."
    )
    return filtered_frequencies

Filters and queries of a gtfs routes table and route_ids.

network_wrangler.transit.feed.routes.route_ids_for_trip_ids

route_ids_for_trip_ids(trips, trip_ids)

Returns route ids for given list of trip_ids.

Source code in network_wrangler/transit/feed/routes.py
def route_ids_for_trip_ids(trips: DataFrame[WranglerTripsTable], trip_ids: list[str]) -> list[str]:
    """Returns route ids for given list of trip_ids."""
    return trips[trips["trip_id"].isin(trip_ids)].route_id.unique().tolist()

network_wrangler.transit.feed.routes.routes_for_trip_ids

routes_for_trip_ids(routes, trips, trip_ids)

Returns route records for given list of trip_ids.

Source code in network_wrangler/transit/feed/routes.py
def routes_for_trip_ids(
    routes: DataFrame[RoutesTable], trips: DataFrame[WranglerTripsTable], trip_ids: list[str]
) -> DataFrame[RoutesTable]:
    """Returns route records for given list of trip_ids."""
    route_ids = route_ids_for_trip_ids(trips, trip_ids)
    return routes.loc[routes.route_id.isin(route_ids)]

network_wrangler.transit.feed.routes.routes_for_trips

routes_for_trips(routes, trips)

Filter routes dataframe to records associated with trip records.

Source code in network_wrangler/transit/feed/routes.py
def routes_for_trips(
    routes: DataFrame[RoutesTable], trips: DataFrame[WranglerTripsTable]
) -> DataFrame[RoutesTable]:
    """Filter routes dataframe to records associated with trip records."""
    _sel_routes = trips.route_id.unique().tolist()
    filtered_routes = routes[routes.route_id.isin(_sel_routes)]
    WranglerLogger.debug(
        f"Filtered routes to {len(filtered_routes)}/{len(routes)} \
                         records that referenced one of {len(trips)} trips."
    )
    return filtered_routes

Filters and queries of a gtfs shapes table and node patterns.

network_wrangler.transit.feed.shapes.find_nearest_stops

find_nearest_stops(
    shapes,
    trips,
    stop_times,
    trip_id,
    node_id,
    pickup_dropoff="either",
)

Returns node_ids (before and after) of nearest node_ids that are stops for a given trip_id.

Parameters:

- shapes (WranglerShapesTable): WranglerShapesTable. Required.
- trips (WranglerTripsTable): WranglerTripsTable. Required.
- stop_times (WranglerStopTimesTable): WranglerStopTimesTable. Required.
- trip_id (str): trip id to find nearest stops for. Required.
- node_id (int): node_id to find nearest stops for. Required.
- pickup_dropoff (PickupDropoffAvailability): str indicating logic for selecting stops based on pickup and dropoff availability at stop. Defaults to "either".
    - "either": either pickup_type or dropoff_type > 0
    - "both": both pickup_type and dropoff_type > 0
    - "pickup_only": only pickup > 0
    - "dropoff_only": only dropoff > 0

Returns:

- tuple[int, int]: node_ids for stop before and stop after.

Source code in network_wrangler/transit/feed/shapes.py
def find_nearest_stops(
    shapes: WranglerShapesTable,
    trips: WranglerTripsTable,
    stop_times: WranglerStopTimesTable,
    trip_id: str,
    node_id: int,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> tuple[int, int]:
    """Returns node_ids (before and after) of nearest node_ids that are stops for a given trip_id.

    Args:
        shapes: WranglerShapesTable
        trips: WranglerTripsTable
        stop_times: WranglerStopTimesTable
        trip_id: trip id to find nearest stops for
        node_id: node_id to find nearest stops for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0

    Returns:
        tuple: node_ids for stop before and stop after
    """
    shapes = shapes_with_stop_id_for_trip_id(
        shapes, trips, stop_times, trip_id, pickup_dropoff=pickup_dropoff
    )
    WranglerLogger.debug(f"Looking for stops near node_id: {node_id}")
    if node_id not in shapes["shape_model_node_id"].values:
        msg = f"Node ID {node_id} not in shapes for trip {trip_id}"
        raise ValueError(msg)
    # Find index of node_id in shapes
    node_idx = shapes[shapes["shape_model_node_id"] == node_id].index[0]

    # Find stops before and after new stop in shapes sequence
    nodes_before = shapes.loc[: node_idx - 1]
    stops_before = nodes_before.loc[nodes_before["stop_id"].notna()]
    stop_node_before = 0 if stops_before.empty else stops_before.iloc[-1]["shape_model_node_id"]

    nodes_after = shapes.loc[node_idx + 1 :]
    stops_after = nodes_after.loc[nodes_after["stop_id"].notna()]
    stop_node_after = 0 if stops_after.empty else stops_after.iloc[0]["shape_model_node_id"]

    return stop_node_before, stop_node_after
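The index-slicing step at the heart of the function can be sketched with a hypothetical joined shape, where a NaN stop_id means the shape point is traversed but not served as a stop:

```python
import pandas as pd

# Hypothetical shape for one trip after stop_ids have been joined on.
shapes = pd.DataFrame({
    "shape_model_node_id": [10, 20, 30, 40, 50],
    "stop_id": [10, None, None, None, 50],
})

node_id = 30  # find the stops bracketing this node
node_idx = shapes[shapes["shape_model_node_id"] == node_id].index[0]

# Last stop at or before the node (0 if none exists).
stops_before = shapes.loc[: node_idx - 1]
stops_before = stops_before[stops_before["stop_id"].notna()]
stop_node_before = 0 if stops_before.empty else stops_before.iloc[-1]["shape_model_node_id"]

# First stop after the node (0 if none exists).
stops_after = shapes.loc[node_idx + 1 :]
stops_after = stops_after[stops_after["stop_id"].notna()]
stop_node_after = 0 if stops_after.empty else stops_after.iloc[0]["shape_model_node_id"]
```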

network_wrangler.transit.feed.shapes.node_pattern_for_shape_id

node_pattern_for_shape_id(shapes, shape_id)

Returns node pattern of a shape.

Source code in network_wrangler/transit/feed/shapes.py
def node_pattern_for_shape_id(shapes: DataFrame[WranglerShapesTable], shape_id: str) -> list[int]:
    """Returns node pattern of a shape."""
    shape_df = shapes.loc[shapes["shape_id"] == shape_id]
    shape_df = shape_df.sort_values(by=["shape_pt_sequence"])
    return shape_df["shape_model_node_id"].to_list()

network_wrangler.transit.feed.shapes.shape_id_for_trip_id

shape_id_for_trip_id(trips, trip_id)

Returns a shape_id for a given trip_id.

Source code in network_wrangler/transit/feed/shapes.py
def shape_id_for_trip_id(trips: WranglerTripsTable, trip_id: str) -> str:
    """Returns a shape_id for a given trip_id."""
    return trips.loc[trips.trip_id == trip_id, "shape_id"].values[0]

network_wrangler.transit.feed.shapes.shape_ids_for_trip_ids

shape_ids_for_trip_ids(trips, trip_ids)

Returns a list of shape_ids for a given list of trip_ids.

Source code in network_wrangler/transit/feed/shapes.py
def shape_ids_for_trip_ids(trips: DataFrame[WranglerTripsTable], trip_ids: list[str]) -> list[str]:
    """Returns a list of shape_ids for a given list of trip_ids."""
    return trips[trips["trip_id"].isin(trip_ids)].shape_id.unique().tolist()
network_wrangler.transit.feed.shapes.shapes_for_road_links

shapes_for_road_links(shapes, links_df)

Filter shapes dataframe to records associated with links dataframe.

EX:

shapes = pd.DataFrame({
    "shape_id": ["1", "1", "1", "1", "2", "2", "2", "2", "2"],
    "shape_pt_sequence": [1, 2, 3, 4, 1, 2, 3, 4, 5],
    "shape_model_node_id": [1, 2, 3, 4, 2, 3, 1, 5, 4]
})

links_df = pd.DataFrame({
    "A": [1, 2, 3],
    "B": [2, 3, 4]
})

shapes

shape_id   shape_pt_sequence   shape_model_node_id   should retain
1          1                   1                     TRUE
1          2                   2                     TRUE
1          3                   3                     TRUE
1          4                   4                     TRUE
1          5                   5                     FALSE
2          1                   1                     TRUE
2          2                   2                     TRUE
2          3                   3                     TRUE
2          4                   1                     FALSE
2          5                   5                     FALSE
2          6                   4                     FALSE
2          7                   1                     FALSE - not largest segment
2          8                   2                     FALSE - not largest segment

links_df

A   B
1   2
2   3
3   4

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_road_links(
    shapes: DataFrame[WranglerShapesTable], links_df: pd.DataFrame
) -> DataFrame[WranglerShapesTable]:
    """Filter shapes dataframe to records associated with links dataframe.

    EX:

    > shapes = pd.DataFrame({
        "shape_id": ["1", "1", "1", "1", "2", "2", "2", "2", "2"],
        "shape_pt_sequence": [1, 2, 3, 4, 1, 2, 3, 4, 5],
        "shape_model_node_id": [1, 2, 3, 4, 2, 3, 1, 5, 4]
    })

    > links_df = pd.DataFrame({
        "A": [1, 2, 3],
        "B": [2, 3, 4]
    })

    > shapes

    shape_id   shape_pt_sequence   shape_model_node_id *should retain*
    1          1                  1                        TRUE
    1          2                  2                        TRUE
    1          3                  3                        TRUE
    1          4                  4                        TRUE
    1          5                  5                       FALSE
    2          1                  1                        TRUE
    2          2                  2                        TRUE
    2          3                  3                        TRUE
    2          4                  1                       FALSE
    2          5                  5                       FALSE
    2          6                  4                       FALSE
    2          7                  1                       FALSE - not largest segment
    2          8                  2                       FALSE - not largest segment

    > links_df

    A   B
    1   2
    2   3
    3   4
    """
    """
    > shape_links

    shape_id  shape_pt_sequence_A  shape_model_node_id_A shape_pt_sequence_B shape_model_node_id_B
    1          1                        1                       2                        2
    1          2                        2                       3                        3
    1          3                        3                       4                        4
    1          4                        4                       5                        5
    2          1                        1                       2                        2
    2          2                        2                       3                        3
    2          3                        3                       4                        1
    2          4                        1                       5                        5
    2          5                        5                       6                        4
    2          6                        4                       7                        1
    2          7                        1                       8                        2
    """
    shape_links = shapes_to_shape_links(shapes)

    """
    > shape_links_w_links

    shape_id  shape_pt_sequence_A shape_pt_sequence_B  A  B
    1          1                         2             1  2
    1          2                         3             2  3
    1          3                         4             3  4
    2          1                         2             1  2
    2          2                         3             2  3
    2          7                         8             1  2
    """

    shape_links_w_links = shape_links.merge(
        links_df[["A", "B"]],
        how="inner",
        on=["A", "B"],
    )

    """
    Find largest segment of each shape_id that is in the links

    > longest_shape_segments
    shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq
    1          1                        1                       4
    2          1                        1                       3
    """
    longest_shape_segments = shape_links_to_longest_shape_segments(shape_links_w_links)

    """
    > shapes

    shape_id   shape_pt_sequence   shape_model_node_id
    1          1                  1
    1          2                  2
    1          3                  3
    1          4                  4
    2          1                  1
    2          2                  2
    2          3                  3
    """
    filtered_shapes = filter_shapes_to_segments(shapes, longest_shape_segments)
    filtered_shapes = filtered_shapes.reset_index(drop=True)
    return filtered_shapes
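The first step, turning shape point sequences into candidate A-B links and keeping only those present in the roadway links table, can be approximated in plain pandas (a hypothetical sketch of what `shapes_to_shape_links` plus the inner merge accomplish, not the exact internal implementation):

```python
import pandas as pd

# Hypothetical shape: node 5 is off the roadway network.
shapes = pd.DataFrame({
    "shape_id": ["1", "1", "1", "1"],
    "shape_pt_sequence": [1, 2, 3, 4],
    "shape_model_node_id": [1, 2, 3, 5],
})

# Pair each shape point with the next point in the same shape to form A -> B links.
shapes = shapes.sort_values(["shape_id", "shape_pt_sequence"])
shapes["A"] = shapes["shape_model_node_id"]
shapes["B"] = shapes.groupby("shape_id")["shape_model_node_id"].shift(-1)
shape_links = shapes.dropna(subset=["B"]).astype({"B": int})

# Keep only candidate links that exist in the roadway links table.
links_df = pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]})
matched = shape_links.merge(links_df[["A", "B"]], how="inner", on=["A", "B"])
```

Only the 1-2 and 2-3 pairs survive; the 3-5 pair has no matching roadway link and drops out, which is what later triggers the longest-segment selection.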

network_wrangler.transit.feed.shapes.shapes_for_shape_id

shapes_for_shape_id(shapes, shape_id)

Returns shape records for a given shape_id.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_shape_id(
    shapes: DataFrame[WranglerShapesTable], shape_id: str
) -> DataFrame[WranglerShapesTable]:
    """Returns shape records for a given shape_id."""
    shapes = shapes.loc[shapes.shape_id == shape_id]
    return shapes.sort_values(by=["shape_pt_sequence"])

network_wrangler.transit.feed.shapes.shapes_for_trip_id

shapes_for_trip_id(shapes, trips, trip_id)

Returns shape records for a single given trip_id.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_trip_id(
    shapes: DataFrame[WranglerShapesTable], trips: DataFrame[WranglerTripsTable], trip_id: str
) -> DataFrame[WranglerShapesTable]:
    """Returns shape records for a single given trip_id."""
    shape_id = shape_id_for_trip_id(trips, trip_id)
    return shapes.loc[shapes.shape_id == shape_id]

network_wrangler.transit.feed.shapes.shapes_for_trip_ids

shapes_for_trip_ids(shapes, trips, trip_ids)

Returns shape records for list of trip_ids.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_trip_ids(
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
    trip_ids: list[str],
) -> DataFrame[WranglerShapesTable]:
    """Returns shape records for list of trip_ids."""
    shape_ids = shape_ids_for_trip_ids(trips, trip_ids)
    return shapes.loc[shapes.shape_id.isin(shape_ids)]

network_wrangler.transit.feed.shapes.shapes_for_trips

shapes_for_trips(shapes, trips)

Filter shapes dataframe to records associated with trips table.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_trips(
    shapes: DataFrame[WranglerShapesTable], trips: DataFrame[WranglerTripsTable]
) -> DataFrame[WranglerShapesTable]:
    """Filter shapes dataframe to records associated with trips table."""
    _sel_shapes = trips.shape_id.unique().tolist()
    filtered_shapes = shapes[shapes.shape_id.isin(_sel_shapes)]
    WranglerLogger.debug(
        f"Filtered shapes to {len(filtered_shapes)}/{len(shapes)} \
                         records that referenced one of {len(trips)} trips."
    )
    return filtered_shapes

network_wrangler.transit.feed.shapes.shapes_with_stop_id_for_trip_id

shapes_with_stop_id_for_trip_id(
    shapes,
    trips,
    stop_times,
    trip_id,
    pickup_dropoff="either",
)

Returns shapes.txt for a given trip_id with the stop_id added based on pickup_type.

Parameters:

- shapes (DataFrame[WranglerShapesTable]): WranglerShapesTable. Required.
- trips (DataFrame[WranglerTripsTable]): WranglerTripsTable. Required.
- stop_times (DataFrame[WranglerStopTimesTable]): WranglerStopTimesTable. Required.
- trip_id (str): trip id to select. Required.
- pickup_dropoff (PickupDropoffAvailability): str indicating logic for selecting stops based on pickup and dropoff availability at stop. Defaults to "either".
    - "either": either pickup_type or dropoff_type > 0
    - "both": both pickup_type and dropoff_type > 0
    - "pickup_only": only pickup > 0
    - "dropoff_only": only dropoff > 0
Source code in network_wrangler/transit/feed/shapes.py
def shapes_with_stop_id_for_trip_id(
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> DataFrame[WranglerShapesTable]:
    """Returns shapes.txt for a given trip_id with the stop_id added based on pickup_type.

    Args:
        shapes: WranglerShapesTable
        trips: WranglerTripsTable
        stop_times: WranglerStopTimesTable
        trip_id: trip id to select
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0
    """
    from .stop_times import stop_times_for_pickup_dropoff_trip_id

    shapes = shapes_for_trip_id(shapes, trips, trip_id)
    trip_stop_times = stop_times_for_pickup_dropoff_trip_id(
        stop_times, trip_id, pickup_dropoff=pickup_dropoff
    )

    stop_times_cols = [
        "stop_id",
        "trip_id",
        "pickup_type",
        "drop_off_type",
    ]

    shape_with_trip_stops = shapes.merge(
        trip_stop_times[stop_times_cols],
        how="left",
        right_on="stop_id",
        left_on="shape_model_node_id",
    )
    shape_with_trip_stops = shape_with_trip_stops.sort_values(by=["shape_pt_sequence"])
    return shape_with_trip_stops

network_wrangler.transit.feed.shapes.shapes_with_stops_for_shape_id

shapes_with_stops_for_shape_id(
    shapes, trips, stop_times, shape_id
)

Returns a DataFrame containing shapes with associated stops for a given shape_id.

Parameters:

- shapes (DataFrame[WranglerShapesTable]): DataFrame containing shape data. Required.
- trips (DataFrame[WranglerTripsTable]): DataFrame containing trip data. Required.
- stop_times (DataFrame[WranglerStopTimesTable]): DataFrame containing stop times data. Required.
- shape_id (str): The shape_id for which to retrieve shapes with stops. Required.

Returns:

- DataFrame[WranglerShapesTable]: DataFrame containing shapes with associated stops.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_with_stops_for_shape_id(
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    shape_id: str,
) -> DataFrame[WranglerShapesTable]:
    """Returns a DataFrame containing shapes with associated stops for a given shape_id.

    Parameters:
        shapes (DataFrame[WranglerShapesTable]): DataFrame containing shape data.
        trips (DataFrame[WranglerTripsTable]): DataFrame containing trip data.
        stop_times (DataFrame[WranglerStopTimesTable]): DataFrame containing stop times data.
        shape_id (str): The shape_id for which to retrieve shapes with stops.

    Returns:
        DataFrame[WranglerShapesTable]: DataFrame containing shapes with associated stops.
    """
    from .trips import trip_ids_for_shape_id

    trip_ids = trip_ids_for_shape_id(trips, shape_id)
    all_shape_stop_times = concat_with_attr(
        [shapes_with_stop_id_for_trip_id(shapes, trips, stop_times, t) for t in trip_ids]
    )
    shapes_with_stops = all_shape_stop_times[all_shape_stop_times["stop_id"].notna()]
    shapes_with_stops = shapes_with_stops.sort_values(by=["shape_pt_sequence"])
    return shapes_with_stops

Filters and queries of a gtfs stop_times table.

network_wrangler.transit.feed.stop_times.stop_times_for_longest_segments

stop_times_for_longest_segments(stop_times)

Find the longest segment of each trip_id that is in the stop_times.

Segment ends defined based on interruptions in stop_sequence.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_longest_segments(
    stop_times: DataFrame[WranglerStopTimesTable],
) -> pd.DataFrame:
    """Find the longest segment of each trip_id that is in the stop_times.

    Segment ends defined based on interruptions in `stop_sequence`.
    """
    stop_times = stop_times.sort_values(by=["trip_id", "stop_sequence"])

    stop_times["prev_stop_sequence"] = stop_times.groupby("trip_id")["stop_sequence"].shift(1)
    stop_times["gap"] = (stop_times["stop_sequence"] - stop_times["prev_stop_sequence"]).ne(
        1
    ) | stop_times["prev_stop_sequence"].isna()

    stop_times["segment_id"] = stop_times["gap"].cumsum()
    # WranglerLogger.debug(f"stop_times with segment_id:\n{stop_times}")

    # Calculate the length of each segment
    segment_lengths = (
        stop_times.groupby(["trip_id", "segment_id"]).size().reset_index(name="segment_length")
    )

    # Identify the longest segment for each trip
    idx = segment_lengths.groupby("trip_id")["segment_length"].idxmax()
    longest_segments = segment_lengths.loc[idx]

    # Merge longest segment info back to stop_times
    stop_times = stop_times.merge(
        longest_segments[["trip_id", "segment_id"]],
        on=["trip_id", "segment_id"],
        how="inner",
    )

    # Drop temporary columns used for calculations
    stop_times.drop(columns=["prev_stop_sequence", "gap", "segment_id"], inplace=True)
    # WranglerLogger.debug(f"stop_timesw/longest segments:\n{stop_times}")
    return stop_times
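The gap-detection-plus-cumsum idiom used above can be sketched with hypothetical toy data: a break in `stop_sequence` starts a new segment, and only the longest segment per trip is kept.

```python
import pandas as pd

# Hypothetical trip whose stop_sequence breaks after 3 (1, 2, 3 then 7, 8).
stop_times = pd.DataFrame({"trip_id": ["t1"] * 5, "stop_sequence": [1, 2, 3, 7, 8]})
stop_times = stop_times.sort_values(["trip_id", "stop_sequence"])

# A new segment starts wherever stop_sequence does not increase by exactly 1.
prev = stop_times.groupby("trip_id")["stop_sequence"].shift(1)
gap = (stop_times["stop_sequence"] - prev).ne(1) | prev.isna()
stop_times["segment_id"] = gap.cumsum()

# Measure each segment and keep the longest one per trip.
seg_len = stop_times.groupby(["trip_id", "segment_id"]).size().reset_index(name="n")
longest = seg_len.loc[seg_len.groupby("trip_id")["n"].idxmax(), ["trip_id", "segment_id"]]
longest_stop_times = stop_times.merge(longest, on=["trip_id", "segment_id"], how="inner")
```

The 1-2-3 run (length 3) wins over the 7-8 run (length 2).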

network_wrangler.transit.feed.stop_times.stop_times_for_min_stops

stop_times_for_min_stops(stop_times, min_stops)

Filter stop_times dataframe to only the records which have >= min_stops for the trip.

Parameters:

- stop_times (DataFrame[WranglerStopTimesTable]): stop_times table to filter. Required.
- min_stops (int): minimum number of stops required to keep a trip in stop_times. Required.
Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_min_stops(
    stop_times: DataFrame[WranglerStopTimesTable], min_stops: int
) -> DataFrame[WranglerStopTimesTable]:
    """Filter stop_times dataframe to only the records which have >= min_stops for the trip.

    Args:
        stop_times: stop_times table to filter
        min_stops: minimum number of stops required to keep a trip in stop_times
    """
    stop_ct_by_trip_df = stop_count_by_trip(stop_times)

    # Filter to obtain DataFrame of trips with stop counts >= min_stops
    min_stop_ct_trip_df = stop_ct_by_trip_df[stop_ct_by_trip_df.stop_count >= min_stops]
    if len(min_stop_ct_trip_df) == 0:
        msg = f"No trips meet threshold of minimum stops: {min_stops}"
        raise ValueError(msg)
    WranglerLogger.debug(
        f"Found {len(min_stop_ct_trip_df)} trips with a minimum of {min_stops} stops."
    )

    # Filter the original stop_times DataFrame to only include trips with >= min_stops
    filtered_stop_times = stop_times.merge(
        min_stop_ct_trip_df["trip_id"], on="trip_id", how="inner"
    )
    WranglerLogger.debug(
        f"Filter stop times to {len(filtered_stop_times)}/{len(stop_times)}\
            w/a minimum of {min_stops} stops."
    )

    return filtered_stop_times

network_wrangler.transit.feed.stop_times.stop_times_for_pickup_dropoff_trip_id

stop_times_for_pickup_dropoff_trip_id(
    stop_times, trip_id, pickup_dropoff="either"
)

Filters stop_times for a given trip_id based on pickup type.

GTFS values for pickup_type and drop_off_type:

- 0 or empty - Regularly scheduled pickup/dropoff.
- 1 - No pickup/dropoff available.
- 2 - Must phone agency to arrange pickup/dropoff.
- 3 - Must coordinate with driver to arrange pickup/dropoff.

Parameters:

- stop_times (DataFrame[WranglerStopTimesTable]): A WranglerStopTimesTable to query. Required.
- trip_id (str): trip_id to get stop pattern for. Required.
- pickup_dropoff (PickupDropoffAvailability): str indicating logic for selecting stops based on pickup and dropoff availability at stop. Defaults to "either".
    - "any": all stoptime records
    - "either": either pickup_type or dropoff_type != 1
    - "both": both pickup_type and dropoff_type != 1
    - "pickup_only": dropoff = 1; pickup != 1
    - "dropoff_only": pickup = 1; dropoff != 1
Source code in network_wrangler/transit/feed/stop_times.py
@validate_call_pyd
def stop_times_for_pickup_dropoff_trip_id(
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> DataFrame[WranglerStopTimesTable]:
    """Filters stop_times for a given trip_id based on pickup type.

    GTFS values for pickup_type and drop_off_type:
        0 or empty - Regularly scheduled pickup/dropoff.
        1 - No pickup/dropoff available.
        2 - Must phone agency to arrange pickup/dropoff.
        3 - Must coordinate with driver to arrange pickup/dropoff.

    Args:
        stop_times: A WranglerStopTimesTable to query.
        trip_id: trip_id to get stop pattern for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "any": all stoptime records
            "either": either pickup_type or dropoff_type != 1
            "both": both pickup_type and dropoff_type != 1
            "pickup_only": dropoff = 1; pickup != 1
            "dropoff_only":  pickup = 1; dropoff != 1
    """
    trip_stop_pattern = stop_times_for_trip_id(stop_times, trip_id)

    if pickup_dropoff == "any":
        return trip_stop_pattern

    pickup_type_selection = {
        "either": (trip_stop_pattern.pickup_type != 1) | (trip_stop_pattern.drop_off_type != 1),
        "both": (trip_stop_pattern.pickup_type != 1) & (trip_stop_pattern.drop_off_type != 1),
        "pickup_only": (trip_stop_pattern.pickup_type != 1)
        & (trip_stop_pattern.drop_off_type == 1),
        "dropoff_only": (trip_stop_pattern.drop_off_type != 1)
        & (trip_stop_pattern.pickup_type == 1),
    }

    selection = pickup_type_selection[pickup_dropoff]
    trip_stops = trip_stop_pattern[selection]

    return trip_stops
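The boolean-mask selection above can be sketched with a hypothetical stop pattern, recalling that a value of 1 means "not available":

```python
import pandas as pd

# Hypothetical stop pattern: stop 2 is dropoff-only, stop 3 is... dropoff unavailable,
# i.e. pickup-only.
pattern = pd.DataFrame({
    "stop_id": [1, 2, 3],
    "pickup_type": [0, 1, 0],
    "drop_off_type": [0, 0, 1],
})

# "either": boarding OR alighting is possible (type != 1 means available).
either = pattern[(pattern.pickup_type != 1) | (pattern.drop_off_type != 1)]

# "both": boarding AND alighting are possible.
both = pattern[(pattern.pickup_type != 1) & (pattern.drop_off_type != 1)]
```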

network_wrangler.transit.feed.stop_times.stop_times_for_route_ids

stop_times_for_route_ids(stop_times, trips, route_ids)

Returns stop_time records for a list of route_ids.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_route_ids(
    stop_times: DataFrame[WranglerStopTimesTable],
    trips: DataFrame[WranglerTripsTable],
    route_ids: list[str],
) -> DataFrame[WranglerStopTimesTable]:
    """Returns a stop_time records for a list of route_ids."""
    trip_ids = trips.loc[trips.route_id.isin(route_ids)].trip_id.unique()
    return stop_times_for_trip_ids(stop_times, trip_ids)

network_wrangler.transit.feed.stop_times.stop_times_for_shapes

stop_times_for_shapes(stop_times, shapes, trips)

Filter stop_times dataframe to records associated with shapes dataframe.

Where multiple segments of stop_times are found to match shapes, retain only the longest.

Parameters:

- stop_times (DataFrame[WranglerStopTimesTable]): stop_times dataframe to filter. Required.
- shapes (DataFrame[WranglerShapesTable]): shapes dataframe to filter stop_times to. Required.
- trips (DataFrame[WranglerTripsTable]): trips dataframe to link stop_times to shapes. Required.

Returns:

- DataFrame[WranglerStopTimesTable]: filtered stop_times dataframe.

EX (* = should be retained):

stop_times

trip_id   stop_sequence   stop_id
*t1       1               1
*t1       2               2
*t1       3               3
t1        4               5
*t2       1               1
*t2       2               3
t2        3               7

shapes

shape_id   shape_pt_sequence   shape_model_node_id
s1         1                   1
s1         2                   2
s1         3                   3
s1         4                   4
s2         1                   1
s2         2                   2
s2         3                   3

trips

trip_id   shape_id
t1        s1
t2        s2

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_shapes(
    stop_times: DataFrame[WranglerStopTimesTable],
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
) -> DataFrame[WranglerStopTimesTable]:
    """Filter stop_times dataframe to records associated with shapes dataframe.

    Where multiple segments of stop_times are found to match shapes, retain only the longest.

    Args:
        stop_times: stop_times dataframe to filter
        shapes: shapes dataframe to filter stop_times to.
        trips: trips dataframe to link stop_times to shapes.

    Returns:
        filtered stop_times dataframe

    EX:
    * should be retained
    > stop_times

    trip_id   stop_sequence   stop_id
    *t1          1                  1
    *t1          2                  2
    *t1          3                  3
    t1           4                  5
    *t2          1                  1
    *t2          2                  3
    t2           3                  7

    > shapes

    shape_id   shape_pt_sequence   shape_model_node_id
    s1          1                  1
    s1          2                  2
    s1          3                  3
    s1          4                  4
    s2          1                  1
    s2          2                  2
    s2          3                  3

    > trips

    trip_id   shape_id
    t1          s1
    t2          s2
    """
    """
    > stop_times_w_shapes

    trip_id   stop_sequence   stop_id    shape_id   shape_pt_sequence
    *t1          1                  1        s1          1
    *t1          2                  2        s1          2
    *t1          3                  3        s1          3
    t1           4                  5        NA          NA
    *t2          1                  1        s2          1
    *t2          2                  3        s2          2
    t2           3                  7        NA          NA

    """
    stop_times_w_shapes = merge_shapes_to_stop_times(stop_times, shapes, trips)
    # WranglerLogger.debug(f"stop_times_w_shapes :\n{stop_times_w_shapes}")
    """
    > stop_times_w_shapes

    trip_id   stop_sequence   stop_id   shape_id   shape_pt_sequence
    *t1          1               1        s1          1
    *t1          2               2        s1          2
    *t1          3               3        s1          3
    *t2          1               1        s2          1
    *t2          2               3        s2          2

    """
    filtered_stop_times = stop_times_w_shapes[stop_times_w_shapes["shape_pt_sequence"].notna()]
    # WranglerLogger.debug(f"filtered_stop_times:\n{filtered_stop_times}")

    # Filter out any stop_times the shape_pt_sequence is not ascending
    valid_stop_times = filtered_stop_times.groupby("trip_id").filter(
        lambda x: x["shape_pt_sequence"].is_monotonic_increasing
    )
    # WranglerLogger.debug(f"valid_stop_times:\n{valid_stop_times}")

    valid_stop_times = valid_stop_times.drop(columns=["shape_id", "shape_pt_sequence"])

    longest_valid_stop_times = stop_times_for_longest_segments(valid_stop_times)
    longest_valid_stop_times = longest_valid_stop_times.reset_index(drop=True)

    return longest_valid_stop_times
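The monotonicity check above, which drops trips whose matched shape points run backwards, can be sketched with hypothetical toy data:

```python
import pandas as pd

# Hypothetical stop_times already joined to shapes: trip t2's shape_pt_sequence
# decreases, so its match is spurious and the whole trip is filtered out.
st = pd.DataFrame({
    "trip_id": ["t1", "t1", "t2", "t2"],
    "stop_sequence": [1, 2, 1, 2],
    "shape_pt_sequence": [1, 2, 5, 3],
})

# Keep only trips whose shape_pt_sequence is non-decreasing.
valid = st.groupby("trip_id").filter(
    lambda g: g["shape_pt_sequence"].is_monotonic_increasing
)
```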

network_wrangler.transit.feed.stop_times.stop_times_for_stops

stop_times_for_stops(stop_times, stops)

Filter stop_times dataframe to only have stop_times associated with stops records.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_stops(
    stop_times: DataFrame[WranglerStopTimesTable], stops: DataFrame[WranglerStopsTable]
) -> DataFrame[WranglerStopTimesTable]:
    """Filter stop_times dataframe to only have stop_times associated with stops records."""
    _sel_stops = stops.stop_id.unique().tolist()
    filtered_stop_times = stop_times[stop_times.stop_id.isin(_sel_stops)]
    WranglerLogger.debug(
        f"Filtered stop_times to {len(filtered_stop_times)}/{len(stop_times)} \
                         records that referenced one of {len(stops)} stops."
    )
    return filtered_stop_times

network_wrangler.transit.feed.stop_times.stop_times_for_trip_id

stop_times_for_trip_id(stop_times, trip_id)

Returns stop_time records for a given trip_id.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_trip_id(
    stop_times: DataFrame[WranglerStopTimesTable], trip_id: str
) -> DataFrame[WranglerStopTimesTable]:
    """Returns a stop_time records for a given trip_id."""
    stop_times = stop_times.loc[stop_times.trip_id == trip_id]
    return stop_times.sort_values(by=["stop_sequence"])

network_wrangler.transit.feed.stop_times.stop_times_for_trip_ids

stop_times_for_trip_ids(stop_times, trip_ids)

Returns stop_time records for a given list of trip_ids.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_trip_ids(
    stop_times: DataFrame[WranglerStopTimesTable], trip_ids: list[str]
) -> DataFrame[WranglerStopTimesTable]:
    """Returns a stop_time records for a given list of trip_ids."""
    stop_times = stop_times.loc[stop_times.trip_id.isin(trip_ids)]
    return stop_times.sort_values(by=["stop_sequence"])

network_wrangler.transit.feed.stop_times.stop_times_for_trip_node_segment

stop_times_for_trip_node_segment(
    stop_times,
    trip_id,
    node_id_start,
    node_id_end,
    include_start=True,
    include_end=True,
)

Returns stop_times for a given trip_id between two nodes, optionally including those nodes.

Parameters:

Name Type Description Default
stop_times DataFrame[WranglerStopTimesTable]

WranglerStopTimesTable

required
trip_id str

trip id to select

required
node_id_start int

int of the starting node

required
node_id_end int

int of the ending node

required
include_start bool

bool indicating if the start node should be included in the segment. Defaults to True.

True
include_end bool

bool indicating if the end node should be included in the segment. Defaults to True.

True
Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_trip_node_segment(
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    node_id_start: int,
    node_id_end: int,
    include_start: bool = True,
    include_end: bool = True,
) -> DataFrame[WranglerStopTimesTable]:
    """Returns stop_times for a given trip_id between two nodes, optionally including those nodes.

    Args:
        stop_times: WranglerStopTimesTable
        trip_id: trip id to select
        node_id_start: int of the starting node
        node_id_end: int of the ending node
        include_start: bool indicating if the start node should be included in the segment.
            Defaults to True.
        include_end: bool indicating if the end node should be included in the segment.
            Defaults to True.
    """
    stop_times = stop_times_for_trip_id(stop_times, trip_id)
    start_idx = stop_times[stop_times["stop_id"] == node_id_start].index[0]
    end_idx = stop_times[stop_times["stop_id"] == node_id_end].index[0]
    if not include_start:
        start_idx += 1
    if include_end:
        end_idx += 1
    return stop_times.loc[start_idx:end_idx]
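A simplified sketch of the segment selection, assuming a clean RangeIndex and using position-based indexing for clarity (the toy data and the `node_segment` helper are illustrative, not the library's implementation):

```python
import pandas as pd

# Toy stop_times for one trip, already sorted by stop_sequence.
st = pd.DataFrame({"stop_id": [10, 20, 30, 40], "stop_sequence": [1, 2, 3, 4]})

def node_segment(st, start_node, end_node, include_start=True, include_end=True):
    # Find the row positions of the two nodes, then trim either end
    # when it should be excluded from the segment.
    start_pos = st.index.get_loc(st[st["stop_id"] == start_node].index[0])
    end_pos = st.index.get_loc(st[st["stop_id"] == end_node].index[0])
    if not include_start:
        start_pos += 1
    if not include_end:
        end_pos -= 1
    return st.iloc[start_pos : end_pos + 1]

print(node_segment(st, 20, 40, include_start=False)["stop_id"].tolist())  # [30, 40]
```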

Filters and queries of a gtfs stops table and stop_ids.

network_wrangler.transit.feed.stops.node_is_stop

node_is_stop(
    stops,
    stop_times,
    node_id,
    trip_id,
    pickup_dropoff="either",
)

Returns boolean(s) indicating whether a node (or list of nodes) is a stop for a given trip_id.

Parameters:

Name Type Description Default
stops DataFrame[WranglerStopsTable]

WranglerStopsTable

required
stop_times DataFrame[WranglerStopTimesTable]

WranglerStopTimesTable

required
node_id int | list[int]

node ID for roadway

required
trip_id str

trip_id to get stop pattern for

required
pickup_dropoff PickupDropoffAvailability

str indicating logic for selecting stops based on pickup and dropoff availability at stop. Defaults to "either". "either": either pickup_type or dropoff_type > 0; "both": both pickup_type and dropoff_type > 0; "pickup_only": only pickup > 0; "dropoff_only": only dropoff > 0

'either'
Source code in network_wrangler/transit/feed/stops.py
def node_is_stop(
    stops: DataFrame[WranglerStopsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    node_id: int | list[int],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> bool | list[bool]:
    """Returns boolean(s) indicating whether a node (or list of nodes) is a stop for a given trip_id.

    Args:
        stops: WranglerStopsTable
        stop_times: WranglerStopTimesTable
        node_id: node ID for roadway
        trip_id: trip_id to get stop pattern for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0
    """
    trip_stop_nodes = stops_for_trip_id(stops, stop_times, trip_id, pickup_dropoff=pickup_dropoff)[
        "stop_id"
    ]
    if isinstance(node_id, list):
        return [n in trip_stop_nodes.values for n in node_id]
    return node_id in trip_stop_nodes.values

network_wrangler.transit.feed.stops.stop_id_pattern_for_trip

stop_id_pattern_for_trip(
    stop_times, trip_id, pickup_dropoff="either"
)

Returns the stop pattern for a given trip_id as a list of stop_ids.

Parameters:

Name Type Description Default
stop_times DataFrame[WranglerStopTimesTable]

WranglerStopTimesTable

required
trip_id str

trip_id to get stop pattern for

required
pickup_dropoff PickupDropoffAvailability

str indicating logic for selecting stops based on pickup and dropoff availability at stop. Defaults to "either". "either": either pickup_type or dropoff_type > 0; "both": both pickup_type and dropoff_type > 0; "pickup_only": only pickup > 0; "dropoff_only": only dropoff > 0

'either'
Source code in network_wrangler/transit/feed/stops.py
@validate_call_pyd
def stop_id_pattern_for_trip(
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> list[str]:
    """Returns the stop pattern for a given trip_id as a list of stop_ids.

    Args:
        stop_times: WranglerStopTimesTable
        trip_id: trip_id to get stop pattern for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0
    """
    from .stop_times import stop_times_for_pickup_dropoff_trip_id

    trip_stops = stop_times_for_pickup_dropoff_trip_id(
        stop_times, trip_id, pickup_dropoff=pickup_dropoff
    )
    return trip_stops.stop_id.to_list()
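The core of the pattern extraction can be sketched on a toy frame (the real function first applies the pickup/dropoff filter via stop_times_for_pickup_dropoff_trip_id):

```python
import pandas as pd

# Toy stop_times spanning two trips, intentionally out of order.
stop_times = pd.DataFrame(
    {
        "trip_id": ["t1", "t2", "t1", "t1"],
        "stop_id": ["s3", "s9", "s1", "s2"],
        "stop_sequence": [3, 1, 1, 2],
    }
)

# Same idea as stop_id_pattern_for_trip with pickup_dropoff="either":
# select the trip's records, order by stop_sequence, emit stop_ids.
pattern = (
    stop_times.loc[stop_times.trip_id == "t1"]
    .sort_values("stop_sequence")
    .stop_id.to_list()
)
print(pattern)  # ['s1', 's2', 's3']
```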

network_wrangler.transit.feed.stops.stops_for_stop_times

stops_for_stop_times(stops, stop_times)

Filter stops dataframe to only have stops associated with stop_times records.

Source code in network_wrangler/transit/feed/stops.py
def stops_for_stop_times(
    stops: DataFrame[WranglerStopsTable], stop_times: DataFrame[WranglerStopTimesTable]
) -> DataFrame[WranglerStopsTable]:
    """Filter stops dataframe to only have stops associated with stop_times records."""
    _sel_stops_ge_min = stop_times.stop_id.unique().tolist()
    filtered_stops = stops[stops.stop_id.isin(_sel_stops_ge_min)]
    WranglerLogger.debug(
        f"Filtered stops to {len(filtered_stops)}/{len(stops)} \
                         records that referenced one of {len(stop_times)} stop_times."
    )
    return filtered_stops

network_wrangler.transit.feed.stops.stops_for_trip_id

stops_for_trip_id(
    stops, stop_times, trip_id, pickup_dropoff="any"
)

Returns stops records that are used for a given trip_id.

Source code in network_wrangler/transit/feed/stops.py
def stops_for_trip_id(
    stops: DataFrame[WranglerStopsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "any",
) -> DataFrame[WranglerStopsTable]:
    """Returns stops records that are used for a given trip_id."""
    stop_ids = stop_id_pattern_for_trip(stop_times, trip_id, pickup_dropoff=pickup_dropoff)
    return stops.loc[stops.stop_id.isin(stop_ids)]

Filters and queries of a gtfs trips table and trip_ids.

network_wrangler.transit.feed.trips.trip_ids_for_shape_id

trip_ids_for_shape_id(trips, shape_id)

Returns a list of trip_ids for a given shape_id.

Source code in network_wrangler/transit/feed/trips.py
def trip_ids_for_shape_id(trips: DataFrame[WranglerTripsTable], shape_id: str) -> list[str]:
    """Returns a list of trip_ids for a given shape_id."""
    return trips_for_shape_id(trips, shape_id)["trip_id"].unique().tolist()

network_wrangler.transit.feed.trips.trips_for_shape_id

trips_for_shape_id(trips, shape_id)

Returns trips records for a given shape_id.

Source code in network_wrangler/transit/feed/trips.py
def trips_for_shape_id(
    trips: DataFrame[WranglerTripsTable], shape_id: str
) -> DataFrame[WranglerTripsTable]:
    """Returns trips records for a given shape_id."""
    return trips.loc[trips.shape_id == shape_id]

network_wrangler.transit.feed.trips.trips_for_stop_times

trips_for_stop_times(trips, stop_times)

Filter trips dataframe to records associated with stop_time records.

Source code in network_wrangler/transit/feed/trips.py
def trips_for_stop_times(
    trips: DataFrame[WranglerTripsTable], stop_times: DataFrame[WranglerStopTimesTable]
) -> DataFrame[WranglerTripsTable]:
    """Filter trips dataframe to records associated with stop_time records."""
    _sel_trips = stop_times.trip_id.unique().tolist()
    filtered_trips = trips[trips.trip_id.isin(_sel_trips)]
    WranglerLogger.debug(
        f"Filtered trips to {len(filtered_trips)}/{len(trips)} \
                         records that referenced one of {len(stop_times)} stop_times."
    )
    return filtered_trips

Functions for translating transit tables into visualizable links relatable to roadway network.

shapes_to_shape_links(shapes)

Converts shapes DataFrame to shape links DataFrame.

Parameters:

Name Type Description Default
shapes DataFrame[WranglerShapesTable]

The input shapes DataFrame.

required

Returns:

Type Description
DataFrame

pd.DataFrame: The resulting shape links DataFrame.

Source code in network_wrangler/transit/feed/transit_links.py
def shapes_to_shape_links(shapes: DataFrame[WranglerShapesTable]) -> pd.DataFrame:
    """Converts shapes DataFrame to shape links DataFrame.

    Args:
        shapes (DataFrame[WranglerShapesTable]): The input shapes DataFrame.

    Returns:
        pd.DataFrame: The resulting shape links DataFrame.
    """
    return point_seq_to_links(
        shapes,
        id_field="shape_id",
        seq_field="shape_pt_sequence",
        node_id_field="shape_model_node_id",
    )
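The point-sequence-to-links idea behind point_seq_to_links can be sketched with groupby and shift (a minimal illustration; the real utility also carries along other columns and the configured field names):

```python
import pandas as pd

# Toy shape points for two shapes.
shapes = pd.DataFrame(
    {
        "shape_id": ["sh1"] * 3 + ["sh2"] * 2,
        "shape_pt_sequence": [1, 2, 3, 1, 2],
        "shape_model_node_id": [10, 20, 30, 40, 50],
    }
)

links = shapes.sort_values(["shape_id", "shape_pt_sequence"]).copy()
links["A"] = links["shape_model_node_id"]
# Each point links to the next point in the same shape; the last point
# of each shape has no successor and is dropped.
links["B"] = links.groupby("shape_id")["shape_model_node_id"].shift(-1)
links = links.dropna(subset=["B"]).astype({"B": int})
print(links[["A", "B"]].values.tolist())  # [[10, 20], [20, 30], [40, 50]]
```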
stop_times_to_stop_times_links(
    stop_times, from_field="A", to_field="B"
)

Converts stop times to stop times links.

Parameters:

Name Type Description Default
stop_times DataFrame[WranglerStopTimesTable]

The stop times data.

required
from_field str

The name of the field representing the ‘from’ stop. Defaults to “A”.

'A'
to_field str

The name of the field representing the ‘to’ stop. Defaults to “B”.

'B'

Returns:

Type Description
DataFrame

pd.DataFrame: The resulting stop times links.

Source code in network_wrangler/transit/feed/transit_links.py
def stop_times_to_stop_times_links(
    stop_times: DataFrame[WranglerStopTimesTable],
    from_field: str = "A",
    to_field: str = "B",
) -> pd.DataFrame:
    """Converts stop times to stop times links.

    Args:
        stop_times (DataFrame[WranglerStopTimesTable]): The stop times data.
        from_field (str, optional): The name of the field representing the 'from' stop.
            Defaults to "A".
        to_field (str, optional): The name of the field representing the 'to' stop.
            Defaults to "B".

    Returns:
        pd.DataFrame: The resulting stop times links.
    """
    return point_seq_to_links(
        stop_times,
        id_field="trip_id",
        seq_field="stop_sequence",
        node_id_field="stop_id",
        from_field=from_field,
        to_field=to_field,
    )
unique_shape_links(shapes, from_field='A', to_field='B')

Returns a DataFrame containing unique shape links based on the provided shapes DataFrame.

Parameters:

Name Type Description Default
shapes DataFrame[WranglerShapesTable]

The input DataFrame containing shape information.

required
from_field str

The name of the column representing the ‘from’ field. Defaults to “A”.

'A'
to_field str

The name of the column representing the ‘to’ field. Defaults to “B”.

'B'

Returns:

Type Description
DataFrame

pd.DataFrame: DataFrame containing unique shape links based on the provided shapes df.

Source code in network_wrangler/transit/feed/transit_links.py
def unique_shape_links(
    shapes: DataFrame[WranglerShapesTable], from_field: str = "A", to_field: str = "B"
) -> pd.DataFrame:
    """Returns a DataFrame containing unique shape links based on the provided shapes DataFrame.

    Parameters:
        shapes (DataFrame[WranglerShapesTable]): The input DataFrame containing shape information.
        from_field (str, optional): The name of the column representing the 'from' field.
            Defaults to "A".
        to_field (str, optional): The name of the column representing the 'to' field.
            Defaults to "B".

    Returns:
        pd.DataFrame: DataFrame containing unique shape links based on the provided shapes df.
    """
    shape_links = shapes_to_shape_links(shapes)
    # WranglerLogger.debug(f"Shape links: \n {shape_links[['shape_id', from_field, to_field]]}")

    _agg_dict: dict[str, type | str] = {"shape_id": list}
    _opt_fields = [f"shape_pt_{v}_{t}" for v in ["lat", "lon"] for t in [from_field, to_field]]
    for f in _opt_fields:
        if f in shape_links:
            _agg_dict[f] = "first"

    unique_shape_links = shape_links.groupby([from_field, to_field]).agg(_agg_dict).reset_index()
    return unique_shape_links
unique_stop_time_links(
    stop_times, from_field="A", to_field="B"
)

Returns a DataFrame containing unique stop time links based on the given stop times DataFrame.

Parameters:

Name Type Description Default
stop_times DataFrame[WranglerStopTimesTable]

The DataFrame containing stop times data.

required
from_field str

The name of the column representing the ‘from’ field in the stop times DataFrame. Defaults to “A”.

'A'
to_field str

The name of the column representing the ‘to’ field in the stop times DataFrame. Defaults to “B”.

'B'

Returns:

Type Description
DataFrame

pd.DataFrame: A DataFrame containing unique stop time links with columns ‘from_field’, ‘to_field’, and ‘trip_id’.

Source code in network_wrangler/transit/feed/transit_links.py
def unique_stop_time_links(
    stop_times: DataFrame[WranglerStopTimesTable],
    from_field: str = "A",
    to_field: str = "B",
) -> pd.DataFrame:
    """Returns a DataFrame containing unique stop time links based on the given stop times DataFrame.

    Parameters:
        stop_times (DataFrame[WranglerStopTimesTable]): The DataFrame containing stop times data.
        from_field (str, optional): The name of the column representing the 'from' field in the
            stop times DataFrame. Defaults to "A".
        to_field (str, optional): The name of the column representing the 'to' field in the stop
            times DataFrame. Defaults to "B".

    Returns:
        pd.DataFrame: A DataFrame containing unique stop time links with columns 'from_field',
            'to_field', and 'trip_id'.
    """
    links = stop_times_to_stop_times_links(stop_times, from_field=from_field, to_field=to_field)
    unique_links = links.groupby([from_field, to_field])["trip_id"].apply(list).reset_index()
    return unique_links
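The aggregation step can be sketched on a toy links frame (toy data, not the library's validated tables):

```python
import pandas as pd

# Toy per-trip links, as produced by stop_times_to_stop_times_links.
links = pd.DataFrame(
    {
        "A": [1, 1, 2],
        "B": [2, 2, 3],
        "trip_id": ["t1", "t2", "t1"],
    }
)

# Same aggregation as unique_stop_time_links: one row per (A, B) pair,
# collecting every trip_id that traverses that link.
unique_links = links.groupby(["A", "B"])["trip_id"].apply(list).reset_index()
print(unique_links["trip_id"].tolist())  # [['t1', 't2'], ['t1']]
```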

Functions to create segments from shapes and shape_links.

network_wrangler.transit.feed.transit_segments.filter_shapes_to_segments

filter_shapes_to_segments(shapes, segments)

Filter shapes dataframe to records associated with segments dataframe.

Parameters:

Name Type Description Default
shapes DataFrame[WranglerShapesTable]

shapes dataframe to filter

required
segments DataFrame

segments dataframe to filter by with shape_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq. Should have one record per shape_id.

required

Returns:

Type Description
DataFrame[WranglerShapesTable]

filtered shapes dataframe

Source code in network_wrangler/transit/feed/transit_segments.py
def filter_shapes_to_segments(
    shapes: DataFrame[WranglerShapesTable], segments: pd.DataFrame
) -> DataFrame[WranglerShapesTable]:
    """Filter shapes dataframe to records associated with segments dataframe.

    Args:
        shapes: shapes dataframe to filter
        segments: segments dataframe to filter by with shape_id, segment_start_shape_pt_seq,
            segment_end_shape_pt_seq. Should have one record per shape_id.

    Returns:
        filtered shapes dataframe
    """
    shapes_w_segs = shapes.merge(segments, on="shape_id", how="left")

    # Retain only those points within the segment sequences
    filtered_shapes = shapes_w_segs[
        (shapes_w_segs["shape_pt_sequence"] >= shapes_w_segs["segment_start_shape_pt_seq"])
        & (shapes_w_segs["shape_pt_sequence"] <= shapes_w_segs["segment_end_shape_pt_seq"])
    ]

    drop_cols = [
        "segment_id",
        "segment_start_shape_pt_seq",
        "segment_end_shape_pt_seq",
        "segment_length",
    ]
    filtered_shapes = filtered_shapes.drop(columns=drop_cols)

    return filtered_shapes
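The merge-and-mask pattern can be sketched with toy frames:

```python
import pandas as pd

# Toy shapes and a one-row-per-shape segments frame.
shapes = pd.DataFrame(
    {"shape_id": ["sh1"] * 5, "shape_pt_sequence": [1, 2, 3, 4, 5]}
)
segments = pd.DataFrame(
    {
        "shape_id": ["sh1"],
        "segment_start_shape_pt_seq": [2],
        "segment_end_shape_pt_seq": [4],
    }
)

# Same pattern as filter_shapes_to_segments: broadcast the segment bounds
# onto every shape point via merge, then keep the points inside the bounds.
merged = shapes.merge(segments, on="shape_id", how="left")
kept = merged[
    (merged["shape_pt_sequence"] >= merged["segment_start_shape_pt_seq"])
    & (merged["shape_pt_sequence"] <= merged["segment_end_shape_pt_seq"])
]
print(kept["shape_pt_sequence"].tolist())  # [2, 3, 4]
```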
shape_links_to_longest_shape_segments(shape_links)

Find the longest segment of each shape_id that is in the links.

Parameters:

Name Type Description Default
shape_links

DataFrame with shape_id, shape_pt_sequence_A, shape_pt_sequence_B

required

Returns:

Type Description
DataFrame

DataFrame with shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq

Source code in network_wrangler/transit/feed/transit_segments.py
def shape_links_to_longest_shape_segments(shape_links) -> pd.DataFrame:
    """Find the longest segment of each shape_id that is in the links.

    Args:
        shape_links: DataFrame with shape_id, shape_pt_sequence_A, shape_pt_sequence_B

    Returns:
        DataFrame with shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq
    """
    segments = shape_links_to_segments(shape_links)
    idx = segments.groupby("shape_id")["segment_length"].idxmax()
    longest_segments = segments.loc[idx]
    return longest_segments
shape_links_to_segments(shape_links)

Convert shape_links to segments by shape_id with segments of continuous shape_pt_sequence.

Returns:

Type Description
DataFrame

DataFrame with shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq

Source code in network_wrangler/transit/feed/transit_segments.py
def shape_links_to_segments(shape_links) -> pd.DataFrame:
    """Convert shape_links to segments by shape_id with segments of continuous shape_pt_sequence.

    Returns: DataFrame with shape_id, segment_id, segment_start_shape_pt_seq,
        segment_end_shape_pt_seq
    """
    shape_links["gap"] = shape_links.groupby("shape_id")["shape_pt_sequence_A"].diff().gt(1)
    shape_links["segment_id"] = shape_links.groupby("shape_id")["gap"].cumsum()

    # Define segment starts and ends
    segment_definitions = (
        shape_links.groupby(["shape_id", "segment_id"])
        .agg(
            segment_start_shape_pt_seq=("shape_pt_sequence_A", "min"),
            segment_end_shape_pt_seq=("shape_pt_sequence_B", "max"),
        )
        .reset_index()
    )

    # Optionally calculate segment lengths for further uses
    segment_definitions["segment_length"] = (
        segment_definitions["segment_end_shape_pt_seq"]
        - segment_definitions["segment_start_shape_pt_seq"]
        + 1
    )

    return segment_definitions
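The gap/cumsum segmentation, together with the idxmax selection used by shape_links_to_longest_shape_segments, can be sketched on toy data:

```python
import pandas as pd

# Toy shape links with a break in the point sequence (points 4-5 are
# missing), so shape sh1 splits into two continuous segments.
shape_links = pd.DataFrame(
    {
        "shape_id": ["sh1"] * 5,
        "shape_pt_sequence_A": [1, 2, 5, 6, 7],
        "shape_pt_sequence_B": [2, 3, 6, 7, 8],
    }
)

# Same gap/cumsum trick as shape_links_to_segments: a jump of more than 1
# in the A-sequence starts a new segment id.
shape_links["gap"] = shape_links.groupby("shape_id")["shape_pt_sequence_A"].diff().gt(1)
shape_links["segment_id"] = shape_links.groupby("shape_id")["gap"].cumsum()
segments = (
    shape_links.groupby(["shape_id", "segment_id"])
    .agg(
        start=("shape_pt_sequence_A", "min"),
        end=("shape_pt_sequence_B", "max"),
    )
    .reset_index()
)
segments["length"] = segments["end"] - segments["start"] + 1
# Longest segment per shape_id, as in shape_links_to_longest_shape_segments.
longest = segments.loc[segments.groupby("shape_id")["length"].idxmax()]
print(longest["start"].tolist(), longest["end"].tolist())  # [5] [8]
```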

Transit Projects

Functions for adding a transit route to a TransitNetwork.

network_wrangler.transit.projects.add_route.apply_transit_route_addition

apply_transit_route_addition(
    net, transit_route_addition, reference_road_net=None
)

Add transit route to TransitNetwork.

Parameters:

Name Type Description Default
net TransitNetwork

Network to modify.

required
transit_route_addition dict

route dictionary to add to the feed.

required
reference_road_net RoadwayNetwork | None

Reference roadway network to use for adding shapes and stops. Defaults to None.

None

Returns:

Name Type Description
TransitNetwork TransitNetwork

Modified network.

Source code in network_wrangler/transit/projects/add_route.py
def apply_transit_route_addition(
    net: TransitNetwork,
    transit_route_addition: dict,
    reference_road_net: RoadwayNetwork | None = None,
) -> TransitNetwork:
    """Add transit route to TransitNetwork.

    Args:
        net (TransitNetwork): Network to modify.
        transit_route_addition: route dictionary to add to the feed.
        reference_road_net: Reference roadway network to use for adding shapes and stops.
            Defaults to None.

    Returns:
        TransitNetwork: Modified network.
    """
    WranglerLogger.debug("Applying add transit route project.")

    add_routes = transit_route_addition["routes"]

    road_net = net.road_net if reference_road_net is None else reference_road_net
    if road_net is None:
        WranglerLogger.error(
            "! Must have a reference road network set in order to update transit \
                         routing.  Either provide as an input to this function or set it for the \
                         transit network: >> transit_net.road_net = ..."
        )
        msg = "Must have a reference road network set in order to update transit routing."
        raise TransitRouteAddError(msg)

    net.feed = _add_route_to_feed(net.feed, add_routes, road_net)

    return net

Module for applying calculated transit projects to a transit network object.

These projects are stored in the project card's pycode property as Python code strings, which are executed to change the transit network object.

network_wrangler.transit.projects.calculate.apply_calculated_transit

apply_calculated_transit(net, pycode)

Changes transit network object by executing pycode.

Parameters:

Name Type Description Default
net TransitNetwork

transit network to manipulate

required
pycode str

python code which changes values in the transit network object

required
Source code in network_wrangler/transit/projects/calculate.py
def apply_calculated_transit(
    net: TransitNetwork,
    pycode: str,
) -> TransitNetwork:
    """Changes transit network object by executing pycode.

    Args:
        net: transit network to manipulate
        pycode: python code which changes values in the transit network object
    """
    WranglerLogger.debug("Applying calculated transit project.")
    exec(pycode)

    return net
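Since the project simply executes the pycode string in scope, the mechanism can be illustrated with a hypothetical stand-in object (ToyNet is illustrative, not the real TransitNetwork):

```python
class ToyNet:
    """Hypothetical stand-in for a TransitNetwork."""

    def __init__(self):
        self.note = ""

net = ToyNet()

# A pycode string, as would be stored in a project card's pycode property,
# mutates the in-scope network object when executed.
pycode = "net.note = 'calculated project applied'"
exec(pycode)
print(net.note)  # calculated project applied
```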

Functions for deleting transit service from a TransitNetwork.

network_wrangler.transit.projects.delete_service.apply_transit_service_deletion

apply_transit_service_deletion(
    net, selection, clean_shapes=False, clean_routes=False
)

Delete transit service from a TransitNetwork.

Parameters:

Name Type Description Default
net TransitNetwork

Network to modify.

required
selection TransitSelection

TransitSelection object, created from a selection dictionary.

required
clean_shapes bool

If True, remove shapes not used by any trips. Defaults to False.

False
clean_routes bool

If True, remove routes not used by any trips. Defaults to False.

False

Returns:

Name Type Description
TransitNetwork TransitNetwork

Modified network.

Source code in network_wrangler/transit/projects/delete_service.py
def apply_transit_service_deletion(
    net: TransitNetwork,
    selection: TransitSelection,
    clean_shapes: bool | None = False,
    clean_routes: bool | None = False,
) -> TransitNetwork:
    """Delete transit service from a TransitNetwork.

    Args:
        net (TransitNetwork): Network to modify.
        selection: TransitSelection object, created from a selection dictionary.
        clean_shapes (bool, optional): If True, remove shapes not used by any trips.
            Defaults to False.
        clean_routes (bool, optional): If True, remove routes not used by any trips.
            Defaults to False.

    Returns:
        TransitNetwork: Modified network.
    """
    WranglerLogger.debug("Applying delete transit service project.")

    trip_ids = selection.selected_trips
    net.feed = _delete_trips_from_feed(
        net.feed, trip_ids, clean_shapes=clean_shapes, clean_routes=clean_routes
    )

    return net

Functions for editing transit properties in a TransitNetwork.

network_wrangler.transit.projects.edit_property.apply_transit_property_change

apply_transit_property_change(
    net, selection, property_changes, project_name=None
)

Apply changes to transit properties.

Parameters:

Name Type Description Default
net TransitNetwork

Network to modify.

required
selection TransitSelection

Selection of trips to modify.

required
property_changes dict

Dictionary of properties to change.

required
project_name str

Name of the project. Defaults to None.

None

Returns:

Name Type Description
TransitNetwork TransitNetwork

Modified network.

Source code in network_wrangler/transit/projects/edit_property.py
def apply_transit_property_change(
    net: TransitNetwork,
    selection: TransitSelection,
    property_changes: dict,
    project_name: str | None = None,
) -> TransitNetwork:
    """Apply changes to transit properties.

    Args:
        net (TransitNetwork): Network to modify.
        selection (TransitSelection): Selection of trips to modify.
        property_changes (dict): Dictionary of properties to change.
        project_name (str, optional): Name of the project. Defaults to None.

    Returns:
        TransitNetwork: Modified network.
    """
    WranglerLogger.debug("Applying transit property change project.")
    for property, property_change in property_changes.items():
        net = _apply_transit_property_change_to_table(
            net,
            selection,
            property,
            property_change,
            project_name=project_name,
        )
    return net

Functions for editing the transit route shapes and stop patterns.

network_wrangler.transit.projects.edit_routing.apply_transit_routing_change

apply_transit_routing_change(
    net,
    selection,
    routing_change,
    reference_road_net=None,
    project_name=None,
)

Apply a routing change to the transit network, including stop updates.

Parameters:

Name Type Description Default
net TransitNetwork

TransitNetwork object to apply routing change to.

required
selection Selection

TransitSelection object, created from a selection dictionary.

required
routing_change dict

Routing Change dictionary, e.g.

{
    "existing": [46665, 150855],
    "set": [-46665, 150855, 46665, 150855],
}

required
reference_road_net RoadwayNetwork

Reference roadway network to use for updating shapes and stops. Defaults to None.

None
project_name str

Name of the project. Defaults to None.

None
Source code in network_wrangler/transit/projects/edit_routing.py
def apply_transit_routing_change(
    net: TransitNetwork,
    selection: TransitSelection,
    routing_change: dict,
    reference_road_net: RoadwayNetwork | None = None,
    project_name: str | None = None,
) -> TransitNetwork:
    """Apply a routing change to the transit network, including stop updates.

    Args:
        net (TransitNetwork): TransitNetwork object to apply routing change to.
        selection (Selection): TransitSelection object, created from a selection dictionary.
        routing_change (dict): Routing Change dictionary, e.g.
            ```python
            {
                "existing": [46665, 150855],
                "set": [-46665, 150855, 46665, 150855],
            }
            ```
        reference_road_net (RoadwayNetwork, optional): Reference roadway network to use for
            updating shapes and stops. Defaults to None.
        project_name (str, optional): Name of the project. Defaults to None.
    """
    WranglerLogger.debug("Applying transit routing change project.")
    WranglerLogger.debug(f"...selection: {selection.selection_dict}")
    WranglerLogger.debug(f"...routing: {routing_change}")

    # ---- Secure all inputs needed --------------
    updated_feed = copy.deepcopy(net.feed)
    trip_ids = selection.selected_trips
    if project_name:
        updated_feed.trips.loc[updated_feed.trips.trip_id.isin(trip_ids), "projects"] += (
            f"{project_name},"
        )

    road_net = net.road_net if reference_road_net is None else reference_road_net
    if road_net is None:
        WranglerLogger.error(
            "! Must have a reference road network set in order to update transit \
                         routing.  Either provide as an input to this function or set it for the \
                         transit network: >> transit_net.road_net = ..."
        )
        msg = "Must have a reference road network set in order to update transit routing."
        raise TransitRoutingChangeError(msg)

    # ---- update each shape that is used by selected trips to use new routing -------
    shape_ids = shape_ids_for_trip_ids(updated_feed.trips, trip_ids)
    # WranglerLogger.debug(f"shape_ids: {shape_ids}")
    for shape_id in shape_ids:
        updated_feed.shapes, updated_feed.trips = _update_shapes_and_trips(
            updated_feed,
            shape_id,
            trip_ids,
            routing_change["set"],
            net.config.IDS.TRANSIT_SHAPE_ID_SCALAR,
            road_net,
            routing_existing=routing_change.get("existing", []),
            project_name=project_name,
        )
    # WranglerLogger.debug(f"updated_feed.shapes: \n{updated_feed.shapes}")
    # WranglerLogger.debug(f"updated_feed.trips: \n{updated_feed.trips}")
    # ---- Check if any stops need adding to stops.txt and add if they do ----------
    updated_feed.stops = _update_stops(
        updated_feed, routing_change["set"], road_net, project_name=project_name
    )
    # WranglerLogger.debug(f"updated_feed.stops: \n{updated_feed.stops}")
    # ---- Update stop_times --------------------------------------------------------
    for trip_id in trip_ids:
        updated_feed.stop_times = _update_stop_times_for_trip(
            updated_feed,
            trip_id,
            routing_change["set"],
            routing_change.get("existing", []),
        )

    # ---- Check result -------------------------------------------------------------
    _show_col = [
        "trip_id",
        "stop_id",
        "stop_sequence",
        "departure_time",
        "arrival_time",
    ]
    _ex_stoptimes = updated_feed.stop_times.loc[
        updated_feed.stop_times.trip_id == trip_ids[0], _show_col
    ]
    # WranglerLogger.debug(f"stop_times for first updated trip: \n {_ex_stoptimes}")

    # ---- Update transit network with updated feed.
    net.feed = updated_feed
    # WranglerLogger.debug(f"net.feed.stops: \n {net.feed.stops}")
    return net

Transit Helper Modules

Utilities for getting GTFS into wrangler.

network_wrangler.utils.transit.add_additional_data_to_shapes

add_additional_data_to_shapes(
    feed_tables, local_crs, crs_units, trace_shape_ids=None
)

Updates feed_tables[‘shapes’] with route/trip metadata and snaps shape points to stops.

Enriches shape points with information from trips, routes, and agencies tables, then matches shape points to nearby stops and updates their locations. Processes stops in sequence order, always searching forward to handle routes that double back.

Process Steps:

  1. Converts shapes to GeoDataFrame if needed (using shape_pt_lon/lat)
  2. Joins with trips, routes, and agencies to add metadata
  3. Projects to local CRS for distance calculations
  4. For each shape, calls _align_shape_with_stops() to:
      - Match stops to existing shape points (via _match_stop_to_shape_points())
      - Insert unmatched stops as new shape points (via _insert_stop_into_shape())
      - Verify stop_sequence is monotonically increasing
  5. Writes debug GeoJSON output (via _write_debug_shapes())

Assumes create_feed_frequencies() has already run, so each shape corresponds to one consolidated trip_id.

Modifies feed_tables in place:

feed_tables['shapes'] - Adds/modifies columns:

Route/Trip Metadata:
- trip_id (str): Associated trip ID
- direction_id (int): Direction of travel (0 or 1)
- route_id (str): Route identifier
- agency_id (str): Agency identifier
- agency_name (str): Agency name
- route_short_name (str): Route short name
- route_type (int): GTFS route type

Stop Matching (for matched points only):
- stop_id (str): Matched stop ID
- stop_name (str): Matched stop name
- stop_sequence (int): Order of stop in trip
- match_distance_{crs_units} (float): Distance from original point to stop location
- shape_pt_lon, shape_pt_lat: Updated to stop coordinates
- geometry: Updated to stop location

feed_tables['stop_times'] - Converted to GeoDataFrame with:
- geometry: Stop location added from stops table

Parameters:

  feed_tables (dict[str, DataFrame]): dictionary with required tables:
    - 'shapes': Shape points to update
    - 'trips': Trip information
    - 'routes': Route information
    - 'agencies': Agency information
    - 'stops': Stop locations
    - 'stop_times': Stop sequences
    Required.
  local_crs (str): Coordinate reference system for projections. Required.
  crs_units (str): Distance units ('feet' or 'meters'). Required.
  trace_shape_ids (list[str] | None): Optional shape IDs for debug logging. Default: None.
Helper Functions

  _align_shape_with_stops(): Process all stops for one shape
  _match_stop_to_shape_points(): Find nearest shape point for a stop
  _insert_stop_into_shape(): Insert unmatched stop as new shape point
  _write_debug_shapes(): Write debug GeoJSON output

Notes
  • Searches forward from previous matched position to handle routes that double back
  • Inserts stops immediately when unmatched (not batched) for accurate positioning
  • Handles circular routes (first/last stop same) with constrained search ranges
  • Writes debug_shapes.geojson with all shapes, stops, and shape points
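The forward-only search noted above (searching onward from the previous matched position so that routes that double back stay in sequence order) can be sketched in isolation. This is an illustrative stand-in, not the module's `_match_stop_to_shape_points()` implementation; `shape_points`, `stop_point`, and `start_idx` are hypothetical names, and coordinates are assumed to already be in a projected CRS.

```python
import math

def match_stop_forward(shape_points, stop_point, start_idx, max_distance):
    """Find the nearest shape point to a stop, searching only forward
    from the previous match so doubled-back routes stay in order.

    shape_points: list of (x, y) tuples in a projected CRS
    stop_point: (x, y) of the stop in the same CRS
    start_idx: index of the previously matched shape point
    max_distance: reject matches farther than this (CRS units)
    Returns (index, distance) of the best match, or (None, None).
    """
    best_idx, best_dist = None, math.inf
    for i in range(start_idx, len(shape_points)):
        sx, sy = shape_points[i]
        dist = math.hypot(sx - stop_point[0], sy - stop_point[1])
        if dist < best_dist:
            best_idx, best_dist = i, dist
    if best_idx is not None and best_dist <= max_distance:
        return best_idx, best_dist
    return None, None

# A stop near the third point matches forward of the previous match (index 1)
points = [(0, 0), (10, 0), (20, 0), (30, 0)]
idx, dist = match_stop_forward(points, (21, 1), start_idx=1, max_distance=50)
```

When no point within `max_distance` is found, the real function falls back to inserting the stop as a new shape point (`_insert_stop_into_shape()`).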
Source code in network_wrangler/utils/transit.py
def add_additional_data_to_shapes(
    feed_tables: dict[str, pd.DataFrame],
    local_crs: str,
    crs_units: str,
    trace_shape_ids: list[str] | None = None,
):
    """Updates feed_tables['shapes'] with route/trip metadata and snaps shape points to stops.

    Enriches shape points with information from trips, routes, and agencies tables,
    then matches shape points to nearby stops and updates their locations. Processes
    stops in sequence order, always searching forward to handle routes that double back.

    Process Steps:

    1. Converts shapes to GeoDataFrame if needed (using shape_pt_lon/lat)
    2. Joins with trips, routes, and agencies to add metadata
    3. Projects to local CRS for distance calculations
    4. For each shape, calls _align_shape_with_stops() to:
       - Match stops to existing shape points (via _match_stop_to_shape_points())
       - Insert unmatched stops as new shape points (via _insert_stop_into_shape())
       - Verify stop_sequence is monotonically increasing
    5. Writes debug GeoJSON output (via _write_debug_shapes())

    Assumes create_feed_frequencies() has already run, so each shape corresponds
    to one consolidated trip_id.

    Modifies feed_tables in place:

    feed_tables['shapes'] - Adds/modifies columns:
        Route/Trip Metadata:
        - trip_id (str): Associated trip ID
        - direction_id (int): Direction of travel (0 or 1)
        - route_id (str): Route identifier
        - agency_id (str): Agency identifier
        - agency_name (str): Agency name
        - route_short_name (str): Route short name
        - route_type (int): GTFS route type

        Stop Matching (for matched points only):
        - stop_id (str): Matched stop ID
        - stop_name (str): Matched stop name
        - stop_sequence (int): Order of stop in trip
        - match_distance_{crs_units} (float): Distance from original to stop location
        - shape_pt_lon, shape_pt_lat: Updated to stop coordinates
        - geometry: Updated to stop location

    feed_tables['stop_times'] - Converted to GeoDataFrame with:
        - geometry: Stop location added from stops table

    Args:
        feed_tables: dictionary with required tables:
            - 'shapes': Shape points to update
            - 'trips': Trip information
            - 'routes': Route information
            - 'agencies': Agency information
            - 'stops': Stop locations
            - 'stop_times': Stop sequences
        local_crs: Coordinate reference system for projections
        crs_units: Distance units ('feet' or 'meters')
        trace_shape_ids: Optional shape IDs for debug logging

    Helper Functions:
        _align_shape_with_stops(): Process all stops for one shape
        _match_stop_to_shape_points(): Find nearest shape point for a stop
        _insert_stop_into_shape(): Insert unmatched stop as new shape point
        _write_debug_shapes(): Write debug GeoJSON output

    Notes:
        - Searches forward from previous matched position to handle routes that double back
        - Inserts stops immediately when unmatched (not batched) for accurate positioning
        - Handles circular routes (first/last stop same) with constrained search ranges
        - Writes debug_shapes.geojson with all shapes, stops, and shape points
    """
    # Step 1: Convert shapes to GeoDataFrame if needed and add geometry from lat/lon coordinates
    # Create GeoDataFrame from shape points if not already one
    if not isinstance(feed_tables["shapes"], gpd.GeoDataFrame):
        shape_geometry = [
            shapely.geometry.Point(lon, lat)
            for lon, lat in zip(
                feed_tables["shapes"]["shape_pt_lon"],
                feed_tables["shapes"]["shape_pt_lat"],
                strict=False,
            )
        ]
        feed_tables["shapes"] = gpd.GeoDataFrame(
            feed_tables["shapes"], geometry=shape_geometry, crs=LAT_LON_CRS
        )
        WranglerLogger.debug(f"Converted feed_tables['shapes'] to GeoDataFrame")
    else:
        WranglerLogger.debug(f"feed_tables['shapes'].crs={feed_tables['shapes'].crs}")

    # Step 2: Add agency, route, and trip information to shapes by joining with trips and routes tables
    # Get unique shape_ids to trips mapping
    trip_shapes_df = feed_tables["trips"][
        ["shape_id", "trip_id", "direction_id", "route_id"]
    ].drop_duplicates()
    WranglerLogger.debug(f"trip_shapes_df\n{trip_shapes_df}")
    # assumes trip_id and shape_id are equivalent due to add_additional_data_to_shapes
    assert feed_tables["trips"]["trip_id"].nunique() == feed_tables["trips"]["shape_id"].nunique()

    # Get route information: Join routes with agencies to get agency names
    routes_with_agency_df = pd.merge(
        feed_tables["routes"][["route_id", "agency_id", "route_short_name", "route_type"]],
        feed_tables["agencies"][["agency_id", "agency_name"]],
        on="agency_id",
        how="left",
    )
    # Add agency information to trips_shapes_df
    trip_shapes_df = pd.merge(
        left=trip_shapes_df, right=routes_with_agency_df, how="left", on="route_id"
    )
    # Add this data to shapes table
    feed_tables["shapes"] = pd.merge(
        feed_tables["shapes"], trip_shapes_df, on="shape_id", how="left", validate="many_to_one"
    )

    WranglerLogger.debug(f"Added agency and route information to shapes table")
    WranglerLogger.debug(f"feed_tables['shapes'].head():\n{feed_tables['shapes'].head()}")
    # shapes columns: shape_id, shape_pt_lat, shape_pt_lon, shape_pt_sequence, shape_dist_traveled, geometry
    #                 trip_id, direction_id, route_id, agency_id, route_short_name, route_type, agency_name

    # Match stops to shape points using segment iteration
    WranglerLogger.info(f"Matching stops to shape points using segment iteration")

    # Project both GeoDataFrames to specified CRS for distance calculations
    feed_tables["shapes"].to_crs(local_crs, inplace=True)
    feed_tables["stops"].to_crs(local_crs, inplace=True)

    # Initialize shape columns for stop information
    feed_tables["shapes"][f"match_distance_{crs_units}"] = np.inf
    feed_tables["shapes"]["stop_id"] = None
    feed_tables["shapes"]["stop_name"] = ""
    feed_tables["shapes"]["stop_sequence"] = None
    # these are the most useful columns for debugging
    stoptime_debug_cols = ["stop_sequence", "stop_id", "stop_name"]
    shape_debug_cols = [
        "shape_id",
        "shape_pt_sequence",
        "shape_dist_traveled",
        "stop_sequence",
        "stop_id",
        "stop_name",
        f"match_distance_{crs_units}",
    ]

    # Add stop geometry to stop_times and convert it a GeoDataFrame
    WranglerLogger.debug(f"Before merge, {len(feed_tables['stop_times'])=:,}")
    feed_tables["stop_times"] = gpd.GeoDataFrame(
        pd.merge(
            left=feed_tables["stop_times"],
            right=feed_tables["stops"][["stop_id", "stop_name", "geometry"]],
            how="left",
            on="stop_id",
            validate="many_to_one",
        ),
        geometry="geometry",
        crs=feed_tables["stops"].crs,
    )

    WranglerLogger.debug(
        f"feed_tables['trips'] type={type(feed_tables['trips'])}:\n{feed_tables['trips']}"
    )
    WranglerLogger.debug(
        f"feed_tables['stops'] type={type(feed_tables['stops'])}:\n{feed_tables['stops']}"
    )
    WranglerLogger.debug(
        f"feed_tables['stop_times'] type={type(feed_tables['stop_times'])}:\n{feed_tables['stop_times']}"
    )

    # Sort tables for processing
    feed_tables["stop_times"].sort_values(
        by=["trip_id", "stop_sequence"], inplace=True, ignore_index=True
    )
    feed_tables["shapes"].sort_values(
        by=["shape_id", "shape_pt_sequence"], inplace=True, ignore_index=True
    )

    # Process each shape_id
    unique_shape_ids = sorted(feed_tables["shapes"]["shape_id"].unique())
    WranglerLogger.info(f"Finding stops for {len(unique_shape_ids):,} unique shape_ids")

    matched_count = 0
    inserted_count = 0
    max_distance = (
        DefaultConfig.TRANSIT.MAX_DISTANCE_STOP_FEET
        if crs_units == "feet"
        else DefaultConfig.TRANSIT.MAX_DISTANCE_STOP_METERS
    )

    # Collect debug features for all shapes
    debug_features = []

    for shape_id in unique_shape_ids:
        # Process this shape: match and/or insert all stops
        shape_matched, shape_inserted, shape_debug_features = _align_shape_with_stops(
            shape_id=shape_id,
            feed_tables=feed_tables,
            local_crs=local_crs,
            crs_units=crs_units,
            max_distance=max_distance,
            trace_shape_ids=trace_shape_ids,
            stoptime_debug_cols=stoptime_debug_cols,
            shape_debug_cols=shape_debug_cols,
        )

        matched_count += shape_matched
        inserted_count += shape_inserted
        debug_features.extend(shape_debug_features)

    WranglerLogger.info(
        f"Finished adding stop information to shapes: "
        f"{matched_count} stops matched, {inserted_count} stops inserted"
    )

    # Write consolidated debug GeoJSON
    _write_debug_shapes(debug_features, local_crs)

network_wrangler.utils.transit.add_additional_data_to_stops

add_additional_data_to_stops(feed_tables)

Updates feed_tables['stops'] with additional metadata about routes and agencies.

Aggregates information from stop_times, trips, routes, and agencies tables to add comprehensive metadata about which routes and agencies serve each stop.

Process Steps:

  1. Joins stop_times with trips to get route and shape information
  2. Joins with routes and agencies to get route types and agency names
  3. Groups by stop_id to aggregate all serving routes/agencies
  4. Identifies mixed-traffic stops based on route types
  5. Handles parent stations by checking child stop characteristics

Modifies feed_tables['stops'] in place, adding columns:

  Route/Agency Information:
  - agency_ids (list of str): All agencies serving this stop
  - agency_names (list of str): Names of agencies serving this stop
  - route_ids (list of str): All routes serving this stop
  - route_names (list of str): Short names of routes serving this stop
  - route_types (list of int): Types of routes serving this stop
  - shape_ids (list of str): All shapes associated with this stop

  Stop Type Flags:
  - is_parent (bool): True if other stops reference this as parent_station

Parameters:

  feed_tables (dict[str, DataFrame]): dictionary with required tables:
    - 'stop_times': Links stops to trips
    - 'trips': Links trips to routes and shapes
    - 'routes': Route information including type
    - 'agencies': Agency names
    - 'stops': Table to be updated
    Required.
Notes
  • Parent stations may not appear in trips but are retained if referenced
  • Empty lists are used for stops with no associated routes
  • Handles missing parent_station column gracefully
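The per-stop aggregation in step 3 (grouping by stop_id and collecting the serving routes and agencies as lists) can be illustrated with a toy frame. The data below is fabricated for illustration; only the groupby/agg pattern mirrors the source.

```python
import pandas as pd

# Toy stop-to-route records, as produced by joining stop_times with trips/routes
stop_agencies = pd.DataFrame({
    "stop_id": ["S1", "S1", "S2"],
    "route_id": ["R10", "R20", "R10"],
    "agency_id": ["A", "A", "A"],
})

# Collapse to one row per stop, with the serving routes/agencies as lists
stop_info = (
    stop_agencies.groupby("stop_id")
    .agg({
        "route_id": lambda x: list(x.dropna().unique()),
        "agency_id": lambda x: list(x.dropna().unique()),
    })
    .reset_index()
)
```

`Series.unique()` preserves order of first appearance, so each list reflects the order routes were encountered in stop_times.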
Source code in network_wrangler/utils/transit.py
def add_additional_data_to_stops(
    feed_tables: dict[str, pd.DataFrame],
):
    """Updates feed_tables['stops'] with additional metadata about routes and agencies.

    Aggregates information from stop_times, trips, routes, and agencies tables to add
    comprehensive metadata about which routes and agencies serve each stop.

    Process Steps:

    1. Joins stop_times with trips to get route and shape information
    2. Joins with routes and agencies to get route types and agency names
    3. Groups by stop_id to aggregate all serving routes/agencies
    4. Identifies mixed-traffic stops based on route types
    5. Handles parent stations by checking child stop characteristics

    Modifies feed_tables['stops'] in place, adding columns:

    Route/Agency Information:
    - agency_ids (list of str): All agencies serving this stop
    - agency_names (list of str): Names of agencies serving this stop
    - route_ids (list of str): All routes serving this stop
    - route_names (list of str): Short names of routes serving this stop
    - route_types (list of int): Types of routes serving this stop
    - shape_ids (list of str): All shapes associated with this stop

    Stop Type Flags:
    - is_parent (bool): True if other stops reference this as parent_station

    Args:
        feed_tables: dictionary with required tables:
            - 'stop_times': Links stops to trips
            - 'trips': Links trips to routes and shapes
            - 'routes': Route information including type
            - 'agencies': Agency names
            - 'stops': Table to be updated

    Notes:
        - Parent stations may not appear in trips but are retained if referenced
        - Empty lists are used for stops with no associated routes
        - Handles missing parent_station column gracefully
    """
    # Add information about agencies, routes, directions and shapes to stops
    # Join stop_times with trips and routes
    stop_trips = pd.merge(
        feed_tables["stop_times"][["stop_id", "trip_id"]].drop_duplicates(),
        feed_tables["trips"][["trip_id", "direction_id", "route_id", "shape_id"]],
        on="trip_id",
        how="left",
    )
    WranglerLogger.debug(f"After joining stop_times with trips: {len(stop_trips):,} records")

    # Create stop to route, agency mapping with direction information
    stop_agencies = pd.merge(
        stop_trips,
        feed_tables["routes"][["route_id", "agency_id", "route_short_name", "route_type"]],
        on="route_id",
        how="left",
    )[
        [
            "stop_id",
            "agency_id",
            "route_id",
            "direction_id",
            "shape_id",
            "route_short_name",
            "route_type",
        ]
    ].drop_duplicates()
    WranglerLogger.debug(f"stop_agencies.head():\n{stop_agencies.head()}")

    # pick up agency information
    stop_agencies = pd.merge(
        stop_agencies,
        feed_tables["agencies"][["agency_id", "agency_name"]],
        on="agency_id",
        how="left",
    )
    WranglerLogger.debug(f"stop_agencies.head():\n{stop_agencies.head()}")

    # Group by stop to get all agencies and routes serving each stop
    # Now including route_dir_ids as list of (route_id, direction_id) tuples
    stop_agency_info = (
        stop_agencies.groupby("stop_id")
        .agg(
            {
                "agency_id": lambda x: list(x.dropna().unique()),
                "agency_name": lambda x: list(x.dropna().unique()) if x.notna().any() else [],
                "route_id": lambda x: list(x.dropna().unique()),
                "route_short_name": lambda x: list(x.dropna().unique()),
                "route_type": lambda x: list(x.dropna().unique()),
                "shape_id": lambda x: list(x.dropna().unique()),
            }
        )
        .reset_index()
    )

    stop_agency_info.columns = [
        "stop_id",
        "agency_ids",
        "agency_names",
        "route_ids",
        "route_names",
        "route_types",
        "shape_ids",
    ]

    # columns: stop_id (str), agency_ids (list of str), agency_names (list of str),
    #   route_ids (list of str), route_names (list of str), route_types (list of int),
    #   shape_ids (list of str)
    WranglerLogger.debug(f"stop_agency_info.head():\n{stop_agency_info.head()}")

    # Merge this information back to stops
    feed_tables["stops"] = pd.merge(
        feed_tables["stops"], stop_agency_info, on="stop_id", how="left"
    )

    # Handle parent stations that may not be included in trips
    if "parent_station" in feed_tables["stops"].columns:
        # Find which stops are referenced as parent stations
        child_stops = feed_tables["stops"][
            feed_tables["stops"]["parent_station"].notna()
            & (feed_tables["stops"]["parent_station"] != "")
        ]

        if len(child_stops) > 0:
            # Get unique parent station IDs
            parent_station_ids = child_stops["parent_station"].unique()

            # Mark parent stations
            feed_tables["stops"]["is_parent"] = False
            feed_tables["stops"].loc[
                feed_tables["stops"]["stop_id"].isin(parent_station_ids), "is_parent"
            ] = True

            # Log parent stations
            WranglerLogger.debug(
                f"Found {len(parent_station_ids)} parent stations:\n"
                + f"{feed_tables['stops'].loc[feed_tables['stops']['is_parent'] == True, ['stop_id', 'stop_name']]}"
            )
        else:
            feed_tables["stops"]["is_parent"] = False
    else:
        feed_tables["stops"]["is_parent"] = False

    WranglerLogger.debug(
        f"add_additional_data_to_stops() completed. feed_tables['stops']:\n{feed_tables['stops']}"
    )
network_wrangler.utils.transit.add_stations_and_links_to_roadway_network

add_stations_and_links_to_roadway_network(
    feed_tables,
    roadway_net,
    local_crs,
    crs_units,
    trace_shape_ids=None,
    default_node_attribute_dict=None,
    default_link_attribute_dict=None,
)

Add transit station nodes and dedicated transit links to the roadway network.

Creates new roadway nodes for transit stations and adds dedicated transit links between stations for fixed-guideway transit (rail, subway, ferry, etc.). Bus stops use existing roadway nodes from match_bus_stops_to_roadway_nodes().

Process Steps:

  1. Creates stop link pairs from consecutive stops in stop_times
  2. Aggregates intermediate shape points between stops into multi-point lines
  3. Filters links to STATION_ROUTE_TYPES for network addition
  4. Creates new roadway nodes for stations not already in network
  5. Creates dedicated transit links with appropriate access restrictions
  6. Updates feed_tables['stops'] with model_node_id for all stops
  7. Updates feed_tables['shapes'] with shape_model_node_id for all stations
  8. Returns bus stop links separately (not added to network)

Modifies in place:

roadway_net - Adds:
  - New nodes for transit stations with model_node_id
  - New links between stations with:
    - rail_only=True for rail types
    - ferry_only=True for ferry types
    - drive/bike/walk/truck_access=False
    - Geometry following shape points if available

feed_tables['stops'] - Adds/updates:
  - model_node_id (int): Roadway node ID for the stop
  - Updates existing bus stop model_node_ids
  - Adds new station model_node_ids

feed_tables['shapes'] - Adds/updates:
  - shape_model_node_id (int): Roadway node ID for the shape point

Parameters:

  feed_tables (dict[str, DataFrame]): dictionary with required tables:
    - 'stops': Stop information with geometry
    - 'stop_times': Stop sequences for trips
    - 'shapes': Shape points between stops
    - 'routes': Route types
    Required.
  roadway_net (RoadwayNetwork): RoadwayNetwork to modify with new nodes/links. Required.
  local_crs (str): Coordinate reference system for projections. Required.
  crs_units (str): Distance units ('feet' or 'meters'). Required.
  trace_shape_ids (list[str] | None): Optional shape IDs for debug logging. Default: None.
  default_node_attribute_dict (dict[str, any] | None): Optional dict of column-name to default value to set on new nodes. Default: None.
  default_link_attribute_dict (dict[str, any] | None): Optional dict of column-name to default value to set on new links. Default: None.

Returns:

  tuple[dict[str, int], gpd.GeoDataFrame]:
    - dictionary mapping new station stop_ids to model_node_ids
    - GeoDataFrame of bus stop links (not added to network) with columns:
      shape_id, stop_sequence, stop_id, stop_name, next_stop_id,
      next_stop_name, A, B, geometry

Notes
  • Stations are new nodes; bus stops use existing road nodes
  • Self-loops (stop appearing twice consecutively) are filtered out
  • Links follow actual shape geometry when available
  • Parent stations without trips are handled correctly
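Step 1 (pairing consecutive stops per trip via a grouped shift, then dropping self-loops) can be sketched with a toy stop_times frame. Column names follow GTFS, but the data is fabricated for illustration.

```python
import pandas as pd

stop_times = pd.DataFrame({
    "trip_id": ["T1", "T1", "T1", "T2", "T2"],
    "stop_sequence": [1, 2, 3, 1, 2],
    "stop_id": ["A", "B", "B", "C", "D"],
})

links = stop_times.sort_values(["trip_id", "stop_sequence"]).copy()
# Pair each stop with the next stop on the same trip; the last stop of each
# trip gets NaN because shift(-1) operates within the trip_id group
links["next_stop_id"] = links.groupby("trip_id")["stop_id"].shift(-1)

# Keep rows that have a next stop; drop self-loops (same stop twice in a row)
has_next = links["next_stop_id"].notna()
is_self_loop = links["next_stop_id"] == links["stop_id"]
links = links[has_next & ~is_self_loop]
```

Here T1's repeated stop B produces a self-loop segment that is filtered out, leaving the links A→B and C→D.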
Source code in network_wrangler/utils/transit.py
def add_stations_and_links_to_roadway_network(  # noqa: PLR0912, PLR0915
    feed_tables: dict[str, pd.DataFrame],
    roadway_net: RoadwayNetwork,
    local_crs: str,
    crs_units: str,
    trace_shape_ids: list[str] | None = None,
    default_node_attribute_dict: dict[str, any] | None = None,
    default_link_attribute_dict: dict[str, any] | None = None,
) -> tuple[dict[str, int], gpd.GeoDataFrame]:
    """Add transit station nodes and dedicated transit links to the roadway network.

    Creates new roadway nodes for transit stations and adds dedicated transit links
    between stations for fixed-guideway transit (rail, subway, ferry, etc.). Bus stops
    use existing roadway nodes from match_bus_stops_to_roadway_nodes().

    Process Steps:
    1. Creates stop link pairs from consecutive stops in stop_times
    2. Aggregates intermediate shape points between stops into multi-point lines
    3. Filters links to STATION_ROUTE_TYPES for network addition
    4. Creates new roadway nodes for stations not already in network
    5. Creates dedicated transit links with appropriate access restrictions
    6. Updates feed_tables['stops'] with model_node_id for all stops
    7. Updates feed_tables['shapes'] with shape_model_node_id for all stations
    8. Returns bus stop links separately (not added to network)

    Modifies in place:

    roadway_net - Adds:
        - New nodes for transit stations with model_node_id
        - New links between stations with:
            - rail_only=True for rail types
            - ferry_only=True for ferry types
            - drive/bike/walk/truck_access=False
            - Geometry following shape points if available

    feed_tables['stops'] - Adds/updates:
        - model_node_id (int): Roadway node ID for the stop
        - Updates existing bus stop model_node_ids
        - Adds new station model_node_ids

    feed_tables['shapes'] - Adds/updates:
        - shape_model_node_id (int): Roadway node ID for the shape point

    Args:
        feed_tables: dictionary with required tables:
            - 'stops': Stop information with geometry
            - 'stop_times': Stop sequences for trips
            - 'shapes': Shape points between stops
            - 'routes': Route types
        roadway_net: RoadwayNetwork to modify with new nodes/links
        local_crs: Coordinate reference system for projections
        crs_units: Distance units ('feet' or 'meters')
        trace_shape_ids: Optional shape IDs for debug logging
        default_node_attribute_dict: Optional dict of column-name to default value to set on new nodes.
        default_link_attribute_dict: Optional dict of column-name to default value to set on new links.

    Returns:
        tuple[dict[str,int], gpd.GeoDataFrame]:
            - dictionary mapping new station stop_ids to model_node_ids
            - GeoDataFrame of bus stop links (not added to network) with columns:
                shape_id, stop_sequence, stop_id, stop_name, next_stop_id,
                next_stop_name, A, B, geometry

    Notes:
        - Stations are new nodes; bus stops use existing road nodes
        - Self-loops (stop appearing twice consecutively) are filtered out
        - Links follow actual shape geometry when available
        - Parent stations without trips are handled correctly
    """
    WranglerLogger.info(f"Adding transit stations and station-based links to the roadway network")
    WranglerLogger.debug(
        f"feed_tables['shapes'] type={type(feed_tables['shapes'])}:\n{feed_tables['shapes']}"
    )
    WranglerLogger.debug(
        f"feed_tables['stops'] type={type(feed_tables['stops'])}:\n{feed_tables['stops']}"
    )
    WranglerLogger.debug(
        f"feed_tables['stop_times'] type={type(feed_tables['stop_times'])}:\n{feed_tables['stop_times']}"
    )

    # Add route_type to stop_times
    if "route_type" not in feed_tables["stop_times"].columns:
        feed_tables["stop_times"] = pd.merge(
            feed_tables["stop_times"],
            feed_tables["routes"][["route_id", "route_type"]],
            how="left",
            on="route_id",
            validate="many_to_one",
        )
    # keep trace_stop_id_set
    trace_stop_id_set = None
    if trace_shape_ids:
        trace_stop_id_set = set(
            feed_tables["stop_times"]
            .loc[feed_tables["stop_times"]["shape_id"].isin(trace_shape_ids), "stop_id"]
            .to_list()
        )
        WranglerLogger.debug(f"trace_stop_id_set:{trace_stop_id_set}")

    # Prepare new link list first
    # For all consecutive stop_ids in feed_table['stop_times'], create consecutive node pairs for each shape
    stop_links_df = feed_tables["stop_times"][
        [
            "route_type",
            "route_id",
            "direction_id",
            "shape_id",
            "trip_id",
            "stop_sequence",
            "stop_id",
            "stop_name",
            "geometry",
        ]
    ].copy()
    stop_links_df.rename(columns={"geometry": "stop_geometry"}, inplace=True)
    stop_links_df.sort_values(by=["trip_id", "stop_sequence"], inplace=True)

    stop_links_df["next_stop_id"] = stop_links_df.groupby("trip_id")["stop_id"].shift(-1)
    stop_links_df["next_stop_name"] = stop_links_df.groupby("trip_id")["stop_name"].shift(-1)
    stop_links_df["next_stop_geometry"] = stop_links_df.groupby("trip_id")["stop_geometry"].shift(
        -1
    )

    # Filter to only rows that have a next node (excludes last point of each shape)
    # and filter out self-loops where the stop occurs twice in a row
    has_next = stop_links_df["next_stop_id"].notna()
    is_self_loop = stop_links_df["next_stop_id"] == stop_links_df["stop_id"]

    num_self_loops = (has_next & is_self_loop).sum()
    if num_self_loops > 0:
        WranglerLogger.debug(
            f"Filtering out {num_self_loops:,} self-loop segments where stop_id == next_stop_id"
        )

    stop_links_df = stop_links_df[has_next & ~is_self_loop]
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace stop_links_df for {trace_shape_id}:\n"
                f"{stop_links_df.loc[stop_links_df.shape_id == trace_shape_id]}"
            )

    WranglerLogger.debug(f"stop_links_df.dtypes:\n{stop_links_df.dtypes}")
    # route_type            category
    # route_id                object
    # direction_id          category
    # shape_id                object
    # trip_id                 object
    # stop_sequence            int64
    # stop_id                 object
    # stop_name               object
    # stop_geometry         geometry
    # next_stop_id            object
    # next_stop_name          object
    # next_stop_geometry    geometry
    # dtype: object

    # feed_tables['shapes'] is a GeoDataFrame of points, with the columns 'shape_id', 'stop_id' and 'stop_sequence'
    # Match these sequences with the stop_id/next_stop_id in stop_links_gdf based on shape_id and add intermediate points to the shape
    shape_links_df = feed_tables["shapes"][
        ["shape_id", "geometry", "stop_sequence", "stop_id"]
    ].copy()
    # set the stop_sequence to -1 for the first row of each shape_id if it's not set
    shape_links_df.loc[
        (~shape_links_df["shape_id"].duplicated()) & (shape_links_df["stop_sequence"].isna()),
        "stop_sequence",
    ] = -1
    shape_links_df.loc[
        (~shape_links_df["shape_id"].duplicated()) & (shape_links_df["stop_id"].isna()), "stop_id"
    ] = -1
    # fill forward
    # Suppress downcasting warning for ffill
    with pd.option_context("future.no_silent_downcasting", True):
        shape_links_df["shape_stop_sequence"] = shape_links_df["stop_sequence"].ffill()
        shape_links_df["shape_stop_id"] = shape_links_df["stop_id"].ffill()
    # clear the ffilled values on the stop rows themselves - those points are already covered by the stop geometry
    shape_links_df.loc[shape_links_df["stop_id"].notna(), "shape_stop_sequence"] = np.nan
    shape_links_df.loc[shape_links_df["stop_id"].notna(), "shape_stop_id"] = np.nan

    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace shape_links_df for {trace_shape_id} after ffill():\n"
                f"{shape_links_df.loc[shape_links_df.shape_id == trace_shape_id]}"
            )
    # Example: this shape's first stop is stop_sequence=2, the shape points at the beginning are pre-trip points
    #                 shape_id                         geometry stop_sequence stop_id  shape_stop_sequence shape_stop_id
    # 909222  SF:9717:20230930  POINT (6013196.693 2109556.951)            -1      -1                  NaN           NaN
    # 909223  SF:9717:20230930  POINT (6013227.638 2109614.595)          None    None                -1.00            -1
    # 909224  SF:9717:20230930  POINT (6013265.791 2109685.568)          None    None                -1.00            -1
    # 909225  SF:9717:20230930  POINT (6013316.874 2109752.637)          None    None                -1.00            -1
    # 909226  SF:9717:20230930  POINT (6013604.053 2110029.431)          None    None                -1.00            -1
    # 909227  SF:9717:20230930  POINT (6013637.289 2110057.529)          None    None                -1.00            -1
    # 909228  SF:9717:20230930  POINT (6014283.122 2110658.483)             2   15240                  NaN           NaN
    # 909229  SF:9717:20230930  POINT (6014293.455 2110683.403)          None    None                 2.00         15240
    # 909230  SF:9717:20230930  POINT (6014940.906 2111322.944)          None    None                 2.00         15240
    # 909231  SF:9717:20230930  POINT (6015538.148 2111853.892)             3   15237                  NaN           NaN
    # 909232  SF:9717:20230930  POINT (6015603.806 2111941.794)          None    None                 3.00         15237
    # 909233  SF:9717:20230930  POINT (6015814.592 2112130.926)          None    None                 3.00         15237
    # 909234  SF:9717:20230930  POINT (6015897.484 2112242.153)          None    None                 3.00         15237
    # 909235  SF:9717:20230930   POINT (6015932.62 2112307.364)          None    None                 3.00         15237
    # 909236  SF:9717:20230930  POINT (6015955.806 2112367.716)          None    None                 3.00         15237
    # 909237  SF:9717:20230930  POINT (6015971.693 2112438.778)          None    None                 3.00         15237
    # 909238  SF:9717:20230930  POINT (6015985.888 2112526.263)          None    None                 3.00         15237
    # 909239  SF:9717:20230930   POINT (6015989.04 2112639.464)          None    None                 3.00         15237
    # 909240  SF:9717:20230930  POINT (6016045.522 2113160.945)          None    None                 3.00         15237
    # 909241  SF:9717:20230930  POINT (6016108.599 2113665.541)             4   17145                  NaN           NaN

    # aggregate and convert to list
    shape_links_agg_df = (
        shape_links_df.groupby(by=["shape_id", "shape_stop_sequence"])
        .aggregate(
            point_list=pd.NamedAgg(column="geometry", aggfunc=list),
            num_points=pd.NamedAgg(column="geometry", aggfunc="nunique"),
        )
        .reset_index(drop=False)
    )
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace shape_links_agg_df for {trace_shape_id}:\n"
                f"{shape_links_agg_df.loc[shape_links_agg_df.shape_id == trace_shape_id]}"
            )
    # columns are shape_id, shape_stop_sequence, point_list, num_points
    #                shape_id  shape_stop_sequence                                         point_list  num_points
    # 36235  SF:9717:20230930                -1.00  [POINT (6013227.6381617775 2109614.5949144145)...           5
    # 36236  SF:9717:20230930                 2.00  [POINT (6014293.454981883 2110683.4032201055),...           2
    # 36237  SF:9717:20230930                 3.00  [POINT (6015603.805947471 2111941.793745517), ...           9
    # 36238  SF:9717:20230930                 4.00  [POINT (6016092.661659231 2113720.8560314486),...          13

    # add intermediate stops to stop_links
    stop_links_df = pd.merge(
        left=stop_links_df,
        right=shape_links_agg_df.rename(columns={"shape_stop_sequence": "stop_sequence"}),
        on=["shape_id", "stop_sequence"],
        how="left",
        indicator=True,
    )
    WranglerLogger.debug(
        f"stop_links_df._merge.value_counts():\n{stop_links_df._merge.value_counts()}"
    )
    # for links without intermediate shape points
    stop_links_df.loc[stop_links_df["_merge"] == "left_only", "num_points"] = 0
    if trace_shape_ids:
        WranglerLogger.debug(
            f"trace stop_links_df:\n{stop_links_df.loc[stop_links_df.shape_id.isin(trace_shape_ids)]}"
        )

    # turn them into multi-point lines
    stop_links_df["geometry"] = stop_links_df.apply(
        lambda row: (
            # if intermediate points
            shapely.geometry.LineString(
                [row["stop_geometry"]] + row["point_list"] + [row["next_stop_geometry"]]
            )
            if row["num_points"] > 0
            # no intermediate points
            else shapely.geometry.LineString([row["stop_geometry"], row["next_stop_geometry"]])
        ),
        axis=1,
    )
    WranglerLogger.debug(f"stop_links_df including multi-point lines:\n{stop_links_df}")

    # create GeoDataFrame; this is in the local crs
    stop_links_df.drop(
        columns=["stop_geometry", "next_stop_geometry", "point_list", "_merge"], inplace=True
    )
    stop_links_gdf = gpd.GeoDataFrame(
        stop_links_df, geometry="geometry", crs=feed_tables["stops"].crs
    )
    WranglerLogger.debug(f"stop_links_gdf.dtypes\n{stop_links_gdf.dtypes}")
    # route_type        category
    # route_id            object
    # direction_id      category
    # shape_id            object
    # trip_id             object
    # stop_sequence        int64
    # stop_id             object
    # stop_name           object
    # next_stop_id        object
    # next_stop_name      object
    # num_points           int64
    # geometry          geometry
    # dtype: object

    # Filter to STATION_ROUTE_TYPES for adding to the roadway network
    station_stop_links_gdf = stop_links_gdf.loc[
        stop_links_gdf.route_type.isin(STATION_ROUTE_TYPES)
    ].copy()
    station_stop_id_set = set(station_stop_links_gdf["stop_id"]) | set(
        station_stop_links_gdf["next_stop_id"]
    )
    # Also add parent stations
    parent_station_id_set = set(
        feed_tables["stops"].loc[feed_tables["stops"]["is_parent"] == True, "stop_id"].tolist()
    )
    WranglerLogger.debug(
        f"parent_station_id_set len={len(parent_station_id_set)}:\n{parent_station_id_set}"
    )

    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace station_stop_links_gdf for {trace_shape_id}:\n"
                f"{station_stop_links_gdf.loc[station_stop_links_gdf.shape_id == trace_shape_id]}"
            )

    # Prepare nodes to add - station stops and parent stops
    station_stop_ids_gdf = (
        feed_tables["stops"]
        .loc[
            (feed_tables["stops"]["stop_id"].isin(station_stop_id_set))
            | (feed_tables["stops"]["stop_id"].isin(parent_station_id_set)),
            [
                "stop_id",
                "stop_name",
                "is_parent",
                "stop_lat",
                "stop_lon",
                "model_node_id",
                "geometry",
            ],
        ]
        .reset_index(drop=True)
        .copy()
    )
    station_stop_ids_gdf.rename(
        columns={"stop_lon": "X", "stop_lat": "Y", "stop_id": "stop_id_GTFS"}, inplace=True
    )
    station_stop_ids_gdf.to_crs(LAT_LON_CRS, inplace=True)
    WranglerLogger.debug(f"station_stop_ids_gdf:\n{station_stop_ids_gdf}")
    if trace_stop_id_set:
        WranglerLogger.debug(
            f"trace station_stop_ids_gdf for trace_stop_id_set:\n"
            f"{station_stop_ids_gdf.loc[station_stop_ids_gdf['stop_id_GTFS'].isin(trace_stop_id_set)]}"
        )

    # Don't create new stations where a roadway node already exists (e.g., stops that serve both bus and light rail)
    # => keep only stops without a model_node_id
    new_station_stop_ids_gdf = station_stop_ids_gdf.loc[
        station_stop_ids_gdf["model_node_id"].isna()
    ].reset_index(drop=False)
    new_station_stop_ids_gdf.drop(columns=["model_node_id"], inplace=True)

    # Assign model_node_id and add new stations to roadway network as roadway nodes
    max_node_num = roadway_net.nodes_df.model_node_id.max()
    new_station_stop_ids_gdf["model_node_id"] = new_station_stop_ids_gdf.index + max_node_num + 1
    new_station_stop_ids_gdf.drop(columns=["index"], inplace=True)

    # Apply default node attributes
    if default_node_attribute_dict is None:
        default_node_attribute_dict = {}
    for colname, default_value in default_node_attribute_dict.items():
        new_station_stop_ids_gdf[colname] = default_value

    WranglerLogger.info(f"Adding {len(new_station_stop_ids_gdf):,} nodes to roadway network")
    WranglerLogger.debug(f"new_station_stop_ids_gdf:\n{new_station_stop_ids_gdf}")
    WranglerLogger.debug(f"Before adding nodes, {len(roadway_net.nodes_df)=:,}")
    roadway_net.add_nodes(new_station_stop_ids_gdf)
    WranglerLogger.debug(f"After adding nodes, {len(roadway_net.nodes_df)=:,}")

    # get stop_id -> model_node_id for new nodes and stations that mapped to roadway nodes
    # (e.g. for LRT that have road node stations)
    new_stop_id_to_model_node_id_dict = (
        new_station_stop_ids_gdf[["stop_id_GTFS", "model_node_id"]]
        .set_index("stop_id_GTFS")
        .to_dict()["model_node_id"]
    )
    stop_id_to_model_node_id_dict = (
        station_stop_ids_gdf[["stop_id_GTFS", "model_node_id"]]
        .set_index("stop_id_GTFS")
        .to_dict()["model_node_id"]
    )
    stop_id_to_model_node_id_dict.update(new_stop_id_to_model_node_id_dict)
    WranglerLogger.debug(f"stop_id_to_model_node_id_dict:\n{stop_id_to_model_node_id_dict}")

    # Prepare links to add
    station_stop_links_gdf["A"] = station_stop_links_gdf["stop_id"].map(
        stop_id_to_model_node_id_dict
    )
    station_stop_links_gdf["B"] = station_stop_links_gdf["next_stop_id"].map(
        stop_id_to_model_node_id_dict
    )
    # Set rail/ferry only values
    station_stop_links_gdf = station_stop_links_gdf[
        [
            "route_type",
            "A",
            "stop_id",
            "stop_name",
            "B",
            "next_stop_id",
            "next_stop_name",
            "shape_id",
            "geometry",
        ]
    ]
    station_stop_links_gdf["rail_only"] = False
    station_stop_links_gdf.loc[
        station_stop_links_gdf.route_type.isin(RAIL_ROUTE_TYPES), "rail_only"
    ] = True
    station_stop_links_gdf["ferry_only"] = False
    station_stop_links_gdf.loc[
        station_stop_links_gdf.route_type.isin(FERRY_ROUTE_TYPES), "ferry_only"
    ] = True
    # Aggregate by A,B, choosing first values, and convert back to GeoDataFrame
    station_road_links_gdf = gpd.GeoDataFrame(
        station_stop_links_gdf.groupby(by=["A", "B"])
        .aggregate(
            stop_id=pd.NamedAgg(column="stop_id", aggfunc="first"),
            stop_name=pd.NamedAgg(column="stop_name", aggfunc="first"),
            next_stop_id=pd.NamedAgg(column="next_stop_id", aggfunc="first"),
            next_stop_name=pd.NamedAgg(column="next_stop_name", aggfunc="first"),
            geometry=pd.NamedAgg(column="geometry", aggfunc="first"),
            rail_only=pd.NamedAgg(column="rail_only", aggfunc=any),
            ferry_only=pd.NamedAgg(column="ferry_only", aggfunc=any),
            shape_ids=pd.NamedAgg(column="shape_id", aggfunc=list),
        )
        .reset_index(drop=False),
        crs=station_stop_links_gdf.crs,
    )
    station_road_links_gdf["A"] = station_road_links_gdf["A"].astype(int)
    station_road_links_gdf["B"] = station_road_links_gdf["B"].astype(int)

    # Drop links that are already in roadway network - this may happen for LRT links on roadways
    # But first, make sure rail_only or ferry_only is set to True in the roadway links version

    # save this to re-apply
    links_df_name = roadway_net.links_df.attrs["name"]
    roadway_net.links_df = roadway_net.links_df.merge(
        right=station_road_links_gdf[["A", "B", "rail_only", "ferry_only"]],
        how="left",
        on=["A", "B"],
        validate="one_to_one",
        suffixes=["", "_update"],
        indicator=True,
    )
    # re-apply
    roadway_net.links_df.attrs["name"] = links_df_name
    WranglerLogger.debug(
        f"Making sure existing roadway links corresponding to station pairs have transit access\n"
        f"{roadway_net.links_df.loc[roadway_net.links_df._merge == 'both']}"
    )
    # if any of these are footway or cycleway, warn
    if "roadway" in roadway_net.links_df.columns:
        ACTIVE_OSM_HIGHWAY = ["footway", "cycleway", "path", "pedestrian"]
        active_only = roadway_net.links_df.loc[
            (roadway_net.links_df["_merge"] == "both")
            & roadway_net.links_df["roadway"].isin(ACTIVE_OSM_HIGHWAY)
        ]
        if len(active_only) > 0:
            WranglerLogger.warning(
                f"Adding rail or ferry access to {len(active_only)} active links -- See debug log"
            )
            WranglerLogger.debug(f"Updating the following:\n{active_only}")

    roadway_net.links_df.loc[roadway_net.links_df["_merge"] == "both", "rail_only"] = (
        roadway_net.links_df["rail_only"] | roadway_net.links_df["rail_only_update"]
    )
    roadway_net.links_df.loc[roadway_net.links_df["_merge"] == "both", "ferry_only"] = (
        roadway_net.links_df["ferry_only"] | roadway_net.links_df["ferry_only_update"]
    )
    WranglerLogger.debug(
        f"After updating:\n{roadway_net.links_df.loc[roadway_net.links_df._merge == 'both']}"
    )
    roadway_net.links_df.drop(
        columns=["_merge", "rail_only_update", "ferry_only_update"], inplace=True
    )

    # Now drop those that are already in the roadway network
    station_road_links_gdf = station_road_links_gdf.merge(
        right=roadway_net.links_df[["A", "B"]], how="left", validate="one_to_one", indicator=True
    )
    WranglerLogger.debug(
        f"Dropping the following station_road_links_gdf rows that are already in the roadway network:\n"
        f"{station_road_links_gdf.loc[station_road_links_gdf['_merge'] == 'both']}"
    )

    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace station_road_links_gdf including check for roadway_net.links_df for {trace_shape_id}:\n"
                f"{station_road_links_gdf.loc[station_road_links_gdf['shape_ids'].apply(lambda x, tid=trace_shape_id: isinstance(x, list) and tid in x)]}"
            )
    station_road_links_gdf = station_road_links_gdf.loc[
        station_road_links_gdf["_merge"] == "left_only"
    ]
    station_road_links_gdf.drop(columns=["_merge"], inplace=True)

    # Assign model_link_id, access for drive,walk,bike,truck,bus
    max_model_link_id = roadway_net.links_df.model_link_id.max()
    station_road_links_gdf["model_link_id"] = station_road_links_gdf.index + max_model_link_id + 1
    station_road_links_gdf["name"] = (
        station_road_links_gdf["stop_name"] + " to " + station_road_links_gdf["next_stop_name"]
    )
    station_road_links_gdf["shape_id"] = (
        station_road_links_gdf["stop_id"] + " to " + station_road_links_gdf["next_stop_id"]
    )
    station_road_links_gdf["drive_access"] = False
    station_road_links_gdf["bike_access"] = False
    station_road_links_gdf["walk_access"] = False
    station_road_links_gdf["truck_access"] = False
    station_road_links_gdf["bus_only"] = False
    station_road_links_gdf["lanes"] = 0
    if "roadway" in roadway_net.links_df.columns:
        station_road_links_gdf["roadway"] = "transit"

    # Set distance
    station_road_links_gdf.to_crs(local_crs, inplace=True)
    station_road_links_gdf["length"] = station_road_links_gdf.length
    if crs_units == "feet":
        station_road_links_gdf["distance"] = station_road_links_gdf["length"] / FEET_PER_MILE
    else:
        station_road_links_gdf["distance"] = (
            station_road_links_gdf["length"] / METERS_PER_KILOMETER
        )
    station_road_links_gdf.to_crs(LAT_LON_CRS, inplace=True)

    # Add to roadway network
    WranglerLogger.info(f"Adding {len(station_road_links_gdf):,} links to roadway network")
    WranglerLogger.debug(f"station_road_links_gdf:\n{station_road_links_gdf}")

    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace station_road_links_gdf for {trace_shape_id}:\n"
                f"{station_road_links_gdf.loc[station_road_links_gdf['shape_ids'].apply(lambda x, tid=trace_shape_id: isinstance(x, list) and tid in x)]}"
            )

    # Apply default link attributes
    if default_link_attribute_dict is None:
        default_link_attribute_dict = {}
    for colname, default_value in default_link_attribute_dict.items():
        station_road_links_gdf[colname] = default_value

    WranglerLogger.debug(f"Before adding links, {len(roadway_net.links_df)=:,}")
    roadway_net.add_links(station_road_links_gdf)
    WranglerLogger.debug(f"After adding links, {len(roadway_net.links_df)=:,}")

    WranglerLogger.info(f"Adding {len(station_road_links_gdf):,} shapes to roadway network")
    roadway_net.add_shapes(station_road_links_gdf)

    # Update feed_tables['stops']: set model_node_id for stations
    WranglerLogger.debug(
        f"Before updating, feed_tables['stops'] with model_node_id set: "
        f"{feed_tables['stops']['model_node_id'].notna().sum():,}"
    )
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace feed_tables['stops'] for {trace_shape_id}:\n"
                f"{feed_tables['stops'].loc[feed_tables['stops']['shape_ids'].apply(lambda x, tid=trace_shape_id: isinstance(x, list) and tid in x)]}"
            )

    feed_tables["stops"]["station_node_id"] = feed_tables["stops"]["stop_id"].map(
        stop_id_to_model_node_id_dict
    )

    # Verify no stops have a model_node_id already set (from match_bus_stops_to_roadway_nodes()) *and* a conflicting station_node_id
    have_both_df = feed_tables["stops"].loc[
        feed_tables["stops"]["model_node_id"].notna()
        & feed_tables["stops"]["station_node_id"].notna()
        & (feed_tables["stops"]["model_node_id"] != feed_tables["stops"]["station_node_id"])
    ]
    assert len(have_both_df) == 0, f"Conflicting model_node_id vs. station_node_id:\n{have_both_df}"

    feed_tables["stops"].loc[feed_tables["stops"]["station_node_id"].notna(), "model_node_id"] = (
        feed_tables["stops"]["station_node_id"]
    )
    WranglerLogger.debug(
        f"After updating, feed_tables['stops'] with model_node_id set:\n"
        f"{feed_tables['stops']['model_node_id'].notna().sum():,}"
    )
    # Log feed_tables['stops'] for trace_shape_ids
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace feed_tables['stops'] for {trace_shape_id}:\n"
                f"{feed_tables['stops'].loc[feed_tables['stops']['shape_ids'].apply(lambda x, tid=trace_shape_id: isinstance(x, list) and tid in x)]}"
            )

    feed_tables["stops"].drop(columns=["station_node_id"], inplace=True)
    WranglerLogger.debug(f"feed_tables['stops']:\n{feed_tables['stops']}")
    # TODO: I think these are all parent nodes...
    WranglerLogger.debug(
        f"feed_tables['stops'] without model_node_id:\n{feed_tables['stops'].loc[feed_tables['stops']['model_node_id'].isna()]}"
    )

    # Update feed_tables['shapes']: set shape_model_node_id for stations and delete the other (non-station) shape points;
    # those geometries are now carried by the roadway network links
    WranglerLogger.debug(
        f"About to update feed['shapes'] in add_stations_and_links_to_roadway_network:\n{feed_tables['shapes']}"
    )
    feed_tables["shapes"]["station_node_id"] = feed_tables["shapes"]["stop_id"].map(
        stop_id_to_model_node_id_dict
    )

    # Verify no shapes table rows have a shape_model_node_id already set *and* a conflicting station_node_id
    have_both_df = feed_tables["shapes"].loc[
        feed_tables["shapes"]["shape_model_node_id"].notna()
        & feed_tables["shapes"]["station_node_id"].notna()
        & (
            feed_tables["shapes"]["shape_model_node_id"]
            != feed_tables["shapes"]["station_node_id"]
        )
    ]
    if len(have_both_df) > 0:
        WranglerLogger.fatal(f"have_both_df:\n{have_both_df}")
    assert len(have_both_df) == 0
    feed_tables["shapes"].loc[
        feed_tables["shapes"]["station_node_id"].notna(), "shape_model_node_id"
    ] = feed_tables["shapes"]["station_node_id"]
    feed_tables["shapes"].drop(columns=["station_node_id"], inplace=True)
    # Delete other nodes
    feed_tables["shapes"] = feed_tables["shapes"].loc[
        # leave non station route types (bus) alone -- these will be handled elsewhere
        ~feed_tables["shapes"]["route_type"].isin(STATION_ROUTE_TYPES)
        |
        # for station route types, only keep those with model_node_id set
        feed_tables["shapes"]["shape_model_node_id"].notna()
    ]

    # Log feed_tables['shapes'] for trace_shape_ids
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace feed_tables['shapes'] for {trace_shape_id} at the end of add_stations_and_links_to_roadway_network():\n"
                f"{feed_tables['shapes'].loc[feed_tables['shapes']['shape_id'] == trace_shape_id]}"
            )

    # set A for stop_id's model_node_id
    stop_links_gdf = pd.merge(
        left=stop_links_gdf,
        right=feed_tables["stops"][["stop_id", "model_node_id"]],
        how="left",
        validate="many_to_one",
    ).rename(columns={"model_node_id": "A"})
    # set B for next_stop_id's model_node_id
    stop_links_gdf = pd.merge(
        left=stop_links_gdf,
        right=feed_tables["stops"][["stop_id", "model_node_id"]].rename(
            columns={"stop_id": "next_stop_id"}
        ),
        how="left",
        validate="many_to_one",
    ).rename(columns={"model_node_id": "B"})
    WranglerLogger.debug(f"stop_links_gdf after setting A and B:\n{stop_links_gdf}")

    # Filter non-station stop links to return
    bus_stop_links_gdf = stop_links_gdf.loc[~stop_links_gdf.route_type.isin(STATION_ROUTE_TYPES)]

    return stop_id_to_model_node_id_dict, bus_stop_links_gdf
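
The next-stop pairing used in the function above (sort by `trip_id`/`stop_sequence`, then `groupby` + `shift(-1)`, then drop rows with no successor or with self-loops) is a generic pandas pattern. A minimal, self-contained sketch with hypothetical data:

```python
import pandas as pd

# Toy stop_times: two trips, each with an ordered stop sequence.
# Trip t1 visits stop "B" twice in a row to demonstrate self-loop filtering.
stop_times = pd.DataFrame(
    {
        "trip_id": ["t1", "t1", "t1", "t2", "t2"],
        "stop_sequence": [1, 2, 3, 1, 2],
        "stop_id": ["A", "B", "B", "X", "Y"],
    }
)
stop_times = stop_times.sort_values(["trip_id", "stop_sequence"])

# Pair each stop with the following stop on the same trip
stop_times["next_stop_id"] = stop_times.groupby("trip_id")["stop_id"].shift(-1)

# Keep only rows that have a next stop, and drop self-loops
links = stop_times[
    stop_times["next_stop_id"].notna()
    & (stop_times["next_stop_id"] != stop_times["stop_id"])
]
# -> two links: t1 A->B and t2 X->Y
```

The same frame-level `groupby(...)["col"].shift(-1)` call also carries along stop names and geometries, which is how the function builds `next_stop_name` and `next_stop_geometry` in one pass per column.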

network_wrangler.utils.transit.add_unmatched_bus_stops_to_network

add_unmatched_bus_stops_to_network(
    feed_tables,
    roadway_net,
    local_crs,
    max_distance,
    trace_shape_ids=None,
    default_node_attribute_dict=None,
)

Add unmatched bus stops as new nodes in the roadway network.

Creates new roadway nodes for bus stops that couldn’t be matched to the bus-accessible network. Clusters nearby unmatched stops (e.g., at transit stations) and creates one node per cluster at the centroid location.

Process Steps:

  1. Identifies unmatched bus stops (poor_match=True)
  2. Clusters stops using max_distance threshold with DBSCAN
  3. Calculates centroid for each cluster
  4. Creates new roadway nodes at cluster centroids
  5. Updates stops table with new node IDs and locations
  6. Adds nodes to roadway network

Parameters:

Name Type Description Default
feed_tables dict[str, DataFrame]

dictionary of GTFS feed tables with 'stops' containing:

  • poor_match (bool): True for unmatched stops
  • model_node_id (int): Nearest bus-accessible node (for creating connector links)
  • Other stop attributes

required
roadway_net RoadwayNetwork

RoadwayNetwork to add nodes to

required
local_crs str

Coordinate reference system for projections (e.g., "EPSG:2227")

required
max_distance float

Distance threshold in crs_units for clustering

required
trace_shape_ids list[str] | None

Optional list of shape_ids for debug logging

None
default_node_attribute_dict dict[str, any] | None

Optional dict of column-name to default value to set on new nodes.

None

Returns:

Type Description
GeoDataFrame

GeoDataFrame of added nodes with columns:

  • model_node_id (int): New node ID
  • X, Y (float): Node coordinates in lat/lon
  • geometry (Point): Node geometry
  • cluster_id (int): Cluster assignment
  • stop_id_GTFS (str): Stringified unique list of the GTFS stop_ids in this cluster
  • stop_name (str): Stringified unique list of the stop names in this cluster
  • nearest_bus_node (int): Nearest bus-accessible node for connectivity
  • is_transit_stop_node (bool): True (marks these as special transit nodes)
  • num_stops_in_cluster (int): Number of stops merged into this cluster

Notes
  • Only processes bus stops (not rail/ferry stations)
  • Clusters stops within max_distance of each other
  • One node created per cluster at centroid
  • Original GTFS stop locations preserved before updating
  • Modifies feed_tables['stops'] in place
  • Modifies roadway_net.nodes_df in place
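On the clustering step: with min_samples=1, every point is a core point, so DBSCAN's clusters are exactly the connected components of the "within eps of each other" graph (points chain together even when the endpoints are farther than eps apart). A dependency-free sketch of that equivalence, using hypothetical coordinates and plain union-find in place of scikit-learn:

```python
import numpy as np

def cluster_min_samples_1(coords: np.ndarray, eps: float) -> np.ndarray:
    """Assign cluster ids: two points share a cluster if linked by a chain
    of hops each <= eps (what DBSCAN does when min_samples=1)."""
    n = len(coords)
    parent = list(range(n))

    def find(i: int) -> int:
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path compression
            i = parent[i]
        return i

    # Union every pair within eps of each other
    for i in range(n):
        for j in range(i + 1, n):
            if np.linalg.norm(coords[i] - coords[j]) <= eps:
                parent[find(i)] = find(j)

    roots = [find(i) for i in range(n)]
    # Relabel roots as contiguous cluster ids 0..k-1
    ids = {r: k for k, r in enumerate(dict.fromkeys(roots))}
    return np.array([ids[r] for r in roots])

# Three stops near a station plus one isolated stop (projected units)
coords = np.array([[0.0, 0.0], [30.0, 0.0], [55.0, 10.0], [500.0, 500.0]])
labels = cluster_min_samples_1(coords, eps=50.0)
# One new node per cluster, placed at the cluster centroid
centroids = np.array([coords[labels == c].mean(axis=0) for c in np.unique(labels)])
```

Note the chaining: the first and third stops are ~55.9 units apart (beyond eps), yet land in one cluster because each is within 50 units of the middle stop. The function itself uses `sklearn.cluster.DBSCAN(eps=max_distance, min_samples=1)`, which scales better than this quadratic sketch.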
Source code in network_wrangler/utils/transit.py
def add_unmatched_bus_stops_to_network(  # noqa: PLR0915
    feed_tables: dict[str, pd.DataFrame],
    roadway_net: RoadwayNetwork,
    local_crs: str,
    max_distance: float,
    trace_shape_ids: list[str] | None = None,
    default_node_attribute_dict: dict[str, any] | None = None,
) -> gpd.GeoDataFrame:
    """Add unmatched bus stops as new nodes in the roadway network.

    Creates new roadway nodes for bus stops that couldn't be matched to the bus-accessible
    network. Clusters nearby unmatched stops (e.g., at transit stations) and creates one
    node per cluster at the centroid location.

    Process Steps:
    1. Identifies unmatched bus stops (poor_match=True)
    2. Clusters stops using max_distance threshold with DBSCAN
    3. Calculates centroid for each cluster
    4. Creates new roadway nodes at cluster centroids
    5. Updates stops table with new node IDs and locations
    6. Adds nodes to roadway network

    Args:
        feed_tables: dictionary of GTFS feed tables with 'stops' containing:
            - poor_match (bool): True for unmatched stops
            - model_node_id (int): Nearest bus-accessible node (for creating connector links)
            - Other stop attributes
        roadway_net: RoadwayNetwork to add nodes to
        local_crs: Coordinate reference system for projections (e.g., "EPSG:2227")
        max_distance: Distance threshold in crs_units for clustering
        trace_shape_ids: Optional list of shape_ids for debug logging
        default_node_attribute_dict: Optional dict of column-name to default value to set on new nodes.

    Returns:
        GeoDataFrame of added nodes with columns:
            - model_node_id (int): New node ID
            - X, Y (float): Node coordinates in lat/lon
            - geometry (Point): Node geometry
            - cluster_id (int): Cluster assignment
            - stop_id_GTFS (str): Stringified unique list of the GTFS stop_ids in this cluster
            - stop_name (str): Stringified unique list of the stop names in this cluster
            - num_stops_in_cluster (int): Number of stops merged into this cluster
            - nearest_bus_node (int): Nearest bus-accessible node for connectivity
            - is_transit_stop_node (bool): True (marks these as special transit nodes)

    Notes:
        - Only processes bus stops (not rail/ferry stations)
        - Clusters stops within max_distance of each other
        - One node created per cluster at centroid
        - Original GTFS stop locations preserved before updating
        - Modifies feed_tables['stops'] in place
        - Modifies roadway_net.nodes_df in place
    """
    WranglerLogger.info("Adding unmatched bus stops to roadway network")

    # Get unmatched bus stops (poor_match=True means they have a model_node_id for nearest bus node)
    stops_df = feed_tables["stops"]
    unmatched_mask = (
        (stops_df["is_bus_stop"] == True)
        & (stops_df["poor_match"] == True)
        & (stops_df["model_node_id"].notna())
    )
    unmatched_stops_gdf = stops_df[unmatched_mask].copy()

    if len(unmatched_stops_gdf) == 0:
        WranglerLogger.info("No unmatched bus stops (poor_match=True) found; skipping")
        return gpd.GeoDataFrame()

    WranglerLogger.info(
        f"Processing {len(unmatched_stops_gdf)} unmatched bus stops "
        f"(clustering with max_distance={max_distance})"
    )

    # Ensure it's a GeoDataFrame
    if not isinstance(unmatched_stops_gdf, gpd.GeoDataFrame):
        unmatched_stops_gdf = gpd.GeoDataFrame(unmatched_stops_gdf, geometry="geometry")

    # Project to local CRS for clustering
    unmatched_stops_proj = unmatched_stops_gdf.to_crs(local_crs)

    # Extract coordinates for clustering
    coords = np.array([(geom.x, geom.y) for geom in unmatched_stops_proj.geometry])

    # Cluster using DBSCAN with max_distance threshold
    try:
        from sklearn.cluster import DBSCAN
    except ImportError as e:
        msg = "sklearn is required for clustering. Install with: pip install scikit-learn"
        raise ImportError(msg) from e

    # DBSCAN eps parameter is the maximum distance between two samples in a cluster
    # min_samples=1 means a single point can be a cluster
    clustering = DBSCAN(eps=max_distance, min_samples=1, metric="euclidean")
    unmatched_stops_proj["cluster_id"] = clustering.fit_predict(coords)

    WranglerLogger.info(
        f"Clustered {len(unmatched_stops_gdf)} unmatched stops into "
        f"{unmatched_stops_proj['cluster_id'].nunique()} clusters"
    )

    # Calculate centroid for each cluster
    cluster_centroids = []
    for cluster_id in unmatched_stops_proj["cluster_id"].unique():
        cluster_stops = unmatched_stops_proj[unmatched_stops_proj["cluster_id"] == cluster_id]

        # Calculate centroid in projected CRS
        centroid_x = cluster_stops.geometry.x.mean()
        centroid_y = cluster_stops.geometry.y.mean()
        centroid_geom_proj = shapely.geometry.Point(centroid_x, centroid_y)

        # Convert back to lat/lon
        centroid_gdf = gpd.GeoDataFrame({"geometry": [centroid_geom_proj]}, crs=local_crs).to_crs(
            LAT_LON_CRS
        )
        centroid_geom = centroid_gdf.geometry.iloc[0]

        # Get representative nearest bus node (use model_node_id from first stop in cluster)
        nearest_bus_node = cluster_stops.iloc[0]["model_node_id"]

        cluster_centroids.append(
            {
                "cluster_id": cluster_id,
                "geometry": centroid_geom,
                "X": centroid_geom.x,
                "Y": centroid_geom.y,
                "stop_id_GTFS": str(
                    list(set(cluster_stops["stop_id"].tolist()))
                ),  # make unique list
                "stop_name": str(list(set(cluster_stops["stop_name"].tolist()))),  # make unique
                "nearest_bus_node": nearest_bus_node,
                "is_transit_stop_node": True,
                "num_stops_in_cluster": len(cluster_stops),
            }
        )

    # Create GeoDataFrame of new nodes
    new_nodes_gdf = gpd.GeoDataFrame(cluster_centroids, crs=LAT_LON_CRS)
    WranglerLogger.debug(f"new_nodes_gdf:\n{new_nodes_gdf}")

    # Assign new model_node_ids
    # TODO: county numbering
    max_node_id = roadway_net.nodes_df["model_node_id"].max()
    new_nodes_gdf["model_node_id"] = range(max_node_id + 1, max_node_id + 1 + len(new_nodes_gdf))

    WranglerLogger.info(
        f"Creating {len(new_nodes_gdf)} new roadway nodes "
        f"(IDs {max_node_id + 1} to {max_node_id + len(new_nodes_gdf)})"
    )

    # Apply default node attributes if provided
    if default_node_attribute_dict:
        for attr, value in default_node_attribute_dict.items():
            new_nodes_gdf[attr] = value

    WranglerLogger.debug(
        f"Before adding nodes, roadway network has {len(roadway_net.nodes_df)} nodes"
    )
    roadway_net.add_nodes(new_nodes_gdf)
    WranglerLogger.debug(
        f"After adding nodes, roadway network has {len(roadway_net.nodes_df)} nodes"
    )

    # Update feed_tables['stops'] with new node IDs and locations
    # Create mapping from stop_id to new model_node_id
    stop_to_cluster = unmatched_stops_proj[["stop_id", "cluster_id"]].copy()
    cluster_to_node = new_nodes_gdf[["cluster_id", "model_node_id", "X", "Y", "geometry"]].copy()
    stop_to_node = stop_to_cluster.merge(cluster_to_node, on="cluster_id")
    WranglerLogger.debug(f"stop_to_node:\n{stop_to_node}")
    # Update stops table
    for _, row in stop_to_node.iterrows():
        stop_id = row["stop_id"]
        new_node_id = row["model_node_id"]
        new_x = row["X"]
        new_y = row["Y"]
        new_geom = row["geometry"]

        # Update in feed_tables['stops']
        mask = feed_tables["stops"]["stop_id"] == stop_id
        feed_tables["stops"].loc[mask, "model_node_id"] = new_node_id
        feed_tables["stops"].loc[mask, "stop_lon"] = new_x
        feed_tables["stops"].loc[mask, "stop_lat"] = new_y
        feed_tables["stops"].loc[mask, "geometry"] = new_geom

    WranglerLogger.info(f"Updated {len(stop_to_node)} stops to point to new cluster nodes")

    # Log for trace_shape_ids
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            trace_stops = feed_tables["stops"][
                feed_tables["stops"]["shape_ids"].apply(
                    lambda x, tid=trace_shape_id: isinstance(x, list) and tid in x
                )
            ]
            WranglerLogger.debug(f"trace_stops for {trace_shape_id}:\n{trace_stops}")

    return new_nodes_gdf
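The DBSCAN call above groups stops whose pairwise distance is within `max_distance` and then takes the mean X/Y per cluster. A dependency-free sketch of the same eps-based grouping (a naive union-find single-linkage pass standing in for sklearn's DBSCAN with `min_samples=1`) illustrates the clustering and centroid steps:

```python
def cluster_points(points, eps):
    """Group 2D points so any two points within eps share a cluster
    (transitive closure), like DBSCAN with min_samples=1."""
    n = len(points)
    parent = list(range(n))  # each point starts in its own cluster

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(n):
        for j in range(i + 1, n):
            dx = points[i][0] - points[j][0]
            dy = points[i][1] - points[j][1]
            if (dx * dx + dy * dy) ** 0.5 <= eps:
                parent[find(i)] = find(j)  # union the two clusters

    return [find(i) for i in range(n)]


def centroids(points, labels):
    """Mean X/Y per cluster label, mirroring the centroid step above."""
    sums = {}
    for (x, y), lab in zip(points, labels):
        sx, sy, c = sums.get(lab, (0.0, 0.0, 0))
        sums[lab] = (sx + x, sy + y, c + 1)
    return {lab: (sx / c, sy / c) for lab, (sx, sy, c) in sums.items()}
```

With `points = [(0, 0), (1, 0), (10, 0)]` and `eps = 2.0`, the first two points merge into one cluster and the third stands alone.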

network_wrangler.utils.transit.assess_stop_name_roadway_compatibility

assess_stop_name_roadway_compatibility(
    stop_name,
    node_link_names,
    threshold=0.5,
    config=DefaultConfig,
)

Assess if a transit stop name is compatible with a roadway node’s link names.

Checks if street names in the stop name match any of the node’s connected link names. Handles common patterns like “Street1 & Street2” or “Street1 at Street2”.

Exact name matches receive a special high score of 10.0 to enable users to force specific stop-to-node matches by ensuring the stop name exactly matches a link name.

Parameters:

Name Type Description Default
stop_name str

Name of the transit stop (e.g., “Van Ness Ave & Market St”)

required
node_link_names list[str]

List of link names connected to the roadway node

required
threshold float

Minimum fraction of stop streets that must match node links (default 0.5)

0.5
config WranglerConfig

WranglerConfig with TRANSIT.MIN_SUBSTRING_MATCH_LENGTH setting.

DefaultConfig

Returns:

Type Description
tuple[bool, float, list[str]]

Tuple of (is_compatible, match_score, matched_streets) where:
  • is_compatible: True if stop name is compatible with node
  • match_score: 10.0 for exact match (to force selection), otherwise fraction of stop streets found in node links (0.0 to 1.0)
  • matched_streets: List of street names from stop that matched node links

Source code in network_wrangler/utils/transit.py
def assess_stop_name_roadway_compatibility(  # noqa: PLR0912
    stop_name: str,
    node_link_names: list[str],
    threshold: float = 0.5,
    config: WranglerConfig = DefaultConfig,
) -> tuple[bool, float, list[str]]:
    """Assess if a transit stop name is compatible with a roadway node's link names.

    Checks if street names in the stop name match any of the node's connected link names.
    Handles common patterns like "Street1 & Street2" or "Street1 at Street2".

    Exact name matches receive a special high score of 10.0 to enable users to force
    specific stop-to-node matches by ensuring the stop name exactly matches a link name.

    Args:
        stop_name: Name of the transit stop (e.g., "Van Ness Ave & Market St")
        node_link_names: List of link names connected to the roadway node
        threshold: Minimum fraction of stop streets that must match node links (default 0.5)
        config: WranglerConfig with TRANSIT.MIN_SUBSTRING_MATCH_LENGTH setting.

    Returns:
        Tuple of (is_compatible, match_score, matched_streets) where:
            - is_compatible: True if stop name is compatible with node
            - match_score: 10.0 for exact match (to force selection), otherwise fraction
                           of stop streets found in node links (0.0 to 1.0)
            - matched_streets: List of street names from stop that matched node links
    """
    import re

    if not stop_name or (node_link_names is None or len(node_link_names) == 0):
        return False, 0.0, []

    # Check for exact match first - allows users to force specific matches by ensuring
    # stop names and link names are identical (case-insensitive)
    stop_name_normalized = stop_name.lower().strip()
    for node_link in node_link_names:
        if node_link.lower().strip() == stop_name_normalized:
            # Exact match gets special high score to strongly prefer/force this match
            return True, 10.0, [stop_name]

    # Common separators in stop names
    separators = [" & ", " and ", " at ", " @ ", " / ", " near "]

    # Split stop name by separators to get individual street names
    stop_streets = [stop_name]
    for sep in separators:
        if sep in stop_name.lower():
            # Split and clean up each part
            parts = re.split(re.escape(sep), stop_name, flags=re.IGNORECASE)
            stop_streets = [part.strip() for part in parts if part.strip()]
            break

    # Normalize for comparison (lowercase, remove extra spaces)
    normalized_node_links = [link.lower().strip() for link in node_link_names]

    matched_streets = []
    for street in stop_streets:
        street_normalized = street.lower().strip()

        # Check for exact match or substring match
        for node_link in normalized_node_links:
            # Check if the street name is contained in the node link name or vice versa
            # Only do substring matching if both strings meet minimum length to avoid
            # spurious matches with single letters (e.g., "E" matching "Deer Creek")
            if (
                len(street_normalized) >= config.TRANSIT.MIN_SUBSTRING_MATCH_LENGTH
                and len(node_link) >= config.TRANSIT.MIN_SUBSTRING_MATCH_LENGTH
                and (street_normalized in node_link or node_link in street_normalized)
            ):
                matched_streets.append(street)
                break

            # Also check for partial matches (e.g., "Market St" matches "Market Street")
            # Remove common suffixes for comparison
            suffixes = [
                " street",
                " st",
                " avenue",
                " ave",
                " road",
                " rd",
                " boulevard",
                " blvd",
                " drive",
                " dr",
                " lane",
                " ln",
                " way",
                " court",
                " ct",
                " place",
                " pl",
                " parkway",
                " pkwy",
                " highway",
                " hwy",
            ]

            street_base = street_normalized
            for suffix in suffixes:
                if street_base.endswith(suffix):
                    street_base = street_base[: -len(suffix)].strip()
                    break

            node_base = node_link
            for suffix in suffixes:
                if node_base.endswith(suffix):
                    node_base = node_base[: -len(suffix)].strip()
                    break

            # Apply same minimum length requirement for suffix-removed matching
            if (
                street_base
                and node_base
                and len(street_base) >= config.TRANSIT.MIN_SUBSTRING_MATCH_LENGTH
                and len(node_base) >= config.TRANSIT.MIN_SUBSTRING_MATCH_LENGTH
                and (street_base in node_base or node_base in street_base)
            ):
                matched_streets.append(street)
                break

    # Calculate match score
    match_score = len(matched_streets) / len(stop_streets) if len(stop_streets) > 0 else 0.0

    is_compatible = match_score >= threshold

    return is_compatible, match_score, matched_streets
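The core scoring idea above can be stripped down to a short sketch: split the stop name on common separators, then count what fraction of the resulting street names substring-match a link name. This is a simplified illustration, not the full implementation (it omits exact-match boosting and suffix normalization); `MIN_MATCH_LEN = 5` is a hypothetical stand-in for `config.TRANSIT.MIN_SUBSTRING_MATCH_LENGTH`:

```python
import re

MIN_MATCH_LEN = 5  # stands in for config.TRANSIT.MIN_SUBSTRING_MATCH_LENGTH


def name_match_score(stop_name, node_link_names, threshold=0.5):
    """Fraction of the stop's street names found among the node's link names."""
    seps = [" & ", " and ", " at ", " @ ", " / ", " near "]
    streets = [stop_name]
    for sep in seps:
        if sep in stop_name.lower():
            parts = re.split(re.escape(sep), stop_name, flags=re.IGNORECASE)
            streets = [p.strip() for p in parts if p.strip()]
            break
    links = [n.lower().strip() for n in node_link_names]
    matched = []
    for street in streets:
        s = street.lower().strip()
        # substring match in either direction, gated on minimum length
        if any(len(s) >= MIN_MATCH_LEN and len(n) >= MIN_MATCH_LEN
               and (s in n or n in s) for n in links):
            matched.append(street)
    score = len(matched) / len(streets) if streets else 0.0
    return score >= threshold, score, matched
```

For example, "Van Ness Ave & Market St" against links ["Van Ness Avenue", "Market Street"] matches both streets for a score of 1.0.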

network_wrangler.utils.transit.calculate_path_deviation_from_shape

calculate_path_deviation_from_shape(
    path_nodes,
    original_shape_points,
    roadway_net,
    trace=False,
)

Calculate total deviation of a path from original shape points.

Creates a LineString from the path nodes and calculates the distance from each shape point to the nearest point on the line.

Parameters:

Name Type Description Default
path_nodes list

List of roadway node IDs in the path

required
original_shape_points DataFrame

DataFrame of original GTFS shape points

required
roadway_net

RoadwayNetwork to get node coordinates

required
trace bool

If True, enable trace logging for debugging

False

Returns:

Type Description
float

Total deviation distance (sum of distances from shape points to path line)

Source code in network_wrangler/utils/transit.py
def calculate_path_deviation_from_shape(
    path_nodes: list, original_shape_points: pd.DataFrame, roadway_net, trace: bool = False
) -> float:
    """Calculate total deviation of a path from original shape points.

    Creates a LineString from the path nodes and calculates the distance from each
    shape point to the nearest point on the line.

    Args:
        path_nodes: List of roadway node IDs in the path
        original_shape_points: DataFrame of original GTFS shape points
        roadway_net: RoadwayNetwork to get node coordinates
        trace: If True, enable trace logging for debugging

    Returns:
        Total deviation distance (sum of distances from shape points to path line)
    """
    if original_shape_points.empty or not path_nodes:
        return float("inf")

    try:
        from shapely.geometry import LineString, Point

        # Create LineString from path nodes
        path_coords = []
        for node_id in path_nodes:
            node_row = roadway_net.nodes_df[roadway_net.nodes_df["model_node_id"] == node_id]
            if not node_row.empty:
                path_coords.append((node_row.iloc[0]["X"], node_row.iloc[0]["Y"]))

        if len(path_coords) < 2:  # noqa: PLR2004
            return float("inf")

        path_line = LineString(path_coords)

        # Calculate total deviation - for each shape point, find distance to path line
        total_deviation = 0.0
        if trace:
            WranglerLogger.debug(
                f"Calculating path deviation for line with {len(path_nodes)} nodes against {len(original_shape_points)} shape points"
            )
            # Get link names for first and last nodes for debugging
            first_node_row = roadway_net.nodes_df[
                roadway_net.nodes_df["model_node_id"] == path_nodes[0]
            ]
            last_node_row = roadway_net.nodes_df[
                roadway_net.nodes_df["model_node_id"] == path_nodes[-1]
            ]
            first_link_names = (
                first_node_row.iloc[0].get("link_names", [])
                if not first_node_row.empty and "link_names" in first_node_row.columns
                else []
            )
            last_link_names = (
                last_node_row.iloc[0].get("link_names", [])
                if not last_node_row.empty and "link_names" in last_node_row.columns
                else []
            )
            WranglerLogger.debug(
                f"  Path from node {path_nodes[0]} ({first_link_names}) to node {path_nodes[-1]} ({last_link_names})"
            )

        for idx, shape_row in original_shape_points.iterrows():
            shape_point = Point(shape_row["shape_pt_lon"], shape_row["shape_pt_lat"])
            # Distance from point to nearest point on line
            dist = shape_point.distance(path_line)
            total_deviation += dist

            if (
                trace and idx % max(1, len(original_shape_points) // 5) == 0
            ):  # Log every ~20% of points
                WranglerLogger.debug(f"  Shape point {idx} distance to path: {dist:.6f}")

        if trace:
            WranglerLogger.debug(
                f"Total path deviation: {total_deviation:.6f} (avg per point: {total_deviation / len(original_shape_points):.6f})"
            )

        return total_deviation
    except Exception as e:
        if trace:
            WranglerLogger.debug(f"Error calculating path deviation: {e}")
        return float("inf")
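The deviation metric is simply a sum of point-to-polyline distances. A dependency-free sketch of that geometry (standing in for shapely's `Point.distance(LineString)`, which the function above uses):

```python
def point_to_segment_dist(px, py, ax, ay, bx, by):
    """Distance from point P to segment AB."""
    abx, aby = bx - ax, by - ay
    denom = abx * abx + aby * aby
    if denom == 0.0:  # degenerate segment: A == B
        return ((px - ax) ** 2 + (py - ay) ** 2) ** 0.5
    # project P onto AB, clamped to the segment
    t = max(0.0, min(1.0, ((px - ax) * abx + (py - ay) * aby) / denom))
    cx, cy = ax + t * abx, ay + t * aby  # closest point on the segment
    return ((px - cx) ** 2 + (py - cy) ** 2) ** 0.5


def path_deviation(path_coords, shape_points):
    """Sum over shape points of the distance to the nearest path segment."""
    if len(path_coords) < 2 or not shape_points:
        return float("inf")
    total = 0.0
    for px, py in shape_points:
        total += min(
            point_to_segment_dist(px, py, ax, ay, bx, by)
            for (ax, ay), (bx, by) in zip(path_coords, path_coords[1:])
        )
    return total
```

A path along the x-axis from (0, 0) to (10, 0) against shape points (5, 1) and (5, 0) yields distances 1 and 0, so a total deviation of 1.0.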
network_wrangler.utils.transit.create_connector_links_for_poor_match_stops

create_connector_links_for_poor_match_stops(
    roadway_net,
    unmatched_stops_gdf,
    local_crs,
    crs_units,
    trace_shape_ids=None,
    default_link_attribute_dict=None,
)

Create connector links between poor match bus stop nodes and nearest bus-accessible nodes.

Creates bidirectional bus-only connector links in the roadway network to enable routing through bus stops that couldn’t be matched directly to existing roadway nodes. These are typically stops that are too far from the road network (poor_match stops added as new nodes).

Parameters:

Name Type Description Default
roadway_net RoadwayNetwork

RoadwayNetwork to add connector links to

required
unmatched_stops_gdf GeoDataFrame

GeoDataFrame of unmatched stops with columns:
  • model_node_id (int): New transit stop node ID
  • nearest_bus_node (int): Nearest bus-accessible node for connectivity
  • stop_ids, stop_names: Stop information
  • geometry (Point): Stop location

required
local_crs str

Coordinate reference system for distance calculations

required
crs_units str

Distance units (‘feet’ or ‘meters’)

required
trace_shape_ids list[str] | None

Optional list of shape_ids for debug logging

None
default_link_attribute_dict dict[str, any] | None

Optional dict of column-name to default value to set on new links.

None
Notes
  • Creates bidirectional links (forward and reverse) for each stop
  • Links are marked with ref=”unmatched_bus_stop” for identification
  • All links are bus_only=True, no other mode access
  • Modifies roadway_net.links_df and roadway_net.shapes in place
Source code in network_wrangler/utils/transit.py
def create_connector_links_for_poor_match_stops(
    roadway_net: RoadwayNetwork,
    unmatched_stops_gdf: gpd.GeoDataFrame,
    local_crs: str,
    crs_units: str,
    trace_shape_ids: list[str] | None = None,  # noqa: ARG001
    default_link_attribute_dict: dict[str, any] | None = None,
):
    """Create connector links between poor match bus stop nodes and nearest bus-accessible nodes.

    Creates bidirectional bus-only connector links in the roadway network to enable routing
    through bus stops that couldn't be matched directly to existing roadway nodes. These are
    typically stops that are too far from the road network (poor_match stops added as new nodes).

    Args:
        roadway_net: RoadwayNetwork to add connector links to
        unmatched_stops_gdf: GeoDataFrame of unmatched stops with columns:
            - model_node_id (int): New transit stop node ID
            - nearest_bus_node (int): Nearest bus-accessible node for connectivity
            - stop_ids, stop_names: Stop information
            - geometry (Point): Stop location
        local_crs: Coordinate reference system for distance calculations
        crs_units: Distance units ('feet' or 'meters')
        trace_shape_ids: Optional list of shape_ids for debug logging
        default_link_attribute_dict: Optional dict of column-name to default value to set on new links.

    Notes:
        - Creates bidirectional links (forward and reverse) for each stop
        - Links are marked with ref="unmatched_bus_stop" for identification
        - All links are bus_only=True, no other mode access
        - Modifies roadway_net.links_df and roadway_net.shapes in place
    """
    if unmatched_stops_gdf is None or len(unmatched_stops_gdf) == 0:
        WranglerLogger.info("No unmatched stops to create connector links for")
        return

    WranglerLogger.info(
        f"Creating connector links for {len(unmatched_stops_gdf)} unmatched stop clusters"
    )
    WranglerLogger.debug(f"unmatched_stops_gdf:\n{unmatched_stops_gdf}")

    # Create bidirectional links between each new stop node and its nearest bus node
    connector_links = []
    for _, stop_node in unmatched_stops_gdf.iterrows():
        stop_node_id = stop_node["model_node_id"]
        nearest_node_id = stop_node["nearest_bus_node"]

        # Get node geometries
        stop_geom = stop_node["geometry"]
        nearest_node_row = roadway_net.nodes_df[
            roadway_net.nodes_df["model_node_id"] == nearest_node_id
        ]
        if len(nearest_node_row) == 0:
            WranglerLogger.warning(
                f"Could not find nearest_bus_node {nearest_node_id} for stop node {stop_node_id}"
            )
            continue
        nearest_geom = nearest_node_row.iloc[0]["geometry"]

        # Create LineString geometry for the link
        link_geom = shapely.geometry.LineString(
            [(stop_geom.x, stop_geom.y), (nearest_geom.x, nearest_geom.y)]
        )

        # Forward link: stop -> nearest
        connector_links.append(
            {
                "A": stop_node_id,
                "B": nearest_node_id,
                "geometry": link_geom,
                "name": f"unmatched_stop_{stop_node_id}_connector",
                "ref": "unmatched_bus_stop",
            }
        )

        # Reverse link: nearest -> stop
        connector_links.append(
            {
                "A": nearest_node_id,
                "B": stop_node_id,
                "geometry": link_geom,
                "name": f"unmatched_stop_{stop_node_id}_connector_reverse",
                "ref": "unmatched_bus_stop",
            }
        )

    if len(connector_links) == 0:
        WranglerLogger.info("No valid connector links created")
        return

    connector_links_gdf = gpd.GeoDataFrame(connector_links, crs=LAT_LON_CRS)
    connector_links_gdf["shape_id"] = (
        connector_links_gdf["A"].astype(str) + " to " + connector_links_gdf["B"].astype(str)
    )
    connector_links_gdf["name"] = "bus stop connector"

    # Set transit access attributes
    connector_links_gdf["bus_only"] = True
    connector_links_gdf["rail_only"] = False
    connector_links_gdf["ferry_only"] = False
    connector_links_gdf["drive_access"] = True

    # No non-transit access
    connector_links_gdf["truck_access"] = False
    connector_links_gdf["bike_access"] = False
    connector_links_gdf["walk_access"] = False

    # Other attributes
    connector_links_gdf["roadway"] = "transit"
    connector_links_gdf["lanes"] = 1
    connector_links_gdf["managed"] = 0

    # Calculate distance
    connector_links_gdf.to_crs(local_crs, inplace=True)
    connector_links_gdf["length"] = connector_links_gdf.length
    if crs_units == "feet":
        connector_links_gdf["distance"] = connector_links_gdf["length"] / FEET_PER_MILE
    else:
        connector_links_gdf["distance"] = connector_links_gdf["length"] / METERS_PER_KILOMETER
    connector_links_gdf.to_crs(LAT_LON_CRS, inplace=True)

    # Assign model_link_ids
    max_model_link_id = roadway_net.links_df.model_link_id.max()
    connector_links_gdf["model_link_id"] = range(
        max_model_link_id + 1, max_model_link_id + 1 + len(connector_links_gdf)
    )

    WranglerLogger.info(
        f"Adding {len(connector_links_gdf)} connector links for unmatched stops "
        f"(IDs {max_model_link_id + 1} to {max_model_link_id + len(connector_links_gdf)})"
    )
    WranglerLogger.debug(f"connector_links_gdf:\n{connector_links_gdf}")

    # Apply default link attributes
    if default_link_attribute_dict is None:
        default_link_attribute_dict = {}
    for colname, default_value in default_link_attribute_dict.items():
        connector_links_gdf[colname] = default_value

    # Add to roadway network
    roadway_net.add_links(connector_links_gdf)
    roadway_net.add_shapes(connector_links_gdf)
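The bidirectional link construction above reduces to a small pattern: for each stop node, emit a forward and a reverse (A, B) record. A minimal sketch (real links also carry geometry, names, access flags, lengths, and model_link_ids):

```python
def make_connector_pairs(stop_to_nearest):
    """Build forward and reverse (A, B) link records for each stop node,
    mirroring the stop -> nearest and nearest -> stop links above."""
    links = []
    for stop_node_id, nearest_node_id in stop_to_nearest.items():
        links.append({"A": stop_node_id, "B": nearest_node_id,
                      "ref": "unmatched_bus_stop"})
        links.append({"A": nearest_node_id, "B": stop_node_id,
                      "ref": "unmatched_bus_stop"})
    return links
```

Each stop contributes exactly two records, so routing can both enter and leave the stop node.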

network_wrangler.utils.transit.create_feed_from_gtfs_model

create_feed_from_gtfs_model(
    gtfs_model,
    roadway_net,
    local_crs,
    crs_units,
    timeperiods,
    frequency_method,
    default_frequency_for_onetime_route=180,
    add_stations_and_links=True,
    max_stop_distance=None,
    trace_shape_ids=None,
    errors="raise",
    default_node_attribute_dict=None,
    default_link_attribute_dict=None,
)

Convert GTFS model to Wrangler Feed with stops mapped to roadway network.

Comprehensive conversion that transforms GTFS schedule data into a frequency-based Feed representation compatible with travel modeling. Maps transit stops to roadway nodes and optionally adds station infrastructure to the network.

Process Steps:

1. Prepare roadway network:
  • Convert roadway_net.nodes_df to GeoDataFrame if needed
  • Create Point geometries from X, Y coordinates
  • Set CRS to LAT_LON_CRS (EPSG:4326)
  • Modifies roadway_net.nodes_df in place

2. Copy GTFS tables to feed_tables dictionary:
  • Copy routes, trips, agencies, stops, stop_times, shapes from gtfs_model
  • Convert stops to GeoDataFrame with Point geometries from stop_lon/stop_lat
  • Creates feed_tables dict for all subsequent operations

3. Enrich stops with route/agency metadata:
  • Calls: add_additional_data_to_stops()
  • Joins route and agency information to each stop via stop_times and trips
  • Adds columns: agency_ids, agency_names, route_ids, route_names, route_types, shape_ids, is_parent, is_bus_stop
  • Modifies feed_tables['stops'] in place

4. Create frequency-based schedules from timetables:
  • Calls: create_feed_frequencies() (network_wrangler.models.gtfs.converters)
  • Converts GTFS trip-based schedules to frequency-based representation
  • Groups trips by stop pattern (shape_id) and time period
  • Calculates headways using specified method (uniform/mean/median)
  • Creates one representative trip per shape_id
  • Creates feed_tables['frequencies'] table
  • Modifies: feed_tables['stop_times'] (adds departure_minutes), feed_tables['trips'] (one row per shape_id)

5. Match stops to shape points and enrich shapes:
  • Calls: add_additional_data_to_shapes()
  • For each shape_id, processes stops in sequence order
  • For each stop: Match (find nearest existing shape point within threshold, forward-only search) or Insert (if no match, create new shape point at stop location)
  • Uses local minimum matching to handle routes that double back
  • Renumbers shape_pt_sequence if duplicates or non-integers detected
  • Writes debug_shapes.geojson with stop matching information
  • Calls helpers: _match_stop_to_shape_points(), _insert_stop_into_shape(), _align_shape_with_stops(), _write_debug_shapes()
  • Modifies feed_tables['shapes']: adds stop_id, stop_name, stop_sequence, match_distance_{crs_units}, poor_match

6. Match bus stops to roadway nodes:
  • Calls: match_bus_stops_to_roadway_nodes()
  • Gets bus modal graph from roadway network (bus-accessible nodes only)
  • For each bus stop: finds K nearest bus-accessible nodes using BallTree spatial index; if use_name_matching=True, scores by distance plus name compatibility (combined_score = 0.1 * normalized_dist + 0.9 * (1 - name_score)); selects best match within max_distance threshold
  • If name matching enabled: marks stops with combined_score > 0.9 as poor_match (sets poor_match = True without updating the location; keeps model_node_id as the nearest bus-accessible node for connector links; these stops are added to the network in step 6a)
  • If name matching disabled: poor_match = False for all stops
  • Updates stop locations to matched node positions (except poor_match stops)
  • Modifies feed_tables['stops']: adds model_node_id, match_distance_{crs_units}, close_match, poor_match (always added, but only True when name matching enabled), node_link_names, name_match_score, normalized_dist, combined_score
  • Modifies feed_tables['shapes']: updates bus stop shape points to node locations, adds poor_match flag
  • Modifies roadway_net.nodes_df: adds bus_access column

6a. Add unmatched bus stops to network:
  • Calls: add_unmatched_bus_stops_to_network()
  • Identifies unmatched bus stops (poor_match=True)
  • Clusters nearby stops using DBSCAN (max_distance threshold)
  • Creates new roadway nodes at cluster centroids, one node per cluster for grouped transit stations
  • Updates feed_tables['stops'] with new node IDs
  • Modifies roadway_net.nodes_df: adds new transit stop nodes with is_transit_stop_node flag

7. Add rail/ferry stations and links to roadway network:
  • Calls: add_stations_and_links_to_roadway_network()
  • For STATION_ROUTE_TYPES (rail, light rail, subway, etc.): creates new nodes for each station; adds dedicated transit links between consecutive stations; links follow original GTFS shape geometry
  • For BUS/TROLLEYBUS routes: creates bus_stop_links_gdf with consecutive stop pairs (A->B) including route metadata (route_id, trip_id, direction_id, etc.); no new nodes/links added yet (handled in step 8)
  • Returns station_id_to_model_node_id_dict and bus_stop_links_gdf
  • Modifies roadway_net.nodes_df: adds station nodes
  • Modifies roadway_net.links_df: adds station-to-station links
  • Modifies feed_tables['stops']: updates station stops with new model_node_ids
  • Modifies feed_tables['shapes']: removes intermediate shape points for station routes, keeps only station stop points

7a. Create connector links for unmatched bus stops:
  • Calls: create_connector_links_for_poor_match_stops()
  • Creates bidirectional bus-only links between the new transit stop nodes (from step 6a) and their nearest bus-accessible node
  • Links marked with ref="unmatched_bus_stop"
  • Enables routing through these previously unmatched stops
  • Modifies roadway_net.links_df: adds connector links
  • Modifies roadway_net.shapes: adds link geometries

8. Route bus services through road network:
  • Calls: route_shapes_between_stops()
  • Gets bus modal graph (DiGraph for pathfinding)
  • For each consecutive bus stop pair (A->B): if either node is not in the bus graph (e.g., poor_match), adds a direct A->B connection to bus_node_sequence, adds the pair to no_path_sequence for special handling, and skips pathfinding; otherwise finds the NetworkX shortest path through the bus network (optionally shape-aware routing, which prefers paths close to the original shape) and adds all intermediate path nodes as shape points
  • Handles exceptions (NetworkXNoPath, NodeNotFound): adds to no_path_sequence
  • If errors='raise' and no_path_sequence not empty: raises TransitValidationError
  • If errors='ignore' and no_path_sequence not empty: calls create_links_for_failed_bus_paths() to create special bus-only links (marked ref="bad_bus_path") for failed routing segments
  • Calls helpers: get_original_shape_points_between_stops(), find_shape_aware_shortest_path()
  • Modifies feed_tables['shapes']: replaces bus route shapes with routed paths through road network, adds shape_model_node_id from roadway

9. Consolidate duplicate stops mapped to same node:
  • Renames stop_id -> stop_id_GTFS (original GTFS IDs)
  • Renames model_node_id -> stop_id (now uses network node IDs as stop IDs)
  • Multiple GTFS stops may map to the same network node
  • Groups by stop_id (model_node_id) and aggregates: converts singular fields to lists (stop_id_GTFS, stop_name, etc.); takes first geometry/location; merges route/agency lists (flattens and deduplicates)
  • Creates stop_id_to_model_node_id_dict mapping GTFS stop_id -> model_node_id
  • Modifies feed_tables['stops']: consolidated rows, one per unique network node

10. Update stop references and create Feed object:
  • Updates feed_tables['stop_times']: maps stop_id_GTFS -> stop_id (model_node_id)
  • Converts stop_times to Wrangler format
  • Creates Feed object with all processed tables: routes, trips, agencies (from GTFS); stops (consolidated by model_node_id with metadata); stop_times (with updated stop_id references); shapes (routed through road network with shape_model_node_id); frequencies (frequency-based schedules by time period)
  • Returns Feed object ready for network modeling
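Step 4's headway calculation (group a shape_id's departures within a time period, then derive one headway per the chosen frequency_method) can be sketched as follows. This is a hypothetical helper for illustration; the real logic lives in create_feed_frequencies:

```python
import statistics


def headway_secs(departure_minutes, method, period_minutes, default_headway=180):
    """Headway in seconds for one shape_id within one time period.

    departure_minutes: sorted trip departure times (minutes) in the period.
    method: 'uniform_headway', 'mean_headway', or 'median_headway'.
    default_headway: fallback in minutes for routes with a single trip,
        mirroring default_frequency_for_onetime_route.
    """
    if len(departure_minutes) < 2:
        return default_headway * 60  # one-time route fallback
    if method == "uniform_headway":
        # spread trips evenly across the whole period
        headway_min = period_minutes / len(departure_minutes)
    else:
        gaps = [b - a for a, b in zip(departure_minutes, departure_minutes[1:])]
        headway_min = (statistics.mean(gaps) if method == "mean_headway"
                       else statistics.median(gaps))
    return int(headway_min * 60)
```

For departures at 6:00, 6:15, 6:30, and 7:00 (minutes 360, 375, 390, 420) in a 240-minute AM period: uniform gives 60 min, mean of the gaps [15, 15, 30] gives 20 min, and the median gives 15 min.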

Parameters:

Name Type Description Default
gtfs_model GtfsModel

Source GTFS data model

required
roadway_net RoadwayNetwork

Target roadway network for stop mapping

required
local_crs str

Coordinate system for distance calculations

required
crs_units str

Distance units (‘feet’ or ‘meters’)

required
timeperiods dict[str, tuple[str, str]]

Time period definitions for frequencies. Example: {'EA': ('03:00','06:00'), 'AM': ('06:00','10:00')}

required
frequency_method Literal['uniform_headway', 'mean_headway', 'median_headway']

How to calculate headways (‘uniform_headway’, ‘mean_headway’, or ‘median_headway’)

required
default_frequency_for_onetime_route int

Default headway in minutes for routes with one trip per period (default: 180)

180
add_stations_and_links bool

If True, add stations to the roadway network (recommended; False is not implemented)

True
max_stop_distance float | None

Maximum distance in crs_units for matching bus stops to roadway nodes. If None, uses default MAX_DISTANCE_STOP values

None
trace_shape_ids list[str] | None

Shape IDs for detailed debug logging

None
errors Literal['raise', 'ignore']

How to handle routing errors (‘raise’ or ‘ignore’)

'raise'
default_node_attribute_dict dict[str, any] | None

node attributes to set for new transit nodes. Defaults to None.

None
default_link_attribute_dict dict[str, any] | None

link attributes to set for new transit links. Defaults to None.

None

Returns:

Name Type Description
Feed Feed

Wrangler Feed object with:
  • Stops mapped to roadway nodes
  • Frequency-based trip representation
  • Routes following road network paths

Raises:

Type Description
TransitValidationError

If bus stops can’t be matched to roadway

NodeNotFoundError

If required nodes aren’t found

Notes
  • Bus routes are re-routed through actual road network
  • Station routes keep original alignment with new nodes/links
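The timeperiods bounds are 'HH:MM' strings. Converting a bound pair to a period length in minutes can be illustrated with a small helper (hypothetical, for illustration only; shown wrapping past midnight for an assumed evening period):

```python
def period_length_minutes(start, end):
    """Length of a timeperiod given 'HH:MM' bounds; wraps past midnight."""
    def to_min(hhmm):
        h, m = hhmm.split(":")
        return int(h) * 60 + int(m)

    length = to_min(end) - to_min(start)
    # a non-positive length means the period crosses midnight, e.g. 19:00-03:00
    return length if length > 0 else length + 24 * 60


timeperiods = {"EA": ("03:00", "06:00"), "AM": ("06:00", "10:00")}
lengths = {name: period_length_minutes(s, e) for name, (s, e) in timeperiods.items()}
```

Here EA spans 180 minutes and AM spans 240 minutes.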
Source code in network_wrangler/utils/transit.py
def create_feed_from_gtfs_model(  # noqa: PLR0912, PLR0915
    gtfs_model: GtfsModel,
    roadway_net: RoadwayNetwork,
    local_crs: str,
    crs_units: str,
    timeperiods: dict[str, tuple[str, str]],
    frequency_method: Literal["uniform_headway", "mean_headway", "median_headway"],
    default_frequency_for_onetime_route: int = 180,
    add_stations_and_links: bool = True,
    max_stop_distance: float | None = None,
    trace_shape_ids: list[str] | None = None,
    errors: Literal["raise", "ignore"] = "raise",
    default_node_attribute_dict: dict[str, any] | None = None,
    default_link_attribute_dict: dict[str, any] | None = None,
) -> Feed:
    """Convert GTFS model to Wrangler Feed with stops mapped to roadway network.

    Comprehensive conversion that transforms GTFS schedule data into a frequency-based
    Feed representation compatible with travel modeling. Maps transit stops to roadway
    nodes and optionally adds station infrastructure to the network.

    Process Steps:
    1. Prepare roadway network:
       - Convert roadway_net.nodes_df to GeoDataFrame if needed
       - Create Point geometries from X, Y coordinates
       - Set CRS to LAT_LON_CRS (EPSG:4326)
       - Modifies roadway_net.nodes_df in place

    2. Copy GTFS tables to feed_tables dictionary:
       - Copy routes, trips, agencies, stops, stop_times, shapes from gtfs_model
       - Convert stops to GeoDataFrame with Point geometries from stop_lon/stop_lat
       - Creates feed_tables dict for all subsequent operations

    3. Enrich stops with route/agency metadata:
       - Calls: [`add_additional_data_to_stops()`][network_wrangler.utils.transit.add_additional_data_to_stops]
       - Joins route and agency information to each stop via stop_times and trips
       - Adds columns: agency_ids, agency_names, route_ids, route_names, route_types,
         shape_ids, is_parent, is_bus_stop
       - Modifies feed_tables['stops'] in place

    4. Create frequency-based schedules from timetables:
       - Calls: [`create_feed_frequencies()`][network_wrangler.models.gtfs.converters.create_feed_frequencies]
       - Converts GTFS trip-based schedules to frequency-based representation
       - Groups trips by stop pattern (shape_id) and time period
       - Calculates headways using specified method (uniform/mean/median)
       - Creates one representative trip per shape_id
       - Creates feed_tables['frequencies'] table
       - Modifies: feed_tables['stop_times'] (adds departure_minutes),
         feed_tables['trips'] (one row per shape_id)

    5. Match stops to shape points and enrich shapes:
       - Calls: [`add_additional_data_to_shapes()`][network_wrangler.utils.transit.add_additional_data_to_shapes]
       - For each shape_id, processes stops in sequence order
       - For each stop:
           - Match: Find nearest existing shape point within threshold (forward-only search)
           - Insert: If no match, create new shape point at stop location
       - Uses local minimum matching to handle routes that double back
       - Renumbers shape_pt_sequence if duplicates or non-integers detected
       - Writes debug_shapes.geojson with stop matching information
       - Calls helpers: `_match_stop_to_shape_points()`, `_insert_stop_into_shape()`,
         `_align_shape_with_stops()`, `_write_debug_shapes()`
       - Modifies feed_tables['shapes']: adds stop_id, stop_name, stop_sequence,
         match_distance_{crs_units}, poor_match

    6. Match bus stops to roadway nodes:
       - Calls: [`match_bus_stops_to_roadway_nodes()`][network_wrangler.utils.transit.match_bus_stops_to_roadway_nodes]
       - Gets bus modal graph from roadway network (bus-accessible nodes only)
       - For each bus stop:
           - Finds K nearest bus-accessible nodes using BallTree spatial index
           - If use_name_matching=True: scores by distance + name compatibility
             (combined_score = 0.1 * normalized_dist + 0.9 * (1 - name_score))
           - Selects best match within max_distance threshold
       - If name matching enabled: marks stops with combined_score > 0.9 as poor_match:
           - Sets poor_match = True (does not update location yet)
           - Keeps model_node_id as nearest bus-accessible node (for connector links)
           - These stops will be added to network in step 6a
       - If name matching disabled: poor_match = False for all stops
       - Updates stop locations to matched node positions (except poor_match stops)
       - Modifies feed_tables['stops']: adds model_node_id, match_distance_{crs_units},
         close_match, poor_match (always added, but only True when name matching enabled),
         node_link_names, name_match_score, normalized_dist, combined_score
       - Modifies feed_tables['shapes']: updates bus stop shape points to node locations,
         adds poor_match flag
       - Modifies roadway_net.nodes_df: adds bus_access column

    6a. Add unmatched bus stops to network:
        - Calls: [`add_unmatched_bus_stops_to_network()`][network_wrangler.utils.transit.add_unmatched_bus_stops_to_network]
        - Identifies unmatched bus stops (poor_match=True)
        - Clusters nearby stops using DBSCAN (max_distance threshold)
        - Creates new roadway nodes at cluster centroids
        - One node per cluster for grouped transit stations
        - Updates feed_tables['stops'] with new node IDs
        - Modifies roadway_net.nodes_df: adds new transit stop nodes with is_transit_stop_node flag

    7. Add rail/ferry stations and links to roadway network:
       - Calls: [`add_stations_and_links_to_roadway_network()`][network_wrangler.utils.transit.add_stations_and_links_to_roadway_network]
       - For STATION_ROUTE_TYPES (rail, light rail, subway, etc.):
           - Creates new nodes for each station
           - Adds dedicated transit links between consecutive stations
           - Links follow original GTFS shape geometry
       - For BUS/TROLLEYBUS routes:
           - Creates bus_stop_links_gdf with consecutive stop pairs (A->B)
           - Includes route metadata (route_id, trip_id, direction_id, etc.)
           - No new nodes/links added yet (handled in next step)
       - Returns station_id_to_model_node_id_dict and bus_stop_links_gdf
       - Modifies roadway_net.nodes_df: adds station nodes
       - Modifies roadway_net.links_df: adds station-to-station links
       - Modifies feed_tables['stops']: updates station stops with new model_node_ids
       - Modifies feed_tables['shapes']: removes intermediate shape points for station routes,
         keeps only station stop points

    7a. Create connector links for unmatched bus stops:
        - Calls: [`create_connector_links_for_poor_match_stops()`][network_wrangler.utils.transit.create_connector_links_for_poor_match_stops]
        - Creates bidirectional bus-only links between:
            - New transit stop nodes (from step 6a) to their nearest bus-accessible node
        - Links marked with ref="unmatched_bus_stop"
        - Enables routing through these previously unmatched stops
        - Modifies roadway_net.links_df: adds connector links
        - Modifies roadway_net.shapes: adds link geometries

    8. Route bus services through road network:
       - Calls: [`route_shapes_between_stops()`][network_wrangler.utils.transit.route_shapes_between_stops]
       - Gets bus modal graph (DiGraph for pathfinding)
       - For each consecutive bus stop pair (A->B):
           - Check: If either node not in bus graph (e.g., poor_match):
               - Add direct A->B connection to bus_node_sequence
               - Add to no_path_sequence for special handling
               - Skip pathfinding
           - Find path: Use NetworkX shortest path through bus network
               - Optional: Shape-aware routing (prefers paths close to original shape)
           - Create shape points: Add all intermediate nodes in path as shape points
       - Handles exceptions (NetworkXNoPath, NodeNotFound): adds to no_path_sequence
       - If errors='raise' and no_path_sequence not empty: raises TransitValidationError
       - If errors='ignore' and no_path_sequence not empty: calls create_links_for_failed_bus_paths()
         to create special bus-only links (marked ref="bad_bus_path") for failed routing segments
       - Calls helpers: `get_original_shape_points_between_stops()`,
         `find_shape_aware_shortest_path()`
       - Modifies feed_tables['shapes']: replaces bus route shapes with routed paths
         through road network, adds shape_model_node_id from roadway

    9. Consolidate duplicate stops mapped to same node:
       - Renames stop_id -> stop_id_GTFS (original GTFS IDs)
       - Renames model_node_id -> stop_id (now uses network node IDs as stop IDs)
       - Multiple GTFS stops may map to same network node
       - Groups by stop_id (model_node_id) and aggregates:
           - Converts singular fields to lists (stop_id_GTFS, stop_name, etc.)
           - Takes first geometry/location
           - Merges route/agency lists (flattens and deduplicates)
       - Creates stop_id_to_model_node_id_dict mapping GTFS stop_id -> model_node_id
       - Modifies feed_tables['stops']: consolidated rows, one per unique network node

    10. Update stop references and create Feed object:
        - Updates feed_tables['stop_times']: maps stop_id_GTFS -> stop_id (model_node_id)
        - Converts stop_times to Wrangler format
        - Creates Feed object with all processed tables:
            - routes, trips, agencies: from GTFS
            - stops: consolidated by model_node_id with metadata
            - stop_times: with updated stop_id references
            - shapes: routed through road network with shape_model_node_id
            - frequencies: frequency-based schedules by time period
        - Returns Feed object ready for network modeling

    Args:
        gtfs_model: Source GTFS data model
        roadway_net: Target roadway network for stop mapping
        local_crs: Coordinate system for distance calculations
        crs_units: Distance units ('feet' or 'meters')
        timeperiods: Time period definitions for frequencies
            Example: {'EA': ('03:00','06:00'), 'AM': ('06:00','10:00')}
        frequency_method: How to calculate headways
            ('uniform_headway', 'mean_headway', or 'median_headway')
        default_frequency_for_onetime_route: Default headway in minutes
            for routes with one trip per period (default: 180)
        add_stations_and_links: If True, add stations to roadway network
            (recommended, False not implemented)
        max_stop_distance: Maximum distance in crs_units for matching bus stops
            to roadway nodes. If None, uses default MAX_DISTANCE_STOP values
        trace_shape_ids: Shape IDs for detailed debug logging
        errors: How to handle routing errors ('raise' or 'ignore')
        default_node_attribute_dict: node attributes to set for new transit nodes.
            Defaults to None.
        default_link_attribute_dict: link attributes to set for new transit links.
            Defaults to None.

    Returns:
        Feed: Wrangler Feed object with:
            - Stops mapped to roadway nodes
            - Frequency-based trip representation
            - Routes following road network paths

    Raises:
        TransitValidationError: If bus stops can't be matched to roadway
        NodeNotFoundError: If required nodes aren't found

    Notes:
        - Bus routes are re-routed through actual road network
        - Station routes keep original alignment with new nodes/links
    """
    WranglerLogger.debug("create_feed_from_gtfs_model()")
    if crs_units not in ["feet", "meters"]:
        msg = f"crs_units must be one of 'feet' or 'meters'; received {crs_units}"
        raise ValueError(msg)

    # Convert roadway_net.nodes_df to a GeoDataFrame if needed (modifying in place)
    if not isinstance(roadway_net.nodes_df, gpd.GeoDataFrame):
        if "geometry" not in roadway_net.nodes_df.columns:
            node_geometry = [
                shapely.geometry.Point(x, y)
                for x, y in zip(roadway_net.nodes_df["X"], roadway_net.nodes_df["Y"], strict=False)
            ]
            roadway_net.nodes_df = gpd.GeoDataFrame(
                roadway_net.nodes_df, geometry=node_geometry, crs=LAT_LON_CRS
            )
        else:
            roadway_net.nodes_df = gpd.GeoDataFrame(roadway_net.nodes_df, crs=LAT_LON_CRS)
    elif roadway_net.nodes_df.crs is None:
        roadway_net.nodes_df = roadway_net.nodes_df.set_crs(LAT_LON_CRS)

    # Start with the tables from the GTFS model
    feed_tables = {}

    # Copy over standard tables that don't need modification.
    # GtfsModel guarantees routes and trips exist;
    # create copies that we'll manipulate.
    feed_tables["routes"] = gtfs_model.routes.copy()
    feed_tables["trips"] = gtfs_model.trips.copy()
    feed_tables["agencies"] = gtfs_model.agency.copy()
    feed_tables["stops"] = gtfs_model.stops.copy()
    feed_tables["stop_times"] = gtfs_model.stop_times.copy()
    feed_tables["shapes"] = gtfs_model.shapes.copy()

    # create mapping from gtfs_model stop to RoadwayNetwork nodes
    # GtfsModel guarantees stops exists
    if not isinstance(feed_tables["stops"], gpd.GeoDataFrame):
        stop_geometry = [
            shapely.geometry.Point(lon, lat)
            for lon, lat in zip(
                gtfs_model.stops["stop_lon"], gtfs_model.stops["stop_lat"], strict=False
            )
        ]
        feed_tables["stops"] = gpd.GeoDataFrame(
            feed_tables["stops"], geometry=stop_geometry, crs=LAT_LON_CRS
        )

    # Add helpful extra data to stops table
    add_additional_data_to_stops(feed_tables)

    # create frequencies table from GTFS stop_times (if no frequencies table is specified)
    if hasattr(gtfs_model, "frequencies") and gtfs_model.frequencies is not None:
        feed_tables["frequencies"] = gtfs_model.frequencies
        # TODO: What if the frequencies are specified for the wrong time periods?
    else:
        # GtfsModel specifies every individual trip but Feed expects the trip to be
        # representative with frequencies. This makes that conversion
        create_feed_frequencies(
            feed_tables,
            timeperiods,
            frequency_method,
            default_frequency_for_onetime_route,
            trace_shape_ids,
        )

    if not add_stations_and_links:
        msg = "create_feed_from_gtfs_model() doesn't implement add_stations_and_links==False."
        raise NotImplementedError(msg)

    # Add helpful extra data to shapes table
    add_additional_data_to_shapes(feed_tables, local_crs, crs_units, trace_shape_ids)

    # Use provided max_stop_distance or default
    if max_stop_distance is None:
        max_stop_distance = (
            DefaultConfig.TRANSIT.MAX_DISTANCE_STOP_FEET
            if crs_units == "feet"
            else DefaultConfig.TRANSIT.MAX_DISTANCE_STOP_METERS
        )

    match_bus_stops_to_roadway_nodes(
        feed_tables,
        roadway_net,
        local_crs,
        crs_units,
        max_stop_distance,
        trace_shape_ids,
        use_name_matching=True,  # Use name matching when available
    )

    # Add unmatched bus stops as new nodes in the roadway network
    unmatched_stops_nodes_gdf = add_unmatched_bus_stops_to_network(
        feed_tables,
        roadway_net,
        local_crs,
        max_stop_distance,
        trace_shape_ids,
        default_node_attribute_dict,
    )

    # for fixed route transit, add the links and stops to the roadway network
    _, bus_stop_links_gdf = add_stations_and_links_to_roadway_network(
        feed_tables,
        roadway_net,
        local_crs,
        crs_units,
        trace_shape_ids,
        default_node_attribute_dict,
        default_link_attribute_dict,
    )

    # Create connector links for unmatched bus stops
    if len(unmatched_stops_nodes_gdf) > 0:
        create_connector_links_for_poor_match_stops(
            roadway_net=roadway_net,
            unmatched_stops_gdf=unmatched_stops_nodes_gdf,
            local_crs=local_crs,
            crs_units=crs_units,
            trace_shape_ids=trace_shape_ids,
            default_link_attribute_dict=default_link_attribute_dict,
        )

    WranglerLogger.debug(f"bus_stop_links_gdf:\n{bus_stop_links_gdf}")

    # finally, we need to find shortest paths through the bus network
    # between bus stops and update stops and shapes accordingly
    try:
        route_shapes_between_stops(
            bus_stop_links_gdf,
            feed_tables,
            roadway_net,
            local_crs,
            crs_units,
            trace_shape_ids,
            errors,
            default_link_attribute_dict,
        )
    except Exception as e:
        raise e

    # Getting ready to create Feed object
    # stop_id is now really the model_node_id -- set it
    feed_tables["stops"].rename(
        columns={"stop_id": "stop_id_GTFS", "model_node_id": "stop_id"}, inplace=True
    )
    # But some of the stops are mapped to the same model_node_id (now, stop_id) -- merge them.
    duplicate_stop_ids_df = feed_tables["stops"].loc[
        feed_tables["stops"].duplicated(subset=["stop_id"], keep=False)
    ]
    WranglerLogger.debug(f"duplicate_stop_ids_df:\n{duplicate_stop_ids_df}")
    WranglerLogger.debug(f"feed_tables['stops'].dtypes:\n{feed_tables['stops'].dtypes}")
    # stop_id_GTFS                     object
    # stop_name                        object
    # stop_lat                        float64
    # stop_lon                        float64
    # zone_id                          object
    # location_type                  category
    # parent_station                   object
    # level_id                         object
    # geometry                       geometry
    # agency_ids                       object
    # agency_names                     object
    # route_ids                        object
    # route_names                      object
    # route_types                      object
    # shape_ids                        object
    # is_parent                          bool
    # is_bus_stop                        bool
    # stop_id                          object
    # match_distance_feet             float64
    # valid_match                      object

    # create full stop_id_to_model_node_id_dict mapping
    stop_id_to_model_node_id_dict = (
        feed_tables["stops"][["stop_id_GTFS", "stop_id"]]
        .set_index("stop_id_GTFS")
        .to_dict()["stop_id"]
    )
    WranglerLogger.debug(f"stop_id_to_model_node_id_dict: {stop_id_to_model_node_id_dict}")

    # Convert NaN to empty lists before aggregation
    feed_tables["stops"]["agency_ids"] = feed_tables["stops"]["agency_ids"].apply(
        lambda x: x if isinstance(x, list) else []
    )
    feed_tables["stops"]["agency_names"] = feed_tables["stops"]["agency_names"].apply(
        lambda x: x if isinstance(x, list) else []
    )
    feed_tables["stops"]["route_ids"] = feed_tables["stops"]["route_ids"].apply(
        lambda x: x if isinstance(x, list) else []
    )
    feed_tables["stops"]["route_names"] = feed_tables["stops"]["route_names"].apply(
        lambda x: x if isinstance(x, list) else []
    )
    feed_tables["stops"]["route_types"] = feed_tables["stops"]["route_types"].apply(
        lambda x: x if isinstance(x, list) else []
    )
    feed_tables["stops"]["shape_ids"] = feed_tables["stops"]["shape_ids"].apply(
        lambda x: x if isinstance(x, list) else []
    )

    feed_tables["stops"] = (
        feed_tables["stops"]
        .groupby(by=["stop_id"])
        .aggregate(
            stop_id_GTFS=pd.NamedAgg(column="stop_id_GTFS", aggfunc=list),
            stop_name=pd.NamedAgg(column="stop_name", aggfunc=list),
            stop_lat=pd.NamedAgg(column="stop_lat", aggfunc="first"),
            stop_lon=pd.NamedAgg(column="stop_lon", aggfunc="first"),
            zone_id=pd.NamedAgg(column="zone_id", aggfunc=list),
            location_type=pd.NamedAgg(column="location_type", aggfunc="first"),
            parent_station=pd.NamedAgg(column="parent_station", aggfunc="first"),
            level_id=pd.NamedAgg(column="level_id", aggfunc=list),
            geometry=pd.NamedAgg(column="geometry", aggfunc="first"),
            agency_ids=pd.NamedAgg(
                column="agency_ids",
                aggfunc=lambda x: list({item for sublist in x for item in sublist}),
            ),
            agency_names=pd.NamedAgg(
                column="agency_names",
                aggfunc=lambda x: list({item for sublist in x for item in sublist}),
            ),
            route_ids=pd.NamedAgg(
                column="route_ids",
                aggfunc=lambda x: list({item for sublist in x for item in sublist}),
            ),
            route_names=pd.NamedAgg(
                column="route_names",
                aggfunc=lambda x: list({item for sublist in x for item in sublist}),
            ),
            route_types=pd.NamedAgg(
                column="route_types",
                aggfunc=lambda x: list({item for sublist in x for item in sublist}),
            ),
            shape_ids=pd.NamedAgg(
                column="shape_ids",
                aggfunc=lambda x: list({item for sublist in x for item in sublist}),
            ),
            is_parent=pd.NamedAgg(column="is_parent", aggfunc=any),
            is_bus_stop=pd.NamedAgg(column="is_bus_stop", aggfunc=any),
        )
        .reset_index(drop=False)
    )
    feed_tables["stops"]["stop_id"] = feed_tables["stops"]["stop_id"].astype(int)

    # Update feed_tables['stop_times']
    feed_tables["stop_times"].rename(columns={"stop_id": "stop_id_GTFS"}, inplace=True)
    feed_tables["stop_times"]["stop_id"] = feed_tables["stop_times"]["stop_id_GTFS"].map(
        stop_id_to_model_node_id_dict
    )

    # Log all tables
    for table_name, table_data in feed_tables.items():
        WranglerLogger.debug(
            f"Before creating Feed object, feed_tables[{table_name}]:\n{table_data}"
        )

    # create Feed object from results of the above
    try:
        feed = Feed(**feed_tables)
        WranglerLogger.info(f"Successfully created Feed with {len(feed_tables)} tables")
        return feed
    except Exception as e:
        WranglerLogger.error(f"Error creating Feed: {e}")
        raise e
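The stop consolidation above relies on a flatten-and-deduplicate groupby aggregation over list-valued columns. A minimal pandas sketch of that pattern on toy data (the real stops table carries many more columns, and the source uses an unordered set comprehension where this sketch sorts for determinism):

```python
import pandas as pd

# Two GTFS stops ("a1", "a2") that matched the same roadway node (stop_id 101)
stops_df = pd.DataFrame({
    "stop_id": [101, 101, 205],
    "stop_id_GTFS": ["a1", "a2", "b1"],
    "route_ids": [["R1"], ["R1", "R2"], ["R3"]],
})

consolidated = (
    stops_df.groupby("stop_id")
    .aggregate(
        # collect the original GTFS ids into one list per network node
        stop_id_GTFS=pd.NamedAgg(column="stop_id_GTFS", aggfunc=list),
        # flatten the per-stop route lists and de-duplicate
        route_ids=pd.NamedAgg(
            column="route_ids",
            aggfunc=lambda x: sorted({item for sublist in x for item in sublist}),
        ),
    )
    .reset_index()
)
print(consolidated)
```

After consolidation there is one row per unique network node, with `route_ids` of [101] merged to ["R1", "R2"].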
network_wrangler.utils.transit.create_links_for_failed_bus_paths

create_links_for_failed_bus_paths(
    roadway_net,
    no_bus_path_gdf,
    local_crs,
    crs_units,
    trace_shape_ids=None,
    default_link_attribute_dict=None,
)

Create direct transit-only links for bus stop pairs that couldn’t be routed.

When pathfinding through the bus network fails for consecutive bus stop pairs, this function creates direct transit-only links connecting them. These links enable the transit route to continue even when the underlying road network doesn’t provide a valid bus path.

Parameters:

Name Type Description Default
roadway_net RoadwayNetwork

RoadwayNetwork to add links to

required
no_bus_path_gdf GeoDataFrame

GeoDataFrame of failed bus path segments with columns:
  • A, B (int): Stop node IDs that couldn't be connected via pathfinding
  • trip_id, stop_sequence: Trip and sequence information
  • stop_id, stop_name: Stop identifiers
  • next_stop_id, next_stop_name: Next stop identifiers
  • geometry (LineString): Direct connection geometry between stops
  • route_type, route_id, direction_id, shape_id: Route metadata

required
local_crs str

Coordinate reference system for distance calculations

required
crs_units str

Distance units (‘feet’ or ‘meters’)

required
trace_shape_ids list[str] | None

Optional list of shape_ids for debug logging

None
default_link_attribute_dict dict[str, any] | None

Optional dict of column-name to default value to set on new links.

None
Notes
  • Links are marked with ref=”bad_bus_path” for identification
  • Links allow all transit modes (bus_only, rail_only, ferry_only = True)
  • If a link already exists for an A->B pair, transit access is added to existing link
  • Modifies roadway_net.links_df and roadway_net.shapes in place
Source code in network_wrangler/utils/transit.py
def create_links_for_failed_bus_paths(  # noqa: PLR0915
    roadway_net: RoadwayNetwork,
    no_bus_path_gdf: gpd.GeoDataFrame,
    local_crs: str,
    crs_units: str,
    trace_shape_ids: list[str] | None = None,
    default_link_attribute_dict: dict[str, any] | None = None,
):
    """Create direct transit-only links for bus stop pairs that couldn't be routed.

    When pathfinding through the bus network fails for consecutive bus stop pairs,
    this function creates direct transit-only links connecting them. These links
    enable the transit route to continue even when the underlying road network
    doesn't provide a valid bus path.

    Args:
        roadway_net: RoadwayNetwork to add links to
        no_bus_path_gdf: GeoDataFrame of failed bus path segments with columns:
            - A, B (int): Stop node IDs that couldn't be connected via pathfinding
            - trip_id, stop_sequence: Trip and sequence information
            - stop_id, stop_name: Stop identifiers
            - next_stop_id, next_stop_name: Next stop identifiers
            - geometry (LineString): Direct connection geometry between stops
            - route_type, route_id, direction_id, shape_id: Route metadata
        local_crs: Coordinate reference system for distance calculations
        crs_units: Distance units ('feet' or 'meters')
        trace_shape_ids: Optional list of shape_ids for debug logging
        default_link_attribute_dict: Optional dict of column-name to default value to set on new links.

    Notes:
        - Links are marked with ref="bad_bus_path" for identification
        - Links allow all transit modes (bus_only, rail_only, ferry_only = True)
        - If a link already exists for an A->B pair, transit access is added to existing link
        - Modifies roadway_net.links_df and roadway_net.shapes in place
    """
    if no_bus_path_gdf is None or len(no_bus_path_gdf) == 0:
        WranglerLogger.info("No failed bus paths to create links for")
        return

    WranglerLogger.info(f"Creating links for {len(no_bus_path_gdf)} failed bus path segments")

    # no_bus_path_gdf columns:
    # A, B, shape_id, stop_sequence, route_type, route_id, direction_id, trip_id,
    # stop_id, stop_name, next_stop_id, next_stop_name, num_points, geometry
    add_links_gdf = no_bus_path_gdf.copy()
    # drop some unneeded columns
    add_links_gdf = add_links_gdf.drop(
        columns=["route_type", "route_id", "direction_id", "shape_id", "num_points"]
    )
    # roll up to unique A,B, using the first
    add_links_gdf = gpd.GeoDataFrame(
        data=add_links_gdf.groupby(by=["A", "B"])
        .agg(
            trip_ids=pd.NamedAgg(column="trip_id", aggfunc=list),
            stop_seqs=pd.NamedAgg(column="stop_sequence", aggfunc=list),
            stop_id=pd.NamedAgg(column="stop_id", aggfunc="first"),
            stop_name=pd.NamedAgg(column="stop_name", aggfunc="first"),
            next_stop_id=pd.NamedAgg(column="next_stop_id", aggfunc="first"),
            next_stop_name=pd.NamedAgg(column="next_stop_name", aggfunc="first"),
            geometry=pd.NamedAgg(column="geometry", aggfunc="first"),
        )
        .reset_index(drop=False),
        geometry="geometry",
        crs=no_bus_path_gdf.crs,
    )
    add_links_gdf["name"] = add_links_gdf["trip_ids"].astype(str)
    add_links_gdf["shape_id"] = add_links_gdf["stop_id"] + " to " + add_links_gdf["next_stop_id"]
    # make ok for buses (but not ferry or rail, since that's confusing)
    add_links_gdf["rail_only"] = False
    add_links_gdf["ferry_only"] = False
    add_links_gdf["bus_only"] = True
    add_links_gdf["drive_access"] = True
    # not ok for others
    add_links_gdf["truck_access"] = False
    add_links_gdf["bike_access"] = False
    add_links_gdf["walk_access"] = False
    # fill in some defaults
    add_links_gdf["roadway"] = "transit"
    add_links_gdf["lanes"] = 1
    add_links_gdf["managed"] = 0
    # this is how you find me
    add_links_gdf["ref"] = "bad_bus_path"

    add_links_gdf.to_crs(local_crs, inplace=True)
    add_links_gdf["length"] = add_links_gdf.length
    if crs_units == "feet":
        add_links_gdf["distance"] = add_links_gdf["length"] / FEET_PER_MILE
    else:
        add_links_gdf["distance"] = add_links_gdf["length"] / METERS_PER_KILOMETER
    add_links_gdf.to_crs(LAT_LON_CRS, inplace=True)
    add_links_gdf.reset_index(drop=True, inplace=True)

    # check if any exist already
    add_links_gdf["temp_model_link_id"] = add_links_gdf.index
    exists_already_df = pd.merge(
        left=roadway_net.links_df,
        right=add_links_gdf[["A", "B", "temp_model_link_id"]],
        on=["A", "B"],
        how="inner",
    )
    if len(exists_already_df) > 0:
        WranglerLogger.warning(
            f"Can't add the following links because they exist already; adding transit modes:\n"
            f"{exists_already_df}"
        )
        # set transit usability for those links
        roadway_net.links_df.loc[
            roadway_net.links_df["model_link_id"].isin(exists_already_df["model_link_id"]),
            "rail_only",
        ] = True
        roadway_net.links_df.loc[
            roadway_net.links_df["model_link_id"].isin(exists_already_df["model_link_id"]),
            "bus_only",
        ] = True
        roadway_net.links_df.loc[
            roadway_net.links_df["model_link_id"].isin(exists_already_df["model_link_id"]),
            "ferry_only",
        ] = True
        # remove the duplicate from add_links_gdf
        add_links_gdf = add_links_gdf.loc[
            ~add_links_gdf["temp_model_link_id"].isin(exists_already_df["temp_model_link_id"])
        ]
        add_links_gdf.reset_index(drop=True, inplace=True)

    # we're done with this
    add_links_gdf.drop(columns=["temp_model_link_id"], inplace=True)

    if len(add_links_gdf) == 0:
        WranglerLogger.info("All failed bus path links already exist")
        return

    # assign model_link_id
    max_model_link_id = roadway_net.links_df.model_link_id.max()
    add_links_gdf["model_link_id"] = add_links_gdf.index + max_model_link_id + 1

    # log for trace_shape_ids
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            trace_trip_id = f"{trace_shape_id}_trip"
            shape_mask = add_links_gdf["trip_ids"].apply(
                lambda x, tid=trace_trip_id: tid in x if isinstance(x, list) else False
            )
            if shape_mask.any():
                WranglerLogger.debug(
                    f"adding links for trace {trace_shape_id} in create_links_for_failed_bus_paths:\n"
                    f"{add_links_gdf.loc[shape_mask]}"
                )

    # Add links
    WranglerLogger.debug(f"add_links_gdf:\n{add_links_gdf}")

    # Apply default link attributes
    if default_link_attribute_dict is None:
        default_link_attribute_dict = {}
    for colname, default_value in default_link_attribute_dict.items():
        add_links_gdf[colname] = default_value

    roadway_net.add_links(add_links_gdf)
    WranglerLogger.info(f"Adding {len(add_links_gdf):,} links for failed bus paths")

    # Add shapes
    roadway_net.add_shapes(add_links_gdf)
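
The duplicate check above hinges on an inner merge on the (A, B) node pair: any candidate link whose endpoints already exist in the network is flagged and dropped before the add. A minimal sketch of that pattern with hypothetical link tables (column values are made up):

```python
import pandas as pd

# Existing network links, keyed by their A/B end nodes (hypothetical data).
links_df = pd.DataFrame(
    {"model_link_id": [1, 2, 3], "A": [10, 20, 30], "B": [20, 30, 40]}
)

# Candidate links to add; one (20 -> 30) already exists in the network.
add_links_df = pd.DataFrame({"A": [20, 50], "B": [30, 60]})
add_links_df["temp_model_link_id"] = add_links_df.index

# Inner merge keeps only rows whose (A, B) pair appears in both tables.
exists_already_df = pd.merge(
    left=links_df,
    right=add_links_df[["A", "B", "temp_model_link_id"]],
    on=["A", "B"],
    how="inner",
)

# Drop the duplicates from the candidates before adding them.
add_links_df = add_links_df.loc[
    ~add_links_df["temp_model_link_id"].isin(exists_already_df["temp_model_link_id"])
]
```

Here the merge finds one duplicate (the 20 -> 30 link, `model_link_id` 2), leaving a single new link to add.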

network_wrangler.utils.transit.drop_transit_agency

drop_transit_agency(transit_data, agency_id)

Remove all routes, trips, stops, etc. for a specific agency or agencies.

Filters out all data associated with the specified agency_id(s), ensuring the resulting transit data remains valid by removing orphaned stops and maintaining referential integrity. Modifies transit_data in place.

Parameters:

Name Type Description Default
transit_data GtfsModel | Feed

Either a GtfsModel or Feed object to filter. Modified in place.

required
agency_id str | list[str]

Single agency_id string or list of agency_ids to remove

required
Example

Remove a single agency

drop_transit_agency(gtfs_model, "SFMTA")

Remove multiple agencies

drop_transit_agency(gtfs_model, ["SFMTA", "AC"])

Source code in network_wrangler/transit/filter.py
def drop_transit_agency(  # noqa: PLR0915
    transit_data: GtfsModel | Feed,
    agency_id: str | list[str],
) -> None:
    """Remove all routes, trips, stops, etc. for a specific agency or agencies.

    Filters out all data associated with the specified agency_id(s), ensuring
    the resulting transit data remains valid by removing orphaned stops and
    maintaining referential integrity. Modifies transit_data in place.

    Args:
        transit_data: Either a GtfsModel or Feed object to filter. Modified in place.
        agency_id: Single agency_id string or list of agency_ids to remove

    Example:
        >>> # Remove a single agency
        >>> drop_transit_agency(gtfs_model, "SFMTA")
        >>> # Remove multiple agencies
        >>> drop_transit_agency(gtfs_model, ["SFMTA", "AC"])
    """
    # Convert single agency_id to list for uniform handling
    agency_ids_to_remove = [agency_id] if isinstance(agency_id, str) else agency_id

    WranglerLogger.info(f"Removing transit data for agency/agencies: {agency_ids_to_remove}")

    # Get data tables (references, not copies)
    routes_df = transit_data.routes
    trips_df = transit_data.trips
    stop_times_df = transit_data.stop_times
    stops_df = transit_data.stops
    is_gtfs = isinstance(transit_data, GtfsModel)

    # Find routes to keep (those NOT belonging to agencies being removed)
    if "agency_id" in routes_df.columns:
        routes_to_keep = routes_df[~routes_df.agency_id.isin(agency_ids_to_remove)]
        routes_removed = len(routes_df) - len(routes_to_keep)
    else:
        # If no agency_id column in routes, log warning and keep all routes
        WranglerLogger.warning(
            "No agency_id column found in routes table - cannot filter by agency"
        )
        routes_to_keep = routes_df
        routes_removed = 0

    route_ids_to_keep = set(routes_to_keep.route_id)

    # Filter trips based on remaining routes
    trips_to_keep = trips_df[trips_df.route_id.isin(route_ids_to_keep)]
    trips_removed = len(trips_df) - len(trips_to_keep)
    trip_ids_to_keep = set(trips_to_keep.trip_id)

    # Filter stop_times based on remaining trips
    stop_times_to_keep = stop_times_df[stop_times_df.trip_id.isin(trip_ids_to_keep)]
    stop_times_removed = len(stop_times_df) - len(stop_times_to_keep)

    # Find stops that are still referenced
    stops_still_used = set(stop_times_to_keep.stop_id.unique())
    stops_to_keep = stops_df[stops_df.stop_id.isin(stops_still_used)]

    # Check if any of these stops reference parent stations
    if "parent_station" in stops_to_keep.columns:
        # Get parent stations that are referenced by kept stops
        parent_stations = stops_to_keep["parent_station"].dropna().unique()
        parent_stations = [ps for ps in parent_stations if ps != ""]  # Remove empty strings

        if len(parent_stations) > 0:
            # Find parent stations that aren't already in our filtered stops
            existing_stop_ids = set(stops_to_keep.stop_id)
            missing_parent_stations = [ps for ps in parent_stations if ps not in existing_stop_ids]

            if len(missing_parent_stations) > 0:
                WranglerLogger.debug(
                    f"Adding back {len(missing_parent_stations)} parent stations referenced by kept stops"
                )

                # Get the parent station records
                parent_station_records = stops_df[stops_df.stop_id.isin(missing_parent_stations)]

                # Append parent stations to filtered stops
                stops_to_keep = pd.concat(
                    [stops_to_keep, parent_station_records], ignore_index=True
                )

    stops_removed = len(stops_df) - len(stops_to_keep)

    WranglerLogger.info(
        f"Removed: {routes_removed:,} routes, {trips_removed:,} trips, "
        f"{stop_times_removed:,} stop_times, {stops_removed:,} stops"
    )

    WranglerLogger.info(
        f"Remaining: {len(routes_to_keep):,} routes, {len(trips_to_keep):,} trips, "
        f"{len(stops_to_keep):,} stops"
    )
    WranglerLogger.debug(
        f"Stops removed:\n{stops_df.loc[~stops_df['stop_id'].isin(stops_still_used)]}"
    )

    # Update tables in place, in order so that validation is ok
    transit_data.stop_times = stop_times_to_keep
    transit_data.trips = trips_to_keep
    transit_data.routes = routes_to_keep
    transit_data.stops = stops_to_keep

    # Handle agency table
    if hasattr(transit_data, "agency") and transit_data.agency is not None:
        # Keep agencies that are NOT being removed
        filtered_agency = transit_data.agency[
            ~transit_data.agency.agency_id.isin(agency_ids_to_remove)
        ]
        WranglerLogger.info(
            f"Removed {len(transit_data.agency) - len(filtered_agency):,} agencies"
        )
        transit_data.agency = filtered_agency

    # Handle shapes table
    if (
        hasattr(transit_data, "shapes")
        and transit_data.shapes is not None
        and "shape_id" in trips_to_keep.columns
    ):
        shape_ids = set(trips_to_keep.shape_id.dropna().unique())
        filtered_shapes = transit_data.shapes[transit_data.shapes.shape_id.isin(shape_ids)]
        WranglerLogger.info(
            f"Removed {len(transit_data.shapes) - len(filtered_shapes):,} shape points"
        )
        transit_data.shapes = filtered_shapes

    # Handle calendar table
    if hasattr(transit_data, "calendar") and transit_data.calendar is not None:
        # Keep only service_ids referenced by remaining trips
        service_ids = set(trips_to_keep.service_id.unique())
        transit_data.calendar = transit_data.calendar[
            transit_data.calendar.service_id.isin(service_ids)
        ]

    # Handle calendar_dates table
    if hasattr(transit_data, "calendar_dates") and transit_data.calendar_dates is not None:
        # Keep only service_ids referenced by remaining trips
        service_ids = set(trips_to_keep.service_id.unique())
        transit_data.calendar_dates = transit_data.calendar_dates[
            transit_data.calendar_dates.service_id.isin(service_ids)
        ]

    # Handle frequencies table
    if hasattr(transit_data, "frequencies") and transit_data.frequencies is not None:
        # Keep only frequencies for remaining trips
        transit_data.frequencies = transit_data.frequencies[
            transit_data.frequencies.trip_id.isin(trip_ids_to_keep)
        ]
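
`drop_transit_agency` preserves referential integrity by filtering tables in dependency order: routes first, then trips by the surviving `route_id`s, stop_times by the surviving `trip_id`s, and finally stops by the `stop_id`s still referenced. A toy sketch of that cascade with made-up IDs:

```python
import pandas as pd

# Minimal GTFS-like tables (hypothetical IDs).
routes = pd.DataFrame({"route_id": ["r1", "r2"], "agency_id": ["SFMTA", "AC"]})
trips = pd.DataFrame({"trip_id": ["t1", "t2"], "route_id": ["r1", "r2"]})
stop_times = pd.DataFrame(
    {"trip_id": ["t1", "t1", "t2"], "stop_id": ["s1", "s2", "s3"]}
)
stops = pd.DataFrame({"stop_id": ["s1", "s2", "s3"]})

agency_ids_to_remove = ["AC"]

# Cascade: each table is filtered by the keys surviving in its parent table.
routes_kept = routes[~routes.agency_id.isin(agency_ids_to_remove)]
trips_kept = trips[trips.route_id.isin(set(routes_kept.route_id))]
stop_times_kept = stop_times[stop_times.trip_id.isin(set(trips_kept.trip_id))]
stops_kept = stops[stops.stop_id.isin(set(stop_times_kept.stop_id))]
```

Dropping "AC" removes route r2, its trip t2, that trip's stop_times, and stop s3, which only t2 served; the real function additionally restores parent stations referenced by kept stops.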

network_wrangler.utils.transit.filter_transit_by_boundary

filter_transit_by_boundary(
    transit_data,
    boundary,
    partially_include_route_type_action=None,
)

Filter transit routes based on whether they have stops within a boundary.

Removes routes that are entirely outside the boundary shapefile. Routes that are partially within the boundary are kept by default, but can be configured per route type to be truncated at the boundary. Modifies transit_data in place.

Parameters:

Name Type Description Default
transit_data GtfsModel | Feed

Either a GtfsModel or Feed object to filter. Modified in place.

required
boundary str | Path | GeoDataFrame

Path to boundary shapefile or a GeoDataFrame with boundary polygon(s)

required
partially_include_route_type_action dict[RouteType, str] | None

Optional dictionary mapping RouteType enum to action for routes partially within boundary: - "truncate": Truncate route to only include stops within boundary. Route types not specified in this dictionary will be kept entirely (default).

None
Example

from network_wrangler.models.gtfs.types import RouteType

Remove routes entirely outside the Bay Area

filter_transit_by_boundary(gtfs_model, "bay_area_boundary.shp")

Truncate rail routes at boundary, keep other route types unchanged

filter_transit_by_boundary(
    gtfs_model,
    "bay_area_boundary.shp",
    partially_include_route_type_action={
        RouteType.RAIL: "truncate",  # Rail - will be truncated at boundary
        # Other route types not listed will be kept entirely
    },
)

Todo

This is similar to clip_feed_to_boundary – consolidate?

Source code in network_wrangler/transit/filter.py
def filter_transit_by_boundary(  # noqa: PLR0912, PLR0915
    transit_data: GtfsModel | Feed,
    boundary: str | Path | gpd.GeoDataFrame,
    partially_include_route_type_action: dict[RouteType, str] | None = None,
) -> None:
    """Filter transit routes based on whether they have stops within a boundary.

    Removes routes that are entirely outside the boundary shapefile. Routes that are
    partially within the boundary are kept by default, but can be configured per
    route type to be truncated at the boundary. Modifies transit_data in place.

    Args:
        transit_data: Either a GtfsModel or Feed object to filter. Modified in place.
        boundary: Path to boundary shapefile or a GeoDataFrame with boundary polygon(s)
        partially_include_route_type_action: Optional dictionary mapping RouteType enum to
            action for routes partially within boundary:
            - "truncate": Truncate route to only include stops within boundary
            Route types not specified in this dictionary will be kept entirely (default).

    Example:
        >>> from network_wrangler.models.gtfs.types import RouteType
        >>> # Remove routes entirely outside the Bay Area
        >>> filter_transit_by_boundary(gtfs_model, "bay_area_boundary.shp")
        >>> # Truncate rail routes at boundary, keep other route types unchanged
        >>> filter_transit_by_boundary(
        ...     gtfs_model,
        ...     "bay_area_boundary.shp",
        ...     partially_include_route_type_action={
        ...         RouteType.RAIL: "truncate",  # Rail - will be truncated at boundary
        ...         # Other route types not listed will be kept entirely
        ...     },
        ... )

    !!! todo
        This is similar to [`clip_feed_to_boundary`][network_wrangler.transit.clip.clip_feed_to_boundary] -- consolidate?

    """
    WranglerLogger.info("Filtering transit routes by boundary")

    # Log input parameters
    WranglerLogger.debug(
        f"partially_include_route_type_action: {partially_include_route_type_action}"
    )

    # Load boundary if it's a file path
    if isinstance(boundary, str | Path):
        WranglerLogger.debug(f"Loading boundary from file: {boundary}")
        boundary_gdf = gpd.read_file(boundary)
    else:
        WranglerLogger.debug("Using provided boundary GeoDataFrame")
        boundary_gdf = boundary

    WranglerLogger.debug(f"Boundary has {len(boundary_gdf)} polygon(s)")

    # Ensure boundary is in a geographic CRS for spatial operations
    if boundary_gdf.crs is None:
        WranglerLogger.warning("Boundary has no CRS, assuming EPSG:4326")
        boundary_gdf = boundary_gdf.set_crs(LAT_LON_CRS)
    else:
        WranglerLogger.debug(f"Boundary CRS: {boundary_gdf.crs}")

    # Get references to tables (not copies since we'll modify in place)
    is_gtfs = isinstance(transit_data, GtfsModel)
    stops_df = transit_data.stops
    routes_df = transit_data.routes
    trips_df = transit_data.trips
    stop_times_df = transit_data.stop_times

    if is_gtfs:
        WranglerLogger.debug("Processing GtfsModel data")
    else:
        WranglerLogger.debug("Processing Feed data")

    WranglerLogger.debug(
        f"Input data has {len(stops_df)} stops, {len(routes_df)} routes, {len(trips_df)} trips, {len(stop_times_df)} stop_times"
    )

    # Create GeoDataFrame from stops
    stops_gdf = gpd.GeoDataFrame(
        stops_df,
        geometry=gpd.points_from_xy(stops_df.stop_lon, stops_df.stop_lat),
        crs=LAT_LON_CRS,
    )

    # Reproject to match boundary CRS if needed
    if stops_gdf.crs != boundary_gdf.crs:
        WranglerLogger.debug(f"Reprojecting stops from {stops_gdf.crs} to {boundary_gdf.crs}")
        stops_gdf = stops_gdf.to_crs(boundary_gdf.crs)

    # Spatial join to find stops within boundary
    WranglerLogger.debug("Performing spatial join to find stops within boundary")
    stops_in_boundary = gpd.sjoin(stops_gdf, boundary_gdf, how="inner", predicate="within")
    stops_in_boundary_ids = set(stops_in_boundary.stop_id.unique())

    # Log some stops that are outside boundary for debugging
    stops_outside_boundary = set(stops_df.stop_id) - stops_in_boundary_ids
    if stops_outside_boundary:
        sample_outside = list(stops_outside_boundary)[:5]
        WranglerLogger.debug(f"Sample of stops outside boundary: {sample_outside}")

    WranglerLogger.info(
        f"Found {len(stops_in_boundary_ids):,} stops within boundary "
        f"out of {len(stops_df):,} total stops"
    )

    # Find which routes to keep
    # Get unique stop-route pairs from stop_times and trips
    stop_route_pairs = pd.merge(
        stop_times_df[["trip_id", "stop_id"]], trips_df[["trip_id", "route_id"]], on="trip_id"
    )[["stop_id", "route_id"]].drop_duplicates()

    # Group by route to find which stops each route serves
    route_stops = stop_route_pairs.groupby("route_id")["stop_id"].apply(set).reset_index()
    route_stops.columns = ["route_id", "stop_ids"]

    # Add route_type information
    route_stops = pd.merge(
        route_stops, routes_df[["route_id", "route_type"]], on="route_id", how="left"
    )

    # Initialize with default filters
    if partially_include_route_type_action is None:
        partially_include_route_type_action = {}

    # Convert RouteType enum keys to int values for comparison with dataframe
    normalized_route_type_action = {}
    for key, value in partially_include_route_type_action.items():
        if not isinstance(key, RouteType):
            msg = f"Keys in partially_include_route_type_action must be RouteType enum, got {type(key)}"
            raise TypeError(msg)
        normalized_route_type_action[key.value] = value
    partially_include_route_type_action = normalized_route_type_action

    # Track routes to truncate
    routes_to_truncate = {}

    # Determine which routes to keep and how to handle them
    def determine_route_handling(row):
        route_id = row["route_id"]
        route_type = row["route_type"]
        stop_ids = row["stop_ids"]

        # Check if route has stops both inside and outside boundary
        stops_inside = stop_ids.intersection(stops_in_boundary_ids)
        stops_outside = stop_ids - stops_in_boundary_ids

        # If all stops are outside, always remove
        if len(stops_inside) == 0:
            WranglerLogger.debug(
                f"Route {route_id} (type {route_type}): all {len(stop_ids)} stops outside boundary - REMOVE"
            )
            return "remove"

        # If all stops are inside, always keep
        if len(stops_outside) == 0:
            return "keep"

        # Route has stops both inside and outside - check partially_include_route_type_action
        WranglerLogger.debug(
            f"Route {route_id} (type {route_type}): {len(stops_inside)} stops inside, "
            f"{len(stops_outside)} stops outside boundary"
        )

        if route_type in partially_include_route_type_action:
            action = partially_include_route_type_action[route_type]
            WranglerLogger.debug(
                f"  - Applying configured action for route_type {route_type}: {action}"
            )
            if action == "truncate":
                return "truncate"

        # Default to keep if not specified
        WranglerLogger.debug(
            f"  - No action configured for route_type {route_type}, defaulting to KEEP"
        )
        return "keep"

    route_stops["handling"] = route_stops.apply(determine_route_handling, axis=1)
    WranglerLogger.debug(f"route_stops with handling set:\n{route_stops}")

    routes_to_keep = set(
        route_stops[route_stops["handling"].isin(["keep", "truncate"])]["route_id"]
    )
    routes_to_remove = set(route_stops[route_stops["handling"] == "remove"]["route_id"])
    routes_needing_truncation = set(route_stops[route_stops["handling"] == "truncate"]["route_id"])

    WranglerLogger.info(
        f"Keeping {len(routes_to_keep):,} routes out of {len(routes_df):,} total routes"
    )

    if routes_to_remove:
        WranglerLogger.info(f"Removing {len(routes_to_remove):,} routes entirely outside boundary")
        WranglerLogger.debug(f"Routes being removed: {sorted(routes_to_remove)[:10]}...")

    if routes_needing_truncation:
        WranglerLogger.info(f"Truncating {len(routes_needing_truncation):,} routes at boundary")
        WranglerLogger.debug(
            f"Routes being truncated: {sorted(routes_needing_truncation)[:10]}..."
        )

    # Filter data
    filtered_routes = routes_df[routes_df.route_id.isin(routes_to_keep)]
    filtered_trips = trips_df[trips_df.route_id.isin(routes_to_keep)]
    filtered_trip_ids = set(filtered_trips.trip_id)

    # Handle truncation by calling truncate_route_at_stop for each route needing truncation
    if routes_needing_truncation:
        WranglerLogger.debug(f"Processing truncation for {len(routes_needing_truncation)} routes")

        # Start with the current filtered data
        # Need to ensure stop_times only includes trips that are in filtered_trips
        filtered_stop_times_for_truncation = stop_times_df[
            stop_times_df.trip_id.isin(filtered_trip_ids)
        ]

        # First update transit_data with filtered data before truncation (in order to maintain validation)
        transit_data.stop_times = filtered_stop_times_for_truncation
        transit_data.trips = filtered_trips
        transit_data.routes = filtered_routes

        # Process each route that needs truncation
        for route_id in routes_needing_truncation:
            WranglerLogger.debug(f"Processing truncation for route {route_id}")

            # Get trips for this route
            route_trips = trips_df[trips_df.route_id == route_id]

            # Group by direction_id
            for direction_id in route_trips.direction_id.unique():
                dir_trips = route_trips[route_trips.direction_id == direction_id]
                if len(dir_trips) == 0:
                    continue

                # Analyze stop patterns for this route/direction
                # Get a representative trip (first one)
                sample_trip_id = dir_trips.iloc[0].trip_id
                sample_stop_times = transit_data.stop_times[
                    transit_data.stop_times.trip_id == sample_trip_id
                ].sort_values("stop_sequence")

                # Find which stops are inside/outside boundary
                stop_boundary_status = sample_stop_times["stop_id"].isin(stops_in_boundary_ids)

                # Check if route exits and re-enters boundary (complex case)
                boundary_changes = stop_boundary_status.ne(stop_boundary_status.shift()).cumsum()
                num_segments = boundary_changes.nunique()

                if num_segments > MIN_ROUTE_SEGMENTS:
                    # Complex case: route exits and re-enters boundary
                    route_info = routes_df[routes_df.route_id == route_id].iloc[0]
                    route_name = route_info.get("route_short_name", route_id)
                    msg = (
                        f"Route {route_name} ({route_id}) direction {direction_id} has a complex "
                        f"boundary crossing pattern (crosses boundary {num_segments - 1} times). "
                        f"Can only handle routes that exit boundary at beginning or end."
                    )
                    raise ValueError(msg)

                # Determine truncation type
                first_stop_inside = stop_boundary_status.iloc[0]
                last_stop_inside = stop_boundary_status.iloc[-1]

                if not first_stop_inside and not last_stop_inside:
                    # All stops outside - shouldn't happen as route would be removed
                    continue
                if first_stop_inside and last_stop_inside:
                    # All stops inside - no truncation needed
                    continue
                if not first_stop_inside and last_stop_inside:
                    # Starts outside, ends inside - truncate before first inside stop
                    # Find first True value (first stop inside boundary)
                    first_inside_pos = stop_boundary_status.tolist().index(True)
                    first_inside_stop = sample_stop_times.iloc[first_inside_pos]["stop_id"]

                    WranglerLogger.debug(
                        f"Route {route_id} dir {direction_id}: truncating before stop {first_inside_stop}"
                    )
                    truncate_route_at_stop(
                        transit_data, route_id, direction_id, first_inside_stop, "before"
                    )
                elif first_stop_inside and not last_stop_inside:
                    # Starts inside, ends outside - truncate after last inside stop
                    # Find last True value (last stop inside boundary)
                    reversed_list = stop_boundary_status.tolist()[::-1]
                    last_inside_pos = len(reversed_list) - 1 - reversed_list.index(True)
                    last_inside_stop = sample_stop_times.iloc[last_inside_pos]["stop_id"]

                    WranglerLogger.debug(
                        f"Route {route_id} dir {direction_id}: truncating after stop {last_inside_stop}"
                    )
                    truncate_route_at_stop(
                        transit_data, route_id, direction_id, last_inside_stop, "after"
                    )

        # After truncation, transit_data has been modified in place
        # Update references to current state (in order to maintain validation)
        filtered_stop_times = transit_data.stop_times
        filtered_trips = transit_data.trips
        filtered_routes = transit_data.routes
        filtered_stops = transit_data.stops
    else:
        # No truncation needed - update transit_data with filtered data
        filtered_stop_times = stop_times_df[stop_times_df.trip_id.isin(filtered_trip_ids)]
        filtered_stops = stops_df[stops_df.stop_id.isin(filtered_stop_times.stop_id.unique())]

        # Check if any of the filtered stops reference parent stations
        if "parent_station" in filtered_stops.columns:
            # Get parent stations that are referenced by kept stops
            parent_stations = filtered_stops["parent_station"].dropna().unique()
            parent_stations = [ps for ps in parent_stations if ps != ""]  # Remove empty strings

            if len(parent_stations) > 0:
                # Find parent stations that aren't already in our filtered stops
                existing_stop_ids = set(filtered_stops.stop_id)
                missing_parent_stations = [
                    ps for ps in parent_stations if ps not in existing_stop_ids
                ]

                if len(missing_parent_stations) > 0:
                    WranglerLogger.debug(
                        f"Adding back {len(missing_parent_stations)} parent stations referenced by kept stops"
                    )

                    # Get the parent station records
                    parent_station_records = stops_df[
                        stops_df.stop_id.isin(missing_parent_stations)
                    ]

                    # Append parent stations to filtered stops
                    filtered_stops = pd.concat(
                        [filtered_stops, parent_station_records], ignore_index=True
                    )

        transit_data.stop_times = filtered_stop_times
        transit_data.trips = filtered_trips
        transit_data.routes = filtered_routes
        transit_data.stops = filtered_stops

    # Log details about removed stops
    stops_still_used = set(filtered_stops.stop_id.unique())
    removed_stops = set(stops_df.stop_id) - stops_still_used
    if removed_stops:
        WranglerLogger.debug(f"Removed {len(removed_stops)} stops that are no longer referenced")

        # Get details of removed stops
        removed_stops_df = stops_df[stops_df["stop_id"].isin(removed_stops)][
            ["stop_id", "stop_name"]
        ]

        # Log up to 20 removed stops with their names
        sample_size = min(20, len(removed_stops_df))
        for _, stop in removed_stops_df.head(sample_size).iterrows():
            WranglerLogger.debug(f"  - Removed stop: {stop['stop_id']} ({stop['stop_name']})")

        if len(removed_stops) > sample_size:
            WranglerLogger.debug(f"  ... and {len(removed_stops) - sample_size} more stops")

    WranglerLogger.info(
        f"After filtering: {len(filtered_routes):,} routes, "
        f"{len(filtered_trips):,} trips, {len(filtered_stops):,} stops"
    )

    # Log summary of filtering by action type
    route_handling_summary = route_stops.groupby("handling").size()
    WranglerLogger.debug(f"Route handling summary:\n{route_handling_summary}")

    # Log route type distribution for routes with mixed stops
    mixed_routes = route_stops[
        (route_stops["handling"].isin(["keep", "truncate"]))
        & (
            route_stops["route_id"].isin(routes_needing_truncation)
            | (route_stops["handling"] == "keep")
        )
    ]
    if len(mixed_routes) > 0:
        route_type_summary = mixed_routes.groupby("route_type")["handling"].value_counts()
        WranglerLogger.debug(f"Route types with partial stops:\n{route_type_summary}")

    # Update other tables in transit_data in place
    if is_gtfs:
        # For GtfsModel, also filter shapes and other tables if they exist
        if (
            hasattr(transit_data, "agency")
            and transit_data.agency is not None
            and "agency_id" in filtered_routes.columns
        ):
            agency_ids = set(filtered_routes.agency_id.dropna().unique())
            transit_data.agency = transit_data.agency[
                transit_data.agency.agency_id.isin(agency_ids)
            ]

        if (
            hasattr(transit_data, "shapes")
            and transit_data.shapes is not None
            and "shape_id" in filtered_trips.columns
        ):
            shape_ids = set(filtered_trips.shape_id.dropna().unique())
            transit_data.shapes = transit_data.shapes[transit_data.shapes.shape_id.isin(shape_ids)]

        if hasattr(transit_data, "calendar") and transit_data.calendar is not None:
            # Keep only service_ids referenced by remaining trips
            service_ids = set(filtered_trips.service_id.unique())
            transit_data.calendar = transit_data.calendar[
                transit_data.calendar.service_id.isin(service_ids)
            ]

        if hasattr(transit_data, "calendar_dates") and transit_data.calendar_dates is not None:
            # Keep only service_ids referenced by remaining trips
            service_ids = set(filtered_trips.service_id.unique())
            transit_data.calendar_dates = transit_data.calendar_dates[
                transit_data.calendar_dates.service_id.isin(service_ids)
            ]

        if hasattr(transit_data, "frequencies") and transit_data.frequencies is not None:
            # Keep only frequencies for remaining trips
            transit_data.frequencies = transit_data.frequencies[
                transit_data.frequencies.trip_id.isin(filtered_trip_ids)
            ]

    else:  # Feed
        # For Feed, also handle frequencies and shapes
        if (
            hasattr(transit_data, "shapes")
            and transit_data.shapes is not None
            and "shape_id" in filtered_trips.columns
        ):
            shape_ids = set(filtered_trips.shape_id.dropna().unique())
            transit_data.shapes = transit_data.shapes[transit_data.shapes.shape_id.isin(shape_ids)]

        if hasattr(transit_data, "frequencies") and transit_data.frequencies is not None:
            # Keep only frequencies for remaining trips
            transit_data.frequencies = transit_data.frequencies[
                transit_data.frequencies.trip_id.isin(filtered_trip_ids)
            ]
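
The truncation step above classifies each route/direction by how its ordered stop sequence crosses the boundary: `ne(shift()).cumsum()` assigns a label to each contiguous run of inside/outside stops, and `nunique()` counts those runs (two runs means exactly one boundary crossing). A self-contained sketch with a hypothetical stop pattern:

```python
import pandas as pd

# True = stop inside boundary, in stop_sequence order (hypothetical pattern:
# the route starts outside and ends inside, so it would be truncated
# "before" the first inside stop).
inside = pd.Series([False, False, True, True, True])

# Each change of inside/outside status starts a new segment label.
segments = inside.ne(inside.shift()).cumsum()
num_segments = segments.nunique()  # 2 segments -> one boundary crossing

first_stop_inside = inside.iloc[0]
last_stop_inside = inside.iloc[-1]

if not first_stop_inside and last_stop_inside:
    # Starts outside, ends inside: truncate before the first inside stop.
    first_inside_pos = inside.tolist().index(True)
```

More than two segments would mean the route exits and re-enters the boundary, the complex case the function rejects with a `ValueError`.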

network_wrangler.utils.transit.find_shape_aware_shortest_path

find_shape_aware_shortest_path(
    G_bus,
    start_node,
    end_node,
    original_shape_points,
    roadway_net,
    tolerance=1.1,
    trace=False,
)

Find shortest path that considers original shape points.

Uses a constrained shortest path approach:

  1. Find the shortest path distance
  2. Get all candidate paths within tolerance of that distance
  3. Among those, select the path with minimum deviation from the original shape

Parameters:

  • G_bus (DiGraph): NetworkX DiGraph of the bus network. Required.
  • start_node (int): Starting node ID. Required.
  • end_node (int): Ending node ID. Required.
  • original_shape_points (DataFrame): DataFrame of original GTFS shape points. Required.
  • roadway_net (RoadwayNetwork): RoadwayNetwork to get node coordinates. Required.
  • tolerance (float): Maximum ratio of path distance to shortest distance. Default: 1.1 (110%).
  • trace (bool): Whether to log trace information. Default: False.

Returns:

  • list: List of node IDs representing the best path.

Source code in network_wrangler/utils/transit.py
def find_shape_aware_shortest_path(  # noqa: PLR0912
    G_bus: nx.DiGraph,
    start_node: int,
    end_node: int,
    original_shape_points: pd.DataFrame,
    roadway_net: RoadwayNetwork,
    tolerance: float = 1.10,
    trace: bool = False,
) -> list:
    """Find shortest path that considers original shape points.

    Uses constrained shortest path approach:
    1. Find shortest distance
    2. Get all paths within tolerance of shortest distance
    3. Among those, select path with minimum deviation from original shape

    Args:
        G_bus: NetworkX DiGraph of bus network
        start_node: Starting node ID
        end_node: Ending node ID
        original_shape_points: DataFrame of original GTFS shape points
        roadway_net: RoadwayNetwork to get node coordinates
        tolerance: Maximum ratio of path distance to shortest distance (default 1.10 = 110%)
        trace: Whether to log trace information

    Returns:
        List of node IDs representing the best path
    """
    try:
        # First, get the absolute shortest path distance
        shortest_dist = nx.shortest_path_length(G_bus, start_node, end_node, weight="distance")
        max_allowed_dist = shortest_dist * tolerance

        # Get multiple shortest paths to evaluate
        from itertools import islice

        candidate_paths = list(
            islice(
                nx.shortest_simple_paths(G_bus, start_node, end_node, weight="distance"),
                DefaultConfig.TRANSIT.MAX_SHAPE_CANDIDATE_PATHS,
            )
        )

        best_path = None
        best_deviation = float("inf")
        paths_within_tolerance = 0

        # Store path info for debug output
        debug_paths = []

        for path in candidate_paths:
            # Calculate path distance
            path_dist = 0
            for i in range(len(path) - 1):
                edge_data = G_bus[path[i]][path[i + 1]]
                path_dist += edge_data.get("distance", 0)

            # Check if within tolerance
            if path_dist <= max_allowed_dist:
                paths_within_tolerance += 1

                # Calculate shape deviation for this path
                deviation = calculate_path_deviation_from_shape(
                    path, original_shape_points, roadway_net, trace=trace
                )

                # Store for debug output
                if trace:
                    debug_paths.append(
                        {
                            "path": path,
                            "distance": path_dist,
                            "deviation": deviation,
                            "is_best": False,  # Will update later
                        }
                    )

                # Select path with minimum deviation
                if deviation < best_deviation:
                    best_deviation = deviation
                    best_path = path
                    if trace:
                        WranglerLogger.debug(
                            f"  New best path with deviation {deviation:.6f}, distance {path_dist:.3f}"
                        )

        # Mark the best path
        if trace and debug_paths and best_path:
            for dp in debug_paths:
                if dp["path"] == best_path:
                    dp["is_best"] = True
                    break

        if trace:
            WranglerLogger.debug(
                f"Shape-aware routing: {paths_within_tolerance} paths within {tolerance:.1%} tolerance"
            )
            if not original_shape_points.empty:
                WranglerLogger.debug(
                    f"Original shape has {len(original_shape_points)} points between stops"
                )
            if best_path:
                WranglerLogger.debug(f"Selected path with deviation {best_deviation:.6f}")

        return best_path if best_path else candidate_paths[0]  # Fallback to shortest

    except Exception as e:
        if trace:
            WranglerLogger.debug(
                f"Shape-aware routing failed: {type(e).__name__}: {e}, falling back to standard shortest path"
            )
            import traceback

            WranglerLogger.debug(f"Traceback: {traceback.format_exc()}")
        # Fallback to standard shortest path
        return nx.shortest_path(G_bus, start_node, end_node, weight="distance")
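The three-step selection above can be exercised on a toy graph. A minimal sketch (the graph and edge weights are invented for illustration; only the tolerance filter from steps 1-2 is shown, with deviation scoring left out):

```python
from itertools import islice

import networkx as nx

# Toy bus graph: two ways from node 1 to node 4, one slightly longer.
G = nx.DiGraph()
G.add_edge(1, 2, distance=1.0)
G.add_edge(2, 4, distance=1.0)   # direct route, total 2.0
G.add_edge(1, 3, distance=1.05)
G.add_edge(3, 4, distance=1.05)  # detour route, total 2.1

shortest = nx.shortest_path_length(G, 1, 4, weight="distance")
max_allowed = shortest * 1.10  # 10% tolerance, as in the function above

candidates = list(islice(nx.shortest_simple_paths(G, 1, 4, weight="distance"), 10))
within = [
    p
    for p in candidates
    if sum(G[a][b]["distance"] for a, b in zip(p, p[1:])) <= max_allowed
]
# Both routes fall within 110% of the shortest distance, so shape
# deviation (not shown here) would be used to break the tie.
print(within)  # [[1, 2, 4], [1, 3, 4]]
```

Because `shortest_simple_paths` yields paths in increasing weight order, the first element is always the fallback shortest path.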

network_wrangler.utils.transit.get_original_shape_points_between_stops

get_original_shape_points_between_stops(
    feed_tables,
    shape_id,
    from_stop_seq,
    to_stop_seq,
    trace=False,
)

Get original GTFS shape points between two consecutive stops.

Uses stop_sequence information already added by add_additional_data_to_shapes().

Parameters:

  • feed_tables (dict): GTFS feed tables dictionary. Required.
  • shape_id (str): Shape identifier. Required.
  • from_stop_seq (int): Starting stop sequence number. Required.
  • to_stop_seq (int): Ending stop sequence number (should be from_stop_seq + 1). Required.
  • trace (bool): If True, enable trace logging for debugging. Default: False.

Returns:

  • DataFrame of shape points between the stops, or empty DataFrame if not found.

Source code in network_wrangler/utils/transit.py
def get_original_shape_points_between_stops(  # noqa: PLR0912
    feed_tables: dict, shape_id: str, from_stop_seq: int, to_stop_seq: int, trace: bool = False
):
    """Get original GTFS shape points between two consecutive stops.

    Uses stop_sequence information already added by add_additional_data_to_shapes().

    Args:
        feed_tables: GTFS feed tables dictionary
        shape_id: Shape identifier
        from_stop_seq: Starting stop sequence number
        to_stop_seq: Ending stop sequence number (should be from_stop_seq + 1)
        trace: If True, enable trace logging for debugging

    Returns:
        DataFrame of shape points between stops, or empty DataFrame if not found
    """
    try:
        if trace:
            WranglerLogger.debug(
                f"Getting shape points for shape_id={shape_id} between stop_seq {from_stop_seq} and {to_stop_seq}"
            )

        # Get shape points for this shape_id
        shape_points = feed_tables["shapes"][feed_tables["shapes"]["shape_id"] == shape_id].copy()

        if trace:
            WranglerLogger.debug(
                f"  trace Found {len(shape_points)} total shape points for shape_id={shape_id}"
            )
            if not shape_points.empty and "stop_sequence" in shape_points.columns:
                unique_stop_seqs = shape_points["stop_sequence"].dropna().unique()
                WranglerLogger.debug(
                    f"  Unique stop_sequences in shape: {sorted(unique_stop_seqs)}"
                )

        if shape_points.empty:
            if trace:
                WranglerLogger.debug(f"  No shape points found for shape_id={shape_id}")
            return shape_points

        # Sort by shape_pt_sequence
        shape_points = shape_points.sort_values("shape_pt_sequence")

        # Check if stop_sequence column exists
        if "stop_sequence" not in shape_points.columns:
            if trace:
                WranglerLogger.debug(
                    "  WARNING: 'stop_sequence' column not found in shapes table"
                )
                WranglerLogger.debug(f"  Available columns: {list(shape_points.columns)}")
            return pd.DataFrame()

        # Find the shape_pt_sequence values for the start and end stops
        # Only rows with stop_sequence values are actual stops
        stop_points = shape_points[shape_points["stop_sequence"].notna()]

        if trace:
            WranglerLogger.debug(
                f"  Found {len(stop_points)} stop points (vs {len(shape_points)} total shape points)"
            )
            if not stop_points.empty:
                WranglerLogger.debug(
                    f"  Stop sequences present: {sorted(stop_points['stop_sequence'].unique())}"
                )

        # Find shape_pt_sequence range for the requested stop sequences
        from_stop_points = stop_points[stop_points["stop_sequence"] == from_stop_seq]
        to_stop_points = stop_points[stop_points["stop_sequence"] == to_stop_seq]

        if from_stop_points.empty or to_stop_points.empty:
            if trace:
                WranglerLogger.debug(
                    f"  WARNING: Could not find stop sequences {from_stop_seq} or {to_stop_seq}"
                )
                if from_stop_points.empty:
                    WranglerLogger.debug(f"    from_stop_seq {from_stop_seq} not found")
                if to_stop_points.empty:
                    WranglerLogger.debug(f"    to_stop_seq {to_stop_seq} not found")
            return pd.DataFrame()

        # Get the shape_pt_sequence values for these stops
        from_shape_seq = from_stop_points["shape_pt_sequence"].iloc[0]
        to_shape_seq = to_stop_points["shape_pt_sequence"].iloc[0]

        if trace:
            WranglerLogger.debug(
                f"  Stop sequence {from_stop_seq} is at shape_pt_sequence {from_shape_seq}"
            )
            WranglerLogger.debug(
                f"  Stop sequence {to_stop_seq} is at shape_pt_sequence {to_shape_seq}"
            )

        # Filter to get all shape points between these two stops (inclusive)
        shape_points = shape_points[
            (shape_points["shape_pt_sequence"] >= from_shape_seq)
            & (shape_points["shape_pt_sequence"] <= to_shape_seq)
        ]

        if trace:
            WranglerLogger.debug(
                f"  Filtered to {len(shape_points)} shape points between shape_pt_sequences {from_shape_seq} and {to_shape_seq}"
            )
            if not shape_points.empty:
                num_with_stops = shape_points["stop_sequence"].notna().sum()
                WranglerLogger.debug(
                    f"    Including {num_with_stops} stop points and {len(shape_points) - num_with_stops} intermediate shape points"
                )

        return shape_points
    except Exception as e:
        if trace:
            WranglerLogger.debug(
                f"  ERROR in get_original_shape_points_between_stops: {type(e).__name__}: {e}"
            )
            import traceback

            WranglerLogger.debug(f"  Traceback: {traceback.format_exc()}")
        return pd.DataFrame()
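The core lookup, slicing shape points between two stop sequences, can be sketched with a toy table (made-up values mirroring the `shape_pt_sequence` / `stop_sequence` columns the function relies on):

```python
import numpy as np
import pandas as pd

# Toy shapes table; stop_sequence is set only on shape points that
# coincide with stops, as add_additional_data_to_shapes() would do.
shapes = pd.DataFrame(
    {
        "shape_id": ["s1"] * 6,
        "shape_pt_sequence": [10, 20, 30, 40, 50, 60],
        "stop_sequence": [1, np.nan, np.nan, 2, np.nan, 3],
    }
)

pts = shapes[shapes["shape_id"] == "s1"].sort_values("shape_pt_sequence")
stop_pts = pts[pts["stop_sequence"].notna()]
from_seq = stop_pts.loc[stop_pts["stop_sequence"] == 1, "shape_pt_sequence"].iloc[0]
to_seq = stop_pts.loc[stop_pts["stop_sequence"] == 2, "shape_pt_sequence"].iloc[0]

# All shape points between the two stops, inclusive of the stops themselves
between = pts[(pts["shape_pt_sequence"] >= from_seq) & (pts["shape_pt_sequence"] <= to_seq)]
print(between["shape_pt_sequence"].tolist())  # [10, 20, 30, 40]
```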

network_wrangler.utils.transit.match_bus_stops_to_roadway_nodes

match_bus_stops_to_roadway_nodes(
    feed_tables,
    roadway_net,
    local_crs,
    crs_units,
    max_distance,
    trace_shape_ids=None,
    use_name_matching=True,
    name_match_weight=None,
    config=DefaultConfig,
)

Match bus stops to bus-accessible nodes in the roadway network.

Matches bus and trolleybus stops to the nearest bus-accessible nodes in the roadway network using spatial proximity and optionally street name compatibility. Updates stop and shape locations to snap to road nodes.

Process Steps:

  1. Identifies bus stops (route_types BUS or TROLLEYBUS) in feed_tables[‘stops’]
  2. Builds bus network graph from roadway to find accessible nodes
  3. Projects geometries to local CRS for accurate distance calculations
  4. Uses BallTree spatial index to find candidate nodes within max_distance
  5. If name matching is enabled and link_names exist, scores candidates by both distance and name compatibility, selecting best match within max_distance
  6. Marks stops with combined_score > 0.9 as poor_match=True (only when name matching enabled)
  7. Excludes stops that serve station route types (rail, ferry, etc.) - these are handled separately
  8. For poor_match stops, their model_node_id is the nearest bus-accessible node (to use for creating connector links later)
  9. Updates stop locations to matched road node locations (except poor_match stops)
  10. Updates shape point locations for matched bus stops
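Step 4's spatial query can be illustrated with a brute-force equivalent on toy projected coordinates (sklearn's BallTree returns the same nearest-neighbor ordering; the coordinates here are invented, not real network data):

```python
import numpy as np

# Three candidate roadway nodes and one bus stop, in a projected CRS
# where distances are meaningful (e.g., feet or meters).
node_xy = np.array([[0.0, 0.0], [50.0, 0.0], [200.0, 0.0]])
stop = np.array([45.0, 0.0])

# Brute-force k-nearest query; BallTree.query(stop_xy, k=2) would
# return the same indices and distances, just faster at scale.
d = np.linalg.norm(node_xy - stop, axis=1)
order = np.argsort(d)[:2]  # k > 1 candidates when name matching is enabled
print(order.tolist(), d[order].tolist())  # [1, 0] [5.0, 45.0]
```

With name matching off, only the single nearest node (index 1 here) would be kept; with it on, all candidates within max_distance are rescored by name compatibility.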

Modifies feed_tables in place:

  • feed_tables[‘stops’] - Adds/modifies columns:

    • is_bus_stop (bool): True if stop serves BUS or TROLLEYBUS routes
    • model_node_id (int): Matched roadway node ID (None if no close match)
    • match_distance_{crs_units} (float): Distance to matched node
    • close_match (bool): True if match found within max_distance
    • poor_match (bool): True if combined_score > 0.9 AND stop doesn’t serve station routes (only when name matching enabled)
    • stop_lon, stop_lat, geometry: Updated to road node location if close_match and not poor_match
  • feed_tables[‘shapes’] - Adds/modifies columns:

    • shape_model_node_id (int): Matched roadway node ID for bus stops
    • match_distance_{crs_units} (float): Distance to matched node
    • shape_pt_lon, shape_pt_lat, geometry: Updated to road node location if valid match
  • feed_tables[‘stop_times’] - If GeoDataFrame, updates:

    • geometry: Updated to matched road node location for bus stops

Parameters:

  • feed_tables (dict[str, DataFrame]): Dictionary of GTFS feed tables. Expects: 'stops' (must have a route_types column of RouteType enum lists), 'shapes' (shape points to update), and optionally 'stop_times' (updated if present as a GeoDataFrame). Required.
  • roadway_net (RoadwayNetwork): RoadwayNetwork with nodes to match against. Will be converted to a GeoDataFrame if needed. Required.
  • local_crs (str): Coordinate reference system for projections (e.g., "EPSG:2227"). Required.
  • crs_units (str): Distance units for local_crs ('feet' or 'meters'). Required.
  • max_distance (float): Maximum matching distance in crs_units. Required.
  • trace_shape_ids (list[str] | None): Optional list of shape_ids for debug logging. Default: None.
  • use_name_matching (bool): If True and nodes have 'link_names', considers name compatibility when selecting the best match within max_distance. Default: True.
  • name_match_weight (float | None): Weight for the name match score in combined scoring (0.0 to 1.0). Final score = (1 - name_match_weight) * normalized_distance + name_match_weight * (1 - name_score); lower is better. Defaults to the NAME_MATCH_WEIGHT constant.
  • config (WranglerConfig): WranglerConfig with TRANSIT settings for name matching thresholds. Default: DefaultConfig.
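The combined scoring rule behind name_match_weight can be sketched directly (the weight and distances below are hypothetical; the real default comes from config.TRANSIT.NAME_MATCH_WEIGHT):

```python
# Combined score: lower is better; a value above 0.9 marks a poor_match.
# name_score in [0, 1] measures stop-name vs. street-name compatibility.
def combined_score(dist, max_distance, name_score, name_match_weight):
    normalized_dist = dist / max_distance
    return (1 - name_match_weight) * normalized_dist + name_match_weight * (1 - name_score)

w = 0.5  # hypothetical weight for illustration only
good = combined_score(dist=30, max_distance=100, name_score=0.9, name_match_weight=w)
bad = combined_score(dist=95, max_distance=100, name_score=0.0, name_match_weight=w)
print(round(good, 3), round(bad, 3))  # 0.2 0.975
```

A nearby node on a street whose name matches the stop scores well (0.2); a far node with no name overlap exceeds the 0.9 poor_match threshold.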

Raises:

  • TransitValidationError: If no bus-accessible nodes are found near any bus stops.

Notes:

  • Only matches stops that serve BUS or TROLLEYBUS routes
  • Uses bus modal graph to ensure matched nodes are bus-accessible
  • Preserves original locations for non-bus stops
Source code in network_wrangler/utils/transit.py
def match_bus_stops_to_roadway_nodes(  # noqa: PLR0912, PLR0915
    feed_tables: dict[str, pd.DataFrame],
    roadway_net: RoadwayNetwork,
    local_crs: str,
    crs_units: str,
    max_distance: float,
    trace_shape_ids: list[str] | None = None,
    use_name_matching: bool = True,
    name_match_weight: float | None = None,
    config: WranglerConfig = DefaultConfig,
):
    """Match bus stops to bus-accessible nodes in the roadway network.

    Matches bus and trolleybus stops to the nearest bus-accessible nodes in the roadway
    network using spatial proximity and optionally street name compatibility.
    Updates stop and shape locations to snap to road nodes.

    Process Steps:

    1. Identifies bus stops (route_types BUS or TROLLEYBUS) in feed_tables['stops']
    2. Builds bus network graph from roadway to find accessible nodes
    3. Projects geometries to local CRS for accurate distance calculations
    4. Uses BallTree spatial index to find candidate nodes within max_distance
    5. If name matching is enabled and link_names exist, scores candidates by both
       distance and name compatibility, selecting best match within max_distance
    6. Marks stops with combined_score > 0.9 as poor_match=True (only when name matching enabled)
       - Excludes stops that serve station route types (rail, ferry, etc.) - these are handled separately
    7. For poor_match stops, their model_node_id is the nearest bus-accessible node
       (to use for creating connector links later)
    8. Updates stop locations to matched road node locations (except poor_match stops)
    9. Updates shape point locations for matched bus stops

    Modifies feed_tables in place:

    * feed_tables['stops'] - Adds/modifies columns:

        - is_bus_stop (bool): True if stop serves BUS or TROLLEYBUS routes
        - model_node_id (int): Matched roadway node ID (None if no close match)
        - match_distance_{crs_units} (float): Distance to matched node
        - close_match (bool): True if match found within max_distance
        - poor_match (bool): True if combined_score > 0.9 AND stop doesn't serve station routes (only when name matching enabled)
        - stop_lon, stop_lat, geometry: Updated to road node location if close_match and not poor_match

    * feed_tables['shapes'] - Adds/modifies columns:

        - shape_model_node_id (int): Matched roadway node ID for bus stops
        - match_distance_{crs_units} (float): Distance to matched node
        - shape_pt_lon, shape_pt_lat, geometry: Updated to road node location if valid match

    * feed_tables['stop_times'] - If GeoDataFrame, updates:

        - geometry: Updated to matched road node location for bus stops

    Args:
        feed_tables: dictionary of GTFS feed tables. Expects:
            - 'stops': Must have route_types column (list of RouteType enums)
            - 'shapes': Shape points to update
            - 'stop_times': Optional, updated if present as GeoDataFrame
        roadway_net: RoadwayNetwork with nodes to match against.
            Will be converted to GeoDataFrame if needed.
        local_crs: Coordinate reference system for projections (e.g., "EPSG:2227")
        crs_units: Distance units for local_crs ('feet' or 'meters')
        max_distance: Maximum matching distance in crs_units
        trace_shape_ids: Optional list of shape_ids for debug logging
        use_name_matching: If True and nodes have 'link_names', will consider name
            compatibility when selecting best match within max_distance. Default is True.
        name_match_weight: Weight for name match score in combined scoring (0.0 to 1.0).
            Final score = (1 - name_match_weight) * normalized_dist + name_match_weight * (1 - name_score)
            Lower scores are better. Defaults to NAME_MATCH_WEIGHT constant.
        config: WranglerConfig with TRANSIT settings for name matching thresholds.

    Raises:
        TransitValidationError: If no bus-accessible nodes found near any bus stops

    Notes:
        - Only matches stops that serve BUS or TROLLEYBUS routes
        - Uses bus modal graph to ensure matched nodes are bus-accessible
        - Preserves original locations for non-bus stops
    """
    if crs_units not in ["feet", "meters"]:
        msg = f"crs_units must be one of 'feet' or 'meters'; received {crs_units}"
        raise ValueError(msg)

    # Use config default if name_match_weight not provided
    if name_match_weight is None:
        name_match_weight = config.TRANSIT.NAME_MATCH_WEIGHT

    # Make roadway network nodes a GeoDataFrame if it's not already
    if not isinstance(roadway_net.nodes_df, gpd.GeoDataFrame):
        # Convert to GeoDataFrame if needed
        roadway_net.nodes_df = gpd.GeoDataFrame(
            roadway_net.nodes_df,
            geometry=gpd.points_from_xy(roadway_net.nodes_df.X, roadway_net.nodes_df.Y),
            crs=LAT_LON_CRS,
        )

    # Collect bus stops to match
    feed_tables["stops"]["is_bus_stop"] = False
    feed_tables["stops"].loc[
        feed_tables["stops"]["route_types"].apply(
            lambda x: RouteType.BUS in x if isinstance(x, list) else False
        )
        | feed_tables["stops"]["route_types"].apply(
            lambda x: RouteType.TROLLEYBUS in x if isinstance(x, list) else False
        ),
        "is_bus_stop",
    ] = True
    bus_stops_gdf = (
        feed_tables["stops"].loc[feed_tables["stops"]["is_bus_stop"] == True].copy(deep=True)
    )

    # Reset index to ensure continuous indices from 0 to n-1
    bus_stops_gdf.reset_index(drop=True, inplace=True)

    WranglerLogger.info(f"Matching {len(bus_stops_gdf):,} bus stops to roadway nodes")
    WranglerLogger.debug(f"bus_stops_gdf:\n{bus_stops_gdf}")
    WranglerLogger.debug(f"feed_tables['shapes']:\n{feed_tables['shapes']}")
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace shapes for {trace_shape_id}:\n"
                f"{feed_tables['shapes'].loc[feed_tables['shapes']['shape_id'] == trace_shape_id]}"
            )

    # Build BallTree for bus-access nodes that are in the bus graph
    # This ensures we only match to connected nodes
    # TODO: @e-lo wonders if this is necessary?
    G_bus = roadway_net.get_modal_graph("bus")
    bus_graph_nodes = set(G_bus.nodes())
    roadway_net.nodes_df["bus_access"] = False
    roadway_net.nodes_df.loc[
        roadway_net.nodes_df["model_node_id"].isin(bus_graph_nodes), "bus_access"
    ] = True
    WranglerLogger.debug(
        f"bus-accessible nodes from graph: {roadway_net.nodes_df['bus_access'].sum():,}"
    )

    # Extract bus-accessible nodes
    # TODO: Extract to roadway/nodes/filters.py and use filter_nodes_to_links()
    bus_accessible_nodes_gdf = roadway_net.nodes_df[
        roadway_net.nodes_df["bus_access"] == True
    ].copy()
    WranglerLogger.debug(
        f"Found {len(bus_accessible_nodes_gdf):,} bus-accessible nodes (for mixed-traffic transit) "
        f"out of {len(roadway_net.nodes_df):,} total"
    )

    # Project bus nodes and bus_stops_gdf to specified CRS and bus_stops
    bus_accessible_nodes_gdf.to_crs(local_crs, inplace=True)
    bus_stops_gdf.to_crs(local_crs, inplace=True)

    # Save original projected geometries AFTER projection (for restoring unmatched stops later)
    bus_stops_gdf["geometry_original"] = bus_stops_gdf["geometry"].copy()

    # Initialize results
    bus_stops_gdf["model_node_id"] = None
    bus_stops_gdf[f"match_distance_{crs_units}"] = np.inf  # in crs_units

    # Build spatial index for bus nodes
    try:
        from sklearn.neighbors import BallTree
    except ImportError as e:
        msg = (
            "sklearn is required for transit stop matching. Install with: pip install scikit-learn"
        )
        raise ImportError(msg) from e

    bus_node_coords = np.array([(geom.x, geom.y) for geom in bus_accessible_nodes_gdf.geometry])
    bus_nodes_tree = BallTree(bus_node_coords)
    WranglerLogger.debug(f"Created BallTree for {len(bus_node_coords):,} bus nodes")

    # Get coordinates of stops to match
    bus_stop_coords = np.array([(geom.x, geom.y) for geom in bus_stops_gdf.geometry])

    # Query nearest neighbors - use more candidates if name matching is enabled
    k = 1  # Default to nearest neighbor only
    if use_name_matching and "link_names" in bus_accessible_nodes_gdf.columns:
        k = min(config.TRANSIT.K_NEAREST_CANDIDATES, len(bus_accessible_nodes_gdf))
        WranglerLogger.info(
            f"Using name-aware matching within {max_distance} {crs_units} "
            f"with name weight {name_match_weight}"
        )

    match_distances, match_indices = bus_nodes_tree.query(bus_stop_coords, k=k)

    # Process results based on whether we're doing name matching
    if k > 1:  # Name matching with multiple candidates
        # Initialize arrays to store best matches
        best_indices = np.zeros(len(bus_stops_gdf), dtype=int)
        best_distances = np.zeros(len(bus_stops_gdf))
        name_match_scores = np.zeros(len(bus_stops_gdf))
        normalized_dists = np.zeros(len(bus_stops_gdf))
        combined_scores = np.zeros(len(bus_stops_gdf))

        # Find best match for each stop considering both distance and name
        for stop_idx in range(len(bus_stops_gdf)):
            stop_name = bus_stops_gdf.iloc[stop_idx]["stop_name"]
            distances = match_distances[stop_idx]
            indices = match_indices[stop_idx]

            best_score = float("inf")
            best_idx = 0
            best_name_score = 0.0
            best_normalized_dist = 0.0
            best_combined_score = float("inf")

            # Evaluate candidates within max_distance
            candidates_found = False
            for i, (dist, node_idx) in enumerate(zip(distances, indices, strict=False)):
                # only look at candidates within max_distance
                if dist > max_distance:
                    continue

                candidates_found = True
                node_link_names = bus_accessible_nodes_gdf.iloc[node_idx].get("link_names", [])

                # Calculate name match score
                _, name_score, _ = assess_stop_name_roadway_compatibility(
                    stop_name,
                    node_link_names
                    if (node_link_names is not None and len(node_link_names) > 0)
                    else [],
                )

                # Combined score (lower is better)
                normalized_dist = dist / max_distance
                combined_score = (1 - name_match_weight) * normalized_dist + name_match_weight * (
                    1 - name_score
                )

                if combined_score < best_score:
                    best_score = combined_score
                    best_idx = i
                    best_name_score = name_score
                    best_normalized_dist = normalized_dist
                    best_combined_score = combined_score

            # If no candidates within max_distance, use closest regardless
            if not candidates_found:
                # best_idx is already 0 (closest)
                # Calculate normalized_dist as >1 to indicate beyond max_distance
                best_normalized_dist = distances[0] / max_distance
                best_combined_score = (
                    1 - name_match_weight
                ) * best_normalized_dist + name_match_weight * (1 - best_name_score)

            # Store best match
            best_indices[stop_idx] = indices[best_idx]
            best_distances[stop_idx] = distances[best_idx]
            name_match_scores[stop_idx] = best_name_score
            normalized_dists[stop_idx] = best_normalized_dist
            combined_scores[stop_idx] = best_combined_score

        # Create matches dataframe
        matches_df = pd.DataFrame(
            {
                "stop_idx": bus_stops_gdf.index,
                "match_node_idx": best_indices,
                "match_distance": best_distances,
                "name_match_score": name_match_scores,
                "normalized_dist": normalized_dists,
                "combined_score": combined_scores,
            }
        )
    else:
        # Simple nearest neighbor matching (k=1)
        matches_df = pd.DataFrame(
            {
                "stop_idx": bus_stops_gdf.index,
                "match_node_idx": match_indices.flatten(),
                "match_distance": match_distances.flatten(),
            }
        )

    # Check for close matches (within max_distance)
    matches_df["close_match"] = False
    matches_df.loc[matches_df["match_distance"] <= max_distance, "close_match"] = True
    WranglerLogger.debug(f"matches_df:\n{matches_df}")

    WranglerLogger.info(
        f"Found {matches_df.close_match.sum():,} close matches out of {len(bus_stops_gdf):,} total bus stops"
    )

    if matches_df.close_match.sum() == 0:
        exception = TransitValidationError("Found no bus-accessible nodes near bus stops.")
        raise exception

    # Get matched node information
    matched_nodes_gdf = bus_accessible_nodes_gdf.iloc[matches_df["match_node_idx"]]
    WranglerLogger.debug(f"matched_nodes_gdf:\n{matched_nodes_gdf}")

    # Update bus stops with matched node information (vectorized)
    # Update all matches with their corresponding node information
    # Note: geometry_original was already saved right after projection above
    bus_stops_gdf.loc[matches_df["stop_idx"], f"match_distance_{crs_units}"] = matches_df[
        "match_distance"
    ].values
    bus_stops_gdf.loc[matches_df["stop_idx"], "close_match"] = matches_df["close_match"].values
    bus_stops_gdf.loc[matches_df["stop_idx"], "model_node_id"] = matched_nodes_gdf[
        "model_node_id"
    ].values
    bus_stops_gdf.loc[matches_df["stop_idx"], "stop_lon"] = matched_nodes_gdf["X"].values
    bus_stops_gdf.loc[matches_df["stop_idx"], "stop_lat"] = matched_nodes_gdf["Y"].values
    bus_stops_gdf.loc[matches_df["stop_idx"], "geometry"] = matched_nodes_gdf["geometry"].values

    # Add node link_names and name match scores if available
    if "link_names" in matched_nodes_gdf.columns:
        bus_stops_gdf.loc[matches_df["stop_idx"], "node_link_names"] = matched_nodes_gdf[
            "link_names"
        ].values
    if "name_match_score" in matches_df.columns:
        bus_stops_gdf.loc[matches_df["stop_idx"], "name_match_score"] = matches_df[
            "name_match_score"
        ].values
        bus_stops_gdf.loc[matches_df["stop_idx"], "normalized_dist"] = matches_df[
            "normalized_dist"
        ].values
        bus_stops_gdf.loc[matches_df["stop_idx"], "combined_score"] = matches_df[
            "combined_score"
        ].values
        # Report poor name matches
        poor_name_matches = matches_df[
            (matches_df["close_match"] == True) & (matches_df["name_match_score"] < 0.5)  # noqa: PLR2004
        ]
        if len(poor_name_matches) > 0:
            WranglerLogger.info(
                f"Found {len(poor_name_matches)} bus stops with low name compatibility (score < 0.5)"
            )

    # Mark stops with poor combined_score (> 0.9) as poor_match
    # poor_match only applies when name matching is enabled (combined_score exists)
    # These will be handled by add_unmatched_bus_stops_to_network()
    debug_cols = [
        "stop_id",
        "stop_name",
        "model_node_id",
        f"match_distance_{crs_units}",
        "close_match",
        "name_match_score",
        "node_link_names",
        "geometry",
    ]

    # combined_score only exists when use_name_matching=True and link_names column exists
    # poor_match is defined as having combined_score > 0.9
    # BUT exclude stops that serve station route types (rail, ferry, etc.) - they're handled separately
    if "combined_score" in bus_stops_gdf.columns:
        debug_cols = [*debug_cols, "combined_score", "poor_match"]

        # Check if stop serves any station route types (rail, ferry, etc.)
        bus_stops_gdf["serves_station_routes"] = bus_stops_gdf["route_types"].apply(
            lambda x: any(rt in STATION_ROUTE_TYPES for rt in x) if isinstance(x, list) else False
        )

        # poor_match = poor score AND not a station stop
        bus_stops_gdf["poor_match"] = (
            (bus_stops_gdf["close_match"] == True)
            & (bus_stops_gdf["combined_score"] > 0.9)  # noqa: PLR2004
            & (bus_stops_gdf["serves_station_routes"] == False)
        )
        poor_score_stops = bus_stops_gdf[bus_stops_gdf["poor_match"] == True]

        # Log excluded stops (poor score but serve station routes)
        excluded_station_stops = bus_stops_gdf[
            (bus_stops_gdf["close_match"] == True)
            & (bus_stops_gdf["combined_score"] > 0.9)  # noqa: PLR2004
            & (bus_stops_gdf["serves_station_routes"] == True)
        ]
        if len(excluded_station_stops) > 0:
            WranglerLogger.info(
                f"Found {len(excluded_station_stops)} stops with poor combined_score (> 0.9) "
                f"that serve station route types (rail/ferry/etc). These will NOT be marked as "
                f"poor_match and will be handled as stations in step 7."
            )

        if len(poor_score_stops) > 0:
            WranglerLogger.info(
                f"Found {len(poor_score_stops)} bus-only stops with poor_match=True (combined_score > 0.9). "
                f"These will be treated as unmatched stops and added to the network with "
                f"connector links to nearest bus-accessible nodes."
            )

            # Restore original geometry for poor_match stops (they shouldn't snap to matched nodes)
            # Use the saved geometry from after projection but before matching
            bus_stops_gdf.loc[bus_stops_gdf["poor_match"], "geometry"] = bus_stops_gdf.loc[
                bus_stops_gdf["poor_match"], "geometry_original"
            ]

            debug_cols_with_poor = [*debug_cols, "poor_match"]
            WranglerLogger.debug(
                f"poor_match stops:\n"
                f"{bus_stops_gdf.loc[bus_stops_gdf['poor_match'], debug_cols_with_poor]}"
            )

        # Clean up temporary column
        bus_stops_gdf.drop(columns=["serves_station_routes"], inplace=True)
    else:
        # No name matching, so no poor_match stops (all matches are based on distance only)
        bus_stops_gdf["poor_match"] = False

    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace bus_stops_gdf for {trace_shape_id}:\n"
                f"{bus_stops_gdf.loc[bus_stops_gdf['shape_ids'].apply(lambda x, tid=trace_shape_id: tid in x), debug_cols]}"
            )

    # Clean up temporary column
    if "geometry_original" in bus_stops_gdf.columns:
        bus_stops_gdf.drop(columns=["geometry_original"], inplace=True)

    # verify model_node_id, f'match_distance_{crs_units}' and 'close_match' are not in feed_tables['stops']
    assert "model_node_id" not in feed_tables["stops"].columns
    assert f"match_distance_{crs_units}" not in feed_tables["stops"].columns
    assert "close_match" not in feed_tables["stops"].columns

    # Update feed_tables['stops'] by merging the updates
    merge_cols = [
        "stop_id",
        "model_node_id",
        f"match_distance_{crs_units}",
        "stop_lon",
        "stop_lat",
        "geometry",
        "close_match",
        "poor_match",
    ]
    feed_tables["stops"] = feed_tables["stops"].merge(
        bus_stops_gdf[merge_cols],
        on="stop_id",
        how="left",
        suffixes=("", "_bus"),
        validate="one_to_one",
    )
    # Only update stop location for close match AND not unmatched
    # (unmatched stops keep original location until they're added to network in step 6a)
    update_mask = (feed_tables["stops"]["close_match"] == True) & (
        feed_tables["stops"]["poor_match"] == False
    )
    feed_tables["stops"].loc[update_mask, "stop_lon"] = feed_tables["stops"]["stop_lon_bus"]
    feed_tables["stops"].loc[update_mask, "stop_lat"] = feed_tables["stops"]["stop_lat_bus"]
    feed_tables["stops"].loc[update_mask, "geometry"] = feed_tables["stops"]["geometry_bus"]
    # Drop bus-specific columns
    feed_tables["stops"].drop(
        columns=["stop_lon_bus", "stop_lat_bus", "geometry_bus"], inplace=True
    )

    WranglerLogger.debug(
        f"After merging with bus_stops_gdf, feed_tables['stops']:\n{feed_tables['stops']}"
    )
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace feed_tables['stops'] for {trace_shape_id}:\n"
                f"{feed_tables['stops'].loc[feed_tables['stops']['shape_ids'].apply(lambda x, tid=trace_shape_id: isinstance(x, list) and tid in x)]}"
            )

    # summary of match distances
    WranglerLogger.debug(
        f"\n{feed_tables['stops'][['close_match', f'match_distance_{crs_units}']].describe()}"
    )

    # Update shapes table similarly
    WranglerLogger.debug(f"feed_tables['shapes']:\n{feed_tables['shapes']}")
    # columns: shape_id, shape_pt_lat, shape_pt_lon, shape_pt_sequence, shape_dist_traveled, geometry
    #   trip_id, direction_id, route_id, agency_id, route_short_name, route_type, agency_name,
    #   match_distance_feet, stop_id, stop_name, stop_sequence
    # Note: this is adding: 'model_node_id'
    feed_tables["shapes"] = pd.merge(
        left=feed_tables["shapes"],
        right=bus_stops_gdf[
            [
                "stop_id",
                "model_node_id",
                f"match_distance_{crs_units}",
                "stop_lon",
                "stop_lat",
                "geometry",
                "close_match",
                "poor_match",
            ]
        ],
        on="stop_id",
        how="left",
        suffixes=("", "_bus"),
        validate="many_to_one",
    ).rename(columns={"model_node_id": "shape_model_node_id"})
    # Only update stop location for close match AND not unmatched
    # (unmatched stops keep original location in shapes until they're added to network in step 6a)
    shape_update_mask = (feed_tables["shapes"]["close_match"] == True) & (
        feed_tables["shapes"]["poor_match"] == False
    )
    feed_tables["shapes"].loc[shape_update_mask, "shape_pt_lon"] = feed_tables["shapes"][
        "stop_lon"
    ]
    feed_tables["shapes"].loc[shape_update_mask, "shape_pt_lat"] = feed_tables["shapes"][
        "stop_lat"
    ]
    feed_tables["shapes"].loc[shape_update_mask, "geometry"] = feed_tables["shapes"][
        "geometry_bus"
    ]
    feed_tables["shapes"].loc[shape_update_mask, f"match_distance_{crs_units}"] = feed_tables[
        "shapes"
    ][f"match_distance_{crs_units}_bus"]

    if trace_shape_ids:
        debug_cols = [
            "shape_pt_sequence",
            "stop_sequence",
            "stop_id",
            "stop_name",
            "shape_model_node_id",
            "match_distance_feet_bus",
            "poor_match",
        ]
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace feed_tables['shapes'] for {trace_shape_id}"
                f"at the end of match_bus_stops_to_roadway_nodes():\n"
                f"{feed_tables['shapes'].loc[feed_tables['shapes']['shape_id'] == trace_shape_id][debug_cols]}"
            )

    # Drop bus-specific columns
    feed_tables["shapes"].drop(
        columns=[
            "stop_lon",
            "stop_lat",
            "geometry_bus",
            f"match_distance_{crs_units}_bus",
            "close_match",
        ],
        inplace=True,
    )

    # update geometry of feed_tables['stop_times'] if needed
    if isinstance(feed_tables["stop_times"], gpd.GeoDataFrame):
        feed_tables["stop_times"] = pd.merge(
            left=feed_tables["stop_times"],
            right=bus_stops_gdf[["stop_id", "geometry"]],
            how="left",
            on="stop_id",
            validate="many_to_one",
            suffixes=("", "_bus"),
            indicator=True,
        )
        WranglerLogger.debug(
            f"Updating feed['stop_times'] stop locations for bus stops:\n{feed_tables['stop_times']._merge.value_counts()}"
        )
        feed_tables["stop_times"].loc[
            feed_tables["stop_times"]["geometry_bus"].notna(), "geometry"
        ] = feed_tables["stop_times"]["geometry_bus"]
        feed_tables["stop_times"].drop(columns=["geometry_bus", "_merge"], inplace=True)

network_wrangler.utils.transit.route_shapes_between_stops

route_shapes_between_stops(
    bus_stop_links_gdf,
    feed_tables,
    roadway_net,
    local_crs,
    crs_units,
    trace_shape_ids=None,
    errors="raise",
    default_link_attribute_dict=None,
)

Find shortest paths through the bus network between consecutive bus stops.

Replaces original bus route shapes with new shapes that follow the actual bus network by finding shortest paths between consecutive stops through bus-accessible roads.

Process Steps:

1. Sorts bus stop links by shape_id and stop_sequence
2. Gets bus modal graph from roadway network
3. For each consecutive stop pair in each shape:
    - Finds shortest path through bus network using NetworkX
    - Creates shape points for all nodes in the path
    - Preserves stop information at stop nodes
4. Replaces bus shapes in feed_tables['shapes'] with new routed shapes

Modifies feed_tables['shapes'] in place:

- Removes existing bus/trolleybus shapes
- Adds new shapes with points following road network paths
- Each shape point has shape_model_node_id from roadway network
- Stop points retain stop_id, stop_name, stop_sequence
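The per-pair routing uses a plain NetworkX shortest path over a directed graph; the source also collapses the roadway MultiDiGraph to a DiGraph, keeping the shortest of any parallel edges. A self-contained sketch of that collapse and a stop-to-stop path query (node IDs and distances are made up):

```python
import networkx as nx

# MultiDiGraph with two parallel edges from 1 -> 2; keep only the shorter.
G_multi = nx.MultiDiGraph()
G_multi.add_edge(1, 2, distance=5.0)
G_multi.add_edge(1, 2, distance=3.0)   # parallel edge, shorter
G_multi.add_edge(2, 3, distance=4.0)
G_multi.add_edge(1, 3, distance=10.0)  # direct edge, but longer overall

# Collapse parallel edges while preserving directionality.
G = nx.DiGraph()
for u, v, data in G_multi.edges(data=True):
    d = data.get("distance", 0)
    if G.has_edge(u, v):
        G[u][v]["distance"] = min(G[u][v]["distance"], d)
    else:
        G.add_edge(u, v, distance=d)

# Shortest path between a consecutive stop pair (A=1, B=3), weighted by distance.
path = nx.shortest_path(G, 1, 3, weight="distance")
print(path)  # [1, 2, 3]: 3.0 + 4.0 = 7.0 beats the direct 10.0 edge
```

Collapsing to a DiGraph also makes `nx.shortest_simple_paths()` usable for the shape-aware routing variant, which does not accept multigraphs.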

Parameters:

- bus_stop_links_gdf (GeoDataFrame, required): GeoDataFrame of bus stop pairs, required columns:
    - shape_id (str): Shape identifier
    - stop_sequence (int): Stop order in route
    - stop_id (str): Current stop ID
    - stop_name (str): Current stop name
    - next_stop_id (str): Next stop ID
    - next_stop_name (str): Next stop name
    - A (int): Current stop's model_node_id
    - B (int): Next stop's model_node_id
    - route_id, route_type, trip_id, direction_id: Route metadata
- feed_tables (dict[str, DataFrame], required): Dictionary with required tables:
    - 'stops': Must have is_bus_stop column
    - 'shapes': Will be modified to replace bus shapes
- roadway_net (RoadwayNetwork, required): RoadwayNetwork with bus modal graph
- local_crs (str, required): Coordinate reference system for projections
- crs_units (str, required): Distance units ('feet' or 'meters')
- trace_shape_ids (list[str] | None, default None): Optional shape IDs for debug logging
- errors (Literal['raise', 'ignore'], default 'raise'): 'raise' or 'ignore'
- default_link_attribute_dict (dict[str, any] | None, default None): Optional dict of column-name to default value to set on new links.

Raises:

- TransitValidationError: If no path exists between any consecutive stops. The exception includes no_bus_path_gdf with the failed stop sequences.

Notes
  • Uses NetworkX shortest_path for routing
  • Intermediate nodes between stops are added as shape points
  • Original shape geometry is replaced with routed paths
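When `errors='raise'`, the failed stop pairs are attached to the exception as a DataFrame so callers can inspect them after catching. A hedged sketch of that pattern, using a stand-in exception class rather than the real `TransitValidationError`:

```python
import pandas as pd

class TransitValidationErrorStub(Exception):
    """Stand-in for network_wrangler's TransitValidationError."""

def route_or_raise(no_path_records: list[dict]) -> None:
    """Mimic the errors='raise' behavior: attach failures to the exception."""
    if no_path_records:
        e = TransitValidationErrorStub(
            "Some bus stop sequences failed to find paths. See e.no_bus_path_gdf"
        )
        e.no_bus_path_gdf = pd.DataFrame(no_path_records)
        raise e

failures = [
    {"shape_id": "shp_1", "stop_id": "s1", "next_stop_id": "s2", "stop_sequence": 1}
]
caught = None
try:
    route_or_raise(failures)
except TransitValidationErrorStub as err:
    caught = err
    # The attached DataFrame identifies which consecutive stop pairs failed.
    print(err.no_bus_path_gdf["shape_id"].tolist())  # ['shp_1']
```

Attaching the DataFrame to the exception instance keeps the function signature simple while still surfacing full diagnostics to the caller.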
Source code in network_wrangler/utils/transit.py
def route_shapes_between_stops(  # noqa: PLR0912, PLR0915
    bus_stop_links_gdf: gpd.GeoDataFrame,
    feed_tables: dict[str, pd.DataFrame],
    roadway_net: RoadwayNetwork,
    local_crs: str,
    crs_units: str,
    trace_shape_ids: list[str] | None = None,
    errors: Literal["raise", "ignore"] = "raise",
    default_link_attribute_dict: dict[str, any] | None = None,
):
    """Find shortest paths through the bus network between consecutive bus stops.

    Replaces original bus route shapes with new shapes that follow the actual bus network
    by finding shortest paths between consecutive stops through bus-accessible roads.

    Process Steps:
    1. Sorts bus stop links by shape_id and stop_sequence
    2. Gets bus modal graph from roadway network
    3. For each consecutive stop pair in each shape:
       - Finds shortest path through bus network using NetworkX
       - Creates shape points for all nodes in the path
       - Preserves stop information at stop nodes
    4. Replaces bus shapes in feed_tables['shapes'] with new routed shapes

    Modifies feed_tables['shapes'] in place:
    - Removes existing bus/trolleybus shapes
    - Adds new shapes with points following road network paths
    - Each shape point has shape_model_node_id from roadway network
    - Stop points retain stop_id, stop_name, stop_sequence

    Args:
        bus_stop_links_gdf: GeoDataFrame of bus stop pairs, required columns:
            - shape_id (str): Shape identifier
            - stop_sequence (int): Stop order in route
            - stop_id (str): Current stop ID
            - stop_name (str): Current stop name
            - next_stop_id (str): Next stop ID
            - next_stop_name (str): Next stop name
            - A (int): Current stop's model_node_id
            - B (int): Next stop's model_node_id
            - route_id, route_type, trip_id, direction_id: Route metadata
        feed_tables: dictionary with required tables:
            - 'stops': Must have is_bus_stop column
            - 'shapes': Will be modified to replace bus shapes
        roadway_net: RoadwayNetwork with bus modal graph
        local_crs: Coordinate reference system for projections
        crs_units: Distance units ('feet' or 'meters')
        trace_shape_ids: Optional shape IDs for debug logging
        errors: 'raise' or 'ignore'
        default_link_attribute_dict: Optional dict of column-name to default value to set on new links.

    Raises:
        TransitValidationError: If no path exists between any consecutive stops.
            Exception includes no_bus_path_gdf with failed stop sequences.

    Notes:
        - Uses NetworkX shortest_path for routing
        - Intermediate nodes between stops are added as shape points
        - Original shape geometry is replaced with routed paths
    """
    if crs_units not in ["feet", "meters"]:
        msg = f"crs_units must be one of 'feet' or 'meters'; received {crs_units}"
        raise ValueError(msg)

    WranglerLogger.info(f"Creating bus routes between bus stops")
    WranglerLogger.debug(
        f"bus stops:\n{feed_tables['stops'].loc[feed_tables['stops']['is_bus_stop'] == True]}"
    )
    WranglerLogger.debug(
        f"bus shapes:\n{feed_tables['shapes'].loc[feed_tables['shapes']['route_type'].isin([RouteType.BUS, RouteType.TROLLEYBUS])]}"
    )
    WranglerLogger.debug(f"bus_stop_links_gdf:\n{bus_stop_links_gdf}")
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace bus_stop_links_gdf for {trace_shape_id}:\n"
                f"{bus_stop_links_gdf.loc[bus_stop_links_gdf.shape_id == trace_shape_id]}"
            )

    # Create a shortest path through the bus network between each consecutive bus stop for a given shape_id,
    # traversing through intermediate roadway nodes
    # Create new shape links between those nodes
    # Check how far away the new shape links are from the given stop-to-stop shape links
    bus_stop_links_gdf.sort_values(by=["shape_id", "stop_sequence"], inplace=True)
    G_bus_multi = roadway_net.get_modal_graph("bus")

    # Convert MultiDiGraph to DiGraph for pathfinding
    # DiGraph is required for nx.shortest_simple_paths() used in shape-aware routing
    # Keep directionality but collapse multiple edges to shortest
    G_bus = nx.DiGraph()
    for u, v, data in G_bus_multi.edges(data=True):
        if G_bus.has_edge(u, v):
            # Keep edge with minimum distance
            if data.get("distance", float("inf")) < G_bus[u][v]["distance"]:
                G_bus[u][v]["distance"] = data.get("distance", 0)
        else:
            G_bus.add_edge(u, v, distance=data.get("distance", 0))

    WranglerLogger.debug(
        f"Converted MultiDiGraph ({G_bus_multi.number_of_edges()} edges) to DiGraph ({G_bus.number_of_edges()} edges)"
    )

    # collect node sequences for these shapes
    bus_node_sequence = []
    # also collect failed stop sequences
    no_path_sequence = []

    current_shape_id = None
    current_shape_pt_sequence = None
    for _idx, row in bus_stop_links_gdf.iterrows():
        # restart for each shape_id
        if current_shape_id != row["shape_id"]:
            current_shape_pt_sequence = 1
            current_shape_id = row["shape_id"]

        if trace_shape_ids and current_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace path {current_shape_id}: {_idx} Looking for path from {row['A']} to {row['B']}"
            )
            WranglerLogger.debug(f"\n{row}")

        # Check if either node is not in the bus graph (e.g., poor_match=True)
        # If so, skip pathfinding and add simple A->B connection to bus_node_sequence
        if not G_bus.has_node(row["A"]) or not G_bus.has_node(row["B"]):
            WranglerLogger.warning(
                f"Node not in bus graph for {row['shape_id']} from {row['A']} to {row['B']} "
                f"(likely poor_match). Adding direct A->B connection."
            )
            no_path_sequence.append(
                {
                    "shape_id": row["shape_id"],
                    "stop_id": row["stop_id"],
                    "next_stop_id": row["next_stop_id"],
                    "stop_sequence": row["stop_sequence"],
                }
            )

            # Add simple A->B path to bus_node_sequence (in correct order)
            path = [row["A"], row["B"]]
            if current_shape_pt_sequence != 1:
                path = path[1:]  # Skip first node to avoid duplication

            for path_node_id in path:
                bus_node_dict = {
                    "shape_id": row["shape_id"],
                    "route_id": row["route_id"],
                    "route_type": row["route_type"],
                    "trip_id": row["trip_id"],
                    "direction_id": row["direction_id"],
                    "shape_pt_sequence": current_shape_pt_sequence,
                    "shape_model_node_id": path_node_id,
                }
                # Set stop info for the nodes
                if path_node_id == row["A"]:
                    bus_node_dict["stop_id"] = row["stop_id"]
                    bus_node_dict["stop_name"] = row["stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"]
                elif path_node_id == row["B"]:
                    bus_node_dict["stop_id"] = row["next_stop_id"]
                    bus_node_dict["stop_name"] = row["next_stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"] + 1

                bus_node_sequence.append(bus_node_dict)
                current_shape_pt_sequence += 1
                last_stop_sequence = row["stop_sequence"]

            continue

        try:
            # Find shortest path with optional shape-aware selection
            use_shape_aware_routing = True  # Could be made a parameter

            if use_shape_aware_routing:
                # Get original shape points between these stops for comparison
                original_shape_points = get_original_shape_points_between_stops(
                    feed_tables,
                    row["shape_id"],
                    row["stop_sequence"],
                    row["stop_sequence"] + 1,
                    trace_shape_ids and current_shape_id in trace_shape_ids,
                )

                path = find_shape_aware_shortest_path(
                    G_bus,
                    row["A"],
                    row["B"],
                    original_shape_points,
                    roadway_net,
                    DefaultConfig.TRANSIT.SHAPE_DISTANCE_TOLERANCE,
                    trace_shape_ids and current_shape_id in trace_shape_ids,
                )
            else:
                # Standard shortest path
                path = nx.shortest_path(G_bus, row["A"], row["B"], weight="distance")

            if trace_shape_ids and current_shape_id in trace_shape_ids:
                WranglerLogger.debug(
                    f"trace path {current_shape_id}: Found path for {row['A']} to {row['B']}: len={len(path)} {path}"
                )

            # Create shape point rows for that path
            # Only include first point if it's the first path for the shape,
            # otherwise it'll be added twice -- as the last point of the previous path
            # and the first point of the current one
            if current_shape_pt_sequence != 1:
                path = path[1:]
            for path_node_id in path:
                bus_node_dict = {
                    "shape_id": row["shape_id"],
                    "route_id": row["route_id"],
                    "route_type": row["route_type"],
                    "trip_id": row["trip_id"],
                    "direction_id": row["direction_id"],
                    "shape_pt_sequence": current_shape_pt_sequence,
                    "shape_model_node_id": path_node_id,
                }
                # set these for the stops but leave blank for intermediate nodes
                if path_node_id == row["A"]:
                    bus_node_dict["stop_id"] = row["stop_id"]
                    bus_node_dict["stop_name"] = row["stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"]
                elif path_node_id == row["B"]:
                    bus_node_dict["stop_id"] = row["next_stop_id"]
                    bus_node_dict["stop_name"] = row["next_stop_name"]
                    bus_node_dict["stop_sequence"] = last_stop_sequence + 1

                bus_node_sequence.append(bus_node_dict)
                current_shape_pt_sequence += 1
                last_stop_sequence = row["stop_sequence"]

        except nx.NetworkXNoPath as e:
            WranglerLogger.warning(
                f"No path exists for {row['shape_id']} from {row['A']} ({row['stop_name']}) "
                f"to {row['B']} ({row['next_stop_name']})"
            )
            WranglerLogger.warning(e)
            # No path exists - add to no_path_sequence and add simple A->B to bus_node_sequence
            no_path_sequence.append(
                {
                    "shape_id": row["shape_id"],
                    "stop_id": row["stop_id"],
                    "next_stop_id": row["next_stop_id"],
                    "stop_sequence": row["stop_sequence"],
                }
            )

            # Add simple A->B path to bus_node_sequence (in correct order)
            path = [row["A"], row["B"]]
            if current_shape_pt_sequence != 1:
                path = path[1:]  # Skip first node to avoid duplication

            for path_node_id in path:
                bus_node_dict = {
                    "shape_id": row["shape_id"],
                    "route_id": row["route_id"],
                    "route_type": row["route_type"],
                    "trip_id": row["trip_id"],
                    "direction_id": row["direction_id"],
                    "shape_pt_sequence": current_shape_pt_sequence,
                    "shape_model_node_id": path_node_id,
                }
                # Set stop info for the nodes
                if path_node_id == row["A"]:
                    bus_node_dict["stop_id"] = row["stop_id"]
                    bus_node_dict["stop_name"] = row["stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"]
                elif path_node_id == row["B"]:
                    bus_node_dict["stop_id"] = row["next_stop_id"]
                    bus_node_dict["stop_name"] = row["next_stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"] + 1

                bus_node_sequence.append(bus_node_dict)
                current_shape_pt_sequence += 1
                last_stop_sequence = row["stop_sequence"]

        except nx.NodeNotFound as e:
            WranglerLogger.warning(
                f"Node not found for {row['shape_id']} from {row['A']} to {row['B']}: {e}. "
                f"Adding simple A->B connection."
            )
            # Node not in graph - add to no_path_sequence and add simple A->B to bus_node_sequence
            no_path_sequence.append(
                {
                    "shape_id": row["shape_id"],
                    "stop_id": row["stop_id"],
                    "next_stop_id": row["next_stop_id"],
                    "stop_sequence": row["stop_sequence"],
                }
            )

            # Add simple A->B path to bus_node_sequence (in correct order)
            path = [row["A"], row["B"]]
            if current_shape_pt_sequence != 1:
                path = path[1:]  # Skip first node to avoid duplication

            for path_node_id in path:
                bus_node_dict = {
                    "shape_id": row["shape_id"],
                    "route_id": row["route_id"],
                    "route_type": row["route_type"],
                    "trip_id": row["trip_id"],
                    "direction_id": row["direction_id"],
                    "shape_pt_sequence": current_shape_pt_sequence,
                    "shape_model_node_id": path_node_id,
                }
                # Set stop info for the nodes
                if path_node_id == row["A"]:
                    bus_node_dict["stop_id"] = row["stop_id"]
                    bus_node_dict["stop_name"] = row["stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"]
                elif path_node_id == row["B"]:
                    bus_node_dict["stop_id"] = row["next_stop_id"]
                    bus_node_dict["stop_name"] = row["next_stop_name"]
                    bus_node_dict["stop_sequence"] = row["stop_sequence"] + 1

                bus_node_sequence.append(bus_node_dict)
                current_shape_pt_sequence += 1
                last_stop_sequence = row["stop_sequence"]

    bus_node_sequence_df = pd.DataFrame(bus_node_sequence)
    WranglerLogger.debug(f"bus_node_sequence_df:\n{bus_node_sequence_df}")
    if trace_shape_ids:
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace bus_node_sequence_df for {trace_shape_id}:\n"
                f"{bus_node_sequence_df.loc[bus_node_sequence_df.shape_id == trace_shape_id]}"
            )

    if len(no_path_sequence) == 0:
        WranglerLogger.info(f"All bus route shapes mapped to roadway nodes")
    else:
        no_path_sequence_df = pd.DataFrame(no_path_sequence)
        WranglerLogger.debug(f"no_path_sequence_df:\n{no_path_sequence_df}")

        # join with bus_stop_links_gdf
        no_bus_path_gdf = gpd.GeoDataFrame(
            pd.merge(
                left=no_path_sequence_df,
                right=bus_stop_links_gdf,
                how="left",
                on=["stop_id", "next_stop_id", "stop_sequence", "shape_id"],
                validate="one_to_one",
            ),
            crs=bus_stop_links_gdf.crs,
        )
        WranglerLogger.debug(f"no_bus_path_gdf:\n{no_bus_path_gdf}")
        if trace_shape_ids:
            debug_cols = [
                "A",
                "B",
                "stop_sequence",
                "stop_id",
                "next_stop_id",
                "stop_name",
                "next_stop_name",
            ]
            for trace_shape_id in trace_shape_ids:
                WranglerLogger.debug(
                    f"trace no_bus_path_gdf for {trace_shape_id}:\n"
                    f"{no_bus_path_gdf.loc[no_bus_path_gdf['shape_id'] == trace_shape_id, debug_cols]}"
                )

        # raise an error if requested
        if errors == "raise":
            e = TransitValidationError(
                "Some bus stop sequences failed to find paths. See e.no_bus_path_gdf"
            )
            e.no_bus_path_gdf = no_bus_path_gdf
            raise e

        # if we're ignoring, then we need to create roadway network links for these - and mark them
        create_links_for_failed_bus_paths(
            roadway_net=roadway_net,
            no_bus_path_gdf=no_bus_path_gdf,
            local_crs=local_crs,
            crs_units=crs_units,
            trace_shape_ids=trace_shape_ids,
            default_link_attribute_dict=default_link_attribute_dict,
        )

    # create bus shapes
    # current shapes columns:
    #  shape_id, shape_pt_lat, shape_pt_lon, shape_pt_sequence, shape_dist_traveled, geometry,
    #  trip_id, direction_id, route_id, agency_id, route_short_name, route_type, agency_name, match_distance_feet,
    #  stop_id, stop_name, stop_sequence, shape_model_node_id

    # we have:
    #  shape_id, route_id, route_type, trip_id, direction_id, shape_pt_sequence, shape_model_node_id, stop_id, stop_name

    # Reorder to be similar
    # bus_node_sequence_df.sort_values(by=["trip_id",])
    bus_node_sequence_df = bus_node_sequence_df[
        [
            "shape_id",
            "shape_pt_sequence",
            "trip_id",
            "direction_id",
            "route_id",
            "route_type",
            "stop_id",
            "stop_name",
            "stop_sequence",
            "shape_model_node_id",
        ]
    ]
    # get agency_id, route_short_name, agency_name from existing feed_tables['shapes']
    bus_node_sequence_df = pd.merge(
        left=bus_node_sequence_df,
        right=feed_tables["shapes"][
            ["shape_id", "agency_id", "route_short_name", "agency_name"]
        ].drop_duplicates(),
        on="shape_id",
        how="left",
        validate="many_to_one",
    )

    # get lon, lat and geometry from roadway_net.nodes
    bus_node_sequence_gdf = gpd.GeoDataFrame(
        pd.merge(
            left=bus_node_sequence_df,
            right=roadway_net.nodes_df[["model_node_id", "X", "Y", "geometry"]].rename(
                columns={"model_node_id": "shape_model_node_id"}
            ),
            how="left",
            on="shape_model_node_id",
            validate="many_to_one",
        ).rename(columns={"X": "shape_pt_lon", "Y": "shape_pt_lat"}),
        crs=roadway_net.nodes_df.crs,
    )
    WranglerLogger.debug(f"Final bus_node_sequence_gdf:\n{bus_node_sequence_gdf}")
    if trace_shape_ids:
        debug_cols = [
            "shape_pt_sequence",
            "stop_sequence",
            "stop_id",
            "stop_name",
            "shape_model_node_id",
        ]
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace {trace_shape_id} bus_node_sequence_gdf:\n"
                f"{bus_node_sequence_gdf.loc[bus_node_sequence_gdf['shape_id'] == trace_shape_id, debug_cols]}"
            )

    feed_tables["shapes"].to_crs(LAT_LON_CRS, inplace=True)
    # replace bus links in feed_tables['shapes'] with bus_node_sequence_gdf
    feed_tables["shapes"] = pd.concat(
        [
            feed_tables["shapes"].loc[
                ~feed_tables["shapes"]["route_type"].isin([RouteType.BUS, RouteType.TROLLEYBUS])
            ],
            bus_node_sequence_gdf,
        ]
    )

    if trace_shape_ids:
        debug_cols = [
            "shape_pt_sequence",
            "stop_sequence",
            "stop_id",
            "stop_name",
            "shape_model_node_id",
        ]
        for trace_shape_id in trace_shape_ids:
            WranglerLogger.debug(
                f"trace feed_tables['shapes'] for {trace_shape_id} at the end of route_shapes_between_stops():\n"
                f"{feed_tables['shapes'].loc[feed_tables['shapes']['shape_id'] == trace_shape_id, debug_cols]}"
            )

network_wrangler.utils.transit.truncate_route_at_stop

truncate_route_at_stop(
    transit_data, route_id, direction_id, stop_id, truncate
)

Truncate all trips of a route at a specific stop.

Removes stops before or after the specified stop for all trips matching the given route_id and direction_id. This is useful for shortening routes at terminal stations or service boundaries. Modifies transit_data in place.

Parameters:

Name Type Description Default
transit_data GtfsModel | Feed

Either a GtfsModel or Feed object to modify. Modified in place.

required
route_id str

The route_id to truncate

required
direction_id int

The direction_id of trips to truncate (0 or 1)

required
stop_id str | int

The stop where truncation occurs. For GtfsModel, this should be a string stop_id. For Feed, this should be an integer model_node_id.

required
truncate Literal['before', 'after']

Either "before" to remove stops before stop_id, or "after" to remove stops after stop_id

required

Raises:

Type Description
ValueError

If truncate is not "before" or "after"

ValueError

If stop_id is not found in any trips of the route/direction

Example

Truncate outbound BART trips to end at Embarcadero (GtfsModel)

truncate_route_at_stop(
    gtfs_model,
    route_id="BART-01",
    direction_id=0,
    stop_id="EMBR",  # string stop_id
    truncate="after",
)

Truncate outbound BART trips to end at node 12345 (Feed)

truncate_route_at_stop(
    feed,
    route_id="BART-01",
    direction_id=0,
    stop_id=12345,  # integer model_node_id
    truncate="after",
)

Source code in network_wrangler/transit/filter.py
def truncate_route_at_stop(  # noqa: PLR0912, PLR0915
    transit_data: GtfsModel | Feed,
    route_id: str,
    direction_id: int,
    stop_id: str | int,
    truncate: Literal["before", "after"],
) -> None:
    """Truncate all trips of a route at a specific stop.

    Removes stops before or after the specified stop for all trips matching
    the given route_id and direction_id. This is useful for shortening routes
    at terminal stations or service boundaries. Modifies transit_data in place.

    Args:
        transit_data: Either a GtfsModel or Feed object to modify. Modified in place.
        route_id: The route_id to truncate
        direction_id: The direction_id of trips to truncate (0 or 1)
        stop_id: The stop where truncation occurs. For GtfsModel, this should be
                a string stop_id. For Feed, this should be an integer model_node_id.
        truncate: Either "before" to remove stops before stop_id, or
                 "after" to remove stops after stop_id

    Raises:
        ValueError: If truncate is not "before" or "after"
        ValueError: If stop_id is not found in any trips of the route/direction

    Example:
        >>> # Truncate outbound BART trips to end at Embarcadero (GtfsModel)
        >>> truncate_route_at_stop(
        ...     gtfs_model,
        ...     route_id="BART-01",
        ...     direction_id=0,
        ...     stop_id="EMBR",  # string stop_id
        ...     truncate="after",
        ... )
        >>> # Truncate outbound BART trips to end at node 12345 (Feed)
        >>> truncate_route_at_stop(
        ...     feed,
        ...     route_id="BART-01",
        ...     direction_id=0,
        ...     stop_id=12345,  # integer model_node_id
        ...     truncate="after",
        ... )
    """
    if truncate not in ["before", "after"]:
        msg = f"truncate must be 'before' or 'after', got '{truncate}'"
        raise ValueError(msg)

    WranglerLogger.info(
        f"Truncating route {route_id} direction {direction_id} {truncate} stop {stop_id}"
    )

    # Get data tables (references, not copies)
    routes_df = transit_data.routes
    trips_df = transit_data.trips
    stop_times_df = transit_data.stop_times
    stops_df = transit_data.stops
    is_gtfs = isinstance(transit_data, GtfsModel)

    # Find trips to truncate
    trips_to_truncate = trips_df[
        (trips_df.route_id == route_id) & (trips_df.direction_id == direction_id)
    ]

    if len(trips_to_truncate) == 0:
        WranglerLogger.warning(f"No trips found for route {route_id} direction {direction_id}")
        return  # No changes needed

    trip_ids_to_truncate = set(trips_to_truncate.trip_id)
    WranglerLogger.debug(f"Found {len(trip_ids_to_truncate)} trips to truncate")

    # Check if stop_id exists in any of these trips
    stop_times_for_route = stop_times_df[
        (stop_times_df.trip_id.isin(trip_ids_to_truncate)) & (stop_times_df.stop_id == stop_id)
    ]

    if len(stop_times_for_route) == 0:
        msg = f"Stop {stop_id} not found in any trips of route {route_id} direction {direction_id}"
        raise ValueError(msg)

    # Process stop_times to truncate trips
    truncated_stop_times = []
    trips_truncated = 0

    for trip_id in trip_ids_to_truncate:
        trip_stop_times = stop_times_df[stop_times_df.trip_id == trip_id]
        trip_stop_times = trip_stop_times.sort_values("stop_sequence")

        # Find the stop_sequence for the truncation stop
        stop_mask = trip_stop_times.stop_id == stop_id
        if not stop_mask.any():
            # This trip doesn't have the stop, keep all stops
            truncated_stop_times.append(trip_stop_times)
            continue

        stop_sequence_at_stop = trip_stop_times.loc[stop_mask, "stop_sequence"].iloc[0]

        # Truncate based on direction
        if truncate == "before":
            # Keep stops from stop_id onwards
            truncated_stops = trip_stop_times[
                trip_stop_times.stop_sequence >= stop_sequence_at_stop
            ].copy()  # Need copy here since we'll modify stop_sequence
        else:  # truncate == "after"
            # Keep stops up to and including stop_id
            truncated_stops = trip_stop_times[
                trip_stop_times.stop_sequence <= stop_sequence_at_stop
            ].copy()  # Need copy here since we'll modify stop_sequence

        # Renumber stop_sequence to be consecutive starting from 0
        if len(truncated_stops) > 0:
            truncated_stops["stop_sequence"] = range(len(truncated_stops))

        # Log truncation details
        original_count = len(trip_stop_times)
        truncated_count = len(truncated_stops)
        if truncated_count < original_count:
            trips_truncated += 1

            # Get removed stops details
            removed_stop_ids = set(trip_stop_times.stop_id) - set(truncated_stops.stop_id)
            if removed_stop_ids and len(removed_stop_ids) <= MAX_TRUNCATION_WARNING_STOPS:
                # Get stop names for removed stops
                removed_stops_info = stops_df[stops_df.stop_id.isin(removed_stop_ids)][
                    ["stop_id", "stop_name"]
                ]
                removed_stops_list = [
                    f"{row['stop_id']} ({row['stop_name']})"
                    for _, row in removed_stops_info.iterrows()
                ]

                WranglerLogger.debug(
                    f"Trip {trip_id}: truncated from {original_count} to {truncated_count} stops. "
                    f"Removed: {', '.join(removed_stops_list)}"
                )
            else:
                WranglerLogger.debug(
                    f"Trip {trip_id}: truncated from {original_count} to {truncated_count} stops"
                )

        truncated_stop_times.append(truncated_stops)

    WranglerLogger.info(f"Truncated {trips_truncated} trips")

    # Combine all stop times (truncated and non-truncated)
    other_stop_times = stop_times_df[~stop_times_df.trip_id.isin(trip_ids_to_truncate)]
    all_stop_times = pd.concat([other_stop_times, *truncated_stop_times], ignore_index=True)

    # Find stops that are still referenced
    stops_still_used = set(all_stop_times.stop_id.unique())
    filtered_stops = stops_df[stops_df.stop_id.isin(stops_still_used)]

    # Check if any of these stops reference parent stations
    if "parent_station" in filtered_stops.columns:
        # Get parent stations that are referenced by kept stops
        parent_stations = filtered_stops["parent_station"].dropna().unique()
        parent_stations = [ps for ps in parent_stations if ps != ""]  # Remove empty strings

        if len(parent_stations) > 0:
            # Find parent stations that aren't already in our filtered stops
            existing_stop_ids = set(filtered_stops.stop_id)
            missing_parent_stations = [ps for ps in parent_stations if ps not in existing_stop_ids]

            if len(missing_parent_stations) > 0:
                WranglerLogger.debug(
                    f"Adding back {len(missing_parent_stations)} parent stations referenced by kept stops"
                )

                # Get the parent station records
                parent_station_records = stops_df[stops_df.stop_id.isin(missing_parent_stations)]

                # Append parent stations to filtered stops
                filtered_stops = pd.concat(
                    [filtered_stops, parent_station_records], ignore_index=True
                )

    # Log removed stops
    removed_stops = set(stops_df.stop_id) - set(filtered_stops.stop_id)
    if removed_stops:
        WranglerLogger.debug(f"Removed {len(removed_stops)} stops that are no longer referenced")

        # Get details of removed stops
        removed_stops_df = stops_df[stops_df.stop_id.isin(removed_stops)][["stop_id", "stop_name"]]

        # Log up to 20 removed stops with their names
        sample_size = min(20, len(removed_stops_df))
        for _, stop in removed_stops_df.head(sample_size).iterrows():
            WranglerLogger.debug(f"  - Removed stop: {stop['stop_id']} ({stop['stop_name']})")

        if len(removed_stops) > sample_size:
            WranglerLogger.debug(f"  ... and {len(removed_stops) - sample_size} more stops")

    # Update transit_data in place
    transit_data.stop_times = all_stop_times
    transit_data.trips = trips_df
    transit_data.routes = routes_df
    transit_data.stops = filtered_stops

Functions to clip a TransitNetwork object to a boundary.

Clipped transit is an independent transit network that is a subset of the original transit network.

Example usage:

from network_wrangler.transit import load_transit, write_transit
from network_wrangler.transit.clip import clip_transit

stpaul_transit = load_transit(example_dir / "stpaul")
boundary_file = test_dir / "data" / "ecolab.geojson"
clipped_network = clip_transit(stpaul_transit, boundary_file=boundary_file)
write_transit(clipped_network, out_dir, prefix="ecolab", format="geojson", true_shape=True)

network_wrangler.transit.clip.clip_feed_to_boundary

clip_feed_to_boundary(
    feed,
    ref_nodes_df,
    boundary_gdf=None,
    boundary_geocode=None,
    boundary_file=None,
    min_stops=DEFAULT_MIN_STOPS,
)

Clips a transit Feed object to a boundary and returns the resulting GeoDataFrames.

Retains only the stops within the boundary and trips that traverse them, subject to a minimum number of stops per trip as defined by min_stops.

Parameters:

Name Type Description Default
feed Feed

Feed object to be clipped.

required
ref_nodes_df GeoDataFrame

geodataframe with node geometry to reference

required
boundary_geocode Union[str, dict]

A geocode string or dictionary representing the boundary. Defaults to None.

None
boundary_file Union[str, Path]

A path to the boundary file. Only used if boundary_geocode is None. Defaults to None.

None
boundary_gdf GeoDataFrame

A GeoDataFrame representing the boundary. Only used if boundary_geocode and boundary_file are None. Defaults to None.

None
min_stops int

minimum number of stops needed to retain a transit trip within clipped area. Defaults to DEFAULT_MIN_STOPS which is set to 2.

DEFAULT_MIN_STOPS
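The `min_stops` retention rule can be sketched with plain pandas (hypothetical data and column names; the library's actual implementation lives in `_clip_feed_to_nodes`):

```python
import pandas as pd

# Hypothetical stop_times flagged by whether each stop falls inside the boundary
stop_times = pd.DataFrame(
    {
        "trip_id": ["t1", "t1", "t1", "t2", "t2"],
        "stop_id": ["s1", "s2", "s3", "s1", "s9"],
        "in_boundary": [True, True, False, True, False],
    }
)

MIN_STOPS = 2  # mirrors DEFAULT_MIN_STOPS

# Count in-boundary stops per trip; keep only trips meeting the threshold
stops_in = stop_times[stop_times.in_boundary].groupby("trip_id").size()
kept_trips = stops_in[stops_in >= MIN_STOPS].index.tolist()
print(kept_trips)  # t1 has two stops inside the boundary; t2 has only one
```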
Source code in network_wrangler/transit/clip.py
def clip_feed_to_boundary(
    feed: Feed,
    ref_nodes_df: gpd.GeoDataFrame,
    boundary_gdf: gpd.GeoDataFrame | None = None,
    boundary_geocode: str | dict | None = None,
    boundary_file: str | Path | None = None,
    min_stops: int = DEFAULT_MIN_STOPS,
) -> Feed:
    """Clips a transit Feed object to a boundary and returns the resulting GeoDataFrames.

    Retains only the stops within the boundary and trips that traverse them, subject to a minimum
    number of stops per trip as defined by `min_stops`.

    Args:
        feed: Feed object to be clipped.
        ref_nodes_df: geodataframe with node geometry to reference
        boundary_geocode (Union[str, dict], optional): A geocode string or dictionary
            representing the boundary. Defaults to None.
        boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if
            boundary_geocode is None. Defaults to None.
        boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary.
            Only used if boundary_geocode and boundary_file are None. Defaults to None.
        min_stops: minimum number of stops needed to retain a transit trip within clipped area.
            Defaults to DEFAULT_MIN_STOPS which is set to 2.

    Returns: Feed object trimmed to the boundary.
    """
    WranglerLogger.info("Clipping transit network to boundary.")

    boundary_gdf = get_bounding_polygon(
        boundary_gdf=boundary_gdf,
        boundary_geocode=boundary_geocode,
        boundary_file=boundary_file,
    )

    shape_links_gdf = shapes_to_shape_links_gdf(feed.shapes, ref_nodes_df=ref_nodes_df)

    # make sure boundary_gdf.crs == network.crs
    if boundary_gdf.crs != shape_links_gdf.crs:
        boundary_gdf = boundary_gdf.to_crs(shape_links_gdf.crs)

    # get the boundary as a single polygon
    boundary = boundary_gdf.geometry.union_all()
    # get the shape_links that intersect the boundary
    clipped_shape_links = shape_links_gdf[shape_links_gdf.geometry.intersects(boundary)]

    # nodes within clipped_shape_links
    node_ids = list(set(clipped_shape_links.A.to_list() + clipped_shape_links.B.to_list()))
    WranglerLogger.debug(f"Clipping network to {len(node_ids)} nodes.")
    if not node_ids:
        msg = "No nodes found within the boundary."
        raise ValueError(msg)
    return _clip_feed_to_nodes(feed, node_ids, min_stops=min_stops)

network_wrangler.transit.clip.clip_feed_to_roadway

clip_feed_to_roadway(
    feed, roadway_net, min_stops=DEFAULT_MIN_STOPS
)

Returns a copy of transit feed clipped to the roadway network.

Parameters:

Name Type Description Default
feed Feed

Transit Feed to clip.

required
roadway_net RoadwayNetwork

Roadway network to clip to.

required
min_stops int

minimum number of stops needed to retain a transit trip within clipped area. Defaults to DEFAULT_MIN_STOPS which is set to 2.

DEFAULT_MIN_STOPS

Raises:

Type Description
ValueError

If no stops found within the roadway network.

Returns:

Name Type Description
Feed Feed

Clipped deep copy of feed limited to the roadway network.

Source code in network_wrangler/transit/clip.py
def clip_feed_to_roadway(
    feed: Feed,
    roadway_net: RoadwayNetwork,
    min_stops: int = DEFAULT_MIN_STOPS,
) -> Feed:
    """Returns a copy of transit feed clipped to the roadway network.

    Args:
        feed (Feed): Transit Feed to clip.
        roadway_net: Roadway network to clip to.
        min_stops: minimum number of stops needed to retain a transit trip within clipped area.
            Defaults to DEFAULT_MIN_STOPS which is set to 2.

    Raises:
        ValueError: If no stops found within the roadway network.

    Returns:
        Feed: Clipped deep copy of feed limited to the roadway network.
    """
    WranglerLogger.info("Clipping transit network to roadway network.")

    clipped_feed = _remove_links_from_feed(feed, roadway_net.links_df, min_stops=min_stops)

    return clipped_feed

network_wrangler.transit.clip.clip_transit

clip_transit(
    network,
    node_ids=None,
    boundary_geocode=None,
    boundary_file=None,
    boundary_gdf=None,
    ref_nodes_df=None,
    roadway_net=None,
    min_stops=DEFAULT_MIN_STOPS,
)

Returns a new TransitNetwork clipped to a boundary as determined by arguments.

Will clip based on which arguments are provided as prioritized below:

  1. If node_ids provided, will clip based on node_ids
  2. If boundary_geocode provided, will clip based on a search in OSM for that jurisdiction boundary using reference geometry from ref_nodes_df, roadway_net, or roadway_path
  3. If boundary_file provided, will clip based on that polygon using reference geometry from ref_nodes_df, roadway_net, or roadway_path
  4. If boundary_gdf provided, will clip based on that geodataframe using reference geometry from ref_nodes_df, roadway_net, or roadway_path
  5. If roadway_net provided, will clip based on that roadway network

Parameters:

Name Type Description Default
network TransitNetwork

TransitNetwork to clip.

required
node_ids list[str]

A list of node_ids to clip to. Defaults to None.

None
boundary_geocode Union[str, dict]

A geocode string or dictionary representing the boundary. Only used if node_ids are None. Defaults to None.

None
boundary_file Union[str, Path]

A path to the boundary file. Only used if node_ids and boundary_geocode are None. Defaults to None.

None
boundary_gdf GeoDataFrame

A GeoDataFrame representing the boundary. Only used if node_ids, boundary_geocode and boundary_file are None. Defaults to None.

None
ref_nodes_df None | GeoDataFrame

GeoDataFrame of geographic references for node_ids. Only used if node_ids is None and one of boundary_* is not None.

None
roadway_net None | RoadwayNetwork

Roadway Network instance to clip transit network to. Only used if node_ids is None and all of boundary_* are None.

None
min_stops int

minimum number of stops needed to retain a transit trip within clipped area. Defaults to DEFAULT_MIN_STOPS which is set to 2.

DEFAULT_MIN_STOPS
Source code in network_wrangler/transit/clip.py
def clip_transit(
    network: TransitNetwork | str | Path,
    node_ids: None | list[str] = None,
    boundary_geocode: None | str | dict = None,
    boundary_file: str | Path | None = None,
    boundary_gdf: None | gpd.GeoDataFrame = None,
    ref_nodes_df: None | gpd.GeoDataFrame = None,
    roadway_net: None | RoadwayNetwork = None,
    min_stops: int = DEFAULT_MIN_STOPS,
) -> TransitNetwork:
    """Returns a new TransitNetwork clipped to a boundary as determined by arguments.

    Will clip based on which arguments are provided as prioritized below:

    1. If `node_ids` provided, will clip based on `node_ids`
    2. If `boundary_geocode` provided, will clip based on a search in OSM for that jurisdiction
        boundary using reference geometry from `ref_nodes_df`, `roadway_net`, or `roadway_path`
    3. If `boundary_file` provided, will clip based on that polygon using reference geometry
        from `ref_nodes_df`, `roadway_net`, or `roadway_path`
    4. If `boundary_gdf` provided, will clip based on that geodataframe using reference geometry
        from `ref_nodes_df`, `roadway_net`, or `roadway_path`
    5. If `roadway_net` provided, will clip based on that roadway network

    Args:
        network (TransitNetwork): TransitNetwork to clip.
        node_ids (list[str], optional): A list of node_ids to clip to. Defaults to None.
        boundary_geocode (Union[str, dict], optional): A geocode string or dictionary
            representing the boundary. Only used if node_ids are None. Defaults to None.
        boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if
            node_ids and boundary_geocode are None. Defaults to None.
        boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary.
            Only used if node_ids, boundary_geocode and boundary_file are None. Defaults to None.
        ref_nodes_df: GeoDataFrame of geographic references for node_ids.  Only used if
            node_ids is None and one of boundary_* is not None.
        roadway_net: Roadway Network instance to clip transit network to. Only used if
            node_ids is None and all of boundary_* are None.
        min_stops: minimum number of stops needed to retain a transit trip within clipped area.
            Defaults to DEFAULT_MIN_STOPS which is set to 2.
    """
    if not isinstance(network, TransitNetwork):
        network = load_transit(network)
    set_roadway_network = False
    feed = network.feed

    if node_ids is not None:
        clipped_feed = _clip_feed_to_nodes(feed, node_ids=node_ids, min_stops=min_stops)
    elif any(i is not None for i in [boundary_file, boundary_geocode, boundary_gdf]):
        if ref_nodes_df is None:
            ref_nodes_df = get_nodes(transit_net=network, roadway_net=roadway_net)

        clipped_feed = clip_feed_to_boundary(
            feed,
            ref_nodes_df,
            boundary_file=boundary_file,
            boundary_geocode=boundary_geocode,
            boundary_gdf=boundary_gdf,
            min_stops=min_stops,
        )
    elif roadway_net is not None:
        clipped_feed = clip_feed_to_roadway(feed, roadway_net=roadway_net)
        set_roadway_network = True
    else:
        msg = "Missing required arguments from clip_transit"
        raise ValueError(msg)

    # create a new TransitNetwork object with the clipped feed dataframes
    clipped_net = TransitNetwork(clipped_feed)

    if set_roadway_network:
        WranglerLogger.info("Setting roadway network for clipped transit network.")
        clipped_net.road_net = roadway_net
    return clipped_net

Functions to filter transit feeds by various criteria.

Filtered transit feeds are subsets of the original feed based on selection criteria like service_ids, route_types, etc.

network_wrangler.transit.filter.MAX_TRUNCATION_WARNING_STOPS module-attribute

MAX_TRUNCATION_WARNING_STOPS = 10

Maximum number of removed stops to list individually in truncation warnings.

Used in truncate_route_at_stop() to control verbosity of warning messages. If more stops are removed, only the count is shown instead of listing each stop.

network_wrangler.transit.filter.MIN_ROUTE_SEGMENTS module-attribute

MIN_ROUTE_SEGMENTS = 2

Minimum number of boundary segments before warning about complex route patterns.

Used in filter_transit_by_boundary() to detect routes that exit and re-enter the boundary.

network_wrangler.transit.filter.drop_transit_agency

drop_transit_agency(transit_data, agency_id)

Remove all routes, trips, stops, etc. for a specific agency or agencies.

Filters out all data associated with the specified agency_id(s), ensuring the resulting transit data remains valid by removing orphaned stops and maintaining referential integrity. Modifies transit_data in place.

Parameters:

Name Type Description Default
transit_data GtfsModel | Feed

Either a GtfsModel or Feed object to filter. Modified in place.

required
agency_id str | list[str]

Single agency_id string or list of agency_ids to remove

required
Example

Remove a single agency

drop_transit_agency(gtfs_model, "SFMTA")

Remove multiple agencies

drop_transit_agency(gtfs_model, ["SFMTA", "AC"])
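The referential-integrity step that retains parent stations referenced by kept stops can be sketched with plain pandas (hypothetical stop records, not the library's code):

```python
import pandas as pd

# Hypothetical stops table with parent_station references
stops = pd.DataFrame(
    {
        "stop_id": ["p1", "c1", "c2", "c3"],
        "stop_name": ["Station", "Plat A", "Plat B", "Elsewhere"],
        "parent_station": ["", "p1", "p1", ""],
    }
)
still_used = {"c1", "c2"}  # stop_ids still referenced by stop_times

kept = stops[stops.stop_id.isin(still_used)]
# Add back parent stations referenced by kept stops but not themselves kept
parents = {p for p in kept.parent_station if p} - set(kept.stop_id)
kept = pd.concat([kept, stops[stops.stop_id.isin(parents)]], ignore_index=True)
print(sorted(kept.stop_id))  # parent "p1" is retained alongside its platforms
```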

Source code in network_wrangler/transit/filter.py
def drop_transit_agency(  # noqa: PLR0915
    transit_data: GtfsModel | Feed,
    agency_id: str | list[str],
) -> None:
    """Remove all routes, trips, stops, etc. for a specific agency or agencies.

    Filters out all data associated with the specified agency_id(s), ensuring
    the resulting transit data remains valid by removing orphaned stops and
    maintaining referential integrity. Modifies transit_data in place.

    Args:
        transit_data: Either a GtfsModel or Feed object to filter. Modified in place.
        agency_id: Single agency_id string or list of agency_ids to remove

    Example:
        >>> # Remove a single agency
        >>> drop_transit_agency(gtfs_model, "SFMTA")
        >>> # Remove multiple agencies
        >>> drop_transit_agency(gtfs_model, ["SFMTA", "AC"])
    """
    # Convert single agency_id to list for uniform handling
    agency_ids_to_remove = [agency_id] if isinstance(agency_id, str) else agency_id

    WranglerLogger.info(f"Removing transit data for agency/agencies: {agency_ids_to_remove}")

    # Get data tables (references, not copies)
    routes_df = transit_data.routes
    trips_df = transit_data.trips
    stop_times_df = transit_data.stop_times
    stops_df = transit_data.stops
    is_gtfs = isinstance(transit_data, GtfsModel)

    # Find routes to keep (those NOT belonging to agencies being removed)
    if "agency_id" in routes_df.columns:
        routes_to_keep = routes_df[~routes_df.agency_id.isin(agency_ids_to_remove)]
        routes_removed = len(routes_df) - len(routes_to_keep)
    else:
        # If no agency_id column in routes, log warning and keep all routes
        WranglerLogger.warning(
            "No agency_id column found in routes table - cannot filter by agency"
        )
        routes_to_keep = routes_df
        routes_removed = 0

    route_ids_to_keep = set(routes_to_keep.route_id)

    # Filter trips based on remaining routes
    trips_to_keep = trips_df[trips_df.route_id.isin(route_ids_to_keep)]
    trips_removed = len(trips_df) - len(trips_to_keep)
    trip_ids_to_keep = set(trips_to_keep.trip_id)

    # Filter stop_times based on remaining trips
    stop_times_to_keep = stop_times_df[stop_times_df.trip_id.isin(trip_ids_to_keep)]
    stop_times_removed = len(stop_times_df) - len(stop_times_to_keep)

    # Find stops that are still referenced
    stops_still_used = set(stop_times_to_keep.stop_id.unique())
    stops_to_keep = stops_df[stops_df.stop_id.isin(stops_still_used)]

    # Check if any of these stops reference parent stations
    if "parent_station" in stops_to_keep.columns:
        # Get parent stations that are referenced by kept stops
        parent_stations = stops_to_keep["parent_station"].dropna().unique()
        parent_stations = [ps for ps in parent_stations if ps != ""]  # Remove empty strings

        if len(parent_stations) > 0:
            # Find parent stations that aren't already in our filtered stops
            existing_stop_ids = set(stops_to_keep.stop_id)
            missing_parent_stations = [ps for ps in parent_stations if ps not in existing_stop_ids]

            if len(missing_parent_stations) > 0:
                WranglerLogger.debug(
                    f"Adding back {len(missing_parent_stations)} parent stations referenced by kept stops"
                )

                # Get the parent station records
                parent_station_records = stops_df[stops_df.stop_id.isin(missing_parent_stations)]

                # Append parent stations to filtered stops
                stops_to_keep = pd.concat(
                    [stops_to_keep, parent_station_records], ignore_index=True
                )

    stops_removed = len(stops_df) - len(stops_to_keep)

    WranglerLogger.info(
        f"Removed: {routes_removed:,} routes, {trips_removed:,} trips, "
        f"{stop_times_removed:,} stop_times, {stops_removed:,} stops"
    )

    WranglerLogger.info(
        f"Remaining: {len(routes_to_keep):,} routes, {len(trips_to_keep):,} trips, "
        f"{len(stops_to_keep):,} stops"
    )
    WranglerLogger.debug(
        f"Stops removed:\n{stops_df.loc[~stops_df['stop_id'].isin(stops_still_used)]}"
    )

    # Update tables in place, in order so that validation is ok
    transit_data.stop_times = stop_times_to_keep
    transit_data.trips = trips_to_keep
    transit_data.routes = routes_to_keep
    transit_data.stops = stops_to_keep

    # Handle agency table
    if hasattr(transit_data, "agency") and transit_data.agency is not None:
        # Keep agencies that are NOT being removed
        filtered_agency = transit_data.agency[
            ~transit_data.agency.agency_id.isin(agency_ids_to_remove)
        ]
        WranglerLogger.info(
            f"Removed {len(transit_data.agency) - len(filtered_agency):,} agencies"
        )
        transit_data.agency = filtered_agency

    # Handle shapes table
    if (
        hasattr(transit_data, "shapes")
        and transit_data.shapes is not None
        and "shape_id" in trips_to_keep.columns
    ):
        shape_ids = set(trips_to_keep.shape_id.dropna().unique())
        filtered_shapes = transit_data.shapes[transit_data.shapes.shape_id.isin(shape_ids)]
        WranglerLogger.info(
            f"Removed {len(transit_data.shapes) - len(filtered_shapes):,} shape points"
        )
        transit_data.shapes = filtered_shapes

    # Handle calendar table
    if hasattr(transit_data, "calendar") and transit_data.calendar is not None:
        # Keep only service_ids referenced by remaining trips
        service_ids = set(trips_to_keep.service_id.unique())
        transit_data.calendar = transit_data.calendar[
            transit_data.calendar.service_id.isin(service_ids)
        ]

    # Handle calendar_dates table
    if hasattr(transit_data, "calendar_dates") and transit_data.calendar_dates is not None:
        # Keep only service_ids referenced by remaining trips
        service_ids = set(trips_to_keep.service_id.unique())
        transit_data.calendar_dates = transit_data.calendar_dates[
            transit_data.calendar_dates.service_id.isin(service_ids)
        ]

    # Handle frequencies table
    if hasattr(transit_data, "frequencies") and transit_data.frequencies is not None:
        # Keep only frequencies for remaining trips
        transit_data.frequencies = transit_data.frequencies[
            transit_data.frequencies.trip_id.isin(trip_ids_to_keep)
        ]
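
The cascade in the tail of this function (trips → calendar / calendar_dates / frequencies) can be sketched with plain pandas. This is a minimal illustration of the pattern, not network_wrangler's API: table and column names follow GTFS, the data is made up.

```python
import pandas as pd

# Toy GTFS-like tables; names follow GTFS, the data is illustrative
trips = pd.DataFrame({"trip_id": ["t1", "t2"], "service_id": ["wk", "sat"]})
calendar = pd.DataFrame({"service_id": ["wk", "sat", "sun"]})
frequencies = pd.DataFrame({"trip_id": ["t1", "t3"], "headway_secs": [600, 900]})

# Decide which trips survive, then cascade that choice to dependent tables
trips_to_keep = trips[trips.service_id == "wk"]
service_ids = set(trips_to_keep.service_id.unique())
trip_ids_to_keep = set(trips_to_keep.trip_id)

calendar = calendar[calendar.service_id.isin(service_ids)]
frequencies = frequencies[frequencies.trip_id.isin(trip_ids_to_keep)]
```

Filtering downstream tables after the primary tables keeps referential integrity: every `service_id` and `trip_id` that remains is still referenced by a surviving trip.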

network_wrangler.transit.filter.filter_feed_by_service_ids

filter_feed_by_service_ids(feed, service_ids)

Filter a transit feed to only include trips for specified service_ids.

Filters trips, stop_times, and stops to only include data related to the specified service_ids. Also ensures parent stations are retained if referenced.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `feed` | `Feed \| GtfsModel` | Feed or GtfsModel object to filter | *required* |
| `service_ids` | `list[str]` | List of service_ids to retain | *required* |

Returns:

| Type | Description |
|------|-------------|
| `Feed \| GtfsModel` | Filtered copy of feed with only trips/stops/stop_times for the specified service_ids. Returns the same type as the input. |

Source code in network_wrangler/transit/filter.py
def filter_feed_by_service_ids(  # noqa: PLR0915
    feed: Feed | GtfsModel,
    service_ids: list[str],
) -> Feed | GtfsModel:
    """Filter a transit feed to only include trips for specified service_ids.

    Filters trips, stop_times, and stops to only include data related to the
    specified service_ids. Also ensures parent stations are retained if referenced.

    Args:
        feed: Feed or GtfsModel object to filter
        service_ids: List of service_ids to retain

    Returns:
        Union[Feed, GtfsModel]: Filtered copy of feed with only trips/stops/stop_times
            for specified service_ids. Returns same type as input.
    """
    WranglerLogger.info(f"Filtering feed to {len(service_ids):,} service_ids")

    # Remember the input type to return the same type
    is_feed = isinstance(feed, Feed)

    # Extract dataframes to work with them directly (avoiding validation during filtering)
    feed_dfs = {}
    for table_name in feed.table_names:
        if hasattr(feed, table_name) and getattr(feed, table_name) is not None:
            feed_dfs[table_name] = getattr(feed, table_name).copy()

    # Filter trips for these service_ids
    original_trip_count = len(feed_dfs["trips"])
    feed_dfs["trips"]["service_id"] = feed_dfs["trips"]["service_id"].astype(str)

    # Create a DataFrame from the list for merging
    service_ids_df = pd.DataFrame({"service_id": service_ids})
    feed_dfs["trips"] = feed_dfs["trips"].merge(
        right=service_ids_df, on="service_id", how="left", indicator=True
    )
    WranglerLogger.debug(
        f"trips._merge.value_counts():\n{feed_dfs['trips']._merge.value_counts()}"
    )
    feed_dfs["trips"] = (
        feed_dfs["trips"]
        .loc[feed_dfs["trips"]._merge == "both"]
        .drop(columns=["_merge"])
        .reset_index(drop=True)
    )
    WranglerLogger.info(
        f"Filtered trips from {original_trip_count:,} to {len(feed_dfs['trips']):,}"
    )

    # Filter stop_times for these trip_ids
    feed_dfs["trips"]["trip_id"] = feed_dfs["trips"]["trip_id"].astype(str)
    trip_ids = feed_dfs["trips"][["trip_id"]].drop_duplicates().reset_index(drop=True)
    WranglerLogger.debug(f"After filtering trips to trip_ids (len={len(trip_ids):,})")

    feed_dfs["stop_times"]["trip_id"] = feed_dfs["stop_times"]["trip_id"].astype(str)
    feed_dfs["stop_times"] = feed_dfs["stop_times"].merge(
        right=trip_ids, how="left", indicator=True
    )
    WranglerLogger.debug(
        f"stop_times._merge.value_counts():\n{feed_dfs['stop_times']._merge.value_counts()}"
    )
    feed_dfs["stop_times"] = (
        feed_dfs["stop_times"]
        .loc[feed_dfs["stop_times"]._merge == "both"]
        .drop(columns=["_merge"])
        .reset_index(drop=True)
    )

    # Filter stops for these stop_ids
    feed_dfs["stop_times"]["stop_id"] = feed_dfs["stop_times"]["stop_id"].astype(str)
    stop_ids = feed_dfs["stop_times"][["stop_id"]].drop_duplicates().reset_index(drop=True)
    stop_ids_set = set(stop_ids["stop_id"])
    WranglerLogger.debug(f"After filtering stop_times to stop_ids (len={len(stop_ids):,})")

    feed_dfs["stops"]["stop_id"] = feed_dfs["stops"]["stop_id"].astype(str)

    # Identify parent stations that should be kept
    parent_stations_to_keep = set()
    if "parent_station" in feed_dfs["stops"].columns:
        # Find all parent stations referenced by stops that will be kept
        stops_to_keep = feed_dfs["stops"][feed_dfs["stops"]["stop_id"].isin(stop_ids_set)]
        parent_stations = stops_to_keep["parent_station"].dropna().unique()
        parent_stations_to_keep = {ps for ps in parent_stations if ps != ""}

        if len(parent_stations_to_keep) > 0:
            WranglerLogger.info(
                f"Preserving {len(parent_stations_to_keep)} parent stations referenced by kept stops"
            )

    # Create combined set of stop_ids to keep (original stops + parent stations)
    all_stop_ids_to_keep = stop_ids_set | parent_stations_to_keep

    # Filter stops to include both regular stops and their parent stations
    original_stop_count = len(feed_dfs["stops"])
    feed_dfs["stops"] = feed_dfs["stops"][
        feed_dfs["stops"]["stop_id"].isin(all_stop_ids_to_keep)
    ].reset_index(drop=True)

    WranglerLogger.debug(
        f"Filtered stops from {original_stop_count:,} to {len(feed_dfs['stops']):,} "
        f"(including {len(parent_stations_to_keep)} parent stations)"
    )

    # Check for stop_times with invalid stop_ids after all filtering is complete
    valid_stop_ids = set(feed_dfs["stops"]["stop_id"])
    invalid_mask = ~feed_dfs["stop_times"]["stop_id"].isin(valid_stop_ids)
    invalid_stop_times = feed_dfs["stop_times"][invalid_mask]

    if len(invalid_stop_times) > 0:
        WranglerLogger.warning(
            f"Found {len(invalid_stop_times):,} stop_times entries with invalid stop_ids after filtering"
        )

        # Join with trips to get route_id
        invalid_with_routes = invalid_stop_times.merge(
            feed_dfs["trips"][["trip_id", "route_id"]], on="trip_id", how="left"
        )

        # Log unique invalid stop_ids
        invalid_stop_ids = invalid_stop_times["stop_id"].unique()
        WranglerLogger.warning(
            f"Invalid stop_ids ({len(invalid_stop_ids)} unique): {invalid_stop_ids.tolist()}"
        )

        # Log sample of invalid entries with trip and route context
        sample_invalid = invalid_with_routes.head(10)
        WranglerLogger.warning(
            f"Sample invalid stop_times entries:\n{sample_invalid[['trip_id', 'route_id', 'stop_id', 'stop_sequence']]}"
        )

        # Log summary by route
        route_summary = (
            invalid_with_routes.groupby("route_id")["stop_id"]
            .agg(["count", "nunique"])
            .sort_values("count", ascending=False)
        )
        route_summary.columns = ["invalid_stop_times_count", "unique_invalid_stops"]
        WranglerLogger.warning(f"Invalid stop_times by route (top 20):\n{route_summary.head(20)}")

        WranglerLogger.debug(f"All invalid stop_times entries with routes:\n{invalid_with_routes}")
    else:
        WranglerLogger.info("All stop_times entries have valid stop_ids after filtering")

    # Filter other tables to match filtered trips
    if "shapes" in feed_dfs:
        shape_ids = feed_dfs["trips"]["shape_id"].dropna().unique()
        feed_dfs["shapes"] = feed_dfs["shapes"][
            feed_dfs["shapes"]["shape_id"].isin(shape_ids)
        ].reset_index(drop=True)
        WranglerLogger.debug(f"Filtered shapes to {len(feed_dfs['shapes']):,} records")

    if "routes" in feed_dfs:
        route_ids = feed_dfs["trips"]["route_id"].unique()
        feed_dfs["routes"] = feed_dfs["routes"][
            feed_dfs["routes"]["route_id"].isin(route_ids)
        ].reset_index(drop=True)
        WranglerLogger.debug(f"Filtered routes to {len(feed_dfs['routes']):,} records")

    # Feed has frequencies, GtfsModel doesn't
    if is_feed and "frequencies" in feed_dfs:
        feed_dfs["frequencies"]["trip_id"] = feed_dfs["frequencies"]["trip_id"].astype(str)
        feed_dfs["frequencies"] = feed_dfs["frequencies"][
            feed_dfs["frequencies"]["trip_id"].isin(feed_dfs["trips"]["trip_id"])
        ].reset_index(drop=True)
        WranglerLogger.debug(f"Filtered frequencies to {len(feed_dfs['frequencies']):,} records")

    # Create the appropriate object type with the filtered dataframes
    if is_feed:
        return Feed(**feed_dfs)
    return GtfsModel(**feed_dfs)
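
The function filters trips and stop_times with a left merge plus `indicator=True`, keeping only rows flagged `"both"` — a semi-join that also lets the `_merge` counts be logged for debugging. A self-contained sketch of that pattern (toy data, illustrative names):

```python
import pandas as pd

trips = pd.DataFrame({"trip_id": ["a", "b", "c"], "service_id": ["1", "2", "1"]})
service_ids_df = pd.DataFrame({"service_id": ["1"]})

# Left merge with an indicator column, then keep rows matched on both sides
merged = trips.merge(service_ids_df, on="service_id", how="left", indicator=True)
kept = (
    merged.loc[merged._merge == "both"]
    .drop(columns=["_merge"])
    .reset_index(drop=True)
)
```

Compared with `trips[trips.service_id.isin(...)]`, the merge approach makes the matched/unmatched split inspectable via `merged._merge.value_counts()` before the unmatched rows are dropped.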

network_wrangler.transit.filter.filter_transit_by_boundary

filter_transit_by_boundary(
    transit_data,
    boundary,
    partially_include_route_type_action=None,
)

Filter transit routes based on whether they have stops within a boundary.

Removes routes that are entirely outside the boundary shapefile. Routes that are partially within the boundary are kept by default, but can be configured per route type to be truncated at the boundary. Modifies transit_data in place.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `transit_data` | `GtfsModel \| Feed` | Either a GtfsModel or Feed object to filter. Modified in place. | *required* |
| `boundary` | `str \| Path \| GeoDataFrame` | Path to a boundary shapefile or a GeoDataFrame with boundary polygon(s) | *required* |
| `partially_include_route_type_action` | `dict[RouteType, str] \| None` | Optional dictionary mapping RouteType enum to an action for routes partially within the boundary. `"truncate"` truncates the route to only include stops within the boundary. Route types not specified in this dictionary will be kept entirely (default). | `None` |
Example

```python
from network_wrangler.models.gtfs.types import RouteType

# Remove routes entirely outside the Bay Area
filter_transit_by_boundary(gtfs_model, "bay_area_boundary.shp")

# Truncate rail routes at boundary, keep other route types unchanged
filter_transit_by_boundary(
    gtfs_model,
    "bay_area_boundary.shp",
    partially_include_route_type_action={
        RouteType.RAIL: "truncate",  # Rail - will be truncated at boundary
        # Other route types not listed will be kept entirely
    },
)
```

Note that the function returns `None` and modifies `transit_data` in place, so the result is not assigned.

Todo

This is similar to clip_feed_to_boundary – consolidate?
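
The keep/remove/truncate decision the function makes per route reduces to set arithmetic on that route's stops. A self-contained sketch of the logic (function name and parameters are illustrative, not part of the library API):

```python
def classify_route(
    stop_ids: set, stops_in_boundary: set, truncate_types: set, route_type: int
) -> str:
    """Classify a route as 'keep', 'remove', or 'truncate' by boundary membership."""
    inside = stop_ids & stops_in_boundary
    outside = stop_ids - stops_in_boundary
    if not inside:
        return "remove"  # entirely outside the boundary
    if not outside:
        return "keep"  # entirely inside the boundary
    # Partially inside: truncate only if configured for this route_type
    return "truncate" if route_type in truncate_types else "keep"
```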

Source code in network_wrangler/transit/filter.py
def filter_transit_by_boundary(  # noqa: PLR0912, PLR0915
    transit_data: GtfsModel | Feed,
    boundary: str | Path | gpd.GeoDataFrame,
    partially_include_route_type_action: dict[RouteType, str] | None = None,
) -> None:
    """Filter transit routes based on whether they have stops within a boundary.

    Removes routes that are entirely outside the boundary shapefile. Routes that are
    partially within the boundary are kept by default, but can be configured per
    route type to be truncated at the boundary. Modifies transit_data in place.

    Args:
        transit_data: Either a GtfsModel or Feed object to filter. Modified in place.
        boundary: Path to boundary shapefile or a GeoDataFrame with boundary polygon(s)
        partially_include_route_type_action: Optional dictionary mapping RouteType enum to
            action for routes partially within boundary:
            - "truncate": Truncate route to only include stops within boundary
            Route types not specified in this dictionary will be kept entirely (default).

    Example:
        >>> from network_wrangler.models.gtfs.types import RouteType
        >>> # Remove routes entirely outside the Bay Area
        >>> filter_transit_by_boundary(gtfs_model, "bay_area_boundary.shp")
        >>> # Truncate rail routes at boundary, keep other route types unchanged
        >>> filter_transit_by_boundary(
        ...     gtfs_model,
        ...     "bay_area_boundary.shp",
        ...     partially_include_route_type_action={
        ...         RouteType.RAIL: "truncate",  # Rail - will be truncated at boundary
        ...         # Other route types not listed will be kept entirely
        ...     },
        ... )

    !!! todo
        This is similar to [`clip_feed_to_boundary`][network_wrangler.transit.clip.clip_feed_to_boundary] -- consolidate?

    """
    WranglerLogger.info("Filtering transit routes by boundary")

    # Log input parameters
    WranglerLogger.debug(
        f"partially_include_route_type_action: {partially_include_route_type_action}"
    )

    # Load boundary if it's a file path
    if isinstance(boundary, str | Path):
        WranglerLogger.debug(f"Loading boundary from file: {boundary}")
        boundary_gdf = gpd.read_file(boundary)
    else:
        WranglerLogger.debug("Using provided boundary GeoDataFrame")
        boundary_gdf = boundary

    WranglerLogger.debug(f"Boundary has {len(boundary_gdf)} polygon(s)")

    # Ensure boundary is in a geographic CRS for spatial operations
    if boundary_gdf.crs is None:
        WranglerLogger.warning("Boundary has no CRS, assuming EPSG:4326")
        boundary_gdf = boundary_gdf.set_crs(LAT_LON_CRS)
    else:
        WranglerLogger.debug(f"Boundary CRS: {boundary_gdf.crs}")

    # Get references to tables (not copies since we'll modify in place)
    is_gtfs = isinstance(transit_data, GtfsModel)
    stops_df = transit_data.stops
    routes_df = transit_data.routes
    trips_df = transit_data.trips
    stop_times_df = transit_data.stop_times

    if is_gtfs:
        WranglerLogger.debug("Processing GtfsModel data")
    else:
        WranglerLogger.debug("Processing Feed data")

    WranglerLogger.debug(
        f"Input data has {len(stops_df)} stops, {len(routes_df)} routes, {len(trips_df)} trips, {len(stop_times_df)} stop_times"
    )

    # Create GeoDataFrame from stops
    stops_gdf = gpd.GeoDataFrame(
        stops_df,
        geometry=gpd.points_from_xy(stops_df.stop_lon, stops_df.stop_lat),
        crs=LAT_LON_CRS,
    )

    # Reproject to match boundary CRS if needed
    if stops_gdf.crs != boundary_gdf.crs:
        WranglerLogger.debug(f"Reprojecting stops from {stops_gdf.crs} to {boundary_gdf.crs}")
        stops_gdf = stops_gdf.to_crs(boundary_gdf.crs)

    # Spatial join to find stops within boundary
    WranglerLogger.debug("Performing spatial join to find stops within boundary")
    stops_in_boundary = gpd.sjoin(stops_gdf, boundary_gdf, how="inner", predicate="within")
    stops_in_boundary_ids = set(stops_in_boundary.stop_id.unique())

    # Log some stops that are outside boundary for debugging
    stops_outside_boundary = set(stops_df.stop_id) - stops_in_boundary_ids
    if stops_outside_boundary:
        sample_outside = list(stops_outside_boundary)[:5]
        WranglerLogger.debug(f"Sample of stops outside boundary: {sample_outside}")

    WranglerLogger.info(
        f"Found {len(stops_in_boundary_ids):,} stops within boundary "
        f"out of {len(stops_df):,} total stops"
    )

    # Find which routes to keep
    # Get unique stop-route pairs from stop_times and trips
    stop_route_pairs = pd.merge(
        stop_times_df[["trip_id", "stop_id"]], trips_df[["trip_id", "route_id"]], on="trip_id"
    )[["stop_id", "route_id"]].drop_duplicates()

    # Group by route to find which stops each route serves
    route_stops = stop_route_pairs.groupby("route_id")["stop_id"].apply(set).reset_index()
    route_stops.columns = ["route_id", "stop_ids"]

    # Add route_type information
    route_stops = pd.merge(
        route_stops, routes_df[["route_id", "route_type"]], on="route_id", how="left"
    )

    # Initialize with default filters
    if partially_include_route_type_action is None:
        partially_include_route_type_action = {}

    # Convert RouteType enum keys to int values for comparison with dataframe
    normalized_route_type_action = {}
    for key, value in partially_include_route_type_action.items():
        if not isinstance(key, RouteType):
            msg = f"Keys in partially_include_route_type_action must be RouteType enum, got {type(key)}"
            raise TypeError(msg)
        normalized_route_type_action[key.value] = value
    partially_include_route_type_action = normalized_route_type_action

    # Track routes to truncate
    routes_to_truncate = {}

    # Determine which routes to keep and how to handle them
    def determine_route_handling(row):
        route_id = row["route_id"]
        route_type = row["route_type"]
        stop_ids = row["stop_ids"]

        # Check if route has stops both inside and outside boundary
        stops_inside = stop_ids.intersection(stops_in_boundary_ids)
        stops_outside = stop_ids - stops_in_boundary_ids

        # If all stops are outside, always remove
        if len(stops_inside) == 0:
            WranglerLogger.debug(
                f"Route {route_id} (type {route_type}): all {len(stop_ids)} stops outside boundary - REMOVE"
            )
            return "remove"

        # If all stops are inside, always keep
        if len(stops_outside) == 0:
            return "keep"

        # Route has stops both inside and outside - check partially_include_route_type_action
        WranglerLogger.debug(
            f"Route {route_id} (type {route_type}): {len(stops_inside)} stops inside, "
            f"{len(stops_outside)} stops outside boundary"
        )

        if route_type in partially_include_route_type_action:
            action = partially_include_route_type_action[route_type]
            WranglerLogger.debug(
                f"  - Applying configured action for route_type {route_type}: {action}"
            )
            if action == "truncate":
                return "truncate"

        # Default to keep if not specified
        WranglerLogger.debug(
            f"  - No action configured for route_type {route_type}, defaulting to KEEP"
        )
        return "keep"

    route_stops["handling"] = route_stops.apply(determine_route_handling, axis=1)
    WranglerLogger.debug(f"route_stops with handling set:\n{route_stops}")

    routes_to_keep = set(
        route_stops[route_stops["handling"].isin(["keep", "truncate"])]["route_id"]
    )
    routes_to_remove = set(route_stops[route_stops["handling"] == "remove"]["route_id"])
    routes_needing_truncation = set(route_stops[route_stops["handling"] == "truncate"]["route_id"])

    WranglerLogger.info(
        f"Keeping {len(routes_to_keep):,} routes out of {len(routes_df):,} total routes"
    )

    if routes_to_remove:
        WranglerLogger.info(f"Removing {len(routes_to_remove):,} routes entirely outside boundary")
        WranglerLogger.debug(f"Routes being removed: {sorted(routes_to_remove)[:10]}...")

    if routes_needing_truncation:
        WranglerLogger.info(f"Truncating {len(routes_needing_truncation):,} routes at boundary")
        WranglerLogger.debug(
            f"Routes being truncated: {sorted(routes_needing_truncation)[:10]}..."
        )

    # Filter data
    filtered_routes = routes_df[routes_df.route_id.isin(routes_to_keep)]
    filtered_trips = trips_df[trips_df.route_id.isin(routes_to_keep)]
    filtered_trip_ids = set(filtered_trips.trip_id)

    # Handle truncation by calling truncate_route_at_stop for each route needing truncation
    if routes_needing_truncation:
        WranglerLogger.debug(f"Processing truncation for {len(routes_needing_truncation)} routes")

        # Start with the current filtered data
        # Need to ensure stop_times only includes trips that are in filtered_trips
        filtered_stop_times_for_truncation = stop_times_df[
            stop_times_df.trip_id.isin(filtered_trip_ids)
        ]

        # First update transit_data with filtered data before truncation (in order to maintain validation)
        transit_data.stop_times = filtered_stop_times_for_truncation
        transit_data.trips = filtered_trips
        transit_data.routes = filtered_routes

        # Process each route that needs truncation
        for route_id in routes_needing_truncation:
            WranglerLogger.debug(f"Processing truncation for route {route_id}")

            # Get trips for this route
            route_trips = trips_df[trips_df.route_id == route_id]

            # Group by direction_id
            for direction_id in route_trips.direction_id.unique():
                dir_trips = route_trips[route_trips.direction_id == direction_id]
                if len(dir_trips) == 0:
                    continue

                # Analyze stop patterns for this route/direction
                # Get a representative trip (first one)
                sample_trip_id = dir_trips.iloc[0].trip_id
                sample_stop_times = transit_data.stop_times[
                    transit_data.stop_times.trip_id == sample_trip_id
                ].sort_values("stop_sequence")

                # Find which stops are inside/outside boundary
                stop_boundary_status = sample_stop_times["stop_id"].isin(stops_in_boundary_ids)

                # Check if route exits and re-enters boundary (complex case)
                boundary_changes = stop_boundary_status.ne(stop_boundary_status.shift()).cumsum()
                num_segments = boundary_changes.nunique()

                if num_segments > MIN_ROUTE_SEGMENTS:
                    # Complex case: route exits and re-enters boundary
                    route_info = routes_df[routes_df.route_id == route_id].iloc[0]
                    route_name = route_info.get("route_short_name", route_id)
                    msg = (
                        f"Route {route_name} ({route_id}) direction {direction_id} has a complex "
                        f"boundary crossing pattern (crosses boundary {num_segments - 1} times). "
                        f"Can only handle routes that exit boundary at beginning or end."
                    )
                    raise ValueError(msg)

                # Determine truncation type
                first_stop_inside = stop_boundary_status.iloc[0]
                last_stop_inside = stop_boundary_status.iloc[-1]

                if not first_stop_inside and not last_stop_inside:
                    # All stops outside - shouldn't happen as route would be removed
                    continue
                if first_stop_inside and last_stop_inside:
                    # All stops inside - no truncation needed
                    continue
                if not first_stop_inside and last_stop_inside:
                    # Starts outside, ends inside - truncate before first inside stop
                    # Find first True value (first stop inside boundary)
                    first_inside_pos = stop_boundary_status.tolist().index(True)
                    first_inside_stop = sample_stop_times.iloc[first_inside_pos]["stop_id"]

                    WranglerLogger.debug(
                        f"Route {route_id} dir {direction_id}: truncating before stop {first_inside_stop}"
                    )
                    truncate_route_at_stop(
                        transit_data, route_id, direction_id, first_inside_stop, "before"
                    )
                elif first_stop_inside and not last_stop_inside:
                    # Starts inside, ends outside - truncate after last inside stop
                    # Find last True value (last stop inside boundary)
                    reversed_list = stop_boundary_status.tolist()[::-1]
                    last_inside_pos = len(reversed_list) - 1 - reversed_list.index(True)
                    last_inside_stop = sample_stop_times.iloc[last_inside_pos]["stop_id"]

                    WranglerLogger.debug(
                        f"Route {route_id} dir {direction_id}: truncating after stop {last_inside_stop}"
                    )
                    truncate_route_at_stop(
                        transit_data, route_id, direction_id, last_inside_stop, "after"
                    )

        # After truncation, transit_data has been modified in place
        # Update references to current state (in order to maintain validation)
        filtered_stop_times = transit_data.stop_times
        filtered_trips = transit_data.trips
        filtered_routes = transit_data.routes
        filtered_stops = transit_data.stops
    else:
        # No truncation needed - update transit_data with filtered data
        filtered_stop_times = stop_times_df[stop_times_df.trip_id.isin(filtered_trip_ids)]
        filtered_stops = stops_df[stops_df.stop_id.isin(filtered_stop_times.stop_id.unique())]

        # Check if any of the filtered stops reference parent stations
        if "parent_station" in filtered_stops.columns:
            # Get parent stations that are referenced by kept stops
            parent_stations = filtered_stops["parent_station"].dropna().unique()
            parent_stations = [ps for ps in parent_stations if ps != ""]  # Remove empty strings

            if len(parent_stations) > 0:
                # Find parent stations that aren't already in our filtered stops
                existing_stop_ids = set(filtered_stops.stop_id)
                missing_parent_stations = [
                    ps for ps in parent_stations if ps not in existing_stop_ids
                ]

                if len(missing_parent_stations) > 0:
                    WranglerLogger.debug(
                        f"Adding back {len(missing_parent_stations)} parent stations referenced by kept stops"
                    )

                    # Get the parent station records
                    parent_station_records = stops_df[
                        stops_df.stop_id.isin(missing_parent_stations)
                    ]

                    # Append parent stations to filtered stops
                    filtered_stops = pd.concat(
                        [filtered_stops, parent_station_records], ignore_index=True
                    )

        transit_data.stop_times = filtered_stop_times
        transit_data.trips = filtered_trips
        transit_data.routes = filtered_routes
        transit_data.stops = filtered_stops

    # Log details about removed stops
    stops_still_used = set(filtered_stops.stop_id.unique())
    removed_stops = set(stops_df.stop_id) - stops_still_used
    if removed_stops:
        WranglerLogger.debug(f"Removed {len(removed_stops)} stops that are no longer referenced")

        # Get details of removed stops
        removed_stops_df = stops_df[stops_df["stop_id"].isin(removed_stops)][
            ["stop_id", "stop_name"]
        ]

        # Log up to 20 removed stops with their names
        sample_size = min(20, len(removed_stops_df))
        for _, stop in removed_stops_df.head(sample_size).iterrows():
            WranglerLogger.debug(f"  - Removed stop: {stop['stop_id']} ({stop['stop_name']})")

        if len(removed_stops) > sample_size:
            WranglerLogger.debug(f"  ... and {len(removed_stops) - sample_size} more stops")

    WranglerLogger.info(
        f"After filtering: {len(filtered_routes):,} routes, "
        f"{len(filtered_trips):,} trips, {len(filtered_stops):,} stops"
    )

    # Log summary of filtering by action type
    route_handling_summary = route_stops.groupby("handling").size()
    WranglerLogger.debug(f"Route handling summary:\n{route_handling_summary}")

    # Log route type distribution for routes with mixed stops
    mixed_routes = route_stops[
        (route_stops["handling"].isin(["keep", "truncate"]))
        & (
            route_stops["route_id"].isin(routes_needing_truncation)
            | (route_stops["handling"] == "keep")
        )
    ]
    if len(mixed_routes) > 0:
        route_type_summary = mixed_routes.groupby("route_type")["handling"].value_counts()
        WranglerLogger.debug(f"Route types with partial stops:\n{route_type_summary}")

    # Update other tables in transit_data in place
    if is_gtfs:
        # For GtfsModel, also filter shapes and other tables if they exist
        if (
            hasattr(transit_data, "agency")
            and transit_data.agency is not None
            and "agency_id" in filtered_routes.columns
        ):
            agency_ids = set(filtered_routes.agency_id.dropna().unique())
            transit_data.agency = transit_data.agency[
                transit_data.agency.agency_id.isin(agency_ids)
            ]

        if (
            hasattr(transit_data, "shapes")
            and transit_data.shapes is not None
            and "shape_id" in filtered_trips.columns
        ):
            shape_ids = set(filtered_trips.shape_id.dropna().unique())
            transit_data.shapes = transit_data.shapes[transit_data.shapes.shape_id.isin(shape_ids)]

        if hasattr(transit_data, "calendar") and transit_data.calendar is not None:
            # Keep only service_ids referenced by remaining trips
            service_ids = set(filtered_trips.service_id.unique())
            transit_data.calendar = transit_data.calendar[
                transit_data.calendar.service_id.isin(service_ids)
            ]

        if hasattr(transit_data, "calendar_dates") and transit_data.calendar_dates is not None:
            # Keep only service_ids referenced by remaining trips
            service_ids = set(filtered_trips.service_id.unique())
            transit_data.calendar_dates = transit_data.calendar_dates[
                transit_data.calendar_dates.service_id.isin(service_ids)
            ]

        if hasattr(transit_data, "frequencies") and transit_data.frequencies is not None:
            # Keep only frequencies for remaining trips
            transit_data.frequencies = transit_data.frequencies[
                transit_data.frequencies.trip_id.isin(filtered_trip_ids)
            ]

    else:  # Feed
        # For Feed, also handle frequencies and shapes
        if (
            hasattr(transit_data, "shapes")
            and transit_data.shapes is not None
            and "shape_id" in filtered_trips.columns
        ):
            shape_ids = set(filtered_trips.shape_id.dropna().unique())
            transit_data.shapes = transit_data.shapes[transit_data.shapes.shape_id.isin(shape_ids)]

        if hasattr(transit_data, "frequencies") and transit_data.frequencies is not None:
            # Keep only frequencies for remaining trips
            transit_data.frequencies = transit_data.frequencies[
                transit_data.frequencies.trip_id.isin(filtered_trip_ids)
            ]
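The in-place filtering above repeats one pattern: each dependent table keeps only rows whose foreign key still appears in the already-filtered parent table. A minimal stand-alone sketch of that cascade, with plain lists of dicts standing in for the DataFrames (table contents are illustrative, not real feed data):

```python
# Hypothetical mini-tables standing in for the GTFS DataFrames.
trips = [
    {"trip_id": "t1", "service_id": "wk", "shape_id": "s1"},
    {"trip_id": "t2", "service_id": "sat", "shape_id": "s2"},
]
shapes = [{"shape_id": "s1"}, {"shape_id": "s2"}, {"shape_id": "s3"}]
calendar = [{"service_id": "wk"}, {"service_id": "sat"}, {"service_id": "sun"}]
frequencies = [{"trip_id": "t1"}, {"trip_id": "t9"}]

# Suppose an earlier filter kept only trip t1.
filtered_trips = [t for t in trips if t["trip_id"] == "t1"]

# Cascade: collect the surviving keys, then filter each dependent table.
shape_ids = {t["shape_id"] for t in filtered_trips}
service_ids = {t["service_id"] for t in filtered_trips}
trip_ids = {t["trip_id"] for t in filtered_trips}

shapes = [s for s in shapes if s["shape_id"] in shape_ids]
calendar = [c for c in calendar if c["service_id"] in service_ids]
frequencies = [f for f in frequencies if f["trip_id"] in trip_ids]

print(len(shapes), len(calendar), len(frequencies))  # 1 1 1
```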

network_wrangler.transit.filter.truncate_route_at_stop

truncate_route_at_stop(
    transit_data, route_id, direction_id, stop_id, truncate
)

Truncate all trips of a route at a specific stop.

Removes stops before or after the specified stop for all trips matching the given route_id and direction_id. This is useful for shortening routes at terminal stations or service boundaries. Modifies transit_data in place.

Parameters:

Name Type Description Default
transit_data GtfsModel | Feed

Either a GtfsModel or Feed object to modify. Modified in place.

required
route_id str

The route_id to truncate

required
direction_id int

The direction_id of trips to truncate (0 or 1)

required
stop_id str | int

The stop where truncation occurs. For GtfsModel, this should be a string stop_id. For Feed, this should be an integer model_node_id.

required
truncate Literal['before', 'after']

Either “before” to remove stops before stop_id, or “after” to remove stops after stop_id

required

Raises:

Type Description
ValueError

If truncate is not “before” or “after”

ValueError

If stop_id is not found in any trips of the route/direction

Example

Truncate outbound BART trips to end at Embarcadero (GtfsModel)

truncate_route_at_stop(
    gtfs_model,
    route_id="BART-01",
    direction_id=0,
    stop_id="EMBR",  # string stop_id
    truncate="after",
)

Truncate outbound BART trips to end at node 12345 (Feed)

truncate_route_at_stop(
    feed,
    route_id="BART-01",
    direction_id=0,
    stop_id=12345,  # integer model_node_id
    truncate="after",
)

Source code in network_wrangler/transit/filter.py
def truncate_route_at_stop(  # noqa: PLR0912, PLR0915
    transit_data: GtfsModel | Feed,
    route_id: str,
    direction_id: int,
    stop_id: str | int,
    truncate: Literal["before", "after"],
) -> None:
    """Truncate all trips of a route at a specific stop.

    Removes stops before or after the specified stop for all trips matching
    the given route_id and direction_id. This is useful for shortening routes
    at terminal stations or service boundaries. Modifies transit_data in place.

    Args:
        transit_data: Either a GtfsModel or Feed object to modify. Modified in place.
        route_id: The route_id to truncate
        direction_id: The direction_id of trips to truncate (0 or 1)
        stop_id: The stop where truncation occurs. For GtfsModel, this should be
                a string stop_id. For Feed, this should be an integer model_node_id.
        truncate: Either "before" to remove stops before stop_id, or
                 "after" to remove stops after stop_id

    Raises:
        ValueError: If truncate is not "before" or "after"
        ValueError: If stop_id is not found in any trips of the route/direction

    Example:
        >>> # Truncate outbound BART trips to end at Embarcadero (GtfsModel)
        >>> truncate_route_at_stop(
        ...     gtfs_model,
        ...     route_id="BART-01",
        ...     direction_id=0,
        ...     stop_id="EMBR",  # string stop_id
        ...     truncate="after",
        ... )
        >>> # Truncate outbound BART trips to end at node 12345 (Feed)
        >>> truncate_route_at_stop(
        ...     feed,
        ...     route_id="BART-01",
        ...     direction_id=0,
        ...     stop_id=12345,  # integer model_node_id
        ...     truncate="after",
        ... )
    """
    if truncate not in ["before", "after"]:
        msg = f"truncate must be 'before' or 'after', got '{truncate}'"
        raise ValueError(msg)

    WranglerLogger.info(
        f"Truncating route {route_id} direction {direction_id} {truncate} stop {stop_id}"
    )

    # Get data tables (references, not copies)
    routes_df = transit_data.routes
    trips_df = transit_data.trips
    stop_times_df = transit_data.stop_times
    stops_df = transit_data.stops
    is_gtfs = isinstance(transit_data, GtfsModel)

    # Find trips to truncate
    trips_to_truncate = trips_df[
        (trips_df.route_id == route_id) & (trips_df.direction_id == direction_id)
    ]

    if len(trips_to_truncate) == 0:
        WranglerLogger.warning(f"No trips found for route {route_id} direction {direction_id}")
        return  # No changes needed

    trip_ids_to_truncate = set(trips_to_truncate.trip_id)
    WranglerLogger.debug(f"Found {len(trip_ids_to_truncate)} trips to truncate")

    # Check if stop_id exists in any of these trips
    stop_times_for_route = stop_times_df[
        (stop_times_df.trip_id.isin(trip_ids_to_truncate)) & (stop_times_df.stop_id == stop_id)
    ]

    if len(stop_times_for_route) == 0:
        msg = f"Stop {stop_id} not found in any trips of route {route_id} direction {direction_id}"
        raise ValueError(msg)

    # Process stop_times to truncate trips
    truncated_stop_times = []
    trips_truncated = 0

    for trip_id in trip_ids_to_truncate:
        trip_stop_times = stop_times_df[stop_times_df.trip_id == trip_id]
        trip_stop_times = trip_stop_times.sort_values("stop_sequence")

        # Find the stop_sequence for the truncation stop
        stop_mask = trip_stop_times.stop_id == stop_id
        if not stop_mask.any():
            # This trip doesn't have the stop, keep all stops
            truncated_stop_times.append(trip_stop_times)
            continue

        stop_sequence_at_stop = trip_stop_times.loc[stop_mask, "stop_sequence"].iloc[0]

        # Truncate based on direction
        if truncate == "before":
            # Keep stops from stop_id onwards
            truncated_stops = trip_stop_times[
                trip_stop_times.stop_sequence >= stop_sequence_at_stop
            ].copy()  # Need copy here since we'll modify stop_sequence
        else:  # truncate == "after"
            # Keep stops up to and including stop_id
            truncated_stops = trip_stop_times[
                trip_stop_times.stop_sequence <= stop_sequence_at_stop
            ].copy()  # Need copy here since we'll modify stop_sequence

        # Renumber stop_sequence to be consecutive starting from 0
        if len(truncated_stops) > 0:
            truncated_stops["stop_sequence"] = range(len(truncated_stops))

        # Log truncation details
        original_count = len(trip_stop_times)
        truncated_count = len(truncated_stops)
        if truncated_count < original_count:
            trips_truncated += 1

            # Get removed stops details
            removed_stop_ids = set(trip_stop_times.stop_id) - set(truncated_stops.stop_id)
            if removed_stop_ids and len(removed_stop_ids) <= MAX_TRUNCATION_WARNING_STOPS:
                # Get stop names for removed stops
                removed_stops_info = stops_df[stops_df.stop_id.isin(removed_stop_ids)][
                    ["stop_id", "stop_name"]
                ]
                removed_stops_list = [
                    f"{row['stop_id']} ({row['stop_name']})"
                    for _, row in removed_stops_info.iterrows()
                ]

                WranglerLogger.debug(
                    f"Trip {trip_id}: truncated from {original_count} to {truncated_count} stops. "
                    f"Removed: {', '.join(removed_stops_list)}"
                )
            else:
                WranglerLogger.debug(
                    f"Trip {trip_id}: truncated from {original_count} to {truncated_count} stops"
                )

        truncated_stop_times.append(truncated_stops)

    WranglerLogger.info(f"Truncated {trips_truncated} trips")

    # Combine all stop times (truncated and non-truncated)
    other_stop_times = stop_times_df[~stop_times_df.trip_id.isin(trip_ids_to_truncate)]
    all_stop_times = pd.concat([other_stop_times, *truncated_stop_times], ignore_index=True)

    # Find stops that are still referenced
    stops_still_used = set(all_stop_times.stop_id.unique())
    filtered_stops = stops_df[stops_df.stop_id.isin(stops_still_used)]

    # Check if any of these stops reference parent stations
    if "parent_station" in filtered_stops.columns:
        # Get parent stations that are referenced by kept stops
        parent_stations = filtered_stops["parent_station"].dropna().unique()
        parent_stations = [ps for ps in parent_stations if ps != ""]  # Remove empty strings

        if len(parent_stations) > 0:
            # Find parent stations that aren't already in our filtered stops
            existing_stop_ids = set(filtered_stops.stop_id)
            missing_parent_stations = [ps for ps in parent_stations if ps not in existing_stop_ids]

            if len(missing_parent_stations) > 0:
                WranglerLogger.debug(
                    f"Adding back {len(missing_parent_stations)} parent stations referenced by kept stops"
                )

                # Get the parent station records
                parent_station_records = stops_df[stops_df.stop_id.isin(missing_parent_stations)]

                # Append parent stations to filtered stops
                filtered_stops = pd.concat(
                    [filtered_stops, parent_station_records], ignore_index=True
                )

    # Log removed stops
    removed_stops = set(stops_df.stop_id) - set(filtered_stops.stop_id)
    if removed_stops:
        WranglerLogger.debug(f"Removed {len(removed_stops)} stops that are no longer referenced")

        # Get details of removed stops
        removed_stops_df = stops_df[stops_df.stop_id.isin(removed_stops)][["stop_id", "stop_name"]]

        # Log up to 20 removed stops with their names
        sample_size = min(20, len(removed_stops_df))
        for _, stop in removed_stops_df.head(sample_size).iterrows():
            WranglerLogger.debug(f"  - Removed stop: {stop['stop_id']} ({stop['stop_name']})")

        if len(removed_stops) > sample_size:
            WranglerLogger.debug(f"  ... and {len(removed_stops) - sample_size} more stops")

    # Update transit_data in place
    transit_data.stop_times = all_stop_times
    transit_data.trips = trips_df
    transit_data.routes = routes_df
    transit_data.stops = filtered_stops

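The per-trip slicing and renumbering in the loop above reduces to a short list-level sketch (illustrative helper name, not part of the library):

```python
def truncate_stop_sequence(stop_ids, at_stop, truncate):
    """Keep stops before/after `at_stop` (inclusive) and renumber from 0.

    Illustrative stand-in for the per-trip loop in truncate_route_at_stop;
    not the library API.
    """
    if truncate not in ("before", "after"):
        msg = f"truncate must be 'before' or 'after', got '{truncate}'"
        raise ValueError(msg)
    if at_stop not in stop_ids:
        # Trip doesn't serve the stop: it is left untouched.
        return list(enumerate(stop_ids))
    i = stop_ids.index(at_stop)
    kept = stop_ids[i:] if truncate == "before" else stop_ids[: i + 1]
    # Renumber stop_sequence consecutively from 0, as the real code does.
    return list(enumerate(kept))

print(truncate_stop_sequence(["A", "B", "C", "D"], "C", "after"))
# [(0, 'A'), (1, 'B'), (2, 'C')]
```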
Utilities for working with transit geodataframes.

network_wrangler.transit.geo.shapes_to_shape_links_gdf

shapes_to_shape_links_gdf(
    shapes,
    ref_nodes_df=None,
    from_field="A",
    to_field="B",
    crs=LAT_LON_CRS,
)

Translates shapes to shape links geodataframe using geometry from ref_nodes_df if provided.

TODO: Add join to links and then shapes to get true geometry.

Parameters:

Name Type Description Default
shapes DataFrame[WranglerShapesTable]

Feed shapes table

required
ref_nodes_df DataFrame[RoadNodesTable] | None

If specified, will use geometry from these nodes. Otherwise, will use geometry in shapes file. Defaults to None.

None
from_field str

Field used for the link’s from node model_node_id. Defaults to “A”.

'A'
to_field str

Field used for the link’s to node model_node_id. Defaults to “B”.

'B'
crs int

Coordinate reference system. SHouldn’t be changed unless you know what you are doing. Defaults to LAT_LON_CRS which is WGS84 lat/long.

LAT_LON_CRS

Returns:

Type Description
GeoDataFrame

gpd.GeoDataFrame: Shape links with one LineString per A-B link.

Source code in network_wrangler/transit/geo.py
def shapes_to_shape_links_gdf(
    shapes: DataFrame[WranglerShapesTable],
    ref_nodes_df: DataFrame[RoadNodesTable] | None = None,
    from_field: str = "A",
    to_field: str = "B",
    crs: int = LAT_LON_CRS,
) -> gpd.GeoDataFrame:
    """Translates shapes to shape links geodataframe using geometry from ref_nodes_df if provided.

    TODO: Add join to links and then shapes to get true geometry.

    Args:
        shapes: Feed shapes table
        ref_nodes_df: If specified, will use geometry from these nodes.  Otherwise, will use
            geometry in shapes file. Defaults to None.
        from_field: Field used for the link's from node `model_node_id`. Defaults to "A".
        to_field: Field used for the link's to node `model_node_id`. Defaults to "B".
        crs (int, optional): Coordinate reference system. Shouldn't be changed unless you know
            what you are doing. Defaults to LAT_LON_CRS which is WGS84 lat/long.

    Returns:
        gpd.GeoDataFrame: Shape links with one LineString per A-B link.
    """
    if ref_nodes_df is not None:
        shapes = update_shapes_geometry(shapes, ref_nodes_df)
    tr_links = unique_shape_links(shapes, from_field=from_field, to_field=to_field)
    # WranglerLogger.debug(f"tr_links :\n{tr_links }")

    geometry = linestring_from_lats_lons(
        tr_links,
        [f"shape_pt_lat_{from_field}", f"shape_pt_lat_{to_field}"],
        [f"shape_pt_lon_{from_field}", f"shape_pt_lon_{to_field}"],
    )
    # WranglerLogger.debug(f"geometry\n{geometry}")
    shapes_gdf = gpd.GeoDataFrame(tr_links, geometry=geometry, crs=crs).set_crs(LAT_LON_CRS)
    return shapes_gdf
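Before geometry is attached, `unique_shape_links` pairs consecutive shape points into A-to-B links carrying both endpoints' coordinates. A stand-alone sketch of that assumed behavior (the real helper works on DataFrames and de-duplicates links across shapes):

```python
def shape_point_links(points):
    """Pair consecutive (node, lat, lon) shape points into A->B link records.

    Sketch of what unique_shape_links is assumed to produce for one shape;
    column names mirror the shape_pt_lat_A / shape_pt_lon_B fields above.
    """
    links = []
    for a, b in zip(points, points[1:]):
        links.append(
            {
                "A": a["node"], "B": b["node"],
                "shape_pt_lat_A": a["lat"], "shape_pt_lon_A": a["lon"],
                "shape_pt_lat_B": b["lat"], "shape_pt_lon_B": b["lon"],
            }
        )
    return links

pts = [
    {"node": 1, "lat": 37.79, "lon": -122.39},
    {"node": 2, "lat": 37.80, "lon": -122.40},
    {"node": 3, "lat": 37.81, "lon": -122.41},
]
links = shape_point_links(pts)
print([(lk["A"], lk["B"]) for lk in links])  # [(1, 2), (2, 3)]
```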

network_wrangler.transit.geo.shapes_to_trip_shapes_gdf

shapes_to_trip_shapes_gdf(
    shapes, ref_nodes_df=None, crs=LAT_LON_CRS
)

Geodataframe with one polyline shape per shape_id.

TODO: add information about the route and trips.

Parameters:

Name Type Description Default
shapes DataFrame[WranglerShapesTable]

WranglerShapesTable

required
ref_nodes_df DataFrame[RoadNodesTable] | None

If specified, will use geometry from these nodes. Otherwise, will use geometry in shapes file. Defaults to None.

None
crs int

Coordinate reference system. Defaults to LAT_LON_CRS (EPSG:4326, WGS84 lat/long).

LAT_LON_CRS
Source code in network_wrangler/transit/geo.py
def shapes_to_trip_shapes_gdf(
    shapes: DataFrame[WranglerShapesTable],
    # trips: WranglerTripsTable,
    ref_nodes_df: DataFrame[RoadNodesTable] | None = None,
    crs: int = LAT_LON_CRS,
) -> gpd.GeoDataFrame:
    """Geodataframe with one polyline shape per shape_id.

    TODO: add information about the route and trips.

    Args:
        shapes: WranglerShapesTable
        ref_nodes_df: If specified, will use geometry from these nodes.  Otherwise, will use
            geometry in shapes file. Defaults to None.
        crs: Coordinate reference system. Defaults to LAT_LON_CRS (EPSG:4326, WGS84 lat/long).
    """
    if ref_nodes_df is not None:
        shapes = update_shapes_geometry(shapes, ref_nodes_df)

    shape_geom = (
        shapes[["shape_id", "shape_pt_lat", "shape_pt_lon"]]
        .groupby("shape_id")
        .agg(list)
        .apply(lambda x: LineString(zip(x.iloc[1], x.iloc[0], strict=True)), axis=1)
    )

    route_shapes_gdf = gpd.GeoDataFrame(
        data=shape_geom.index, geometry=shape_geom.values, crs=crs
    ).set_crs(LAT_LON_CRS)

    return route_shapes_gdf
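The groupby/agg step above collects, for each shape_id, an ordered (lon, lat) coordinate sequence; that is the structure handed to LineString (x = lon, y = lat). A stdlib equivalent of the grouping:

```python
from collections import defaultdict

# Flat shape-point rows, as in shapes.txt (row order matters within a shape).
rows = [
    {"shape_id": "s1", "shape_pt_lat": 37.79, "shape_pt_lon": -122.39},
    {"shape_id": "s1", "shape_pt_lat": 37.80, "shape_pt_lon": -122.40},
    {"shape_id": "s2", "shape_pt_lat": 37.70, "shape_pt_lon": -122.30},
]

# Group into one (lon, lat) coordinate sequence per shape_id.
coords = defaultdict(list)
for r in rows:
    coords[r["shape_id"]].append((r["shape_pt_lon"], r["shape_pt_lat"]))

print(sorted(coords))  # ['s1', 's2']
```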

network_wrangler.transit.geo.stop_times_to_stop_time_links_gdf

stop_times_to_stop_time_links_gdf(
    stop_times,
    stops,
    ref_nodes_df=None,
    from_field="A",
    to_field="B",
)

Stop times geodataframe as links using geometry from stops.txt or optionally another df.

Parameters:

Name Type Description Default
stop_times WranglerStopTimesTable

Feed stop times table.

required
stops WranglerStopsTable

Feed stops table.

required
ref_nodes_df DataFrame

If specified, will use geometry from these nodes. Otherwise, will use geometry in shapes file. Defaults to None.

None
from_field str

Field used for the link’s from node model_node_id. Defaults to “A”.

'A'
to_field str

Field used for the link’s to node model_node_id. Defaults to “B”.

'B'
Source code in network_wrangler/transit/geo.py
def stop_times_to_stop_time_links_gdf(
    stop_times: DataFrame[WranglerStopTimesTable],
    stops: DataFrame[WranglerStopsTable],
    ref_nodes_df: DataFrame[RoadNodesTable] | None = None,
    from_field: str = "A",
    to_field: str = "B",
) -> gpd.GeoDataFrame:
    """Stop times geodataframe as links using geometry from stops.txt or optionally another df.

    Args:
        stop_times (WranglerStopTimesTable): Feed stop times table.
        stops (WranglerStopsTable): Feed stops table.
        ref_nodes_df (pd.DataFrame, optional): If specified, will use geometry from these nodes.
            Otherwise, will use geometry in shapes file. Defaults to None.
        from_field: Field used for the link's from node `model_node_id`. Defaults to "A".
        to_field: Field used for the link's to node `model_node_id`. Defaults to "B".
    """
    from ..utils.geo import linestring_from_lats_lons

    if ref_nodes_df is not None:
        stops = update_stops_geometry(stops, ref_nodes_df)

    lat_fields = []
    lon_fields = []
    tr_links = unique_stop_time_links(stop_times, from_field=from_field, to_field=to_field)
    for f in (from_field, to_field):
        tr_links = tr_links.merge(
            stops[["stop_id", "stop_lat", "stop_lon"]],
            right_on="stop_id",
            left_on=f,
            how="left",
        )
        lon_f = f"{f}_X"
        lat_f = f"{f}_Y"
        tr_links = tr_links.rename(columns={"stop_lon": lon_f, "stop_lat": lat_f})
        lon_fields.append(lon_f)
        lat_fields.append(lat_f)

    geometry = linestring_from_lats_lons(tr_links, lat_fields, lon_fields)
    return gpd.GeoDataFrame(tr_links, geometry=geometry).set_crs(LAT_LON_CRS)
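The merge-and-rename loop above amounts to a coordinate lookup per link endpoint. Sketched with plain dicts (the A_X/A_Y/B_X/B_Y names mirror the columns produced above; data is illustrative):

```python
# Stop coordinate lookup, standing in for stops[["stop_id", "stop_lat", "stop_lon"]].
stops = {"S1": (37.79, -122.39), "S2": (37.80, -122.40)}

# Unique stop-time links with from/to stop ids, as produced upstream.
links = [{"A": "S1", "B": "S2"}]

# Attach endpoint coordinates under A_X/A_Y and B_X/B_Y (X = lon, Y = lat).
for link in links:
    for f in ("A", "B"):
        lat, lon = stops[link[f]]
        link[f"{f}_X"], link[f"{f}_Y"] = lon, lat

print(links[0]["A_X"], links[0]["B_Y"])  # -122.39 37.8
```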

network_wrangler.transit.geo.stop_times_to_stop_time_points_gdf

stop_times_to_stop_time_points_gdf(
    stop_times, stops, ref_nodes_df=None
)

Stop times geodataframe as points using geometry from stops.txt or optionally another df.

Parameters:

Name Type Description Default
stop_times WranglerStopTimesTable

Feed stop times table.

required
stops WranglerStopsTable

Feed stops table.

required
ref_nodes_df DataFrame

If specified, will use geometry from these nodes. Otherwise, will use geometry in shapes file. Defaults to None.

None
Source code in network_wrangler/transit/geo.py
def stop_times_to_stop_time_points_gdf(
    stop_times: DataFrame[WranglerStopTimesTable],
    stops: DataFrame[WranglerStopsTable],
    ref_nodes_df: DataFrame[RoadNodesTable] | None = None,
) -> gpd.GeoDataFrame:
    """Stoptimes geodataframe as points using geometry from stops.txt or optionally another df.

    Args:
        stop_times (WranglerStopTimesTable): Feed stop times table.
        stops (WranglerStopsTable): Feed stops table.
        ref_nodes_df (pd.DataFrame, optional): If specified, will use geometry from these nodes.
            Otherwise, will use geometry in shapes file. Defaults to None.
    """
    if ref_nodes_df is not None:
        stops = update_stops_geometry(stops, ref_nodes_df)

    stop_times_geo = stop_times.merge(
        stops[["stop_id", "stop_lat", "stop_lon"]],
        on="stop_id",
        how="left",
    )
    return gpd.GeoDataFrame(
        stop_times_geo,
        geometry=gpd.points_from_xy(stop_times_geo["stop_lon"], stop_times_geo["stop_lat"]),
        crs=LAT_LON_CRS,
    )

network_wrangler.transit.geo.update_shapes_geometry

update_shapes_geometry(shapes, ref_nodes_df)

Returns shapes table with geometry updated from ref_nodes_df.

NOTE: does not update “geometry” field if it exists.

Source code in network_wrangler/transit/geo.py
def update_shapes_geometry(
    shapes: DataFrame[WranglerShapesTable], ref_nodes_df: DataFrame[RoadNodesTable]
) -> DataFrame[WranglerShapesTable]:
    """Returns shapes table with geometry updated from ref_nodes_df.

    NOTE: does not update "geometry" field if it exists.
    """
    return update_point_geometry(
        shapes,
        ref_nodes_df,
        id_field="shape_model_node_id",
        lon_field="shape_pt_lon",
        lat_field="shape_pt_lat",
    )

network_wrangler.transit.geo.update_stops_geometry

update_stops_geometry(stops, ref_nodes_df)

Returns stops table with geometry updated from ref_nodes_df.

NOTE: does not update “geometry” field if it exists.

Source code in network_wrangler/transit/geo.py
def update_stops_geometry(
    stops: DataFrame[WranglerStopsTable], ref_nodes_df: DataFrame[RoadNodesTable]
) -> DataFrame[WranglerStopsTable]:
    """Returns stops table with geometry updated from ref_nodes_df.

    NOTE: does not update "geometry" field if it exists.
    """
    return update_point_geometry(
        stops, ref_nodes_df, id_field="stop_id", lon_field="stop_lon", lat_field="stop_lat"
    )
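Both update helpers delegate to `update_point_geometry`, which is assumed to overwrite a record's lon/lat from a reference node where the ids match, leaving unmatched records alone. A dict-based sketch of that assumed behavior (helper name and data are illustrative):

```python
def update_point_coords(records, ref_coords, id_field, lon_field, lat_field):
    """Overwrite lon/lat on records from a reference id -> (lon, lat) map.

    Sketch of the assumed behavior of update_point_geometry; ids missing
    from the reference keep their original coordinates.
    """
    for rec in records:
        if rec[id_field] in ref_coords:
            rec[lon_field], rec[lat_field] = ref_coords[rec[id_field]]
    return records

stops = [
    {"stop_id": 10, "stop_lon": 0.0, "stop_lat": 0.0},
    {"stop_id": 11, "stop_lon": 1.0, "stop_lat": 1.0},
]
ref = {10: (-122.39, 37.79)}  # node 10 has authoritative geometry
update_point_coords(stops, ref, "stop_id", "stop_lon", "stop_lat")
print(stops[0])  # {'stop_id': 10, 'stop_lon': -122.39, 'stop_lat': 37.79}
```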

Functions for reading and writing transit feeds and networks.

network_wrangler.transit.io.convert_transit_serialization

convert_transit_serialization(
    input_path,
    output_format,
    out_dir=".",
    input_file_format="csv",
    out_prefix="",
    overwrite=True,
)

Converts a transit network from one serialization to another.

Parameters:

Name Type Description Default
input_path str | Path

path to the input network

required
output_format TransitFileTypes

the format of the output files. Should be txt, csv, or parquet.

required
out_dir Path | str

directory to write the network to. Defaults to current directory.

'.'
input_file_format TransitFileTypes

the file_format of the files to read. Should be txt, csv, or parquet. Defaults to “csv”

'csv'
out_prefix str

prefix to add to the file name. Defaults to “”

''
overwrite bool

if True, will overwrite the files if they already exist. Defaults to True

True
Source code in network_wrangler/transit/io.py
def convert_transit_serialization(
    input_path: str | Path,
    output_format: TransitFileTypes,
    out_dir: Path | str = ".",
    input_file_format: TransitFileTypes = "csv",
    out_prefix: str = "",
    overwrite: bool = True,
):
    """Converts a transit network from one serialization to another.

    Args:
        input_path: path to the input network
        output_format: the format of the output files. Should be txt, csv, or parquet.
        out_dir: directory to write the network to. Defaults to current directory.
        input_file_format: the file_format of the files to read. Should be txt, csv, or parquet.
            Defaults to "csv"
        out_prefix: prefix to add to the file name. Defaults to ""
        overwrite: if True, will overwrite the files if they already exist. Defaults to True
    """
    WranglerLogger.info(
        f"Loading transit net from {input_path} with input type {input_file_format}"
    )
    net = load_transit(input_path, file_format=input_file_format)
    WranglerLogger.info(f"Writing transit network to {out_dir} in {output_format} format.")
    write_transit(
        net,
        prefix=out_prefix,
        out_dir=out_dir,
        file_format=output_format,
        overwrite=overwrite,
    )

network_wrangler.transit.io.load_feed_from_dfs

load_feed_from_dfs(feed_dfs, wrangler_flavored=True)

Create a Feed or GtfsModel object from a dictionary of DataFrames representing a GTFS feed.

Parameters:

Name Type Description Default
feed_dfs dict

A dictionary containing DataFrames representing the tables of a GTFS feed.

required
wrangler_flavored bool

If True, creates a Wrangler-enhanced Feed object. If False, creates a pure GtfsModel object. Defaults to True.

True

Returns:

Type Description
Feed | GtfsModel

Union[Feed, GtfsModel]: A Feed or GtfsModel object representing the transit network.

Raises:

Type Description
ValueError

If the feed_dfs dictionary does not contain all the required tables.

Example usage:

feed_dfs = {
    "agency": agency_df,
    "routes": routes_df,
    "stops": stops_df,
    "trips": trips_df,
    "stop_times": stop_times_df,
}
feed = load_feed_from_dfs(feed_dfs)  # Creates Feed by default
gtfs_model = load_feed_from_dfs(feed_dfs, wrangler_flavored=False)  # Creates GtfsModel

Source code in network_wrangler/transit/io.py
def load_feed_from_dfs(feed_dfs: dict, wrangler_flavored: bool = True) -> Feed | GtfsModel:
    """Create a Feed or GtfsModel object from a dictionary of DataFrames representing a GTFS feed.

    Args:
        feed_dfs (dict): A dictionary containing DataFrames representing the tables of a GTFS feed.
        wrangler_flavored: If True, creates a Wrangler-enhanced Feed object.
                           If False, creates a pure GtfsModel object. Defaults to True.

    Returns:
        Union[Feed, GtfsModel]: A Feed or GtfsModel object representing the transit network.

    Raises:
        ValueError: If the feed_dfs dictionary does not contain all the required tables.

    Example usage:
    ```python
    feed_dfs = {
        "agency": agency_df,
        "routes": routes_df,
        "stops": stops_df,
        "trips": trips_df,
        "stop_times": stop_times_df,
    }
    feed = load_feed_from_dfs(feed_dfs)  # Creates Feed by default
    gtfs_model = load_feed_from_dfs(feed_dfs, wrangler_flavored=False)  # Creates GtfsModel
    ```
    """
    # Use the appropriate model class based on the parameter
    model_class = Feed if wrangler_flavored else GtfsModel

    if not all(table in feed_dfs for table in model_class.table_names):
        model_name = "Feed" if wrangler_flavored else "GtfsModel"
        msg = f"feed_dfs must contain the following tables for {model_name}: {model_class.table_names}"
        raise ValueError(msg)

    feed = model_class(**feed_dfs)

    return feed
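The required-table check at the top of `load_feed_from_dfs` can be isolated as a small validator. A sketch under the assumption that the required set includes the core GTFS tables (the actual list comes from `model_class.table_names`):

```python
REQUIRED_TABLES = ("routes", "stops", "trips", "stop_times")  # illustrative set


def check_required_tables(feed_dfs, required=REQUIRED_TABLES):
    """Raise ValueError if any required table is missing from feed_dfs."""
    missing = [t for t in required if t not in feed_dfs]
    if missing:
        msg = f"feed_dfs must contain the following tables: {missing}"
        raise ValueError(msg)


feed_dfs = {"routes": ..., "stops": ..., "trips": ...}  # stop_times missing
try:
    check_required_tables(feed_dfs)
except ValueError as err:
    print(err)  # feed_dfs must contain the following tables: ['stop_times']
```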

network_wrangler.transit.io.load_feed_from_path

load_feed_from_path(
    feed_path,
    file_format="txt",
    wrangler_flavored=True,
    service_ids_filter=None,
    **read_kwargs,
)

Create a Feed or GtfsModel object from the path to a GTFS transit feed.

Parameters:

Name Type Description Default
feed_path Union[Path, str]

The path to the GTFS transit feed.

required
file_format TransitFileTypes

the format of the files to read. Defaults to “txt”

'txt'
wrangler_flavored bool

If True, creates a Wrangler-enhanced Feed object. If False, creates a pure GtfsModel object. Defaults to True.

True
service_ids_filter Optional[list[str]]

If not None, filter to these service_ids. Assumes service_id is a str.

None
**read_kwargs

Additional keyword arguments to pass to the file reader (e.g., low_memory, dtype)

{}

Returns:

Type Description
Feed | GtfsModel

Union[Feed, GtfsModel]: The Feed or GtfsModel object created from the GTFS transit feed.

Source code in network_wrangler/transit/io.py
def load_feed_from_path(
    feed_path: Path | str,
    file_format: TransitFileTypes = "txt",
    wrangler_flavored: bool = True,
    service_ids_filter: list[str] | None = None,
    **read_kwargs,
) -> Feed | GtfsModel:
    """Create a Feed or GtfsModel object from the path to a GTFS transit feed.

    Args:
        feed_path (Union[Path, str]): The path to the GTFS transit feed.
        file_format: the format of the files to read. Defaults to "txt"
        wrangler_flavored: If True, creates a Wrangler-enhanced Feed object.
                          If False, creates a pure GtfsModel object. Defaults to True.
        service_ids_filter (Optional[list[str]]): If not None, filter to these service_ids. Assumes service_id is a str.
        **read_kwargs: Additional keyword arguments to pass to the file reader (e.g., low_memory, dtype)

    Returns:
        Union[Feed, GtfsModel]: The Feed or GtfsModel object created from the GTFS transit feed.
    """
    feed_path = _feed_path_ref(Path(feed_path))  # unzips if needs to be unzipped

    if not feed_path.is_dir():
        msg = f"Feed path not a directory: {feed_path}"
        raise NotADirectoryError(msg)

    WranglerLogger.info(f"Reading GTFS feed tables from {feed_path}")

    # Use the appropriate table names based on the model type
    model_class = Feed if wrangler_flavored else GtfsModel
    feed_possible_files = {
        table: list(feed_path.glob(f"*{table}.{file_format}"))
        for table in model_class.table_names + model_class.optional_table_names
    }
    WranglerLogger.debug(f"model_class={model_class}  feed_possible_files={feed_possible_files}")

    # make sure we have all the tables we need -- missing optional is ok
    _missing_files = []
    for table_name in list(feed_possible_files.keys()):
        if not feed_possible_files[table_name]:
            # remove those that don't have files
            del feed_possible_files[table_name]

            # missing optional is ok
            if table_name in model_class.table_names:
                _missing_files.append(table_name)

    if _missing_files:
        WranglerLogger.debug(f"!!! Missing transit files: {_missing_files}")
        model_name = "Feed" if wrangler_flavored else "GtfsModel"
        msg = f"Required GTFS {model_name} table(s) not in {feed_path}: \n  {_missing_files}"
        raise RequiredTableError(msg)

    # but don't want to have more than one file per search
    _ambiguous_files = [t for t, v in feed_possible_files.items() if len(v) > 1]
    if _ambiguous_files:
        WranglerLogger.warning(
            "! More than one file matches the following tables. "
            + f"Using the first on the list: {_ambiguous_files}"
        )

    feed_files = {t: f[0] for t, f in feed_possible_files.items()}
    feed_dfs = {
        table: _read_table_from_file(table, file, **read_kwargs)
        for table, file in feed_files.items()
    }

    # Create the feed object first
    feed_obj = load_feed_from_dfs(feed_dfs, wrangler_flavored=wrangler_flavored)
    WranglerLogger.debug(f"loaded {type(feed_obj)} from dfs:\n{feed_obj}")

    # Apply service_ids filter if provided
    if service_ids_filter is not None:
        feed_obj = filter_feed_by_service_ids(feed_obj, service_ids_filter)

    return feed_obj
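The per-table file discovery above classifies glob results into chosen, missing, and ambiguous tables. A self-contained sketch of that logic operating on precomputed match lists (helper name is illustrative):

```python
def resolve_feed_files(candidates, required):
    """Classify per-table glob results, mirroring load_feed_from_path.

    candidates maps table name -> list of matching paths (as from glob).
    Missing optional tables are fine; missing required tables are reported,
    and the first match wins when several files match one table.
    """
    missing = [t for t in required if not candidates.get(t)]
    ambiguous = [t for t, v in candidates.items() if len(v) > 1]
    chosen = {t: v[0] for t, v in candidates.items() if v}
    return chosen, missing, ambiguous


candidates = {
    "trips": ["trips.txt"],
    "stops": ["stops.txt", "old_stops.txt"],  # ambiguous: two matches
    "routes": [],  # required but missing
}
chosen, missing, ambiguous = resolve_feed_files(
    candidates, required=["trips", "stops", "routes"]
)
print(missing, ambiguous)  # ['routes'] ['stops']
```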

network_wrangler.transit.io.load_transit

load_transit(feed, file_format='txt', config=DefaultConfig)

Create a TransitNetwork object.

This function takes in a feed parameter, which can be one of the following types:

  • Feed: A Feed object representing a transit feed.
  • dict[str, pd.DataFrame]: A dictionary of DataFrames representing transit data.
  • str or Path: A string or a Path object representing the path to a transit feed file.

Parameters:

Name Type Description Default
feed Feed | GtfsModel | dict[str, DataFrame] | str | Path

Feed object, dict of transit data frames, or path to transit feed data

required
file_format TransitFileTypes

the format of the files to read. Defaults to “txt”

'txt'
config WranglerConfig

WranglerConfig object. Defaults to DefaultConfig.

DefaultConfig

Returns:

Type Description
TransitNetwork

object representing the loaded transit network.

Example usage:

transit_network_from_zip = load_transit("path/to/gtfs.zip")

transit_network_from_unzipped_dir = load_transit("path/to/files")

transit_network_from_parquet = load_transit("path/to/files", file_format="parquet")

dfs_of_transit_data = {"routes": routes_df, "stops": stops_df, "trips": trips_df...}
transit_network_from_dfs = load_transit(dfs_of_transit_data)

Source code in network_wrangler/transit/io.py
def load_transit(
    feed: Feed | GtfsModel | dict[str, pd.DataFrame] | str | Path,
    file_format: TransitFileTypes = "txt",
    config: WranglerConfig = DefaultConfig,
) -> TransitNetwork:
    """Create a [`TransitNetwork`][network_wrangler.transit.network.TransitNetwork] object.

    This function takes in a `feed` parameter, which can be one of the following types:

    - `Feed`: A Feed object representing a transit feed.
    - `dict[str, pd.DataFrame]`: A dictionary of DataFrames representing transit data.
    - `str` or `Path`: A string or a Path object representing the path to a transit feed file.

    Args:
        feed: Feed object, dict of transit data frames, or path to transit feed data
        file_format: the format of the files to read. Defaults to "txt"
        config: WranglerConfig object. Defaults to DefaultConfig.

    Returns:
        (TransitNetwork): object representing the loaded transit network.

    Raises:
        ValueError: If the `feed` parameter is not one of the supported types.

    Example usage:
    ```python
    transit_network_from_zip = load_transit("path/to/gtfs.zip")

    transit_network_from_unzipped_dir = load_transit("path/to/files")

    transit_network_from_parquet = load_transit("path/to/files", file_format="parquet")

    dfs_of_transit_data = {"routes": routes_df, "stops": stops_df, "trips": trips_df...}
    transit_network_from_dfs = load_transit(dfs_of_transit_data)
    ```

    """
    if isinstance(feed, Path | str):
        feed = Path(feed)
        feed_obj = load_feed_from_path(feed, file_format=file_format)
        feed_obj.feed_path = feed
    elif isinstance(feed, dict):
        feed_obj = load_feed_from_dfs(feed)
    elif isinstance(feed, GtfsModel):
        feed_obj = Feed(**feed.__dict__)
    else:
        if not isinstance(feed, Feed):
            msg = f"TransitNetwork must be seeded with a Feed, dict of dfs or Path. Found {type(feed)}"
            raise ValueError(msg)
        feed_obj = feed

    return TransitNetwork(feed_obj, config=config)
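The branching on the `feed` argument above is a type-dispatch pattern; it can be sketched in isolation with the stdlib (`dispatch_feed_source` and its return labels are hypothetical illustrations, not part of network_wrangler):

```python
from pathlib import Path


def dispatch_feed_source(feed):
    """Hypothetical analogue of load_transit's input dispatch."""
    if isinstance(feed, (Path, str)):
        return "path"        # would call load_feed_from_path(...)
    if isinstance(feed, dict):
        return "dataframes"  # would call load_feed_from_dfs(...)
    msg = f"TransitNetwork must be seeded with a Feed, dict of dfs or Path. Found {type(feed)}"
    raise ValueError(msg)
```

Falling through to the `else` branch only after exhausting the known types keeps the error message in one place, which is the shape the real function uses.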

network_wrangler.transit.io.write_feed_geo

write_feed_geo(
    feed,
    ref_nodes_df,
    out_dir,
    file_format="geojson",
    out_prefix=None,
    overwrite=True,
)

Write a Feed object to a directory in a geospatial format.

Parameters:

Name Type Description Default
feed Feed

Feed object to write

required
ref_nodes_df GeoDataFrame

Reference nodes dataframe to use for geometry

required
out_dir str | Path

directory to write the network to

required
file_format Literal['geojson', 'shp', 'parquet']

the format of the output files. Defaults to “geojson”

'geojson'
out_prefix

prefix to add to the file name

None
overwrite bool

if True, will overwrite the files if they already exist. Defaults to True

True
Source code in network_wrangler/transit/io.py
def write_feed_geo(
    feed: Feed,
    ref_nodes_df: gpd.GeoDataFrame,
    out_dir: str | Path,
    file_format: Literal["geojson", "shp", "parquet"] = "geojson",
    out_prefix=None,
    overwrite: bool = True,
) -> None:
    """Write a Feed object to a directory in a geospatial format.

    Args:
        feed: Feed object to write
        ref_nodes_df: Reference nodes dataframe to use for geometry
        out_dir: directory to write the network to
        file_format: the format of the output files. Defaults to "geojson"
        out_prefix: prefix to add to the file name
        overwrite: if True, will overwrite the files if they already exist. Defaults to True
    """
    from .geo import shapes_to_shape_links_gdf

    out_dir = Path(out_dir)
    if not out_dir.is_dir():
        if out_dir.parent.is_dir():
            out_dir.mkdir()
        else:
            msg = f"Output directory {out_dir} and its parent do not exist"
            raise FileNotFoundError(msg)

    prefix = f"{out_prefix}_" if out_prefix else ""
    shapes_outpath = out_dir / f"{prefix}trn_shapes.{file_format}"
    shapes_gdf = shapes_to_shape_links_gdf(feed.shapes, ref_nodes_df=ref_nodes_df)
    write_table(shapes_gdf, shapes_outpath, overwrite=overwrite)

    stops_outpath = out_dir / f"{prefix}trn_stops.{file_format}"
    stops_gdf = to_points_gdf(feed.stops, ref_nodes_df=ref_nodes_df)
    write_table(stops_gdf, stops_outpath, overwrite=overwrite)
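The file-naming convention above (`{prefix}trn_shapes.{format}` and `{prefix}trn_stops.{format}`) can be isolated as a small helper (a hypothetical sketch mirroring `write_feed_geo`, not a library function):

```python
from pathlib import Path


def geo_outpaths(out_dir, out_prefix=None, file_format="geojson"):
    """Hypothetical helper mirroring write_feed_geo's output naming."""
    prefix = f"{out_prefix}_" if out_prefix else ""
    out_dir = Path(out_dir)
    return (
        out_dir / f"{prefix}trn_shapes.{file_format}",
        out_dir / f"{prefix}trn_stops.{file_format}",
    )
```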

network_wrangler.transit.io.write_transit

write_transit(
    transit_obj,
    out_dir=".",
    prefix=None,
    file_format="txt",
    overwrite=True,
)

Writes a transit network, feed, or GTFS model to files.

Parameters:

Name Type Description Default
transit_obj TransitNetwork | Feed | GtfsModel

a TransitNetwork, Feed, or GtfsModel instance

required
out_dir Path | str

directory to write the network to

'.'
file_format Literal['txt', 'csv', 'parquet']

the format of the output files. Defaults to “txt”, which is CSV-formatted data with a .txt extension.

'txt'
prefix Path | str | None

prefix to add to the file name

None
overwrite bool

if True, will overwrite the files if they already exist. Defaults to True

True
Source code in network_wrangler/transit/io.py
def write_transit(
    transit_obj: TransitNetwork | Feed | GtfsModel,
    out_dir: Path | str = ".",
    prefix: Path | str | None = None,
    file_format: Literal["txt", "csv", "parquet"] = "txt",
    overwrite: bool = True,
) -> None:
    """Writes a transit network, feed, or GTFS model to files.

    Args:
        transit_obj: a TransitNetwork, Feed, or GtfsModel instance
        out_dir: directory to write the network to
        file_format: the format of the output files. Defaults to "txt", which is
            CSV-formatted data with a .txt extension.
        prefix: prefix to add to the file name
        overwrite: if True, will overwrite the files if they already exist. Defaults to True
    """
    out_dir = Path(out_dir)
    prefix = f"{prefix}_" if prefix else ""

    # Determine the data source based on input type
    if isinstance(transit_obj, TransitNetwork):
        data_source = transit_obj.feed
        source_type = "TransitNetwork"
    elif isinstance(transit_obj, Feed | GtfsModel):
        data_source = transit_obj
        source_type = type(transit_obj).__name__
    else:
        msg = (
            f"transit_obj must be a TransitNetwork, Feed, or GtfsModel instance, "
            f"not {type(transit_obj).__name__}"
        )
        raise TypeError(msg)

    # Write tables
    tables_written = 0
    for table in data_source.table_names:
        df = data_source.get_table(table)
        if df is not None and len(df) > 0:  # Only write non-empty tables
            outpath = out_dir / f"{prefix}{table}.{file_format}"
            write_table(df, outpath, overwrite=overwrite)
            tables_written += 1

    WranglerLogger.info(f"Wrote {tables_written} {source_type} tables to {out_dir}")
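The "only write non-empty tables" filter in the loop above is worth isolating (a minimal stdlib sketch; `tables_to_write` is hypothetical and uses lists of rows as stand-ins for DataFrames):

```python
def tables_to_write(tables):
    """Hypothetical sketch of write_transit's non-empty-table filter.

    `tables` maps table name -> rows (stand-in for a DataFrame or None).
    """
    return [name for name, rows in tables.items() if rows is not None and len(rows) > 0]
```

Skipping `None` and zero-length tables means a feed with optional tables absent (e.g. no frequencies) writes only the files it actually has data for.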

ModelTransit class and functions for managing consistency between roadway and transit networks.

NOTE: this is not thoroughly tested and may not be fully functional.

network_wrangler.transit.model_transit.ModelTransit

ModelTransit class for managing consistency between roadway and transit networks.

Source code in network_wrangler/transit/model_transit.py
class ModelTransit:
    """ModelTransit class for managing consistency between roadway and transit networks."""

    def __init__(
        self,
        transit_net: TransitNetwork,
        roadway_net: RoadwayNetwork,
        shift_transit_to_managed_lanes: bool = True,
    ):
        """ModelTransit class for managing consistency between roadway and transit networks."""
        self.transit_net = transit_net
        self.roadway_net = roadway_net
        self._roadway_net_version = None
        self._transit_feed_version = None
        self._transit_shifted_to_ML = shift_transit_to_managed_lanes

    @property
    def model_roadway_net(self):
        """ModelRoadwayNetwork associated with this ModelTransit."""
        return self.roadway_net.model_net

    @property
    def consistent_nets(self) -> bool:
        """Indicate if roadway and transit networks have changed since self.m_feed updated."""
        return bool(
            self.roadway_net.modification_version == self._roadway_net_version
            and self.transit_net.feed.modification_version == self._transit_feed_version
        )

    @property
    def m_feed(self):
        """TransitNetwork.feed with updates for consistency with associated ModelRoadwayNetwork."""
        if self.consistent_nets:
            return self._m_feed
        # NOTE: look at this
        # If networks have changed, update model transit and update reference version
        self._roadway_net_version = self.roadway_net.modification_version
        self._transit_feed_version = self.transit_net.feed.modification_version

        if not self._transit_shifted_to_ML:
            self._m_feed = copy.deepcopy(self.transit_net.feed)
            return self._m_feed
        return None

network_wrangler.transit.model_transit.ModelTransit.consistent_nets property

consistent_nets

Indicate if roadway and transit networks have changed since self.m_feed updated.

network_wrangler.transit.model_transit.ModelTransit.m_feed property

m_feed

TransitNetwork.feed with updates for consistency with associated ModelRoadwayNetwork.

network_wrangler.transit.model_transit.ModelTransit.model_roadway_net property

model_roadway_net

ModelRoadwayNetwork associated with this ModelTransit.

network_wrangler.transit.model_transit.ModelTransit.__init__

__init__(
    transit_net,
    roadway_net,
    shift_transit_to_managed_lanes=True,
)

ModelTransit class for managing consistency between roadway and transit networks.

Source code in network_wrangler/transit/model_transit.py
def __init__(
    self,
    transit_net: TransitNetwork,
    roadway_net: RoadwayNetwork,
    shift_transit_to_managed_lanes: bool = True,
):
    """ModelTransit class for managing consistency between roadway and transit networks."""
    self.transit_net = transit_net
    self.roadway_net = roadway_net
    self._roadway_net_version = None
    self._transit_feed_version = None
    self._transit_shifted_to_ML = shift_transit_to_managed_lanes

Classes and functions for selecting transit trips from a transit network.

Usage:

Create a TransitSelection object by providing a TransitNetwork object and a selection dictionary:

```python
selection_dict = {
    "links": {...},
    "nodes": {...},
    "route_properties": {...},
    "trip_properties": {...},
    "timespans": {...},
}
transit_selection = TransitSelection(transit_network, selection_dict)
```

Access the selected trip ids or dataframe as follows:

```python
selected_trips = transit_selection.selected_trips
selected_trips_df = transit_selection.selected_trips_df
```

Note: The selection dictionary should conform to the SelectTransitTrips model defined in the models.projects.transit_selection module.

network_wrangler.transit.selection.TransitSelection

Object to perform and store information about a selection from a project card “facility”.

Attributes:

Name Type Description
selection_dict dict

Dictionary of selection criteria

selected_trips list

List of selected trips

selected_trips_df DataFrame[WranglerTripsTable]

DataFrame of selected trips

sel_key str

Hash of selection_dict

net TransitNetwork

Network to select from
Source code in network_wrangler/transit/selection.py
class TransitSelection:
    """Object to perform and store information about a selection from a project card "facility".

    Attributes:
        selection_dict: dict: Dictionary of selection criteria
        selected_trips: list: List of selected trips
        selected_trips_df: pd.DataFrame: DataFrame of selected trips
        sel_key: str: Hash of selection_dict
        net: TransitNetwork: Network to select from
    """

    def __init__(
        self,
        net: TransitNetwork,
        selection_dict: dict | SelectTransitTrips,
    ):
        """Constructor for TransitSelection object.

        Args:
            net (TransitNetwork): Transit network object to select from.
            selection_dict: Selection dictionary conforming to SelectTransitTrips
        """
        self.net = net
        self.selection_dict = selection_dict

        # Initialize
        self._selected_trips_df = None
        self.sel_key = dict_to_hexkey(selection_dict)
        self._stored_feed_version = self.net.feed.modification_version

        WranglerLogger.debug(f"...created TransitSelection object: {selection_dict}")

    def __nonzero__(self):
        """Return True if there are selected trips."""
        return len(self.selected_trips_df) > 0

    @property
    def selection_dict(self):
        """Getter for selection_dict."""
        return self._selection_dict

    @selection_dict.setter
    def selection_dict(self, value: dict | SelectTransitTrips):
        self._selection_dict = self.validate_selection_dict(value)

    def validate_selection_dict(self, selection_dict: dict | SelectTransitTrips) -> dict:
        """Check that selection dictionary has valid and used properties consistent with network.

        Checks that selection_dict is a valid TransitSelectionDict:
            - query vars exist in respective Feed tables
        Args:
            selection_dict (dict): selection dictionary

        Raises:
            TransitSelectionNetworkConsistencyError: If not consistent with transit network
            ValidationError: if format not consistent with SelectTransitTrips
        """
        if not isinstance(selection_dict, SelectTransitTrips):
            selection_dict = SelectTransitTrips(**selection_dict)
        selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
        WranglerLogger.debug(f"SELECT DICT - before Validation: \n{selection_dict}")
        _trip_selection_fields = list((selection_dict.get("trip_properties", {}) or {}).keys())
        _missing_trip_fields = set(_trip_selection_fields) - set(self.net.feed.trips.columns)

        if _missing_trip_fields:
            msg = f"Fields in trip selection dictionary but not trips.txt: {_missing_trip_fields}"
            raise TransitSelectionNetworkConsistencyError(msg)

        _route_selection_fields = list((selection_dict.get("route_properties", {}) or {}).keys())
        _missing_route_fields = set(_route_selection_fields) - set(self.net.feed.routes.columns)

        if _missing_route_fields:
            msg = (
                f"Fields in route selection dictionary but not routes.txt: {_missing_route_fields}"
            )
            raise TransitSelectionNetworkConsistencyError(msg)
        return selection_dict

    @property
    def selected_trips(self) -> list:
        """List of selected trip_ids."""
        if self.selected_trips_df is None:
            return []
        return self.selected_trips_df.trip_id.tolist()

    @property
    def selected_trips_df(self) -> DataFrame[WranglerTripsTable]:
        """Lazily evaluates selection for trips or returns stored value in self._selected_trips_df.

        Will re-evaluate if the current feed modification version is different than the stored
        one from the last selection.

        Returns:
            DataFrame[WranglerTripsTable] of selected trips
        """
        if (
            self._selected_trips_df is not None
        ) and self._stored_feed_version == self.net.feed.modification_version:
            return self._selected_trips_df

        self._selected_trips_df = self._select_trips()
        self._stored_feed_version = self.net.feed.modification_version
        return self._selected_trips_df

    @property
    def selected_frequencies_df(self) -> DataFrame[WranglerFrequenciesTable]:
        """DataFrame of selected frequencies."""
        sel_freq_df = self.net.feed.frequencies.loc[
            self.net.feed.frequencies.trip_id.isin(self.selected_trips_df.trip_id)
        ]
        # if timespans are selected, filter to those that overlap
        if self.selection_dict.get("timespans"):
            sel_freq_df = filter_df_to_overlapping_timespans(
                sel_freq_df, self.selection_dict.get("timespans")
            )
        return sel_freq_df

    @property
    def selected_shapes_df(self) -> DataFrame[WranglerShapesTable]:
        """DataFrame of selected shapes.

        Can visualize the selected shapes quickly using the following code:

        ```python
        all_routes = net.feed.shapes.plot(color="gray")
        selection.selected_shapes_df.plot(ax=all_routes, color="red")
        ```

        """
        return self.net.feed.shapes.loc[
            self.net.feed.shapes.shape_id.isin(self.selected_trips_df.shape_id)
        ]

    def _select_trips(self) -> DataFrame[WranglerTripsTable]:
        """Selects transit trips based on selection dictionary.

        Returns:
            DataFrame[WranglerTripsTable]: trips_df DataFrame of selected trips
        """
        return _filter_trips_by_selection_dict(
            self.net.feed,
            self.selection_dict,
        )
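The `sel_key` set in the constructor is a hash of the selection dictionary (via `dict_to_hexkey`), so equivalent selections share a key. The idea can be sketched with the stdlib (`stable_hexkey` is a hypothetical analogue; the library's actual implementation may differ):

```python
import hashlib
import json


def stable_hexkey(d):
    """Hypothetical analogue of dict_to_hexkey: a stable hex digest of a dict."""
    # sort_keys makes the digest independent of key insertion order
    canonical = json.dumps(d, sort_keys=True, default=str)
    return hashlib.sha1(canonical.encode()).hexdigest()
```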

network_wrangler.transit.selection.TransitSelection.selected_frequencies_df property

selected_frequencies_df

DataFrame of selected frequencies.

network_wrangler.transit.selection.TransitSelection.selected_shapes_df property

selected_shapes_df

DataFrame of selected shapes.

Can visualize the selected shapes quickly using the following code:

all_routes = net.feed.shapes.plot(color="gray")
selection.selected_shapes_df.plot(ax=all_routes, color="red")

network_wrangler.transit.selection.TransitSelection.selected_trips property

selected_trips

List of selected trip_ids.

network_wrangler.transit.selection.TransitSelection.selected_trips_df property

selected_trips_df

Lazily evaluates selection for trips or returns stored value in self._selected_trips_df.

Will re-evaluate if the current feed modification version is different than the stored one from the last selection.

Returns:

Type Description
DataFrame[WranglerTripsTable]

DataFrame[WranglerTripsTable] of selected trips

network_wrangler.transit.selection.TransitSelection.selection_dict property writable

selection_dict

Getter for selection_dict.

network_wrangler.transit.selection.TransitSelection.__init__

__init__(net, selection_dict)

Constructor for TransitSelection object.

Parameters:

Name Type Description Default
net TransitNetwork

Transit network object to select from.

required
selection_dict dict | SelectTransitTrips

Selection dictionary conforming to SelectTransitTrips

required
Source code in network_wrangler/transit/selection.py
def __init__(
    self,
    net: TransitNetwork,
    selection_dict: dict | SelectTransitTrips,
):
    """Constructor for TransitSelection object.

    Args:
        net (TransitNetwork): Transit network object to select from.
        selection_dict: Selection dictionary conforming to SelectTransitTrips
    """
    self.net = net
    self.selection_dict = selection_dict

    # Initialize
    self._selected_trips_df = None
    self.sel_key = dict_to_hexkey(selection_dict)
    self._stored_feed_version = self.net.feed.modification_version

    WranglerLogger.debug(f"...created TransitSelection object: {selection_dict}")

network_wrangler.transit.selection.TransitSelection.__nonzero__

__nonzero__()

Return True if there are selected trips.

Source code in network_wrangler/transit/selection.py
def __nonzero__(self):
    """Return True if there are selected trips."""
    return len(self.selected_trips_df) > 0

network_wrangler.transit.selection.TransitSelection.validate_selection_dict

validate_selection_dict(selection_dict)

Check that selection dictionary has valid and used properties consistent with network.

Checks that selection_dict is a valid TransitSelectionDict
  • query vars exist in respective Feed tables

Raises:

Type Description
TransitSelectionNetworkConsistencyError

If not consistent with transit network

ValidationError

if format not consistent with SelectTransitTrips

Source code in network_wrangler/transit/selection.py
def validate_selection_dict(self, selection_dict: dict | SelectTransitTrips) -> dict:
    """Check that selection dictionary has valid and used properties consistent with network.

    Checks that selection_dict is a valid TransitSelectionDict:
        - query vars exist in respective Feed tables
    Args:
        selection_dict (dict): selection dictionary

    Raises:
        TransitSelectionNetworkConsistencyError: If not consistent with transit network
        ValidationError: if format not consistent with SelectTransitTrips
    """
    if not isinstance(selection_dict, SelectTransitTrips):
        selection_dict = SelectTransitTrips(**selection_dict)
    selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
    WranglerLogger.debug(f"SELECT DICT - before Validation: \n{selection_dict}")
    _trip_selection_fields = list((selection_dict.get("trip_properties", {}) or {}).keys())
    _missing_trip_fields = set(_trip_selection_fields) - set(self.net.feed.trips.columns)

    if _missing_trip_fields:
        msg = f"Fields in trip selection dictionary but not trips.txt: {_missing_trip_fields}"
        raise TransitSelectionNetworkConsistencyError(msg)

    _route_selection_fields = list((selection_dict.get("route_properties", {}) or {}).keys())
    _missing_route_fields = set(_route_selection_fields) - set(self.net.feed.routes.columns)

    if _missing_route_fields:
        msg = (
            f"Fields in route selection dictionary but not routes.txt: {_missing_route_fields}"
        )
        raise TransitSelectionNetworkConsistencyError(msg)
    return selection_dict
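The field-consistency check in `validate_selection_dict` reduces to a set difference between the fields requested in the selection dict and the columns the feed table actually has. A standalone stdlib sketch (`missing_fields` is a hypothetical helper):

```python
def missing_fields(selection_props, table_columns):
    """Fields requested in a selection dict that the feed table does not have."""
    requested = set((selection_props or {}).keys())
    return requested - set(table_columns)
```

A non-empty result is what triggers `TransitSelectionNetworkConsistencyError` in the real method.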

Functions to check for transit network validity and consistency with roadway network.

network_wrangler.transit.validate.shape_links_without_road_links

shape_links_without_road_links(tr_shapes, rd_links_df)

Validate that links in transit shapes exist in referenced roadway links.

Parameters:

Name Type Description Default
tr_shapes DataFrame[WranglerShapesTable]

transit shapes from shapes.txt to validate foreign key to.

required
rd_links_df DataFrame[RoadLinksTable]

Links dataframe from roadway network to validate

required

Returns:

Type Description
DataFrame

df with shape_id and A, B, as well as whatever other columns were in tr_shapes

Source code in network_wrangler/transit/validate.py
def shape_links_without_road_links(
    tr_shapes: DataFrame[WranglerShapesTable],
    rd_links_df: DataFrame[RoadLinksTable],
) -> pd.DataFrame:
    """Validate that links in transit shapes exist in referenced roadway links.

    Args:
        tr_shapes: transit shapes from shapes.txt to validate foreign key to.
        rd_links_df: Links dataframe from roadway network to validate

    Returns:
        df with shape_id and A, B, as well as whatever other columns were in tr_shapes
    """
    WranglerLogger.debug(
        f"shape_links_without_road_links(): tr_shapes.head():\n{tr_shapes.head()}"
    )
    tr_shape_links = unique_shape_links(tr_shapes)
    WranglerLogger.debug(
        f"shape_links_without_road_links(): tr_shape_links.head():\n{tr_shape_links.head()}"
    )
    rd_links_transit_ok = rd_links_df[
        (rd_links_df["drive_access"])
        | (rd_links_df["bus_only"])
        | (rd_links_df["rail_only"])
        | (rd_links_df["ferry_only"] if "ferry_only" in rd_links_df.columns else False)
    ]

    merged_df = tr_shape_links.merge(
        rd_links_transit_ok[["A", "B"]],
        how="left",
        on=["A", "B"],
        indicator=True,
    )

    missing_links_df = merged_df.loc[merged_df._merge == "left_only"]
    if len(missing_links_df):
        WranglerLogger.error(
            f"! Transit shape links missing in roadway network: \n {missing_links_df}"
        )
    return missing_links_df
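The left merge with `indicator=True` above is an anti-join: keep the (A, B) node pairs that appear in the transit shape links but not in the transit-eligible roadway links. With plain stdlib data the same logic looks like (`links_not_in_roadway` is a hypothetical sketch):

```python
def links_not_in_roadway(shape_links, road_links):
    """Anti-join of (A, B) node pairs: shape links absent from the roadway set."""
    road = set(road_links)  # set lookup makes the membership test O(1) per link
    return [ab for ab in shape_links if ab not in road]
```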
network_wrangler.transit.validate.stop_times_without_road_links

stop_times_without_road_links(tr_stop_times, rd_links_df)

Validate that links implied by transit stop_times exist in referenced roadway links.

Parameters:

Name Type Description Default
tr_stop_times DataFrame[WranglerStopTimesTable]

transit stop_times from stop_times.txt to validate foreign key to.

required
rd_links_df DataFrame[RoadLinksTable]

Links dataframe from roadway network to validate

required

Returns:

Type Description
DataFrame

df with shape_id and A, B

Source code in network_wrangler/transit/validate.py
def stop_times_without_road_links(
    tr_stop_times: DataFrame[WranglerStopTimesTable],
    rd_links_df: DataFrame[RoadLinksTable],
) -> pd.DataFrame:
    """Validate that links implied by transit stop_times exist in referenced roadway links.

    Args:
        tr_stop_times: transit stop_times from stop_times.txt to validate foreign key to.
        rd_links_df: Links dataframe from roadway network to validate

    Returns:
        df with shape_id and A, B
    """
    tr_links = unique_stop_time_links(tr_stop_times)

    rd_links_transit_ok = rd_links_df[
        (rd_links_df["drive_access"]) | (rd_links_df["bus_only"]) | (rd_links_df["rail_only"])
    ]

    merged_df = tr_links.merge(
        rd_links_transit_ok[["A", "B"]],
        how="left",
        on=["A", "B"],
        indicator=True,
    )

    missing_links_df = merged_df.loc[merged_df._merge == "left_only", ["trip_id", "A", "B"]]
    if len(missing_links_df):
        WranglerLogger.error(
            f"! Transit stop_time links missing in roadway network: \n {missing_links_df}"
        )
    return missing_links_df[["trip_id", "A", "B"]]

network_wrangler.transit.validate.transit_nodes_without_road_nodes

transit_nodes_without_road_nodes(
    feed, nodes_df, rd_field="model_node_id"
)

Validate all of a transit feeds node foreign keys exist in referenced roadway nodes.

Parameters:

Name Type Description Default
feed Feed

Transit Feed to query.

required
nodes_df DataFrame

Nodes dataframe from roadway network to validate foreign key to.

required
rd_field str

field in roadway nodes to check against. Defaults to “model_node_id”

'model_node_id'

Returns:

Type Description
list[int]

list of transit node foreign keys not found in the roadway nodes

Source code in network_wrangler/transit/validate.py
def transit_nodes_without_road_nodes(
    feed: Feed,
    nodes_df: DataFrame[RoadNodesTable],
    rd_field: str = "model_node_id",
) -> list[int]:
    """Validate all of a transit feeds node foreign keys exist in referenced roadway nodes.

    Args:
        feed: Transit Feed to query.
        nodes_df: Nodes dataframe from roadway network to validate foreign key to.
        rd_field: field in roadway nodes to check against. Defaults to "model_node_id"

    Returns:
        list of transit node foreign keys not found in the roadway nodes
    """
    feed_nodes_series = [
        feed.stops["stop_id"],
        feed.shapes["shape_model_node_id"],
        feed.stop_times["stop_id"],
    ]
    tr_nodes = set(concat_with_attr(feed_nodes_series).unique())
    rd_nodes = set(nodes_df[rd_field].unique().tolist())
    # nodes in tr_nodes but not rd_nodes
    missing_tr_nodes = list(tr_nodes - rd_nodes)

    if missing_tr_nodes:
        WranglerLogger.error(
            f"! Transit nodes missing in roadway network: \n {missing_tr_nodes}"
        )
    return missing_tr_nodes

network_wrangler.transit.validate.transit_road_net_consistency

transit_road_net_consistency(feed, road_net)

Checks foreign key and network link relationships between transit feed and a road_net.

Parameters:

Name Type Description Default
feed Feed

Transit Feed.

required
road_net RoadwayNetwork

Roadway network to check relationship with.

required

Returns:

Name Type Description
bool bool

boolean indicating if road_net is consistent with transit network.

Source code in network_wrangler/transit/validate.py
def transit_road_net_consistency(feed: Feed, road_net: RoadwayNetwork) -> bool:
    """Checks foreign key and network link relationships between transit feed and a road_net.

    Args:
        feed: Transit Feed.
        road_net (RoadwayNetwork): Roadway network to check relationship with.

    Returns:
        bool: boolean indicating if road_net is consistent with transit network.
    """
    _missing_links = shape_links_without_road_links(feed.shapes, road_net.links_df)
    _missing_nodes = transit_nodes_without_road_nodes(feed, road_net.nodes_df)
    _consistency = _missing_links.empty and not _missing_nodes
    return _consistency

network_wrangler.transit.validate.validate_transit_in_dir

validate_transit_in_dir(
    dir,
    file_format="txt",
    road_dir=None,
    road_file_format="geojson",
)

Validates a transit network in a directory against the wrangler data model specifications.

Parameters:

Name Type Description Default
dir Path

The transit network file directory.

required
file_format str

The format of the transit network files. Defaults to “txt”.

'txt'
road_dir Path

The roadway network file directory. Defaults to None.

None
road_file_format str

The format of roadway network file name. Defaults to “geojson”.

'geojson'
Source code in network_wrangler/transit/validate.py
def validate_transit_in_dir(
    dir: Path,
    file_format: TransitFileTypes = "txt",
    road_dir: Path | None = None,
    road_file_format: RoadwayFileTypes = "geojson",
) -> bool:
    """Validates a transit network in a directory against the wrangler data model specifications.

    Args:
        dir (Path): The transit network file directory.
        file_format (str): The format of the transit network files. Defaults to "txt".
        road_dir (Path): The roadway network file directory. Defaults to None.
        road_file_format (str): The format of the roadway network files. Defaults to "geojson".
    """
    from .io import load_transit

    try:
        t = load_transit(dir, file_format=file_format)
    except SchemaErrors as e:
        WranglerLogger.error(f"!!! [Transit Network invalid] - Failed Loading to Feed object\n{e}")
        return False
    if road_dir is not None:
        from ..roadway import load_roadway_from_dir
        from .network import TransitRoadwayConsistencyError

        try:
            r = load_roadway_from_dir(road_dir, file_format=road_file_format)
        except FileNotFoundError:
            WranglerLogger.error(f"! Roadway network not found in {road_dir}")
            return False
        except Exception as e:
            WranglerLogger.error(
                f"! Error loading roadway network. \
                                 Skipping validation of road to transit network.\n{e}"
            )
            # r is unbound here, so return rather than fall through to t.road_net = r
            return True
        try:
            t.road_net = r
        except TransitRoadwayConsistencyError as e:
            WranglerLogger.error(
                f"!!! [Transit Network inconsistent] Error in road to transit \
                                 network consistency.\n{e}"
            )
            return False

    return True