
API Documentation

Common Usage

Base Objects

Scenario objects manage how a collection of projects is applied to the networks.

Scenarios are built from a base scenario and a list of project cards.

A project card is a YAML file (or similar) that describes a change to the network. The project card can contain multiple changes, each of which is applied to the network in sequence.
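For illustration, a minimal project card might look like the following. The exact schema is defined by the projectcard package; the field names here are indicative only and the values are hypothetical:

```yaml
project: my example project      # unique name, used for dependency and conflict tracking
tags:
  - baseline2050                 # matched against filter_tags when building a scenario
roadway_property_change:         # the change type determines which network it is applied to
  # ...selection and property changes per the projectcard schema...
```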

Create a Scenario

Instantiate a scenario by seeding it with a base scenario and optionally some project cards.

from network_wrangler import create_scenario

my_scenario = create_scenario(
    base_scenario=my_base_year_scenario,
    card_search_dir=project_card_directory,
    filter_tags=["baseline2050"],
)

A base_year_scenario is a dictionary representation of key components of a scenario:

  • road_net: RoadwayNetwork instance
  • transit_net: TransitNetwork instance
  • applied_projects: list of projects that have already been applied to the base scenario, so the scenario can detect conflicts with future projects and check whether a future project's pre-requisites are satisfied.
  • conflicts: dictionary of conflicts for projects that have been applied to the base scenario, used to detect conflicts with future projects.
my_base_year_scenario = {
    "road_net": load_from_roadway_dir(STPAUL_DIR),
    "transit_net": load_transit(STPAUL_DIR),
    "applied_projects": [],
    "conflicts": {},
}

Add Projects to a Scenario

In addition to adding projects when you create the scenario, project cards can be added to a scenario using the add_project_cards method.

from projectcard import read_cards

project_card_dict = read_cards(card_location, filter_tags=["Baseline2030"], recursive=True)
my_scenario.add_project_cards(project_card_dict.values())

Here, card_location can be a single path, a list of paths, a directory, or a glob pattern.

Apply Projects to a Scenario

Projects can be applied to a scenario using the apply_all_projects method. Before applying projects, the scenario will check that all pre-requisites are satisfied, that there are no conflicts, and that the projects are in the planned projects list.

If you want to check the order of projects before applying them, you can use the queued_projects property.

my_scenario.queued_projects
my_scenario.apply_all_projects()
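Under the hood, queued_projects re-orders planned projects so that every pre-requisite is applied before the projects that depend on it (the source listing below does this with a topological sort over a prerequisite adjacency list). A minimal, self-contained sketch of that ordering idea, with illustrative project names and assuming the prerequisite graph has no cycles:

```python
from collections import deque

def order_projects(planned, prerequisites):
    """Return planned projects re-ordered so every prerequisite comes first.

    Assumes the prerequisite graph has no cycles (a cycle would loop forever).
    """
    remaining = deque(planned)
    applied = set()
    ordered = []
    while remaining:
        project = remaining.popleft()
        if set(prerequisites.get(project, [])) <= applied:
            ordered.append(project)    # all prerequisites satisfied
            applied.add(project)
        else:
            remaining.append(project)  # revisit once its prerequisites are applied
    return ordered

# "project b" requires "project a"; "project c" requires "project b".
prerequisites = {"project b": ["project a"], "project c": ["project b"]}
print(order_projects(["project c", "project a", "project b"], prerequisites))
# ['project a', 'project b', 'project c']
```

This is why apply_all_projects can raise a ScenarioPrerequisiteError before anything is applied: the queue can only be built if every pre-requisite is either already applied or in the planned list.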

You can review the resulting scenario, roadway network, and transit networks.

my_scenario.applied_projects
my_scenario.road_net.links_gdf.explore()
my_scenario.transit_net.feed.shapes_gdf.explore()

Write a Scenario to Disk

Scenarios (and their networks) can be written to disk using the write method. In addition to writing out the roadway and transit networks, write serializes the scenario to a YAML file and can also write out the project cards that have been applied.

my_scenario.write(
    "output_dir",
    "scenario_name_to_use",
    overwrite=True,
    projects_write=True,
    file_format="parquet",
)
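Given the call above, write produces output along these lines (a sketch, not verbatim output; the roadway/transit/projects subdirectory names are the defaults used when no *_out_dir arguments are passed):

```
output_dir/
    roadway/                           # roadway network files (parquet)
    transit/                           # transit network files
    projects/                          # applied project cards
    scenario_name_to_use_scenario.yml  # serialized scenario summary
```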
Example Serialized Scenario File
applied_projects: &id001
- project a
- project b
base_scenario:
  applied_projects: *id001
  roadway:
    dir: /Users/elizabeth/Documents/urbanlabs/MetCouncil/NetworkWrangler/working/network_wrangler/examples/small
    file_format: geojson
  transit:
    dir: /Users/elizabeth/Documents/urbanlabs/MetCouncil/NetworkWrangler/working/network_wrangler/examples/small
config:
  CPU:
    EST_PD_READ_SPEED:
      csv: 0.03
      geojson: 0.03
      json: 0.15
      parquet: 0.005
      txt: 0.04
  IDS:
    ML_LINK_ID_METHOD: range
    ML_LINK_ID_RANGE: &id002 !!python/tuple
    - 950000
    - 999999
    ML_LINK_ID_SCALAR: 15000
    ML_NODE_ID_METHOD: range
    ML_NODE_ID_RANGE: *id002
    ML_NODE_ID_SCALAR: 15000
    ROAD_SHAPE_ID_METHOD: scalar
    ROAD_SHAPE_ID_SCALAR: 1000
    TRANSIT_SHAPE_ID_METHOD: scalar
    TRANSIT_SHAPE_ID_SCALAR: 1000000
  MODEL_ROADWAY:
    ADDITIONAL_COPY_FROM_GP_TO_ML: []
    ADDITIONAL_COPY_TO_ACCESS_EGRESS: []
    ML_OFFSET_METERS: -10
conflicts: {}
corequisites: {}
name: first_scenario
prerequisites: {}
roadway:
  dir: /Users/elizabeth/Documents/urbanlabs/MetCouncil/NetworkWrangler/working/network_wrangler/tests/out/first_scenario/roadway
  file_format: parquet
transit:
  dir: /Users/elizabeth/Documents/urbanlabs/MetCouncil/NetworkWrangler/working/network_wrangler/tests/out/first_scenario/transit
  file_format: txt

Load a scenario from disk

If you want to reload a scenario that you wrote, use the load_scenario function.

from network_wrangler import load_scenario

my_scenario = load_scenario("output_dir/scenario_name_to_use_scenario.yml")

BASE_SCENARIO_SUGGESTED_PROPS: list[str] = ['road_net', 'transit_net', 'applied_projects', 'conflicts'] module-attribute

List of suggested properties for a base scenario.

ROADWAY_CARD_TYPES: list[str] = ['roadway_property_change', 'roadway_deletion', 'roadway_addition', 'pycode'] module-attribute

List of card types that will be applied to the roadway network.

TRANSIT_CARD_TYPES: list[str] = ['transit_property_change', 'transit_routing_change', 'transit_route_addition', 'transit_service_deletion'] module-attribute

List of card types that will be applied to the transit network.
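These lists drive the dispatch when a project is applied: a change's change_type determines whether it goes to the roadway network, the transit network, or both. A self-contained sketch of that dispatch, paraphrasing Scenario._apply_change (the function name target_networks is illustrative; the lists are copied from the module attributes above):

```python
ROADWAY_CARD_TYPES = [
    "roadway_property_change",
    "roadway_deletion",
    "roadway_addition",
    "pycode",
]
TRANSIT_CARD_TYPES = [
    "transit_property_change",
    "transit_routing_change",
    "transit_route_addition",
    "transit_service_deletion",
]

def target_networks(change_type: str) -> list[str]:
    """Return which network(s) a change of this type is applied to."""
    targets = []
    if change_type in ROADWAY_CARD_TYPES:
        targets.append("roadway")
    if change_type in TRANSIT_CARD_TYPES:
        targets.append("transit")
    if not targets:
        # Mirrors the ProjectCardError raised for unrecognized card types.
        raise ValueError(f"Unknown card type: {change_type}")
    return targets

print(target_networks("roadway_deletion"))       # ['roadway']
print(target_networks("transit_routing_change"))  # ['transit']
```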

Scenario

Holds information about a scenario.

Typical usage example:

my_base_year_scenario = {
    "road_net": load_roadway(
        links_file=STPAUL_LINK_FILE,
        nodes_file=STPAUL_NODE_FILE,
        shapes_file=STPAUL_SHAPE_FILE,
    ),
    "transit_net": load_transit(STPAUL_DIR),
}

# create a future baseline scenario from base by searching for all cards in dir w/ baseline tag
project_card_directory = Path(STPAUL_DIR) / "project_cards"
my_scenario = create_scenario(
    base_scenario=my_base_year_scenario,
    card_search_dir=project_card_directory,
    filter_tags=["baseline2050"],
)

# check project card queue and then apply the projects
my_scenario.queued_projects
my_scenario.apply_all_projects()

# check applied projects, write it out, and create a summary report.
my_scenario.applied_projects
my_scenario.write("baseline")
my_scenario.summary

# Add some projects to create a build scenario based on a list of files.
build_card_filenames = [
    "3_multiple_roadway_attribute_change.yml",
    "road.prop_changes.segment.yml",
    "4_simple_managed_lane.yml",
]
my_scenario.add_projects_from_files(build_card_filenames)
my_scenario.write("build2050")
my_scenario.summary

Attributes:

  • base_scenario (dict): dictionary representation of a scenario
  • road_net (Optional[RoadwayNetwork]): instance of RoadwayNetwork for the scenario
  • transit_net (Optional[TransitNetwork]): instance of TransitNetwork for the scenario
  • project_cards (dict[str, ProjectCard]): storage of all project cards, keyed by name
  • queued_projects: projects that are "shovel ready", meaning their pre-requisites have been checked and any required re-ordering has been done. Similar to git staging, project cards are no longer in this collection once they are moved to applied.
  • applied_projects (list[str]): list of project names that have been applied
  • projects: list of all projects, whether planned, queued, or applied
  • prerequisites (dict[str, list[str]]): prerequisite info stored as projectA: [prereqs-for-projectA]
  • corequisites (dict[str, list[str]]): corequisite info stored as projectA: [coreqs-for-projectA]
  • conflicts (dict[str, list[str]]): conflict info stored as projectA: [conflicts-for-projectA]
  • config: WranglerConfig instance

Source code in network_wrangler/scenario.py
class Scenario:
    """Holds information about a scenario.

    Typical usage example:

    ```python
    my_base_year_scenario = {
        "road_net": load_roadway(
            links_file=STPAUL_LINK_FILE,
            nodes_file=STPAUL_NODE_FILE,
            shapes_file=STPAUL_SHAPE_FILE,
        ),
        "transit_net": load_transit(STPAUL_DIR),
    }

    # create a future baseline scenario from base by searching for all cards in dir w/ baseline tag
    project_card_directory = Path(STPAUL_DIR) / "project_cards"
    my_scenario = create_scenario(
        base_scenario=my_base_year_scenario,
        card_search_dir=project_card_directory,
        filter_tags=["baseline2050"],
    )

    # check project card queue and then apply the projects
    my_scenario.queued_projects
    my_scenario.apply_all_projects()

    # check applied projects, write it out, and create a summary report.
    my_scenario.applied_projects
    my_scenario.write("baseline")
    my_scenario.summary

    # Add some projects to create a build scenario based on a list of files.
    build_card_filenames = [
        "3_multiple_roadway_attribute_change.yml",
        "road.prop_changes.segment.yml",
        "4_simple_managed_lane.yml",
    ]
    my_scenario.add_projects_from_files(build_card_filenames)
    my_scenario.write("build2050")
    my_scenario.summary
    ```

    Attributes:
        base_scenario: dictionary representation of a scenario
        road_net: instance of RoadwayNetwork for the scenario
        transit_net: instance of TransitNetwork for the scenario
        project_cards: Mapping[ProjectCard.name,ProjectCard] Storage of all project cards by name.
        queued_projects: Projects which are "shovel ready" - have had pre-requisites checked and
            done any required re-ordering. Similar to a git staging, project cards aren't
            recognized in this collection once they are moved to applied.
        applied_projects: list of project names that have been applied
        projects: list of all projects either planned, queued, or applied
        prerequisites:  dictionary storing prerequisite info as `projectA: [prereqs-for-projectA]`
        corequisites:  dictionary storing corequisite info as `projectA: [coreqs-for-projectA]`
        conflicts: dictionary storing conflict info as `projectA: [conflicts-for-projectA]`
        config: WranglerConfig instance.
    """

    def __init__(
        self,
        base_scenario: Union[Scenario, dict],
        project_card_list: Optional[list[ProjectCard]] = None,
        config: Optional[Union[WranglerConfig, dict, Path, list[Path]]] = None,
        name: str = "",
    ):
        """Constructor.

        Args:
            base_scenario: A base scenario object to base this instance off of, or a dict which
                describes the scenario attributes including applied projects and respective
                conflicts. `{"applied_projects": [],"conflicts":{...}}`
            project_card_list: Optional list of ProjectCard instances to add to planned projects.
                Defaults to None.
            config: WranglerConfig instance or a dictionary of configuration settings or a path to
                one or more configuration files. Configurations that are not explicitly set will
                default to the values in the default configuration in
                `/configs/wrangler/default.yml`.
            name: Optional name for the scenario.
        """
        WranglerLogger.info("Creating Scenario")
        self.config = load_wrangler_config(config)

        if project_card_list is None:
            project_card_list = []

        if isinstance(base_scenario, Scenario):
            base_scenario = base_scenario.__dict__

        self.base_scenario: dict = extract_base_scenario_metadata(base_scenario)

        if not set(BASE_SCENARIO_SUGGESTED_PROPS) <= set(base_scenario.keys()):
            WranglerLogger.warning(
                f"Base_scenario doesn't contain {BASE_SCENARIO_SUGGESTED_PROPS}"
            )
        self.name: str = name
        # if the base scenario had roadway or transit networks, use them as the basis.
        self.road_net: Optional[RoadwayNetwork] = copy.deepcopy(
            base_scenario.pop("road_net", None)
        )

        self.transit_net: Optional[TransitNetwork] = copy.deepcopy(
            base_scenario.pop("transit_net", None)
        )
        if self.road_net and self.transit_net:
            self.transit_net.road_net = self.road_net

        # Set configs for networks to be the same as scenario.
        if isinstance(self.road_net, RoadwayNetwork):
            self.road_net.config = self.config
        if isinstance(self.transit_net, TransitNetwork):
            self.transit_net.config = self.config

        self.project_cards: dict[str, ProjectCard] = {}
        self._planned_projects: list[str] = []
        self._queued_projects = None
        self.applied_projects: list[str] = base_scenario.pop("applied_projects", [])

        self.prerequisites: dict[str, list[str]] = base_scenario.pop("prerequisites", {})
        self.corequisites: dict[str, list[str]] = base_scenario.pop("corequisites", {})
        self.conflicts: dict[str, list[str]] = base_scenario.pop("conflicts", {})

        for p in project_card_list:
            self._add_project(p)

    @property
    def projects(self):
        """Returns a list of all projects in the scenario: applied and planned."""
        return self.applied_projects + self._planned_projects

    @property
    def queued_projects(self):
        """Returns a list version of _queued_projects queue.

        Queued projects are those that have been planned, have all pre-requisites satisfied, and
        have been ordered based on pre-requisites.

        If no queued projects, will dynamically generate from planned projects based on
        pre-requisites and return the queue.
        """
        if not self._queued_projects:
            self._check_projects_requirements_satisfied(self._planned_projects)
            self._queued_projects = self.order_projects(self._planned_projects)
        return list(self._queued_projects)

    def __str__(self):
        """String representation of the Scenario object."""
        s = [f"{key}: {value}" for key, value in self.__dict__.items()]
        return "\n".join(s)

    def _add_dependencies(self, project_name, dependencies: dict) -> None:
        """Add dependencies from a project card to relevant scenario variables.

        Updates existing "prerequisites", "corequisites" and "conflicts".
        Lowercases everything to enable string matching.

        Args:
            project_name: name of project you are adding dependencies for.
            dependencies: Dictionary of dependencies by dependency type and list of associated
                projects.
        """
        project_name = project_name.lower()

        for d, v in dependencies.items():
            _dep = list(map(str.lower, v))
            WranglerLogger.debug(f"Adding {_dep} to {project_name} dependency table.")
            self.__dict__[d].update({project_name: _dep})

    def _add_project(
        self,
        project_card: ProjectCard,
        validate: bool = True,
        filter_tags: Optional[list[str]] = None,
    ) -> None:
        """Adds a single ProjectCard instances to the Scenario.

        Checks that a project of same name is not already in scenario.
        If selected, will validate ProjectCard before adding.
        If provided, will only add ProjectCard if it matches at least one filter_tags.

        Resets scenario queued_projects.

        Args:
            project_card (ProjectCard): ProjectCard instance to add to scenario.
            validate (bool, optional): If True, will validate the projectcard before
                being adding it to the scenario. Defaults to True.
            filter_tags: If used, will only add the project card if
                its tags match one or more of these filter_tags. Defaults to []
                which means no tag-filtering will occur.

        """
        filter_tags = filter_tags or []
        project_name = project_card.project.lower()
        filter_tags = list(map(str.lower, filter_tags))

        if project_name in self.projects:
            msg = f"Names not unique from existing scenario projects: {project_card.project}"
            raise ProjectCardError(msg)

        if filter_tags and set(project_card.tags).isdisjoint(set(filter_tags)):
            WranglerLogger.debug(
                f"Skipping {project_name} - no overlapping tags with {filter_tags}."
            )
            return

        if validate:
            project_card.validate()

        WranglerLogger.info(f"Adding {project_name} to scenario.")
        self.project_cards[project_name] = project_card
        self._planned_projects.append(project_name)
        self._queued_projects = None
        self._add_dependencies(project_name, project_card.dependencies)

    def add_project_cards(
        self,
        project_card_list: list[ProjectCard],
        validate: bool = True,
        filter_tags: Optional[list[str]] = None,
    ) -> None:
        """Adds a list of ProjectCard instances to the Scenario.

        Checks that a project of same name is not already in scenario.
        If selected, will validate ProjectCard before adding.
        If provided, will only add ProjectCard if it matches at least one filter_tags.

        Args:
            project_card_list: List of ProjectCard instances to add to
                scenario.
            validate (bool, optional): If True, will require each ProjectCard is validated before
                being added to scenario. Defaults to True.
            filter_tags: If used, will filter ProjectCard instances
                and only add those whose tags match one or more of these filter_tags.
                Defaults to [] - which means no tag-filtering will occur.
        """
        filter_tags = filter_tags or []
        for p in project_card_list:
            self._add_project(p, validate=validate, filter_tags=filter_tags)

    def _check_projects_requirements_satisfied(self, project_list: list[str]):
        """Checks all requirements are satisified to apply this specific set of projects.

        Including:
        1. has an associaed project card
        2. is in scenario's planned projects
        3. pre-requisites satisfied
        4. co-requisies satisfied by applied or co-applied projects
        5. no conflicing applied or co-applied projects

        Args:
            project_list: list of projects to check requirements for.
        """
        self._check_projects_planned(project_list)
        self._check_projects_have_project_cards(project_list)
        self._check_projects_prerequisites(project_list)
        self._check_projects_corequisites(project_list)
        self._check_projects_conflicts(project_list)

    def _check_projects_planned(self, project_names: list[str]) -> None:
        """Checks that a list of projects are in the scenario's planned projects."""
        _missing_ps = [p for p in project_names if p not in self._planned_projects]
        if _missing_ps:
            msg = f"Projects are not in planned projects: \n {_missing_ps}. \
                Add them by using add_project_cards()."
            WranglerLogger.debug(msg)
            raise ValueError(msg)

    def _check_projects_have_project_cards(self, project_list: list[str]) -> bool:
        """Checks that a list of projects has an associated project card in the scenario."""
        _missing = [p for p in project_list if p not in self.project_cards]
        if _missing:
            WranglerLogger.error(
                f"Projects referenced which are missing project cards: {_missing}"
            )
            return False
        return True

    def _check_projects_prerequisites(self, project_names: list[str]) -> None:
        """Check a list of projects' pre-requisites have been or will be applied to scenario."""
        if set(project_names).isdisjoint(set(self.prerequisites.keys())):
            return
        _prereqs = []
        for p in project_names:
            _prereqs += self.prerequisites.get(p, [])
        _projects_applied = self.applied_projects + project_names
        _missing = list(set(_prereqs) - set(_projects_applied))
        if _missing:
            WranglerLogger.debug(
                f"project_names: {project_names}\nprojects_have_or_will_be_applied: \
                    {_projects_applied}\nmissing: {_missing}"
            )
            msg = f"Missing {len(_missing)} pre-requisites."
            raise ScenarioPrerequisiteError(msg)

    def _check_projects_corequisites(self, project_names: list[str]) -> None:
        """Check a list of projects' co-requisites have been or will be applied to scenario."""
        if set(project_names).isdisjoint(set(self.corequisites.keys())):
            return
        _coreqs = []
        for p in project_names:
            _coreqs += self.corequisites.get(p, [])
        _projects_applied = self.applied_projects + project_names
        _missing = list(set(_coreqs) - set(_projects_applied))
        if _missing:
            WranglerLogger.debug(
                f"project_names: {project_names}\nprojects_have_or_will_be_applied: \
                    {_projects_applied}\nmissing: {_missing}"
            )
            msg = f"Missing {len(_missing)} corequisites."
            raise ScenarioCorequisiteError(msg)

    def _check_projects_conflicts(self, project_names: list[str]) -> None:
        """Checks that list of projects' conflicts have not been or will be applied to scenario."""
        # WranglerLogger.debug("Checking Conflicts...")
        projects_to_check = project_names + self.applied_projects
        # WranglerLogger.debug(f"\nprojects_to_check:{projects_to_check}\nprojects_with_conflicts:{set(self.conflicts.keys())}")
        if set(projects_to_check).isdisjoint(set(self.conflicts.keys())):
            # WranglerLogger.debug("Projects have no conflicts to check")
            return
        _conflicts = []
        for p in project_names:
            _conflicts += self.conflicts.get(p, [])
        _conflict_problems = [p for p in _conflicts if p in projects_to_check]
        if _conflict_problems:
            WranglerLogger.warning(f"Conflict Problems: \n{_conflict_problems}")
            _conf_dict = {
                k: v
                for k, v in self.conflicts.items()
                if k in projects_to_check and not set(v).isdisjoint(set(_conflict_problems))
            }
            WranglerLogger.debug(f"Problematic Conflicts: \n{_conf_dict}")
            msg = f"Found {len(_conflict_problems)} conflicts: {_conflict_problems}"
            raise ScenarioConflictError(msg)

    def order_projects(self, project_list: list[str]) -> deque:
        """Orders a list of projects based on moving up pre-requisites into a deque.

        Args:
            project_list: list of projects to order

        Returns: deque for applying projects.
        """
        project_list = [p.lower() for p in project_list]
        assert self._check_projects_have_project_cards(project_list)

        # build prereq (adjacency) list for topological sort
        adjacency_list: dict[str, list] = defaultdict(list)
        visited_list: dict[str, bool] = defaultdict(bool)

        for project in project_list:
            visited_list[project] = False
            if not self.prerequisites.get(project):
                continue
            for prereq in self.prerequisites[project]:
                # this will always be true, else would have been flagged in missing \
                # prerequisite check, but just in case
                if prereq.lower() in project_list:
                    if adjacency_list.get(prereq.lower()):
                        adjacency_list[prereq.lower()].append(project)
                    else:
                        adjacency_list[prereq.lower()] = [project]

        # sorted_project_names is topologically sorted project card names (based on prerequisites)
        _ordered_projects = topological_sort(
            adjacency_list=adjacency_list, visited_list=visited_list
        )

        if set(_ordered_projects) != set(project_list):
            _missing = list(set(project_list) - set(_ordered_projects))
            msg = f"Project sort resulted in missing projects: {_missing}"
            raise ValueError(msg)

        project_deque = deque(_ordered_projects)

        WranglerLogger.debug(f"Ordered Projects: \n{project_deque}")

        return project_deque

    def apply_all_projects(self):
        """Applies all planned projects in the queue."""
        # Call this to make sure projects are appropriately queued in hidden variable.
        self.queued_projects  # noqa: B018

        # Use hidden variable.
        while self._queued_projects:
            self._apply_project(self._queued_projects.popleft())

        # set this so it will trigger re-queuing any more projects.
        self._queued_projects = None

    def _apply_change(self, change: Union[ProjectCard, SubProject]) -> None:
        """Applies a specific change specified in a project card.

        Change type must be in at least one of:
        - ROADWAY_CARD_TYPES
        - TRANSIT_CARD_TYPES

        Args:
            change: a project card or subproject card
        """
        if change.change_type in ROADWAY_CARD_TYPES:
            if not self.road_net:
                msg = "Missing Roadway Network"
                raise ValueError(msg)
            if change.change_type in SECONDARY_TRANSIT_CARD_TYPES and self.transit_net:
                self.road_net.apply(change, transit_net=self.transit_net)
            else:
                self.road_net.apply(change)
        if change.change_type in TRANSIT_CARD_TYPES:
            if not self.transit_net:
                msg = "Missing Transit Network"
                raise ValueError(msg)
            self.transit_net.apply(change)

        if change.change_type not in ROADWAY_CARD_TYPES + TRANSIT_CARD_TYPES:
            msg = f"Project {change.project}: Don't understand project cat: {change.change_type}"
            raise ProjectCardError(msg)

    def _apply_project(self, project_name: str) -> None:
        """Applies project card to scenario.

        If a list of changes is specified in referenced project card, iterates through each change.

        Args:
            project_name (str): name of project to be applied.
        """
        project_name = project_name.lower()

        WranglerLogger.info(f"Applying {project_name} from file:\
                            {self.project_cards[project_name].file}")

        p = self.project_cards[project_name]
        WranglerLogger.debug(f"types: {p.change_types}")
        WranglerLogger.debug(f"type: {p.change_type}")
        if p._sub_projects:
            for sp in p._sub_projects:
                WranglerLogger.debug(f"- applying subproject: {sp.change_type}")
                self._apply_change(sp)

        else:
            self._apply_change(p)

        self._planned_projects.remove(project_name)
        self.applied_projects.append(project_name)

    def apply_projects(self, project_list: list[str]):
        """Applies a specific list of projects from the planned project queue.

        Will order the list of projects based on pre-requisites.

        NOTE: does not check co-requisites b/c that isn't possible when applying a single project.

        Args:
            project_list: List of projects to be applied. All need to be in the planned project
                queue.
        """
        project_list = [p.lower() for p in project_list]

        self._check_projects_requirements_satisfied(project_list)
        ordered_project_queue = self.order_projects(project_list)

        while ordered_project_queue:
            self._apply_project(ordered_project_queue.popleft())

        # Set so that when called again it will retrigger queueing from planned projects.
        self._queued_projects = None

    def write(
        self,
        path: Path,
        name: str,
        overwrite: bool = True,
        roadway_write: bool = True,
        transit_write: bool = True,
        projects_write: bool = True,
        roadway_convert_complex_link_properties_to_single_field: bool = False,
        roadway_out_dir: Optional[Path] = None,
        roadway_prefix: Optional[str] = None,
        roadway_file_format: RoadwayFileTypes = "parquet",
        roadway_true_shape: bool = False,
        transit_out_dir: Optional[Path] = None,
        transit_prefix: Optional[str] = None,
        transit_file_format: TransitFileTypes = "txt",
        projects_out_dir: Optional[Path] = None,
    ) -> Path:
        """Writes scenario networks and summary to disk and returns path to scenario file.

        Args:
            path: Path to write scenario networks and scenario summary to.
            name: Name to use.
            overwrite: If True, will overwrite the files if they already exist.
            roadway_write: If True, will write out the roadway network.
            transit_write: If True, will write out the transit network.
            projects_write: If True, will write out the project cards.
            roadway_convert_complex_link_properties_to_single_field: If True, will convert complex
                link properties to a single field.
            roadway_out_dir: Path to write the roadway network files to.
            roadway_prefix: Prefix to add to the file name.
            roadway_file_format: File format to write the roadway network to
            roadway_true_shape: If True, will write the true shape of the roadway network
            transit_out_dir: Path to write the transit network files to.
            transit_prefix: Prefix to add to the file name.
            transit_file_format: File format to write the transit network to
            projects_out_dir: Path to write the project cards to.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        if self.road_net and roadway_write:
            if roadway_out_dir is None:
                roadway_out_dir = path / "roadway"
            roadway_out_dir.mkdir(parents=True, exist_ok=True)

            write_roadway(
                net=self.road_net,
                out_dir=roadway_out_dir,
                prefix=roadway_prefix or name,
                convert_complex_link_properties_to_single_field=roadway_convert_complex_link_properties_to_single_field,
                file_format=roadway_file_format,
                true_shape=roadway_true_shape,
                overwrite=overwrite,
            )
        if self.transit_net and transit_write:
            if transit_out_dir is None:
                transit_out_dir = path / "transit"
            transit_out_dir.mkdir(parents=True, exist_ok=True)
            write_transit(
                self.transit_net,
                out_dir=transit_out_dir,
                prefix=transit_prefix or name,
                file_format=transit_file_format,
                overwrite=overwrite,
            )
        if projects_write:
            if projects_out_dir is None:
                projects_out_dir = path / "projects"
            write_applied_projects(
                self,
                out_dir=projects_out_dir,
                overwrite=overwrite,
            )

        scenario_data = self.summary
        if transit_write:
            scenario_data["transit"] = {
                "dir": str(transit_out_dir),
                "file_format": transit_file_format,
            }
        if roadway_write:
            scenario_data["roadway"] = {
                "dir": str(roadway_out_dir),
                "file_format": roadway_file_format,
            }
        if projects_write:
            scenario_data["project_cards"] = {"dir": str(projects_out_dir)}
        scenario_file_path = Path(path) / f"{name}_scenario.yml"
        with scenario_file_path.open("w") as f:
            yaml.dump(scenario_data, f, default_flow_style=False, allow_unicode=True)
        return scenario_file_path

    @property
    def summary(self) -> dict:
        """A high level summary of the created scenario and public attributes."""
        skip = ["road_net", "base_scenario", "transit_net", "project_cards", "config"]
        summary_dict = {
            k: v for k, v in self.__dict__.items() if not k.startswith("_") and k not in skip
        }
        summary_dict["config"] = self.config.to_dict()

        """
        # Handle nested dictionary for "base_scenario"
        skip_base = ["project_cards"]
        if "base_scenario" in self.__dict__:
            base_summary_dict = {
                k: v
                for k, v in self.base_scenario.items()
                if not k.startswith("_") and k not in skip_base
            }
            summary_dict["base_scenario"] = base_summary_dict
        """

        return summary_dict

projects property

Returns a list of all projects in the scenario: applied and planned.

queued_projects property

Returns a list version of _queued_projects queue.

Queued projects are those that have been planned, have all pre-requisites satisfied, and have been ordered based on pre-requisites.

If no queued projects, will dynamically generate from planned projects based on pre-requisites and return the queue.

summary: dict property

A high level summary of the created scenario and public attributes.

__init__(base_scenario, project_card_list=None, config=None, name='')

Constructor.

Parameters:

Name Type Description Default
base_scenario Union[Scenario, dict]

A base scenario object on which to base this instance, or a dict which describes the scenario attributes including applied projects and respective conflicts: {"applied_projects": [], "conflicts": {...}}

required
project_card_list Optional[list[ProjectCard]]

Optional list of ProjectCard instances to add to planned projects. Defaults to None.

None
config Optional[Union[WranglerConfig, dict, Path, list[Path]]]

WranglerConfig instance or a dictionary of configuration settings or a path to one or more configuration files. Configurations that are not explicitly set will default to the values in the default configuration in /configs/wrangler/default.yml.

None
name str

Optional name for the scenario.

''
Source code in network_wrangler/scenario.py
def __init__(
    self,
    base_scenario: Union[Scenario, dict],
    project_card_list: Optional[list[ProjectCard]] = None,
    config: Optional[Union[WranglerConfig, dict, Path, list[Path]]] = None,
    name: str = "",
):
    """Constructor.

    Args:
        base_scenario: A base scenario object on which to base this instance, or a dict
            which describes the scenario attributes including applied projects and
            respective conflicts. `{"applied_projects": [], "conflicts": {...}}`
        project_card_list: Optional list of ProjectCard instances to add to planned projects.
            Defaults to None.
        config: WranglerConfig instance or a dictionary of configuration settings or a path to
            one or more configuration files. Configurations that are not explicitly set will
            default to the values in the default configuration in
            `/configs/wrangler/default.yml`.
        name: Optional name for the scenario.
    """
    WranglerLogger.info("Creating Scenario")
    self.config = load_wrangler_config(config)

    if project_card_list is None:
        project_card_list = []

    if isinstance(base_scenario, Scenario):
        base_scenario = base_scenario.__dict__

    self.base_scenario: dict = extract_base_scenario_metadata(base_scenario)

    if not set(BASE_SCENARIO_SUGGESTED_PROPS) <= set(base_scenario.keys()):
        WranglerLogger.warning(
            f"Base_scenario doesn't contain {BASE_SCENARIO_SUGGESTED_PROPS}"
        )
    self.name: str = name
    # if the base scenario had roadway or transit networks, use them as the basis.
    self.road_net: Optional[RoadwayNetwork] = copy.deepcopy(
        base_scenario.pop("road_net", None)
    )

    self.transit_net: Optional[TransitNetwork] = copy.deepcopy(
        base_scenario.pop("transit_net", None)
    )
    if self.road_net and self.transit_net:
        self.transit_net.road_net = self.road_net

    # Set configs for networks to be the same as scenario.
    if isinstance(self.road_net, RoadwayNetwork):
        self.road_net.config = self.config
    if isinstance(self.transit_net, TransitNetwork):
        self.transit_net.config = self.config

    self.project_cards: dict[str, ProjectCard] = {}
    self._planned_projects: list[str] = []
    self._queued_projects = None
    self.applied_projects: list[str] = base_scenario.pop("applied_projects", [])

    self.prerequisites: dict[str, list[str]] = base_scenario.pop("prerequisites", {})
    self.corequisites: dict[str, list[str]] = base_scenario.pop("corequisites", {})
    self.conflicts: dict[str, list[str]] = base_scenario.pop("conflicts", {})

    for p in project_card_list:
        self._add_project(p)
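The constructor warns when suggested base-scenario properties are missing. A minimal sketch of that subset check, assuming `BASE_SCENARIO_SUGGESTED_PROPS` covers the four properties listed in the base-scenario description (the real constant lives in network_wrangler/scenario.py):

```python
# Hypothetical stand-in for BASE_SCENARIO_SUGGESTED_PROPS.
BASE_SCENARIO_SUGGESTED_PROPS = {"road_net", "transit_net", "applied_projects", "conflicts"}

base_scenario = {"applied_projects": [], "conflicts": {}}

# Mirrors the subset check in __init__: warn when suggested props are absent.
missing = BASE_SCENARIO_SUGGESTED_PROPS - set(base_scenario)
print(sorted(missing))  # props that would trigger the warning
```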

__str__()

String representation of the Scenario object.

Source code in network_wrangler/scenario.py
def __str__(self):
    """String representation of the Scenario object."""
    s = [f"{key}: {value}" for key, value in self.__dict__.items()]
    return "\n".join(s)

add_project_cards(project_card_list, validate=True, filter_tags=None)

Adds a list of ProjectCard instances to the Scenario.

Checks that a project of the same name is not already in the scenario. If validate is selected, each ProjectCard will be validated before adding. If filter_tags is provided, a ProjectCard will only be added if its tags match at least one of them.

Parameters:

Name Type Description Default
project_card_list list[ProjectCard]

List of ProjectCard instances to add to scenario.

required
validate bool

If True, will require each ProjectCard is validated before being added to scenario. Defaults to True.

True
filter_tags Optional[list[str]]

If used, will filter ProjectCard instances and only add those whose tags match one or more of these filter_tags. Defaults to [] - which means no tag-filtering will occur.

None
Source code in network_wrangler/scenario.py
def add_project_cards(
    self,
    project_card_list: list[ProjectCard],
    validate: bool = True,
    filter_tags: Optional[list[str]] = None,
) -> None:
    """Adds a list of ProjectCard instances to the Scenario.

    Checks that a project of the same name is not already in the scenario.
    If validate is selected, each ProjectCard will be validated before adding.
    If filter_tags is provided, a ProjectCard will only be added if its tags match at
    least one of them.

    Args:
        project_card_list: List of ProjectCard instances to add to
            scenario.
        validate (bool, optional): If True, will require each ProjectCard is validated before
            being added to scenario. Defaults to True.
        filter_tags: If used, will filter ProjectCard instances
            and only add those whose tags match one or more of these filter_tags.
            Defaults to [] - which means no tag-filtering will occur.
    """
    filter_tags = filter_tags or []
    for p in project_card_list:
        self._add_project(p, validate=validate, filter_tags=filter_tags)
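The tag-matching rule can be sketched as a small predicate. This is an illustrative helper, not the library's actual implementation, and case-insensitive comparison is an assumption here:

```python
def matches_filter_tags(card_tags, filter_tags):
    """Illustrative: a card passes when no filter tags are given, or when at
    least one of its tags matches a filter tag (case-insensitively, assumed)."""
    if not filter_tags:
        return True
    card = {t.lower() for t in card_tags}
    wanted = {t.lower() for t in filter_tags}
    return bool(card & wanted)

print(matches_filter_tags(["Baseline2030"], []))                # True: no filtering
print(matches_filter_tags(["Baseline2030"], ["baseline2030"]))  # True: tag overlap
print(matches_filter_tags(["Baseline2050"], ["baseline2030"]))  # False: no overlap
```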

apply_all_projects()

Applies all planned projects in the queue.

Source code in network_wrangler/scenario.py
def apply_all_projects(self):
    """Applies all planned projects in the queue."""
    # Call this to make sure projects are appropriately queued in hidden variable.
    self.queued_projects  # noqa: B018

    # Use hidden variable.
    while self._queued_projects:
        self._apply_project(self._queued_projects.popleft())

    # set this so it will trigger re-queuing any more projects.
    self._queued_projects = None

apply_projects(project_list)

Applies a specific list of projects from the planned project queue.

Will order the list of projects based on pre-requisites.

NOTE: does not check co-requisites because that isn't possible when applying a single project.

Parameters:

Name Type Description Default
project_list list[str]

List of projects to be applied. All need to be in the planned project queue.

required
Source code in network_wrangler/scenario.py
def apply_projects(self, project_list: list[str]):
    """Applies a specific list of projects from the planned project queue.

    Will order the list of projects based on pre-requisites.

    NOTE: does not check co-requisites b/c that isn't possible when applying a single project.

    Args:
        project_list: List of projects to be applied. All need to be in the planned project
            queue.
    """
    project_list = [p.lower() for p in project_list]

    self._check_projects_requirements_satisfied(project_list)
    ordered_project_queue = self.order_projects(project_list)

    while ordered_project_queue:
        self._apply_project(ordered_project_queue.popleft())

    # Set so that when called again it will retrigger queueing from planned projects.
    self._queued_projects = None

order_projects(project_list)

Orders a list of projects based on moving up pre-requisites into a deque.

Parameters:

Name Type Description Default
project_list list[str]

list of projects to order

required
Source code in network_wrangler/scenario.py
def order_projects(self, project_list: list[str]) -> deque:
    """Orders a list of projects based on moving up pre-requisites into a deque.

    Args:
        project_list: list of projects to order

    Returns: deque for applying projects.
    """
    project_list = [p.lower() for p in project_list]
    assert self._check_projects_have_project_cards(project_list)

    # build prereq (adjacency) list for topological sort
    adjacency_list: dict[str, list] = defaultdict(list)
    visited_list: dict[str, bool] = defaultdict(bool)

    for project in project_list:
        visited_list[project] = False
        if not self.prerequisites.get(project):
            continue
        for prereq in self.prerequisites[project]:
            # this will always be true, else it would have been flagged in the
            # missing-prerequisite check, but just in case
            if prereq.lower() in project_list:
                if adjacency_list.get(prereq.lower()):
                    adjacency_list[prereq.lower()].append(project)
                else:
                    adjacency_list[prereq.lower()] = [project]

    # _ordered_projects holds topologically sorted project card names (based on prerequisites)
    _ordered_projects = topological_sort(
        adjacency_list=adjacency_list, visited_list=visited_list
    )

    if set(_ordered_projects) != set(project_list):
        _missing = list(set(project_list) - set(_ordered_projects))
        msg = f"Project sort resulted in missing projects: {_missing}"
        raise ValueError(msg)

    project_deque = deque(_ordered_projects)

    WranglerLogger.debug(f"Ordered Projects: \n{project_deque}")

    return project_deque
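The same idea can be sketched self-contained with Kahn's algorithm over the prerequisite graph. The real method delegates to a `topological_sort` utility, so this is only an illustration of the ordering guarantee (every project comes after its prerequisites):

```python
from collections import defaultdict, deque

def order_by_prerequisites(project_list, prerequisites):
    """Illustrative: return a deque with every project after its prerequisites."""
    projects = [p.lower() for p in project_list]
    indegree = {p: 0 for p in projects}
    dependents = defaultdict(list)
    for project in projects:
        for prereq in prerequisites.get(project, []):
            if prereq.lower() in indegree:  # only order against projects in the list
                dependents[prereq.lower()].append(project)
                indegree[project] += 1
    ready = deque(p for p in projects if indegree[p] == 0)
    ordered = []
    while ready:
        current = ready.popleft()
        ordered.append(current)
        for child in dependents[current]:
            indegree[child] -= 1
            if indegree[child] == 0:
                ready.append(child)
    if len(ordered) != len(projects):
        raise ValueError("Cycle detected in project prerequisites")
    return deque(ordered)

# Hypothetical project names: "new_ramp" requires "widen_i94".
queue = order_by_prerequisites(["new_ramp", "widen_i94"], {"new_ramp": ["widen_i94"]})
print(list(queue))  # ['widen_i94', 'new_ramp']
```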

write(path, name, overwrite=True, roadway_write=True, transit_write=True, projects_write=True, roadway_convert_complex_link_properties_to_single_field=False, roadway_out_dir=None, roadway_prefix=None, roadway_file_format='parquet', roadway_true_shape=False, transit_out_dir=None, transit_prefix=None, transit_file_format='txt', projects_out_dir=None)

Writes scenario networks and summary to disk and returns path to scenario file.

Parameters:

Name Type Description Default
path Path

Path to write scenario networks and scenario summary to.

required
name str

Name to use in the scenario file name and as the default prefix for network files.

required
overwrite bool

If True, will overwrite the files if they already exist.

True
roadway_write bool

If True, will write out the roadway network.

True
transit_write bool

If True, will write out the transit network.

True
projects_write bool

If True, will write out the project cards.

True
roadway_convert_complex_link_properties_to_single_field bool

If True, will convert complex link properties to a single field.

False
roadway_out_dir Optional[Path]

Path to write the roadway network files to.

None
roadway_prefix Optional[str]

Prefix to add to the file name.

None
roadway_file_format RoadwayFileTypes

File format to write the roadway network to

'parquet'
roadway_true_shape bool

If True, will write the true shape of the roadway network

False
transit_out_dir Optional[Path]

Path to write the transit network files to.

None
transit_prefix Optional[str]

Prefix to add to the file name.

None
transit_file_format TransitFileTypes

File format to write the transit network to

'txt'
projects_out_dir Optional[Path]

Path to write the project cards to.

None
Source code in network_wrangler/scenario.py
def write(
    self,
    path: Path,
    name: str,
    overwrite: bool = True,
    roadway_write: bool = True,
    transit_write: bool = True,
    projects_write: bool = True,
    roadway_convert_complex_link_properties_to_single_field: bool = False,
    roadway_out_dir: Optional[Path] = None,
    roadway_prefix: Optional[str] = None,
    roadway_file_format: RoadwayFileTypes = "parquet",
    roadway_true_shape: bool = False,
    transit_out_dir: Optional[Path] = None,
    transit_prefix: Optional[str] = None,
    transit_file_format: TransitFileTypes = "txt",
    projects_out_dir: Optional[Path] = None,
) -> Path:
    """Writes scenario networks and summary to disk and returns path to scenario file.

    Args:
        path: Path to write scenario networks and scenario summary to.
        name: Name to use in the scenario file name and as the default prefix for
            network files.
        overwrite: If True, will overwrite the files if they already exist.
        roadway_write: If True, will write out the roadway network.
        transit_write: If True, will write out the transit network.
        projects_write: If True, will write out the project cards.
        roadway_convert_complex_link_properties_to_single_field: If True, will convert complex
            link properties to a single field.
        roadway_out_dir: Path to write the roadway network files to.
        roadway_prefix: Prefix to add to the file name.
        roadway_file_format: File format to write the roadway network to.
        roadway_true_shape: If True, will write the true shape of the roadway network.
        transit_out_dir: Path to write the transit network files to.
        transit_prefix: Prefix to add to the file name.
        transit_file_format: File format to write the transit network to.
        projects_out_dir: Path to write the project cards to.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    if self.road_net and roadway_write:
        if roadway_out_dir is None:
            roadway_out_dir = path / "roadway"
        roadway_out_dir.mkdir(parents=True, exist_ok=True)

        write_roadway(
            net=self.road_net,
            out_dir=roadway_out_dir,
            prefix=roadway_prefix or name,
            convert_complex_link_properties_to_single_field=roadway_convert_complex_link_properties_to_single_field,
            file_format=roadway_file_format,
            true_shape=roadway_true_shape,
            overwrite=overwrite,
        )
    if self.transit_net and transit_write:
        if transit_out_dir is None:
            transit_out_dir = path / "transit"
        transit_out_dir.mkdir(parents=True, exist_ok=True)
        write_transit(
            self.transit_net,
            out_dir=transit_out_dir,
            prefix=transit_prefix or name,
            file_format=transit_file_format,
            overwrite=overwrite,
        )
    if projects_write:
        if projects_out_dir is None:
            projects_out_dir = path / "projects"
        write_applied_projects(
            self,
            out_dir=projects_out_dir,
            overwrite=overwrite,
        )

    scenario_data = self.summary
    if transit_write:
        scenario_data["transit"] = {
            "dir": str(transit_out_dir),
            "file_format": transit_file_format,
        }
    if roadway_write:
        scenario_data["roadway"] = {
            "dir": str(roadway_out_dir),
            "file_format": roadway_file_format,
        }
    if projects_write:
        scenario_data["project_cards"] = {"dir": str(projects_out_dir)}
    scenario_file_path = Path(path) / f"{name}_scenario.yml"
    with scenario_file_path.open("w") as f:
        yaml.dump(scenario_data, f, default_flow_style=False, allow_unicode=True)
    return scenario_file_path
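The scenario file that write() produces is a YAML rendering of the summary dict plus output locations. Assuming the default subdirectory names, its top-level shape can be sketched as the Python dict that gets dumped (the scenario name and output path here are hypothetical, and additional keys come from the scenario's public attributes):

```python
from pathlib import Path

name = "build2050"                      # hypothetical scenario name
out_path = Path("scenarios/build2050")  # hypothetical output path

# Mirrors the keys write() adds to the summary before dumping it to YAML.
scenario_data = {
    "applied_projects": [],  # from the scenario summary
    "roadway": {"dir": str(out_path / "roadway"), "file_format": "parquet"},
    "transit": {"dir": str(out_path / "transit"), "file_format": "txt"},
    "project_cards": {"dir": str(out_path / "projects")},
}
scenario_file_path = out_path / f"{name}_scenario.yml"
print(scenario_file_path.name)  # build2050_scenario.yml
```

This file is what load_scenario() later reads back to seed a new base scenario.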

build_scenario_from_config(scenario_config)

Builds a scenario from a dictionary configuration.

Parameters:

Name Type Description Default
scenario_config Union[Path, list[Path], ScenarioConfig, dict]

Path to a configuration file, list of paths, or a dictionary of configuration.

required
Source code in network_wrangler/scenario.py
def build_scenario_from_config(
    scenario_config: Union[Path, list[Path], ScenarioConfig, dict],
) -> Scenario:
    """Builds a scenario from a dictionary configuration.

    Args:
        scenario_config: Path to a configuration file, list of paths, or a dictionary of
            configuration.
    """
    WranglerLogger.info(f"Building Scenario from Configuration: {scenario_config}")
    scenario_config = load_scenario_config(scenario_config)
    WranglerLogger.debug(f"{pprint.pformat(scenario_config)}")

    base_scenario = create_base_scenario(
        **scenario_config.base_scenario.to_dict(), config=scenario_config.wrangler_config
    )

    my_scenario = create_scenario(
        base_scenario=base_scenario,
        config=scenario_config.wrangler_config,
        **scenario_config.projects.to_dict(),
    )

    my_scenario.apply_all_projects()

    write_args = _scenario_output_config_to_scenario_write(scenario_config.output_scenario)
    my_scenario.write(**write_args, name=scenario_config.name)
    return my_scenario

create_base_scenario(roadway=None, transit=None, applied_projects=None, conflicts=None, config=DefaultConfig)

Creates a base scenario dictionary from roadway and transit network files.

Parameters:

Name Type Description Default
roadway Optional[dict]

kwargs for load_roadway_from_dir

None
transit Optional[dict]

kwargs for load_transit

None
applied_projects Optional[list]

list of projects that have been applied to the base scenario.

None
conflicts Optional[dict]

dictionary of conflicts that have been identified in the base scenario. Takes the format of {"projectA": ["projectB", "projectC"]} showing that projectA, which has been applied, conflicts with projectB and projectC and so they shouldn’t be applied in the future.

None
config WranglerConfig

WranglerConfig instance.

DefaultConfig
Source code in network_wrangler/scenario.py
def create_base_scenario(
    roadway: Optional[dict] = None,
    transit: Optional[dict] = None,
    applied_projects: Optional[list] = None,
    conflicts: Optional[dict] = None,
    config: WranglerConfig = DefaultConfig,
) -> dict:
    """Creates a base scenario dictionary from roadway and transit network files.

    Args:
        roadway: kwargs for load_roadway_from_dir
        transit: kwargs for load_transit
        applied_projects: list of projects that have been applied to the base scenario.
        conflicts: dictionary of conflicts that have been identified in the base scenario.
            Takes the format of `{"projectA": ["projectB", "projectC"]}` showing that projectA,
            which has been applied, conflicts with projectB and projectC and so they shouldn't be
            applied in the future.
        config: WranglerConfig instance.
    """
    applied_projects = applied_projects or []
    conflicts = conflicts or {}
    if roadway:
        road_net = load_roadway_from_dir(**roadway, config=config)
    else:
        road_net = None
        WranglerLogger.info(
            "No roadway directory specified, base scenario will have empty roadway network."
        )

    if transit:
        transit_net = load_transit(**transit, config=config)
        if roadway:
            transit_net.road_net = road_net
    else:
        transit_net = None
        WranglerLogger.info(
            "No transit directory specified, base scenario will have empty transit network."
        )

    base_scenario = {
        "road_net": road_net,
        "transit_net": transit_net,
        "applied_projects": applied_projects,
        "conflicts": conflicts,
    }

    return base_scenario
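The conflicts dict format lends itself to a simple membership check. A minimal sketch of how conflicts with already-applied projects could be detected (an illustration of the data format, not the library's actual conflict-checking code; project names are hypothetical):

```python
def conflicting_projects(planned, applied_projects, conflicts):
    """Illustrative: planned projects precluded by an already-applied project.

    `conflicts` maps an applied project to the projects it precludes,
    e.g. {"projectA": ["projectB", "projectC"]}.
    """
    precluded = {c.lower() for p in applied_projects for c in conflicts.get(p, [])}
    return [p for p in planned if p.lower() in precluded]

applied = ["projectA"]
conflicts = {"projectA": ["projectB", "projectC"]}
print(conflicting_projects(["projectB", "projectD"], applied, conflicts))
# ['projectB'] -- projectB was precluded when projectA was applied
```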

create_scenario(base_scenario=None, name=datetime.now().strftime('%Y%m%d%H%M%S'), project_card_list=None, project_card_filepath=None, filter_tags=None, config=None)

Creates scenario from a base scenario and adds project cards.

Project cards can be added using any/all of the following methods: 1. List of ProjectCard instances 2. List of ProjectCard files 3. Directory and optional glob search to find project card files in

Checks that a project of the same name is not already in the scenario. If validate is selected, each ProjectCard will be validated before adding. If filter_tags is provided, a ProjectCard will only be added if its tags match at least one of them.

Parameters:

Name Type Description Default
base_scenario Optional[Union[Scenario, dict]]

Base Scenario instance or a dict of scenario attributes.

None
name str

Optional name for the scenario. Defaults to current datetime.

strftime('%Y%m%d%H%M%S')
project_card_list

List of ProjectCard instances to create Scenario from. Defaults to [].

None
project_card_filepath Optional[Union[list[Path], Path]]

Where the project card(s) can be found. A single path, list of paths, a directory, or a glob pattern.

None
filter_tags Optional[list[str]]

If used, will only add the project card if its tags match one or more of these filter_tags. Defaults to [] which means no tag-filtering will occur.

None
config Optional[Union[dict, Path, list[Path], WranglerConfig]]

Optional wrangler configuration file or dictionary or instance. Defaults to default config.

None
Source code in network_wrangler/scenario.py
def create_scenario(
    base_scenario: Optional[Union[Scenario, dict]] = None,
    name: str = datetime.now().strftime("%Y%m%d%H%M%S"),
    project_card_list=None,
    project_card_filepath: Optional[Union[list[Path], Path]] = None,
    filter_tags: Optional[list[str]] = None,
    config: Optional[Union[dict, Path, list[Path], WranglerConfig]] = None,
) -> Scenario:
    """Creates scenario from a base scenario and adds project cards.

    Project cards can be added using any/all of the following methods:
    1. List of ProjectCard instances
    2. List of ProjectCard files
    3. Directory and optional glob search to find project card files in

    Checks that a project of the same name is not already in the scenario.
    If validate is selected, each ProjectCard will be validated before adding.
    If filter_tags is provided, a ProjectCard will only be added if its tags match at
    least one of them.

    Args:
        base_scenario: Base Scenario instance or a dict of scenario attributes.
        name: Optional name for the scenario. Defaults to current datetime.
        project_card_list: List of ProjectCard instances to create Scenario from. Defaults
            to [].
        project_card_filepath: Where the project card(s) can be found. A single path,
            list of paths, a directory, or a glob pattern. Defaults to None.
        filter_tags: If used, will only add the project card if
            its tags match one or more of these filter_tags. Defaults to []
            which means no tag-filtering will occur.
        config: Optional wrangler configuration file or dictionary or instance. Defaults to
            default config.
    """
    base_scenario = base_scenario or {}
    project_card_list = project_card_list or []
    filter_tags = filter_tags or []

    scenario = Scenario(base_scenario, config=config, name=name)

    if project_card_filepath:
        project_card_list += list(
            read_cards(project_card_filepath, filter_tags=filter_tags).values()
        )

    if project_card_list:
        scenario.add_project_cards(project_card_list, filter_tags=filter_tags)

    return scenario

extract_base_scenario_metadata(base_scenario)

Extract metadata from a base scenario rather than keeping all of its big files.

Useful for summarizing a scenario.

Source code in network_wrangler/scenario.py
def extract_base_scenario_metadata(base_scenario: dict) -> dict:
    """Extract metadata from base scenario rather than keeping all of big files.

    Useful for summarizing a scenario.
    """
    _skip_copy = ["road_net", "transit_net", "config"]
    out_dict = {k: v for k, v in base_scenario.items() if k not in _skip_copy}
    if isinstance(base_scenario.get("road_net"), RoadwayNetwork):
        nodes_file_path = base_scenario["road_net"].nodes_df.attrs.get("source_file", None)
        if nodes_file_path is not None:
            out_dict["roadway"] = {
                "dir": str(Path(nodes_file_path).parent),
                "file_format": str(nodes_file_path.suffix).lstrip("."),
            }
    if isinstance(base_scenario.get("transit_net"), TransitNetwork):
        feed_path = base_scenario["transit_net"].feed.feed_path
        if feed_path is not None:
            out_dict["transit"] = {"dir": str(feed_path)}
    return out_dict
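The roadway metadata above is derived from the recorded source file of the nodes table; the path manipulation reduces to the following (the file path is hypothetical):

```python
from pathlib import Path

nodes_file_path = Path("networks/stpaul/node.geojson")  # hypothetical source file

# Same derivation as in extract_base_scenario_metadata:
# parent directory plus the file extension without its leading dot.
roadway_meta = {
    "dir": str(nodes_file_path.parent),
    "file_format": nodes_file_path.suffix.lstrip("."),
}
print(roadway_meta)
```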

load_scenario(scenario_data, name=datetime.now().strftime('%Y%m%d%H%M%S'))

Loads a scenario from a file written by Scenario.write() as the base scenario.

Parameters:

Name Type Description Default
scenario_data Union[dict, Path]

Scenario data as a dict or path to scenario data file

required
name str

Optional name for the scenario. Defaults to current datetime.

strftime('%Y%m%d%H%M%S')
Source code in network_wrangler/scenario.py
def load_scenario(
    scenario_data: Union[dict, Path],
    name: str = datetime.now().strftime("%Y%m%d%H%M%S"),
) -> Scenario:
    """Loads a scenario from a file written by Scenario.write() as the base scenario.

    Args:
        scenario_data: Scenario data as a dict or path to scenario data file
        name: Optional name for the scenario. Defaults to current datetime.
    """
    if not isinstance(scenario_data, dict):
        WranglerLogger.debug(f"Loading Scenario from file: {scenario_data}")
        scenario_data = load_dict(scenario_data)
    else:
        WranglerLogger.debug("Loading Scenario from dict.")

    base_scenario_data = {
        "roadway": scenario_data.get("roadway"),
        "transit": scenario_data.get("transit"),
        "applied_projects": scenario_data.get("applied_projects", []),
        "conflicts": scenario_data.get("conflicts", {}),
    }
    base_scenario = _load_base_scenario_from_config(
        base_scenario_data, config=scenario_data["config"]
    )
    my_scenario = create_scenario(
        base_scenario=base_scenario, name=name, config=scenario_data["config"]
    )
    return my_scenario

write_applied_projects(scenario, out_dir, overwrite=True)

Summarizes all projects in a scenario to folder.

Parameters:

Name Type Description Default
scenario Scenario

Scenario instance to summarize.

required
out_dir Path

Path to write the project cards.

required
overwrite bool

If True, will overwrite the files if they already exist.

True
Source code in network_wrangler/scenario.py
def write_applied_projects(scenario: Scenario, out_dir: Path, overwrite: bool = True) -> None:
    """Summarizes all projects in a scenario to folder.

    Args:
        scenario: Scenario instance to summarize.
        out_dir: Path to write the project cards.
        overwrite: If True, will overwrite the files if they already exist.
    """
    outdir = Path(out_dir)
    prep_dir(out_dir, overwrite=overwrite)

    for p in scenario.applied_projects:
        if p in scenario.project_cards:
            card = scenario.project_cards[p]
        elif p in scenario.base_scenario["project_cards"]:
            card = scenario.base_scenario["project_cards"][p]
        else:
            continue
        filename = Path(card.__dict__.get("file", f"{p}.yml")).name
        outpath = outdir / filename
        write_card(card, outpath)

Roadway Network class and functions for Network Wrangler.

Used to represent a roadway network and perform operations on it.

Usage:

from network_wrangler import load_roadway_from_dir, write_roadway

net = load_roadway_from_dir("my_dir")
net.get_selection({"links": [{"name": ["I 35E"]}]})
net.apply("my_project_card.yml")

write_roadway(net, "my_out_prefix", "my_dir", file_format="parquet")

RoadwayNetwork

Bases: BaseModel

Representation of a Roadway Network.

Typical usage example:

net = load_roadway(
    links_file=MY_LINK_FILE,
    nodes_file=MY_NODE_FILE,
    shapes_file=MY_SHAPE_FILE,
)
my_selection = {
    "link": [{"name": ["I 35E"]}],
    "A": {"osm_node_id": "961117623"},  # start searching for segments at A
    "B": {"osm_node_id": "2564047368"},
}
net.get_selection(my_selection)

my_change = [
    {
        'property': 'lanes',
        'existing': 1,
        'set': 2,
    },
    {
        'property': 'drive_access',
        'set': 0,
    },
]

my_net.apply_roadway_feature_change(
    my_net.get_selection(my_selection),
    my_change
)

    net.model_net
    net.is_network_connected(mode="drive", nodes=self.m_nodes_df, links=self.m_links_df)
    _, disconnected_nodes = net.assess_connectivity(
        mode="walk",
        ignore_end_nodes=True,
        nodes=self.m_nodes_df,
        links=self.m_links_df
    )
    write_roadway(net, filename=my_out_prefix, path=my_dir, for_model=True)

Attributes:

Name Type Description
nodes_df RoadNodesTable

dataframe of node records.

links_df RoadLinksTable

dataframe of link records and associated properties.

shapes_df RoadShapesTable

dataframe of detailed shape records. This is lazily created if and only if it is accessed, because shape files can be expensive to read.

_selections dict

dictionary of stored roadway selection objects, mapped by RoadwayLinkSelection.sel_key or RoadwayNodeSelection.sel_key in case they are made repeatedly.

network_hash str

dynamic property of the hashed value of links_df and nodes_df. Used for quickly identifying if a network has changed since various expensive operations have taken place (e.g., generating a ModelRoadwayNetwork or a network graph).

model_net ModelRoadwayNetwork

referenced ModelRoadwayNetwork object which will be lazily created if None or if the network_hash has changed.

config WranglerConfig

wrangler configuration object

Source code in network_wrangler/roadway/network.py
class RoadwayNetwork(BaseModel):
    """Representation of a Roadway Network.

    Typical usage example:

    ```py
    net = load_roadway(
        links_file=MY_LINK_FILE,
        nodes_file=MY_NODE_FILE,
        shapes_file=MY_SHAPE_FILE,
    )
    my_selection = {
        "link": [{"name": ["I 35E"]}],
        "A": {"osm_node_id": "961117623"},  # start searching for segments at A
        "B": {"osm_node_id": "2564047368"},
    }
    net.get_selection(my_selection)

    my_change = [
        {
            'property': 'lanes',
            'existing': 1,
            'set': 2,
        },
        {
            'property': 'drive_access',
            'set': 0,
        },
    ]

    my_net.apply_roadway_feature_change(
        my_net.get_selection(my_selection),
        my_change
    )

    net.model_net
    net.is_network_connected(mode="drive", nodes=self.m_nodes_df, links=self.m_links_df)
    _, disconnected_nodes = net.assess_connectivity(
        mode="walk",
        ignore_end_nodes=True,
        nodes=self.m_nodes_df,
        links=self.m_links_df,
    )
    write_roadway(net, filename=my_out_prefix, path=my_dir, for_model=True)
    ```

    Attributes:
        nodes_df (RoadNodesTable): dataframe of node records.
        links_df (RoadLinksTable): dataframe of link records and associated properties.
        shapes_df (RoadShapesTable): dataframe of detailed shape records. This is lazily
            created only when accessed because shapes files can be expensive to read.
        _selections (dict): dictionary of stored roadway selection objects, mapped by
            `RoadwayLinkSelection.sel_key` or `RoadwayNodeSelection.sel_key` in case they are
            made repeatedly.
        network_hash: dynamic property of the hashed value of links_df and nodes_df. Used for
            quickly identifying if a network has changed since various expensive operations have
            taken place (e.g. generating a ModelRoadwayNetwork or a network graph).
        model_net (ModelRoadwayNetwork): referenced `ModelRoadwayNetwork` object which will be
            lazily created if None or if the `network_hash` has changed.
        config (WranglerConfig): wrangler configuration object
    """

    nodes_df: DataFrame[RoadNodesTable]
    links_df: DataFrame[RoadLinksTable]
    _shapes_df: Optional[DataFrame[RoadShapesTable]] = None

    _links_file: Optional[Path] = None
    _nodes_file: Optional[Path] = None
    _shapes_file: Optional[Path] = None

    config: WranglerConfig = DefaultConfig

    _model_net: Optional[ModelRoadwayNetwork] = None
    _selections: dict[str, Selections] = {}
    _modal_graphs: dict[str, dict] = defaultdict(lambda: {"graph": None, "hash": None})

    @field_validator("config")
    def validate_config(cls, v):
        """Validate config."""
        return load_wrangler_config(v)

    @field_validator("nodes_df", "links_df")
    def coerce_crs(cls, v):
        """Coerce crs of nodes_df and links_df to LAT_LON_CRS."""
        if v.crs != LAT_LON_CRS:
            WranglerLogger.warning(
                f"CRS of dataframe ({v.crs}) doesn't match network crs {LAT_LON_CRS}. \
                    Changing to network crs."
            )
            v = v.to_crs(LAT_LON_CRS)
        return v

    @property
    def shapes_df(self) -> DataFrame[RoadShapesTable]:
        """Load and return RoadShapesTable.

        If not already loaded, will read from shapes_file and return. If shapes_file is None,
        will return an empty dataframe with the right schema. If shapes_df is already set, will
        return that.
        """
        if (self._shapes_df is None or self._shapes_df.empty) and self._shapes_file is not None:
            self._shapes_df = read_shapes(
                self._shapes_file,
                filter_to_shape_ids=self.links_df.shape_id.to_list(),
                config=self.config,
            )
        # if there is NONE, then at least create an empty dataframe with right schema
        elif self._shapes_df is None:
            self._shapes_df = empty_df_from_datamodel(RoadShapesTable)
            self._shapes_df.set_index("shape_id_idx", inplace=True)

        return self._shapes_df

    @shapes_df.setter
    def shapes_df(self, value):
        self._shapes_df = df_to_shapes_df(value, config=self.config)

    @property
    def network_hash(self) -> str:
        """Hash of the links and nodes dataframes."""
        _value = str.encode(self.links_df.df_hash() + "-" + self.nodes_df.df_hash())

        _hash = hashlib.sha256(_value).hexdigest()
        return _hash

    @property
    def model_net(self) -> ModelRoadwayNetwork:
        """Return a ModelRoadwayNetwork object for this network."""
        if self._model_net is None or self._model_net._net_hash != self.network_hash:
            self._model_net = ModelRoadwayNetwork(self)
        return self._model_net

    @property
    def summary(self) -> dict:
        """Quick summary dictionary of number of links, nodes."""
        d = {
            "links": len(self.links_df),
            "nodes": len(self.nodes_df),
        }
        return d

    @property
    def link_shapes_df(self) -> gpd.GeoDataFrame:
        """Add shape geometry to links if available.

        returns: shapes merged to links dataframe
        """
        _links_df = copy.deepcopy(self.links_df)
        link_shapes_df = _links_df.merge(
            self.shapes_df,
            left_on="shape_id",
            right_on="shape_id",
            how="left",
        )
        link_shapes_df["geometry"] = link_shapes_df["geometry_y"].combine_first(
            link_shapes_df["geometry_x"]
        )
        link_shapes_df = link_shapes_df.drop(columns=["geometry_x", "geometry_y"])
        link_shapes_df = link_shapes_df.set_geometry("geometry")
        return link_shapes_df

    def get_property_by_timespan_and_group(
        self,
        link_property: str,
        category: Optional[Union[str, int]] = DEFAULT_CATEGORY,
        timespan: Optional[TimespanString] = DEFAULT_TIMESPAN,
        strict_timespan_match: bool = False,
        min_overlap_minutes: int = 60,
    ) -> Any:
        """Returns a new dataframe with model_link_id and link property by category and timespan.

        Convenience method for backward compatibility.

        Args:
            link_property: link property to query
            category: category to query or a list of categories. Defaults to DEFAULT_CATEGORY.
            timespan: timespan to query in the form of ["HH:MM","HH:MM"].
                Defaults to DEFAULT_TIMESPAN.
            strict_timespan_match: If True, will only return links that match the timespan exactly.
                Defaults to False.
            min_overlap_minutes: If strict_timespan_match is False, will return links that overlap
                with the timespan by at least this many minutes. Defaults to 60.
        """
        from .links.scopes import prop_for_scope

        return prop_for_scope(
            self.links_df,
            link_property,
            timespan=timespan,
            category=category,
            strict_timespan_match=strict_timespan_match,
            min_overlap_minutes=min_overlap_minutes,
        )

    def get_selection(
        self,
        selection_dict: Union[dict, SelectFacility],
        overwrite: bool = False,
    ) -> Union[RoadwayNodeSelection, RoadwayLinkSelection]:
        """Return selection if it already exists, otherwise performs selection.

        Args:
            selection_dict (dict): SelectFacility dictionary.
            overwrite: if True, will overwrite any previously cached searches. Defaults to False.
        """
        key = _create_selection_key(selection_dict)
        if (key in self._selections) and not overwrite:
            WranglerLogger.debug(f"Using cached selection from key: {key}")
            return self._selections[key]

        if isinstance(selection_dict, SelectFacility):
            selection_data = selection_dict
        elif isinstance(selection_dict, SelectLinksDict):
            selection_data = SelectFacility(links=selection_dict)
        elif isinstance(selection_dict, SelectNodesDict):
            selection_data = SelectFacility(nodes=selection_dict)
        elif isinstance(selection_dict, dict):
            selection_data = SelectFacility(**selection_dict)
        else:
            msg = "selection_dict arg must be a dictionary or SelectFacility model."
            WranglerLogger.error(
                msg + f" Received: {selection_dict} of type {type(selection_dict)}"
            )
            raise SelectionError(msg)

        WranglerLogger.debug(f"Getting selection from key: {key}")
        if "links" in selection_data.fields:
            return RoadwayLinkSelection(self, selection_dict)
        if "nodes" in selection_data.fields:
            return RoadwayNodeSelection(self, selection_dict)
        msg = "Selection data should have either 'links' or 'nodes'."
        WranglerLogger.error(msg + f" Received: {selection_dict}")
        raise SelectionError(msg)

    def modal_graph_hash(self, mode) -> str:
        """Hash of the links in order to detect a network change from when graph created."""
        _value = str.encode(self.links_df.df_hash() + "-" + mode)
        _hash = hashlib.sha256(_value).hexdigest()

        return _hash

    def get_modal_graph(self, mode) -> MultiDiGraph:
        """Return a networkx graph of the network for a specific mode.

        Args:
            mode: mode of the network, one of `drive`,`transit`,`walk`, `bike`
        """
        from .graph import net_to_graph

        if self._modal_graphs[mode]["hash"] != self.modal_graph_hash(mode):
            self._modal_graphs[mode]["graph"] = net_to_graph(self, mode)

        return self._modal_graphs[mode]["graph"]

    def apply(
        self,
        project_card: Union[ProjectCard, dict],
        transit_net: Optional[TransitNetwork] = None,
        **kwargs,
    ) -> RoadwayNetwork:
        """Wrapper method to apply a roadway project, returning a new RoadwayNetwork instance.

        Args:
            project_card: either a dictionary of the project card object or ProjectCard instance
            transit_net: optional transit network which will be used if the project requires
                it, as noted in `SECONDARY_TRANSIT_CARD_TYPES`. If no transit network is
                provided, anything related to the transit network will be skipped.
            **kwargs: keyword arguments to pass to project application
        """
        if not (isinstance(project_card, (ProjectCard, SubProject))):
            project_card = ProjectCard(project_card)

        # project_card.validate()
        if not project_card.valid:
            msg = f"Project card {project_card.project} not valid."
            WranglerLogger.error(msg)
            raise ProjectCardError(msg)

        if project_card._sub_projects:
            for sp in project_card._sub_projects:
                WranglerLogger.debug(f"- applying subproject: {sp.change_type}")
                self._apply_change(sp, transit_net=transit_net, **kwargs)
            return self
        return self._apply_change(project_card, transit_net=transit_net, **kwargs)

    def _apply_change(
        self,
        change: Union[ProjectCard, SubProject],
        transit_net: Optional[TransitNetwork] = None,
    ) -> RoadwayNetwork:
        """Apply a single change: a single-project project or a sub-project."""
        if not isinstance(change, SubProject):
            WranglerLogger.info(f"Applying Project to Roadway Network: {change.project}")

        if change.change_type == "roadway_property_change":
            return apply_roadway_property_change(
                self,
                self.get_selection(change.roadway_property_change["facility"]),
                change.roadway_property_change["property_changes"],
                project_name=change.project,
            )

        if change.change_type == "roadway_addition":
            return apply_new_roadway(
                self,
                change.roadway_addition,
                project_name=change.project,
            )

        if change.change_type == "roadway_deletion":
            return apply_roadway_deletion(
                self,
                change.roadway_deletion,
                transit_net=transit_net,
            )

        if change.change_type == "pycode":
            return apply_calculated_roadway(self, change.pycode)
        WranglerLogger.error(f"Couldn't find project in: \n{change.__dict__}")
        msg = f"Invalid Project Card Category: {change.change_type}"
        raise ProjectCardError(msg)

    def links_with_link_ids(self, link_ids: list[int]) -> DataFrame[RoadLinksTable]:
        """Return subset of links_df based on link_ids list."""
        return filter_links_to_ids(self.links_df, link_ids)

    def links_with_nodes(self, node_ids: list[int]) -> DataFrame[RoadLinksTable]:
        """Return subset of links_df based on node_ids list."""
        return filter_links_to_node_ids(self.links_df, node_ids)

    def nodes_in_links(self) -> DataFrame[RoadNodesTable]:
        """Returns subset of self.nodes_df that are in self.links_df."""
        return filter_nodes_to_links(self.links_df, self.nodes_df)

    def node_coords(self, model_node_id: int) -> tuple:
        """Return coordinates (x, y) of a node based on model_node_id."""
        try:
            node = self.nodes_df[self.nodes_df.model_node_id == model_node_id]
        except ValueError as err:
            msg = f"Node with model_node_id {model_node_id} not found."
            WranglerLogger.error(msg)
            raise NodeNotFoundError(msg) from err
        return node.geometry.x.values[0], node.geometry.y.values[0]

    def add_links(
        self,
        add_links_df: Union[pd.DataFrame, DataFrame[RoadLinksTable]],
        in_crs: int = LAT_LON_CRS,
    ):
        """Validate combined links_df with LinksSchema before adding to self.links_df.

        Args:
            add_links_df: Dataframe of additional links to add.
            in_crs: crs of input data. Defaults to LAT_LON_CRS.
        """
        dupe_recs = self.links_df.model_link_id.isin(add_links_df.model_link_id)

        if dupe_recs.any():
            dupe_ids = self.links_df.loc[dupe_recs, "model_link_id"]
            WranglerLogger.error(
                f"Cannot add links with model_link_id already in network: {dupe_ids}"
            )
            msg = "Cannot add links with model_link_id already in network."
            raise LinkAddError(msg)

        if add_links_df.attrs.get("name") != "road_links":
            add_links_df = data_to_links_df(add_links_df, nodes_df=self.nodes_df, in_crs=in_crs)
        self.links_df = validate_df_to_model(
            concat_with_attr([self.links_df, add_links_df], axis=0), RoadLinksTable
        )

    def add_nodes(
        self,
        add_nodes_df: Union[pd.DataFrame, DataFrame[RoadNodesTable]],
        in_crs: int = LAT_LON_CRS,
    ):
        """Validate combined nodes_df with NodesSchema before adding to self.nodes_df.

        Args:
            add_nodes_df: Dataframe of additional nodes to add.
            in_crs: crs of input data. Defaults to LAT_LON_CRS.
        """
        dupe_ids = self.nodes_df.model_node_id.isin(add_nodes_df.model_node_id)
        if dupe_ids.any():
            WranglerLogger.error(
                f"Cannot add nodes with model_node_id already in network: {dupe_ids}"
            )
            msg = "Cannot add nodes with model_node_id already in network."
            raise NodeAddError(msg)

        if add_nodes_df.attrs.get("name") != "road_nodes":
            add_nodes_df = data_to_nodes_df(add_nodes_df, in_crs=in_crs, config=self.config)
        self.nodes_df = validate_df_to_model(
            concat_with_attr([self.nodes_df, add_nodes_df], axis=0), RoadNodesTable
        )
        if self.nodes_df.attrs.get("name") != "road_nodes":
            msg = f"Expected nodes_df to have name 'road_nodes', got {self.nodes_df.attrs.get('name')}"
            raise NotNodesError(msg)

    def add_shapes(
        self,
        add_shapes_df: Union[pd.DataFrame, DataFrame[RoadShapesTable]],
        in_crs: int = LAT_LON_CRS,
    ):
        """Validate combined shapes_df with RoadShapesTable before adding to self.shapes_df.

        Args:
            add_shapes_df: Dataframe of additional shapes to add.
            in_crs: crs of input data. Defaults to LAT_LON_CRS.
        """
        dupe_ids = self.shapes_df.shape_id.isin(add_shapes_df.shape_id)
        if dupe_ids.any():
            msg = "Cannot add shapes with shape_id already in network."
            WranglerLogger.error(msg + f"\nDuplicates: {dupe_ids}")
            raise ShapeAddError(msg)

        if add_shapes_df.attrs.get("name") != "road_shapes":
            add_shapes_df = df_to_shapes_df(add_shapes_df, in_crs=in_crs, config=self.config)

        WranglerLogger.debug(f"add_shapes_df: \n{add_shapes_df}")
        WranglerLogger.debug(f"self.shapes_df: \n{self.shapes_df}")

        self.shapes_df = validate_df_to_model(
            concat_with_attr([self.shapes_df, add_shapes_df], axis=0), RoadShapesTable
        )

    def delete_links(
        self,
        selection_dict: Union[dict, SelectLinksDict],
        clean_nodes: bool = False,
        clean_shapes: bool = False,
        transit_net: Optional[TransitNetwork] = None,
    ):
        """Deletes links based on selection dictionary and optionally associated nodes and shapes.

        Args:
            selection_dict (SelectLinks): Dictionary describing link selections as follows:
                `all`: Optional[bool] = False. If true, will select all.
                `name`: Optional[list[str]]
                `ref`: Optional[list[str]]
                `osm_link_id`:Optional[list[str]]
                `model_link_id`: Optional[list[int]]
                `modes`: Optional[list[str]]. Defaults to "any"
                `ignore_missing`: if true, will not error when selected links are missing.
                    Defaults to True.
                ...plus any other link property to select on top of these.
            clean_nodes (bool, optional): If True, will clean nodes uniquely associated with
                deleted links. Defaults to False.
            clean_shapes (bool, optional): If True, will clean shapes uniquely associated with
                deleted links. Defaults to False.
            transit_net (TransitNetwork, optional): If provided, will check TransitNetwork and
                warn if deletion breaks transit shapes. Defaults to None.
        """
        if not isinstance(selection_dict, SelectLinksDict):
            selection_dict = SelectLinksDict(**selection_dict)
        selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
        selection = self.get_selection({"links": selection_dict})
        if isinstance(selection, RoadwayNodeSelection):
            msg = "Selection should be for links, but got nodes."
            raise SelectionError(msg)
        if clean_nodes:
            node_ids_to_delete = node_ids_unique_to_link_ids(
                selection.selected_links, selection.selected_links_df, self.nodes_df
            )
            WranglerLogger.debug(
                f"Dropping nodes associated with dropped links: \n{node_ids_to_delete}"
            )
            self.nodes_df = delete_nodes_by_ids(self.nodes_df, del_node_ids=node_ids_to_delete)

        if clean_shapes:
            shape_ids_to_delete = shape_ids_unique_to_link_ids(
                selection.selected_links, selection.selected_links_df, self.shapes_df
            )
            WranglerLogger.debug(
                f"Dropping shapes associated with dropped links: \n{shape_ids_to_delete}"
            )
            self.shapes_df = delete_shapes_by_ids(
                self.shapes_df, del_shape_ids=shape_ids_to_delete
            )

        self.links_df = delete_links_by_ids(
            self.links_df,
            selection.selected_links,
            ignore_missing=selection.ignore_missing,
            transit_net=transit_net,
        )

    def delete_nodes(
        self,
        selection_dict: Union[dict, SelectNodesDict],
        remove_links: bool = False,
    ) -> None:
        """Deletes nodes from roadway network. Won't delete nodes used by links in network.

        Args:
            selection_dict: dictionary of node selection criteria in the form of a SelectNodesDict.
            remove_links: if True, will remove any links that are associated with the nodes.
                If False, will only remove nodes if they are not associated with any links.
                Defaults to False.

        raises:
            NodeDeletionError: If not ignore_missing and selected nodes to delete aren't in network
        """
        if not isinstance(selection_dict, SelectNodesDict):
            selection_dict = SelectNodesDict(**selection_dict)
        selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
        _selection = self.get_selection({"nodes": selection_dict})
        assert isinstance(_selection, RoadwayNodeSelection)  # for mypy
        selection: RoadwayNodeSelection = _selection
        if remove_links:
            del_node_ids = selection.selected_nodes
            link_ids = self.links_with_nodes(selection.selected_nodes).model_link_id.to_list()
            WranglerLogger.info(f"Removing {len(link_ids)} links associated with nodes.")
            self.delete_links({"model_link_id": link_ids})
        else:
            unused_node_ids = node_ids_without_links(self.nodes_df, self.links_df)
            del_node_ids = list(set(selection.selected_nodes).intersection(unused_node_ids))

        self.nodes_df = delete_nodes_by_ids(
            self.nodes_df, del_node_ids, ignore_missing=selection.ignore_missing
        )

    def clean_unused_shapes(self):
        """Removes any unused shapes from network that aren't referenced by links_df."""
        from .shapes.shapes import shape_ids_without_links

        del_shape_ids = shape_ids_without_links(self.shapes_df, self.links_df)
        self.shapes_df = self.shapes_df.drop(del_shape_ids)

    def clean_unused_nodes(self):
        """Removes any unused nodes from network that aren't referenced by links_df.

        NOTE: does not check if these nodes are used by transit, so use with caution.
        """
        from .nodes.nodes import node_ids_without_links

        node_ids = node_ids_without_links(self.nodes_df, self.links_df)
        self.nodes_df = self.nodes_df.drop(node_ids)

    def move_nodes(
        self,
        node_geometry_change_table: DataFrame[NodeGeometryChangeTable],
    ):
        """Moves nodes based on updated geometry along with associated links and shape geometry.

        Args:
            node_geometry_change_table: a table with model_node_id, X, Y, and CRS.
        """
        node_geometry_change_table = NodeGeometryChangeTable(node_geometry_change_table)
        node_ids = node_geometry_change_table.model_node_id.to_list()
        WranglerLogger.debug(f"Moving nodes: {node_ids}")
        self.nodes_df = edit_node_geometry(self.nodes_df, node_geometry_change_table)
        self.links_df = edit_link_geometry_from_nodes(self.links_df, self.nodes_df, node_ids)
        self.shapes_df = edit_shape_geometry_from_nodes(
            self.shapes_df, self.links_df, self.nodes_df, node_ids
        )

    def has_node(self, model_node_id: int) -> bool:
        """Queries if network has node based on model_node_id.

        Args:
            model_node_id: model_node_id to check for.
        """
        has_node = self.nodes_df.model_node_id.isin([model_node_id]).any()

        return has_node

    def has_link(self, ab: tuple) -> bool:
        """Returns true if network has links with AB values.

        Args:
            ab: Tuple of values corresponding with A and B.
        """
        sel_a, sel_b = ab
        has_link = ((self.links_df["A"] == sel_a) & (self.links_df["B"] == sel_b)).any()
        return has_link

    def is_connected(self, mode: str) -> bool:
        """Determines if the network graph is "strongly" connected.

        A graph is strongly connected if each vertex is reachable from every other vertex.

        Args:
            mode:  mode of the network, one of `drive`,`transit`,`walk`, `bike`
        """
        is_connected = nx.is_strongly_connected(self.get_modal_graph(mode))

        return is_connected
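
The `get_selection` method shown in the source above caches selection objects keyed by a string derived from the selection dictionary, so repeated identical selections are not recomputed. A minimal stdlib sketch of that caching pattern (the cached value here is a hypothetical stand-in for a `RoadwayLinkSelection`/`RoadwayNodeSelection` object):

```python
import json

_selections: dict[str, dict] = {}  # cache of selections, keyed by a stable string


def _create_selection_key(selection_dict: dict) -> str:
    # Serialize with sorted keys so logically-equal dicts map to the same key.
    return json.dumps(selection_dict, sort_keys=True)


def get_selection(selection_dict: dict, overwrite: bool = False) -> dict:
    key = _create_selection_key(selection_dict)
    if key in _selections and not overwrite:
        return _selections[key]  # reuse cached selection
    result = {"selection": selection_dict}  # stand-in for running the query
    _selections[key] = result
    return result


first = get_selection({"links": [{"name": ["I 35E"]}]})
second = get_selection({"links": [{"name": ["I 35E"]}]})
assert first is second  # identical selection dict hits the cache
```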

link_shapes_df: GeoDataFrame property

Add shape geometry to links if available.

Returns: shapes merged to links dataframe.

model_net: ModelRoadwayNetwork property

Return a ModelRoadwayNetwork object for this network.

network_hash: str property

Hash of the links and nodes dataframes.
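
As the source above shows, the hash is built by concatenating the per-dataframe hashes and digesting the result with SHA-256. A minimal sketch of that pattern with `hashlib`, using plain strings in place of the `df_hash()` outputs:

```python
import hashlib


def network_hash(links_hash: str, nodes_hash: str) -> str:
    # Join the two per-dataframe hashes with "-" and digest them together,
    # mirroring the links/nodes combination described above.
    value = str.encode(links_hash + "-" + nodes_hash)
    return hashlib.sha256(value).hexdigest()


h1 = network_hash("abc", "def")
h2 = network_hash("abc", "def")
assert h1 == h2       # deterministic: same inputs give the same hash
assert len(h1) == 64  # sha256 hex digest length
```

Because the hash is deterministic, any change to either dataframe changes the combined hash, which is what lets dependent objects (graphs, the model network) detect staleness cheaply.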

shapes_df: DataFrame[RoadShapesTable] property writable

Load and return RoadShapesTable.

If not already loaded, will read from shapes_file and return. If shapes_file is None, will return an empty dataframe with the right schema. If shapes_df is already set, will return that.
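
A minimal sketch of this lazy-load pattern (class and attribute names here are hypothetical stand-ins, and a list stands in for the dataframe):

```python
class LazyShapes:
    """Sketch of lazy loading: the expensive read happens only on first access."""

    def __init__(self, shapes_file=None):
        self._shapes_file = shapes_file
        self._shapes_df = None
        self.reads = 0  # count expensive loads, for demonstration

    def _read_shapes(self):
        self.reads += 1
        return ["shape records from " + self._shapes_file]

    @property
    def shapes_df(self):
        if self._shapes_df is None and self._shapes_file is not None:
            self._shapes_df = self._read_shapes()  # expensive read, done once
        elif self._shapes_df is None:
            self._shapes_df = []  # no file: empty container with the right shape
        return self._shapes_df


net = LazyShapes("shapes.parquet")
net.shapes_df
net.shapes_df
assert net.reads == 1  # second access reused the cached value
```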

summary: dict property

Quick summary dictionary of number of links, nodes.

add_links(add_links_df, in_crs=LAT_LON_CRS)

Validate combined links_df with LinksSchema before adding to self.links_df.

Parameters:

- `add_links_df` (Union[DataFrame, DataFrame[RoadLinksTable]]): Dataframe of additional links to add. Required.
- `in_crs` (int): crs of input data. Defaults to LAT_LON_CRS.
Source code in network_wrangler/roadway/network.py
def add_links(
    self,
    add_links_df: Union[pd.DataFrame, DataFrame[RoadLinksTable]],
    in_crs: int = LAT_LON_CRS,
):
    """Validate combined links_df with LinksSchema before adding to self.links_df.

    Args:
        add_links_df: Dataframe of additional links to add.
        in_crs: crs of input data. Defaults to LAT_LON_CRS.
    """
    dupe_recs = self.links_df.model_link_id.isin(add_links_df.model_link_id)

    if dupe_recs.any():
        dupe_ids = self.links_df.loc[dupe_recs, "model_link_id"]
        WranglerLogger.error(
            f"Cannot add links with model_link_id already in network: {dupe_ids}"
        )
        msg = "Cannot add links with model_link_id already in network."
        raise LinkAddError(msg)

    if add_links_df.attrs.get("name") != "road_links":
        add_links_df = data_to_links_df(add_links_df, nodes_df=self.nodes_df, in_crs=in_crs)
    self.links_df = validate_df_to_model(
        concat_with_attr([self.links_df, add_links_df], axis=0), RoadLinksTable
    )
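
The duplicate-id guard above can be sketched with plain sets; `ValueError` stands in here for `LinkAddError`, and the function name is a hypothetical stand-in:

```python
def check_no_duplicate_ids(existing_ids, new_ids):
    """Raise if any id to add already exists in the network, mirroring the
    guard at the top of add_links/add_nodes/add_shapes."""
    dupes = sorted(set(existing_ids) & set(new_ids))
    if dupes:
        raise ValueError(f"Cannot add records with ids already in network: {dupes}")


check_no_duplicate_ids([1, 2, 3], [4, 5])  # ok: no overlap
try:
    check_no_duplicate_ids([1, 2, 3], [3, 4])
except ValueError:
    pass  # duplicate id 3 is rejected
```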

add_nodes(add_nodes_df, in_crs=LAT_LON_CRS)

Validate combined nodes_df with NodesSchema before adding to self.nodes_df.

Parameters:

- `add_nodes_df` (Union[DataFrame, DataFrame[RoadNodesTable]]): Dataframe of additional nodes to add. Required.
- `in_crs` (int): crs of input data. Defaults to LAT_LON_CRS.
Source code in network_wrangler/roadway/network.py
def add_nodes(
    self,
    add_nodes_df: Union[pd.DataFrame, DataFrame[RoadNodesTable]],
    in_crs: int = LAT_LON_CRS,
):
    """Validate combined nodes_df with NodesSchema before adding to self.nodes_df.

    Args:
        add_nodes_df: Dataframe of additional nodes to add.
        in_crs: crs of input data. Defaults to LAT_LON_CRS.
    """
    dupe_ids = self.nodes_df.model_node_id.isin(add_nodes_df.model_node_id)
    if dupe_ids.any():
        WranglerLogger.error(
            f"Cannot add nodes with model_node_id already in network: {dupe_ids}"
        )
        msg = "Cannot add nodes with model_node_id already in network."
        raise NodeAddError(msg)

    if add_nodes_df.attrs.get("name") != "road_nodes":
        add_nodes_df = data_to_nodes_df(add_nodes_df, in_crs=in_crs, config=self.config)
    self.nodes_df = validate_df_to_model(
        concat_with_attr([self.nodes_df, add_nodes_df], axis=0), RoadNodesTable
    )
    if self.nodes_df.attrs.get("name") != "road_nodes":
        msg = f"Expected nodes_df to have name 'road_nodes', got {self.nodes_df.attrs.get('name')}"
        raise NotNodesError(msg)

add_shapes(add_shapes_df, in_crs=LAT_LON_CRS)

Validate combined shapes_df with RoadShapesTable before adding to self.shapes_df.

Parameters:

- `add_shapes_df` (Union[DataFrame, DataFrame[RoadShapesTable]]): Dataframe of additional shapes to add. Required.
- `in_crs` (int): crs of input data. Defaults to LAT_LON_CRS.
Source code in network_wrangler/roadway/network.py
def add_shapes(
    self,
    add_shapes_df: Union[pd.DataFrame, DataFrame[RoadShapesTable]],
    in_crs: int = LAT_LON_CRS,
):
    """Validate combined shapes_df with RoadShapesTable before adding to self.shapes_df.

    Args:
        add_shapes_df: Dataframe of additional shapes to add.
        in_crs: crs of input data. Defaults to LAT_LON_CRS.
    """
    dupe_ids = self.shapes_df.shape_id.isin(add_shapes_df.shape_id)
    if dupe_ids.any():
        msg = "Cannot add shapes with shape_id already in network."
        WranglerLogger.error(msg + f"\nDuplicates: {dupe_ids}")
        raise ShapeAddError(msg)

    if add_shapes_df.attrs.get("name") != "road_shapes":
        add_shapes_df = df_to_shapes_df(add_shapes_df, in_crs=in_crs, config=self.config)

    WranglerLogger.debug(f"add_shapes_df: \n{add_shapes_df}")
    WranglerLogger.debug(f"self.shapes_df: \n{self.shapes_df}")

    self.shapes_df = validate_df_to_model(
        concat_with_attr([self.shapes_df, add_shapes_df], axis=0), RoadShapesTable
    )
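
The duplicate-id guard used by `add_shapes` (and `add_nodes` above) can be illustrated in plain pandas; this is a minimal sketch of the pattern, without the Wrangler table models or `ShapeAddError`:

```python
import pandas as pd

existing = pd.DataFrame({"shape_id": [101, 102, 103]})
incoming = pd.DataFrame({"shape_id": [103, 104]})

# Same check as add_shapes: any shape_id already in the network blocks the add.
dupe_mask = existing["shape_id"].isin(incoming["shape_id"])
dupes = existing.loc[dupe_mask, "shape_id"].tolist()
print(dupes)  # [103] -- add_shapes would log this and raise ShapeAddError

# With no collisions, the new rows are simply concatenated and re-validated.
clean = pd.DataFrame({"shape_id": [104, 105]})
if not existing["shape_id"].isin(clean["shape_id"]).any():
    combined = pd.concat([existing, clean], axis=0)
print(len(combined))  # 5
```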

apply(project_card, transit_net=None, **kwargs)

Wrapper method to apply a roadway project, returning a new RoadwayNetwork instance.

Parameters:

Name Type Description Default
project_card Union[ProjectCard, dict]

either a dictionary of the project card object or ProjectCard instance

required
transit_net Optional[TransitNetwork]

optional transit network which will be used if the project requires one, as noted in SECONDARY_TRANSIT_CARD_TYPES. If no transit network is provided, anything related to the transit network will be skipped.

None
**kwargs

keyword arguments to pass to project application

{}
Source code in network_wrangler/roadway/network.py
def apply(
    self,
    project_card: Union[ProjectCard, dict],
    transit_net: Optional[TransitNetwork] = None,
    **kwargs,
) -> RoadwayNetwork:
    """Wrapper method to apply a roadway project, returning a new RoadwayNetwork instance.

    Args:
        project_card: either a dictionary of the project card object or ProjectCard instance
        transit_net: optional transit network which will be used if the project requires
            one, as noted in `SECONDARY_TRANSIT_CARD_TYPES`. If no transit network is
            provided, will skip anything related to the transit network.
        **kwargs: keyword arguments to pass to project application
    """
    if not (isinstance(project_card, (ProjectCard, SubProject))):
        project_card = ProjectCard(project_card)

    # project_card.validate()
    if not project_card.valid:
        msg = f"Project card {project_card.project} not valid."
        WranglerLogger.error(msg)
        raise ProjectCardError(msg)

    if project_card._sub_projects:
        for sp in project_card._sub_projects:
            WranglerLogger.debug(f"- applying subproject: {sp.change_type}")
            self._apply_change(sp, transit_net=transit_net, **kwargs)
        return self
    return self._apply_change(project_card, transit_net=transit_net, **kwargs)

clean_unused_nodes()

Removes any unused nodes from network that aren’t referenced by links_df.

NOTE: does not check if these nodes are used by transit, so use with caution.

Source code in network_wrangler/roadway/network.py
def clean_unused_nodes(self):
    """Removes any unused nodes from network that aren't referenced by links_df.

    NOTE: does not check if these nodes are used by transit, so use with caution.
    """
    from .nodes.nodes import node_ids_without_links

    node_ids = node_ids_without_links(self.nodes_df, self.links_df)
    self.nodes_df = self.nodes_df.drop(node_ids)

clean_unused_shapes()

Removes any unused shapes from network that aren’t referenced by links_df.

Source code in network_wrangler/roadway/network.py
def clean_unused_shapes(self):
    """Removes any unused shapes from network that aren't referenced by links_df."""
    from .shapes.shapes import shape_ids_without_links

    del_shape_ids = shape_ids_without_links(self.shapes_df, self.links_df)
    self.shapes_df = self.shapes_df.drop(del_shape_ids)

coerce_crs(v)

Coerce crs of nodes_df and links_df to LAT_LON_CRS.

Source code in network_wrangler/roadway/network.py
@field_validator("nodes_df", "links_df")
def coerce_crs(cls, v):
    """Coerce crs of nodes_df and links_df to LAT_LON_CRS."""
    if v.crs != LAT_LON_CRS:
        WranglerLogger.warning(
            f"CRS of links_df ({v.crs}) doesn't match network crs {LAT_LON_CRS}. \
                Changing to network crs."
        )
        v = v.to_crs(LAT_LON_CRS)
    return v

delete_links(selection_dict, clean_nodes=False, clean_shapes=False, transit_net=None)

Deletes links based on selection dictionary and optionally associated nodes and shapes.

Parameters:

Name Type Description Default
selection_dict SelectLinks

Dictionary describing link selections as follows: all: Optional[bool] = False, if true will select all; name: Optional[list[str]]; ref: Optional[list[str]]; osm_link_id: Optional[list[str]]; model_link_id: Optional[list[int]]; modes: Optional[list[str]], defaults to “any”; ignore_missing: if true, will not error when selected links are not found, defaults to True; …plus any other link property to select on top of these.

required
clean_nodes bool

If True, will clean nodes uniquely associated with deleted links. Defaults to False.

False
clean_shapes bool

If True, will clean shapes uniquely associated with deleted links. Defaults to False.

False
transit_net TransitNetwork

If provided, will check TransitNetwork and warn if deletion breaks transit shapes. Defaults to None.

None
Source code in network_wrangler/roadway/network.py
def delete_links(
    self,
    selection_dict: Union[dict, SelectLinksDict],
    clean_nodes: bool = False,
    clean_shapes: bool = False,
    transit_net: Optional[TransitNetwork] = None,
):
    """Deletes links based on selection dictionary and optionally associated nodes and shapes.

    Args:
        selection_dict (SelectLinks): Dictionary describing link selections as follows:
            `all`: Optional[bool] = False. If true, will select all.
            `name`: Optional[list[str]]
            `ref`: Optional[list[str]]
            `osm_link_id`:Optional[list[str]]
            `model_link_id`: Optional[list[int]]
            `modes`: Optional[list[str]]. Defaults to "any"
            `ignore_missing`: if true, will not error when selected links are not found.
                Defaults to True.
            ...plus any other link property to select on top of these.
        clean_nodes (bool, optional): If True, will clean nodes uniquely associated with
            deleted links. Defaults to False.
        clean_shapes (bool, optional): If True, will clean shapes uniquely associated with
            deleted links. Defaults to False.
        transit_net (TransitNetwork, optional): If provided, will check TransitNetwork and
            warn if deletion breaks transit shapes. Defaults to None.
    """
    if not isinstance(selection_dict, SelectLinksDict):
        selection_dict = SelectLinksDict(**selection_dict)
    selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
    selection = self.get_selection({"links": selection_dict})
    if isinstance(selection, RoadwayNodeSelection):
        msg = "Selection should be for links, but got nodes."
        raise SelectionError(msg)
    if clean_nodes:
        node_ids_to_delete = node_ids_unique_to_link_ids(
            selection.selected_links, selection.selected_links_df, self.nodes_df
        )
        WranglerLogger.debug(
            f"Dropping nodes associated with dropped links: \n{node_ids_to_delete}"
        )
        self.nodes_df = delete_nodes_by_ids(self.nodes_df, del_node_ids=node_ids_to_delete)

    if clean_shapes:
        shape_ids_to_delete = shape_ids_unique_to_link_ids(
            selection.selected_links, selection.selected_links_df, self.shapes_df
        )
        WranglerLogger.debug(
            f"Dropping shapes associated with dropped links: \n{shape_ids_to_delete}"
        )
        self.shapes_df = delete_shapes_by_ids(
            self.shapes_df, del_shape_ids=shape_ids_to_delete
        )

    self.links_df = delete_links_by_ids(
        self.links_df,
        selection.selected_links,
        ignore_missing=selection.ignore_missing,
        transit_net=transit_net,
    )
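
The `clean_nodes` / `clean_shapes` bookkeeping boils down to finding ids used only by the deleted links. A minimal pandas sketch of that set logic (toy data; column names follow the A/B convention above, not the library's helper functions):

```python
import pandas as pd

links = pd.DataFrame({"model_link_id": [1, 2, 3], "A": [10, 11, 12], "B": [11, 12, 13]})
to_delete = [3]

deleted = links[links["model_link_id"].isin(to_delete)]
kept = links[~links["model_link_id"].isin(to_delete)]

# Nodes referenced only by deleted links are safe to drop.
orphaned = sorted(
    (set(deleted["A"]) | set(deleted["B"])) - (set(kept["A"]) | set(kept["B"]))
)
print(orphaned)  # [13]
```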

delete_nodes(selection_dict, remove_links=False)

Deletes nodes from roadway network. Won't delete nodes used by links in network.

Parameters:

Name Type Description Default
selection_dict Union[dict, SelectNodesDict]

dictionary of node selection criteria in the form of a SelectNodesDict.

required
remove_links bool

if True, will remove any links that are associated with the nodes. If False, will only remove nodes if they are not associated with any links. Defaults to False.

False

Raises:

Type Description
NodeDeletionError

If not ignore_missing and selected nodes to delete aren’t in network

Source code in network_wrangler/roadway/network.py
def delete_nodes(
    self,
    selection_dict: Union[dict, SelectNodesDict],
    remove_links: bool = False,
) -> None:
    """Deletes nodes from roadway network. Wont delete nodes used by links in network.

    Args:
        selection_dict: dictionary of node selection criteria in the form of a SelectNodesDict.
        remove_links: if True, will remove any links that are associated with the nodes.
            If False, will only remove nodes if they are not associated with any links.
            Defaults to False.

    raises:
        NodeDeletionError: If not ignore_missing and selected nodes to delete aren't in network
    """
    if not isinstance(selection_dict, SelectNodesDict):
        selection_dict = SelectNodesDict(**selection_dict)
    selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
    _selection = self.get_selection({"nodes": selection_dict})
    assert isinstance(_selection, RoadwayNodeSelection)  # for mypy
    selection: RoadwayNodeSelection = _selection
    if remove_links:
        del_node_ids = selection.selected_nodes
        link_ids = self.links_with_nodes(selection.selected_nodes).model_link_id.to_list()
        WranglerLogger.info(f"Removing {len(link_ids)} links associated with nodes.")
        self.delete_links({"model_link_id": link_ids})
    else:
        unused_node_ids = node_ids_without_links(self.nodes_df, self.links_df)
        del_node_ids = list(set(selection.selected_nodes).intersection(unused_node_ids))

    self.nodes_df = delete_nodes_by_ids(
        self.nodes_df, del_node_ids, ignore_missing=selection.ignore_missing
    )
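
When `remove_links` is False, only nodes with no incident links may be deleted. The `node_ids_without_links` idea in plain pandas (a sketch of the concept, not the library function):

```python
import pandas as pd

nodes = pd.DataFrame({"model_node_id": [1, 2, 3, 4]})
links = pd.DataFrame({"A": [1, 2], "B": [2, 3]})

# A node is "used" if it appears at either end of any link.
used = pd.concat([links["A"], links["B"]]).unique()
unused = nodes.loc[~nodes["model_node_id"].isin(used), "model_node_id"].tolist()
print(unused)  # [4]

# Only the intersection of requested and unused node ids is deletable.
requested = {3, 4}
deletable = sorted(requested.intersection(unused))
print(deletable)  # [4]
```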

get_modal_graph(mode)

Return a networkx graph of the network for a specific mode.

Parameters:

Name Type Description Default
mode

mode of the network, one of drive,transit,walk, bike

required
Source code in network_wrangler/roadway/network.py
def get_modal_graph(self, mode) -> MultiDiGraph:
    """Return a networkx graph of the network for a specific mode.

    Args:
        mode: mode of the network, one of `drive`,`transit`,`walk`, `bike`
    """
    from .graph import net_to_graph

    if self._modal_graphs[mode]["hash"] != self.modal_graph_hash(mode):
        self._modal_graphs[mode]["graph"] = net_to_graph(self, mode)

    return self._modal_graphs[mode]["graph"]
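
`get_modal_graph` rebuilds the graph only when the link table's content hash changes. The cache-invalidation pattern, stripped of Wrangler types (a stand-in string replaces `net_to_graph`'s output):

```python
import hashlib

def content_hash(links_repr: str, mode: str) -> str:
    # Mirrors modal_graph_hash: hash of link contents plus the mode.
    return hashlib.sha256(f"{links_repr}-{mode}".encode()).hexdigest()

_cache: dict = {"drive": {"hash": None, "graph": None}}

def get_graph(links_repr: str, mode: str):
    h = content_hash(links_repr, mode)
    if _cache[mode]["hash"] != h:  # network changed since the graph was built
        _cache[mode]["graph"] = ("graph", links_repr)  # stand-in for net_to_graph(...)
        _cache[mode]["hash"] = h
    return _cache[mode]["graph"]

g1 = get_graph("links-v1", "drive")
g2 = get_graph("links-v1", "drive")  # unchanged network: cached object returned
g3 = get_graph("links-v2", "drive")  # changed network: rebuilt
```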

get_property_by_timespan_and_group(link_property, category=DEFAULT_CATEGORY, timespan=DEFAULT_TIMESPAN, strict_timespan_match=False, min_overlap_minutes=60)

Returns a new dataframe with model_link_id and link property by category and timespan.

Convenience method for backward compatibility.

Parameters:

Name Type Description Default
link_property str

link property to query

required
category Optional[Union[str, int]]

category to query or a list of categories. Defaults to DEFAULT_CATEGORY.

DEFAULT_CATEGORY
timespan Optional[TimespanString]

timespan to query in the form of [“HH:MM”,”HH:MM”]. Defaults to DEFAULT_TIMESPAN.

DEFAULT_TIMESPAN
strict_timespan_match bool

If True, will only return links that match the timespan exactly. Defaults to False.

False
min_overlap_minutes int

If strict_timespan_match is False, will return links that overlap with the timespan by at least this many minutes. Defaults to 60.

60
Source code in network_wrangler/roadway/network.py
def get_property_by_timespan_and_group(
    self,
    link_property: str,
    category: Optional[Union[str, int]] = DEFAULT_CATEGORY,
    timespan: Optional[TimespanString] = DEFAULT_TIMESPAN,
    strict_timespan_match: bool = False,
    min_overlap_minutes: int = 60,
) -> Any:
    """Returns a new dataframe with model_link_id and link property by category and timespan.

    Convenience method for backward compatibility.

    Args:
        link_property: link property to query
        category: category to query or a list of categories. Defaults to DEFAULT_CATEGORY.
        timespan: timespan to query in the form of ["HH:MM","HH:MM"].
            Defaults to DEFAULT_TIMESPAN.
        strict_timespan_match: If True, will only return links that match the timespan exactly.
            Defaults to False.
        min_overlap_minutes: If strict_timespan_match is False, will return links that overlap
            with the timespan by at least this many minutes. Defaults to 60.
    """
    from .links.scopes import prop_for_scope

    return prop_for_scope(
        self.links_df,
        link_property,
        timespan=timespan,
        category=category,
        strict_timespan_match=strict_timespan_match,
        min_overlap_minutes=min_overlap_minutes,
    )
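
The `min_overlap_minutes` threshold compares a query timespan against each link's scoped timespan. Overlap in minutes between two `["HH:MM","HH:MM"]` spans can be computed as follows (a sketch of the concept, not the `prop_for_scope` internals; assumes both spans fall within one day):

```python
def overlap_minutes(a: list, b: list) -> int:
    """Minutes of overlap between two ["HH:MM", "HH:MM"] timespans (same day)."""
    to_min = lambda s: int(s[:2]) * 60 + int(s[3:5])
    start = max(to_min(a[0]), to_min(b[0]))
    end = min(to_min(a[1]), to_min(b[1]))
    return max(0, end - start)

print(overlap_minutes(["06:00", "09:00"], ["08:00", "10:00"]))  # 60
print(overlap_minutes(["06:00", "07:00"], ["08:00", "10:00"]))  # 0
```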

get_selection(selection_dict, overwrite=False)

Return selection if it already exists, otherwise performs selection.

Parameters:

Name Type Description Default
selection_dict dict

SelectFacility dictionary.

required
overwrite bool

if True, will overwrite any previously cached searches. Defaults to False.

False
Source code in network_wrangler/roadway/network.py
def get_selection(
    self,
    selection_dict: Union[dict, SelectFacility],
    overwrite: bool = False,
) -> Union[RoadwayNodeSelection, RoadwayLinkSelection]:
    """Return selection if it already exists, otherwise performs selection.

    Args:
        selection_dict (dict): SelectFacility dictionary.
        overwrite: if True, will overwrite any previously cached searches. Defaults to False.
    """
    key = _create_selection_key(selection_dict)
    if (key in self._selections) and not overwrite:
        WranglerLogger.debug(f"Using cached selection from key: {key}")
        return self._selections[key]

    if isinstance(selection_dict, SelectFacility):
        selection_data = selection_dict
    elif isinstance(selection_dict, SelectLinksDict):
        selection_data = SelectFacility(links=selection_dict)
    elif isinstance(selection_dict, SelectNodesDict):
        selection_data = SelectFacility(nodes=selection_dict)
    elif isinstance(selection_dict, dict):
        selection_data = SelectFacility(**selection_dict)
    else:
        msg = "selection_dict arg must be a dictionary or SelectFacility model."
        WranglerLogger.error(
            msg + f" Received: {selection_dict} of type {type(selection_dict)}"
        )
        raise SelectionError(msg)

    WranglerLogger.debug(f"Getting selection from key: {key}")
    if "links" in selection_data.fields:
        return RoadwayLinkSelection(self, selection_dict)
    if "nodes" in selection_data.fields:
        return RoadwayNodeSelection(self, selection_dict)
    msg = "Selection data should have either 'links' or 'nodes'."
    WranglerLogger.error(msg + f" Received: {selection_dict}")
    raise SelectionError(msg)
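
The selection cache keys each dictionary deterministically so repeated queries are free. A minimal version of that pattern, with `json`/`hashlib` standing in for `_create_selection_key` and a tuple standing in for the Selection object:

```python
import hashlib
import json

_selections: dict = {}

def selection_key(selection_dict: dict) -> str:
    # Deterministic key: same dict contents -> same key, regardless of insertion order.
    return hashlib.sha256(json.dumps(selection_dict, sort_keys=True).encode()).hexdigest()

def get_selection(selection_dict: dict, overwrite: bool = False):
    key = selection_key(selection_dict)
    if key in _selections and not overwrite:
        return _selections[key]  # cached search
    result = ("selection", key)  # stand-in for RoadwayLinkSelection / RoadwayNodeSelection
    _selections[key] = result
    return result

s1 = get_selection({"links": {"name": ["Main St"]}})
s2 = get_selection({"links": {"name": ["Main St"]}})  # cache hit: same object
```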

has_link(ab)

Returns true if network has links with AB values.

Parameters:

Name Type Description Default
ab tuple

Tuple of values corresponding with A and B.

required
Source code in network_wrangler/roadway/network.py
def has_link(self, ab: tuple) -> bool:
    """Returns true if network has links with AB values.

    Args:
        ab: Tuple of values corresponding with A and B.
    """
    sel_a, sel_b = ab
    has_link = ((self.links_df["A"] == sel_a) & (self.links_df["B"] == sel_b)).any()
    return has_link

has_node(model_node_id)

Queries if network has node based on model_node_id.

Parameters:

Name Type Description Default
model_node_id int

model_node_id to check for.

required
Source code in network_wrangler/roadway/network.py
def has_node(self, model_node_id: int) -> bool:
    """Queries if network has node based on model_node_id.

    Args:
        model_node_id: model_node_id to check for.
    """
    has_node = self.nodes_df.model_node_id.isin([model_node_id]).any()

    return has_node

is_connected(mode)

Determines if the network graph is “strongly” connected.

A graph is strongly connected if each vertex is reachable from every other vertex.

Parameters:

Name Type Description Default
mode str

mode of the network, one of drive,transit,walk, bike

required
Source code in network_wrangler/roadway/network.py
def is_connected(self, mode: str) -> bool:
    """Determines if the network graph is "strongly" connected.

    A graph is strongly connected if each vertex is reachable from every other vertex.

    Args:
        mode:  mode of the network, one of `drive`,`transit`,`walk`, `bike`
    """
    is_connected = nx.is_strongly_connected(self.get_modal_graph(mode))

    return is_connected
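
Strong connectivity means every node reaches every other. The same check networkx performs can be done with two reachability passes (forward and reverse) from any start node; a self-contained sketch:

```python
from collections import defaultdict

def is_strongly_connected(edges: list) -> bool:
    """True if every vertex reaches every other (two BFS passes, Kosaraju-style)."""
    fwd, rev, nodes = defaultdict(set), defaultdict(set), set()
    for a, b in edges:
        fwd[a].add(b)
        rev[b].add(a)
        nodes |= {a, b}

    def reachable(adj, start):
        seen, stack = {start}, [start]
        while stack:
            for nxt in adj[stack.pop()] - seen:
                seen.add(nxt)
                stack.append(nxt)
        return seen

    start = next(iter(nodes))
    # Reachable forward AND backward from one node => reachable between all pairs.
    return reachable(fwd, start) == nodes and reachable(rev, start) == nodes

print(is_strongly_connected([(1, 2), (2, 3), (3, 1)]))  # True  (a cycle)
print(is_strongly_connected([(1, 2), (2, 3)]))          # False (no path back to 1)
```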

links_with_link_ids(link_ids)

Return subset of links_df based on link_ids list.

Source code in network_wrangler/roadway/network.py
def links_with_link_ids(self, link_ids: list[int]) -> DataFrame[RoadLinksTable]:
    """Return subset of links_df based on link_ids list."""
    return filter_links_to_ids(self.links_df, link_ids)

links_with_nodes(node_ids)

Return subset of links_df based on node_ids list.

Source code in network_wrangler/roadway/network.py
def links_with_nodes(self, node_ids: list[int]) -> DataFrame[RoadLinksTable]:
    """Return subset of links_df based on node_ids list."""
    return filter_links_to_node_ids(self.links_df, node_ids)

modal_graph_hash(mode)

Hash of the links in order to detect a network change from when graph created.

Source code in network_wrangler/roadway/network.py
def modal_graph_hash(self, mode) -> str:
    """Hash of the links in order to detect a network change from when graph created."""
    _value = str.encode(self.links_df.df_hash() + "-" + mode)
    _hash = hashlib.sha256(_value).hexdigest()

    return _hash

move_nodes(node_geometry_change_table)

Moves nodes based on updated geometry along with associated links and shape geometry.

Parameters:

Name Type Description Default
node_geometry_change_table DataFrame[NodeGeometryChangeTable]

a table with model_node_id, X, Y, and CRS.

required
Source code in network_wrangler/roadway/network.py
def move_nodes(
    self,
    node_geometry_change_table: DataFrame[NodeGeometryChangeTable],
):
    """Moves nodes based on updated geometry along with associated links and shape geometry.

    Args:
        node_geometry_change_table: a table with model_node_id, X, Y, and CRS.
    """
    node_geometry_change_table = NodeGeometryChangeTable(node_geometry_change_table)
    node_ids = node_geometry_change_table.model_node_id.to_list()
    WranglerLogger.debug(f"Moving nodes: {node_ids}")
    self.nodes_df = edit_node_geometry(self.nodes_df, node_geometry_change_table)
    self.links_df = edit_link_geometry_from_nodes(self.links_df, self.nodes_df, node_ids)
    self.shapes_df = edit_shape_geometry_from_nodes(
        self.shapes_df, self.links_df, self.nodes_df, node_ids
    )

node_coords(model_node_id)

Return coordinates (x, y) of a node based on model_node_id.

Source code in network_wrangler/roadway/network.py
def node_coords(self, model_node_id: int) -> tuple:
    """Return coordinates (x, y) of a node based on model_node_id."""
    try:
        node = self.nodes_df[self.nodes_df.model_node_id == model_node_id]
    except ValueError as err:
        msg = f"Node with model_node_id {model_node_id} not found."
        WranglerLogger.error(msg)
        raise NodeNotFoundError(msg) from err
    return node.geometry.x.values[0], node.geometry.y.values[0]

nodes_in_links()

Returns subset of self.nodes_df that are in self.links_df.

Source code in network_wrangler/roadway/network.py
def nodes_in_links(self) -> DataFrame[RoadNodesTable]:
    """Returns subset of self.nodes_df that are in self.links_df."""
    return filter_nodes_to_links(self.links_df, self.nodes_df)

validate_config(v)

Validate config.

Source code in network_wrangler/roadway/network.py
@field_validator("config")
def validate_config(cls, v):
    """Validate config."""
    return load_wrangler_config(v)

add_incident_link_data_to_nodes(links_df, nodes_df, link_variables=None)

Add data from links going to/from nodes to node.

Parameters:

Name Type Description Default
links_df DataFrame[RoadLinksTable]

Will assess connectivity of this links list

required
nodes_df DataFrame[RoadNodesTable]

Will assess connectivity of this nodes list

required
link_variables Optional[list]

list of columns in links dataframe to add to incident nodes

None

Returns:

Type Description
DataFrame[RoadNodesTable]

nodes DataFrame with link data where length is N*number of links going in/out

Source code in network_wrangler/roadway/network.py
def add_incident_link_data_to_nodes(
    links_df: DataFrame[RoadLinksTable],
    nodes_df: DataFrame[RoadNodesTable],
    link_variables: Optional[list] = None,
) -> DataFrame[RoadNodesTable]:
    """Add data from links going to/from nodes to node.

    Args:
        links_df: Will assess connectivity of this links list
        nodes_df: Will assess connectivity of this nodes list
        link_variables: list of columns in links dataframe to add to incident nodes

    Returns:
        nodes DataFrame with link data where length is N*number of links going in/out
    """
    WranglerLogger.debug("Adding following link data to nodes: ".format())
    link_variables = link_variables or []

    _link_vals_to_nodes = [x for x in link_variables if x in links_df.columns]
    if link_variables not in _link_vals_to_nodes:
        WranglerLogger.warning(
            f"Following columns not in links_df and wont be added to nodes: {list(set(link_variables) - set(_link_vals_to_nodes))} "
        )

    _nodes_from_links_A = nodes_df.merge(
        links_df[["A", *_link_vals_to_nodes]],
        how="outer",
        left_on="model_node_id",
        right_on="A",
    )
    _nodes_from_links_B = nodes_df.merge(
        links_df[["B", *_link_vals_to_nodes]],
        how="outer",
        left_on="model_node_id",
        right_on="B",
    )
    _nodes_from_links_ab = concat_with_attr([_nodes_from_links_A, _nodes_from_links_B])

    return _nodes_from_links_ab
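
The A-side/B-side double merge is easiest to see on a toy network. The result has one row per node-link incidence, which is why the output length grows with the number of links in and out of each node (a sketch with toy data; `how="inner"` is used here for clarity):

```python
import pandas as pd

nodes = pd.DataFrame({"model_node_id": [1, 2, 3]})
links = pd.DataFrame({"A": [1, 2], "B": [2, 3], "lanes": [2, 4]})

# One merge per link end: a node picks up data from every incident link.
from_a = nodes.merge(links[["A", "lanes"]], left_on="model_node_id", right_on="A", how="inner")
from_b = nodes.merge(links[["B", "lanes"]], left_on="model_node_id", right_on="B", how="inner")
incident = pd.concat([from_a, from_b])

print(len(incident))  # 4  (node 2 appears twice: once as an A end, once as a B end)
```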

TransitNetwork class for representing a transit network.

Transit Networks are represented as a Wrangler-flavored GTFS Feed and optionally mapped to a RoadwayNetwork object. The TransitNetwork object is the primary object for managing transit networks in Wrangler.

Usage:

```python
import network_wrangler as wr

t = wr.load_transit(stpaul_gtfs)
t.road_net = wr.load_roadway(stpaul_roadway)
t = t.apply(project_card)
wr.write_transit(t, "output_dir")
```

TransitNetwork

Representation of a Transit Network.

Typical usage example:

import network_wrangler as wr

tc = wr.load_transit(stpaul_gtfs)

Attributes:

Name Type Description
feed

gtfs feed object with interlinked tables.

road_net RoadwayNetwork

Associated roadway network object.

graph MultiDiGraph

Graph for associated roadway network object.

config WranglerConfig

Configuration object for the transit network.

feed_path str

Where the feed was read in from.

validated_frequencies bool

The frequencies have been validated.

validated_road_network_consistency

The network has been validated against the road network.

Source code in network_wrangler/transit/network.py
class TransitNetwork:
    """Representation of a Transit Network.

    Typical usage example:
    ``` py
    import network_wrangler as wr

    tc = wr.load_transit(stpaul_gtfs)
    ```

    Attributes:
        feed: gtfs feed object with interlinked tables.
        road_net (RoadwayNetwork): Associated roadway network object.
        graph (nx.MultiDiGraph): Graph for associated roadway network object.
        config (WranglerConfig): Configuration object for the transit network.
        feed_path (str): Where the feed was read in from.
        validated_frequencies (bool): The frequencies have been validated.
        validated_road_network_consistency (): The network has been validated against
            the road network.
    """

    TIME_COLS: ClassVar = ["arrival_time", "departure_time", "start_time", "end_time"]

    def __init__(self, feed: Feed, config: WranglerConfig = DefaultConfig) -> None:
        """Constructor for TransitNetwork.

        Args:
            feed: Feed object representing the transit network gtfs tables
            config: WranglerConfig object. Defaults to DefaultConfig.
        """
        WranglerLogger.debug("Creating new TransitNetwork.")

        self._road_net: Optional[RoadwayNetwork] = None
        self.feed: Feed = feed
        self.graph: nx.MultiDiGraph = None
        self.config: WranglerConfig = config
        # initialize
        self._consistent_with_road_net = False

        # cached selections
        self._selections: dict[str, TransitSelection] = {}

    @property
    def feed_path(self):
        """Pass through property from Feed."""
        return self.feed.feed_path

    @property
    def applied_projects(self) -> list[str]:
        """List of projects applied to the network.

        Note: This may or may not return a fully accurate account of all the applied projects.
        For better project accounting, please leverage the scenario object.
        """
        return _get_applied_projects_from_tables(self.feed)

    @property
    def feed(self):
        """Feed associated with the transit network."""
        return self._feed

    @feed.setter
    def feed(self, feed: Feed):
        if not isinstance(feed, Feed):
            msg = f"TransitNetwork's feed value must be a valid Feed instance. \
                             This is a {type(feed)}."
            WranglerLogger.error(msg)
            raise TransitValidationError(msg)
        if self._road_net is None or transit_road_net_consistency(feed, self._road_net):
            self._feed = feed
            self._stored_feed_hash = copy.deepcopy(feed.hash)
        else:
            msg = "Can't assign Feed inconsistent with set Roadway Network."
            WranglerLogger.error(msg)
            raise TransitRoadwayConsistencyError(msg)

    @property
    def road_net(self) -> Union[None, RoadwayNetwork]:
        """Roadway network associated with the transit network."""
        return self._road_net

    @road_net.setter
    def road_net(self, road_net_in: RoadwayNetwork):
        if road_net_in is None or road_net_in.__class__.__name__ != "RoadwayNetwork":
            msg = f"TransitNetwork's road_net: value must be a valid RoadwayNetwork instance. \
                             This is a {type(road_net_in)}."
            WranglerLogger.error(msg)
            raise TransitValidationError(msg)
        if transit_road_net_consistency(self.feed, road_net_in):
            self._road_net = road_net_in
            self._stored_road_net_hash = copy.deepcopy(road_net_in.network_hash)
            self._consistent_with_road_net = True
        else:
            msg = "Can't assign inconsistent RoadwayNetwork - Roadway Network not \
                   set, but can be referenced separately."
            WranglerLogger.error(msg)
            raise TransitRoadwayConsistencyError(msg)

    @property
    def feed_hash(self):
        """Return the hash of the feed."""
        return self.feed.hash

    @property
    def consistent_with_road_net(self) -> bool:
        """Indicate if road_net is consistent with transit network.

        Will return True if road_net is None, but provide a warning.

        Checks the network hash of when consistency was last evaluated. If transit network or
        roadway network has changed, will re-evaluate consistency and return the updated value and
        update self._stored_road_net_hash.

        Returns:
            Boolean indicating if road_net is consistent with transit network.
        """
        if self.road_net is None:
            WranglerLogger.warning("Roadway Network not set, cannot accurately check consistency.")
            return True
        updated_road = self.road_net.network_hash != self._stored_road_net_hash
        updated_feed = self.feed_hash != self._stored_feed_hash

        if updated_road or updated_feed:
            self._consistent_with_road_net = transit_road_net_consistency(self.feed, self.road_net)
            self._stored_road_net_hash = copy.deepcopy(self.road_net.network_hash)
            self._stored_feed_hash = copy.deepcopy(self.feed_hash)
        return self._consistent_with_road_net

    def __deepcopy__(self, memo):
        """Returns copied TransitNetwork instance with deep copy of Feed but not roadway net."""
        COPY_REF_NOT_VALUE = ["_road_net"]
        # Create a new, empty instance
        copied_net = self.__class__.__new__(self.__class__)
        # Return the new TransitNetwork instance
        attribute_dict = vars(self)

        # Copy the attributes to the new instance
        for attr_name, attr_value in attribute_dict.items():
            # WranglerLogger.debug(f"Copying {attr_name}")
            if attr_name in COPY_REF_NOT_VALUE:
                # If the attribute is in the COPY_REF_NOT_VALUE list, assign the reference
                setattr(copied_net, attr_name, attr_value)
            else:
                # WranglerLogger.debug(f"making deep copy: {attr_name}")
                # For other attributes, perform a deep copy
                setattr(copied_net, attr_name, copy.deepcopy(attr_value, memo))

        return copied_net

    def deepcopy(self):
        """Returns copied TransitNetwork instance with deep copy of Feed but not roadway net."""
        return copy.deepcopy(self)

    @property
    def stops_gdf(self) -> gpd.GeoDataFrame:
        """Return stops as a GeoDataFrame using set roadway geometry."""
        ref_nodes = self.road_net.nodes_df if self.road_net is not None else None
        return to_points_gdf(self.feed.stops, ref_nodes_df=ref_nodes)

    @property
    def shapes_gdf(self) -> gpd.GeoDataFrame:
        """Return aggregated shapes as a GeoDataFrame using set roadway geometry."""
        ref_nodes = self.road_net.nodes_df if self.road_net is not None else None
        return shapes_to_trip_shapes_gdf(self.feed.shapes, ref_nodes_df=ref_nodes)

    @property
    def shape_links_gdf(self) -> gpd.GeoDataFrame:
        """Return shape-links as a GeoDataFrame using set roadway geometry."""
        ref_nodes = self.road_net.nodes_df if self.road_net is not None else None
        return shapes_to_shape_links_gdf(self.feed.shapes, ref_nodes_df=ref_nodes)

    @property
    def stop_time_links_gdf(self) -> gpd.GeoDataFrame:
        """Return stop-time-links as a GeoDataFrame using set roadway geometry."""
        ref_nodes = self.road_net.nodes_df if self.road_net is not None else None
        return stop_times_to_stop_time_links_gdf(
            self.feed.stop_times, self.feed.stops, ref_nodes_df=ref_nodes
        )

    @property
    def stop_times_points_gdf(self) -> gpd.GeoDataFrame:
        """Return stop-time-points as a GeoDataFrame using set roadway geometry."""
        ref_nodes = self.road_net.nodes_df if self.road_net is not None else None

        return stop_times_to_stop_time_points_gdf(
            self.feed.stop_times, self.feed.stops, ref_nodes_df=ref_nodes
        )

    def get_selection(
        self,
        selection_dict: dict,
        overwrite: bool = False,
    ) -> TransitSelection:
        """Return selection if it already exists, otherwise performs selection.

        Will raise an error if no trips found.

        Args:
            selection_dict (dict): dictionary of transit selection criteria.
            overwrite: if True, will overwrite any previously cached searches. Defaults to False.

        Returns:
            Selection: Selection object
        """
        key = dict_to_hexkey(selection_dict)

        if (key not in self._selections) or overwrite:
            WranglerLogger.debug(f"Performing selection from key: {key}")
            self._selections[key] = TransitSelection(self, selection_dict)
        else:
            WranglerLogger.debug(f"Using cached selection from key: {key}")

        if not self._selections[key]:
            msg = f"No links or nodes found for selection dict: \n {selection_dict}"
            WranglerLogger.error(msg)
            raise TransitSelectionEmptyError(msg)
        return self._selections[key]

    def apply(self, project_card: Union[ProjectCard, dict], **kwargs) -> TransitNetwork:
        """Wrapper method to apply a roadway project, returning a new TransitNetwork instance.

        Args:
            project_card: either a dictionary of the project card object or ProjectCard instance
            **kwargs: keyword arguments to pass to project application
        """
        if not (isinstance(project_card, (ProjectCard, SubProject))):
            project_card = ProjectCard(project_card)

        if not project_card.valid:
            msg = f"Project card {project_card.project} not valid."
            WranglerLogger.error(msg)
            raise ProjectCardError(msg)

        if project_card._sub_projects:
            for sp in project_card._sub_projects:
                WranglerLogger.debug(f"- applying subproject: {sp.change_type}")
                self._apply_change(sp, **kwargs)
            return self
        return self._apply_change(project_card, **kwargs)

    def _apply_change(
        self,
        change: Union[ProjectCard, SubProject],
        reference_road_net: Optional[RoadwayNetwork] = None,
    ) -> TransitNetwork:
        """Apply a single change: a single-project project or a sub-project."""
        if not isinstance(change, SubProject):
            WranglerLogger.info(f"Applying Project to Transit Network: {change.project}")

        if change.change_type == "transit_property_change":
            return apply_transit_property_change(
                self,
                self.get_selection(change.transit_property_change["service"]),
                change.transit_property_change["property_changes"],
                project_name=change.project,
            )

        if change.change_type == "transit_routing_change":
            return apply_transit_routing_change(
                self,
                self.get_selection(change.transit_routing_change["service"]),
                change.transit_routing_change["routing"],
                reference_road_net=reference_road_net,
                project_name=change.project,
            )

        if change.change_type == "pycode":
            return apply_calculated_transit(self, change.pycode)

        if change.change_type == "transit_route_addition":
            return apply_transit_route_addition(
                self,
                change.transit_route_addition,
                reference_road_net=reference_road_net,
            )
        if change.change_type == "transit_service_deletion":
            return apply_transit_service_deletion(
                self,
                self.get_selection(change.transit_service_deletion["service"]),
                clean_shapes=change.transit_service_deletion.get("clean_shapes"),
                clean_routes=change.transit_service_deletion.get("clean_routes"),
            )
        msg = f"Not a currently valid transit project: {change}."
        WranglerLogger.error(msg)
        raise NotImplementedError(msg)
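The chain of `change_type` checks in `_apply_change` is effectively a dispatch table. The same structure can be sketched with stand-in handler functions (the handlers and return values below are placeholders, not the library's actual project-application functions):

```python
# Sketch of dispatch-on-change_type; handler names here are placeholders.
def apply_property_change(net: str, change: dict) -> str:
    return f"{net}+property:{change['project']}"

def apply_routing_change(net: str, change: dict) -> str:
    return f"{net}+routing:{change['project']}"

HANDLERS = {
    "transit_property_change": apply_property_change,
    "transit_routing_change": apply_routing_change,
}

def apply_change(net: str, change: dict) -> str:
    handler = HANDLERS.get(change["change_type"])
    if handler is None:
        # mirrors the NotImplementedError raised for unknown change types above
        raise NotImplementedError(f"Not a valid transit project: {change}")
    return handler(net, change)
```

Adding a new change type then only requires registering one more handler in the table.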

applied_projects: list[str] property

List of projects applied to the network.

Note: This may or may not return a fully accurate account of all the applied projects. For better project accounting, use the scenario object.

consistent_with_road_net: bool property

Indicate if road_net is consistent with transit network.

Will return True if road_net is None, but provide a warning.

Checks the network hash of when consistency was last evaluated. If transit network or roadway network has changed, will re-evaluate consistency and return the updated value and update self._stored_road_net_hash.

Returns:

Type Description
bool

Boolean indicating if road_net is consistent with transit network.

feed property writable

Feed associated with the transit network.

feed_hash property

Return the hash of the feed.

feed_path property

Pass through property from Feed.

road_net: Union[None, RoadwayNetwork] property writable

Roadway network associated with the transit network.

shape_links_gdf: gpd.GeoDataFrame property

Return shape-links as a GeoDataFrame using set roadway geometry.

shapes_gdf: gpd.GeoDataFrame property

Return aggregated shapes as a GeoDataFrame using set roadway geometry.

stop_time_links_gdf: gpd.GeoDataFrame property

Return stop-time-links as a GeoDataFrame using set roadway geometry.

stop_times_points_gdf: gpd.GeoDataFrame property

Return stop-time-points as a GeoDataFrame using set roadway geometry.

stops_gdf: gpd.GeoDataFrame property

Return stops as a GeoDataFrame using set roadway geometry.

__deepcopy__(memo)

Returns copied TransitNetwork instance with deep copy of Feed but not roadway net.

Source code in network_wrangler/transit/network.py
def __deepcopy__(self, memo):
    """Returns copied TransitNetwork instance with deep copy of Feed but not roadway net."""
    COPY_REF_NOT_VALUE = ["_road_net"]
    # Create a new, empty instance
    copied_net = self.__class__.__new__(self.__class__)
    # Return the new TransitNetwork instance
    attribute_dict = vars(self)

    # Copy the attributes to the new instance
    for attr_name, attr_value in attribute_dict.items():
        # WranglerLogger.debug(f"Copying {attr_name}")
        if attr_name in COPY_REF_NOT_VALUE:
            # If the attribute is in the COPY_REF_NOT_VALUE list, assign the reference
            setattr(copied_net, attr_name, attr_value)
        else:
            # WranglerLogger.debug(f"making deep copy: {attr_name}")
            # For other attributes, perform a deep copy
            setattr(copied_net, attr_name, copy.deepcopy(attr_value, memo))

    return copied_net
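The `COPY_REF_NOT_VALUE` pattern above shares the (potentially large) roadway network by reference while deep-copying everything else. A minimal self-contained illustration of the same idea (the `Net` class and its attributes are stand-ins, not Wrangler classes):

```python
import copy

class Net:
    COPY_REF_NOT_VALUE = ["_road_net"]  # attributes shared by reference on copy

    def __init__(self, road_net, feed):
        self._road_net = road_net
        self.feed = feed

    def __deepcopy__(self, memo):
        copied = self.__class__.__new__(self.__class__)
        for name, value in vars(self).items():
            if name in self.COPY_REF_NOT_VALUE:
                setattr(copied, name, value)                      # keep reference
            else:
                setattr(copied, name, copy.deepcopy(value, memo))  # independent copy
        return copied

net = Net(road_net={"links": [1, 2]}, feed={"stops": [3]})
dup = copy.deepcopy(net)
assert dup._road_net is net._road_net  # roadway network shared
assert dup.feed is not net.feed        # feed independently copied
```

This keeps copies cheap when the roadway network is large and read-only from the transit network's perspective.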

__init__(feed, config=DefaultConfig)

Constructor for TransitNetwork.

Parameters:

Name Type Description Default
feed Feed

Feed object representing the transit network gtfs tables

required
config WranglerConfig

WranglerConfig object. Defaults to DefaultConfig.

DefaultConfig
Source code in network_wrangler/transit/network.py
def __init__(self, feed: Feed, config: WranglerConfig = DefaultConfig) -> None:
    """Constructor for TransitNetwork.

    Args:
        feed: Feed object representing the transit network gtfs tables
        config: WranglerConfig object. Defaults to DefaultConfig.
    """
    WranglerLogger.debug("Creating new TransitNetwork.")

    self._road_net: Optional[RoadwayNetwork] = None
    self.feed: Feed = feed
    self.graph: nx.MultiDiGraph = None
    self.config: WranglerConfig = config
    # initialize
    self._consistent_with_road_net = False

    # cached selections
    self._selections: dict[str, TransitSelection] = {}

apply(project_card, **kwargs)

Wrapper method to apply a transit project, returning a new TransitNetwork instance.

Parameters:

Name Type Description Default
project_card Union[ProjectCard, dict]

either a dictionary of the project card object or ProjectCard instance

required
**kwargs

keyword arguments to pass to project application

{}
Source code in network_wrangler/transit/network.py
def apply(self, project_card: Union[ProjectCard, dict], **kwargs) -> TransitNetwork:
    """Wrapper method to apply a roadway project, returning a new TransitNetwork instance.

    Args:
        project_card: either a dictionary of the project card object or ProjectCard instance
        **kwargs: keyword arguments to pass to project application
    """
    if not (isinstance(project_card, (ProjectCard, SubProject))):
        project_card = ProjectCard(project_card)

    if not project_card.valid:
        msg = f"Project card {project_card.project} not valid."
        WranglerLogger.error(msg)
        raise ProjectCardError(msg)

    if project_card._sub_projects:
        for sp in project_card._sub_projects:
            WranglerLogger.debug(f"- applying subproject: {sp.change_type}")
            self._apply_change(sp, **kwargs)
        return self
    return self._apply_change(project_card, **kwargs)

deepcopy()

Returns copied TransitNetwork instance with deep copy of Feed but not roadway net.

Source code in network_wrangler/transit/network.py
def deepcopy(self):
    """Returns copied TransitNetwork instance with deep copy of Feed but not roadway net."""
    return copy.deepcopy(self)

get_selection(selection_dict, overwrite=False)

Return selection if it already exists, otherwise performs selection.

Will raise an error if no trips found.

Parameters:

Name Type Description Default
selection_dict dict

dictionary of selection criteria for transit trips.

required
overwrite bool

if True, will overwrite any previously cached searches. Defaults to False.

False

Returns:

Name Type Description
Selection TransitSelection

Selection object

Source code in network_wrangler/transit/network.py
def get_selection(
    self,
    selection_dict: dict,
    overwrite: bool = False,
) -> TransitSelection:
    """Return selection if it already exists, otherwise performs selection.

    Will raise an error if no trips found.

    Args:
        selection_dict (dict): dictionary of selection criteria for transit trips.
        overwrite: if True, will overwrite any previously cached searches. Defaults to False.

    Returns:
        Selection: Selection object
    """
    key = dict_to_hexkey(selection_dict)

    if (key not in self._selections) or overwrite:
        WranglerLogger.debug(f"Performing selection from key: {key}")
        self._selections[key] = TransitSelection(self, selection_dict)
    else:
        WranglerLogger.debug(f"Using cached selection from key: {key}")

    if not self._selections[key]:
        msg = f"No links or nodes found for selection dict: \n {selection_dict}"
        WranglerLogger.error(msg)
        raise TransitSelectionEmptyError(msg)
    return self._selections[key]
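Selections are cached under a key derived from the selection dictionary. A minimal sketch of what a `dict_to_hexkey` helper can look like (the library's actual implementation and the selection keys shown are illustrative, and may differ):

```python
import hashlib
import json

def dict_to_hexkey(d: dict) -> str:
    """Hash a JSON-serializable dict into a stable hex cache key."""
    # sort_keys=True makes the key independent of dict insertion order
    return hashlib.md5(json.dumps(d, sort_keys=True).encode()).hexdigest()

# Equivalent selection dicts map to the same cache key
a = dict_to_hexkey({"route_id": "10", "direction_id": 0})
b = dict_to_hexkey({"direction_id": 0, "route_id": "10"})
assert a == b
```

Because the key is deterministic, repeating the same selection reuses the cached `TransitSelection` instead of re-running the search.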

Configs

Configuration for parameters for Network Wrangler.

Users can change a handful of parameters which control the way Wrangler runs. These parameters can be saved as a wrangler config file which can be read in repeatedly to make sure the same parameters are used each time.

Usage

At runtime, you can specify configurable parameters at the scenario level which will then also be assigned and accessible to the roadway and transit networks.

create_scenario(..., config=myconfig)

Or if you are not using Scenario functionality, you can specify the config when you read in a RoadwayNetwork.

load_roadway_from_dir(**roadway, config=myconfig)
load_transit(**transit, config=myconfig)

myconfig can be a:

  • Path to a config file in yaml/toml/json (recommended),
  • List of paths to config files (in case you want to split up various sub-configurations)
  • Dictionary which is in the same structure as a config file, or
  • A WranglerConfig() instance.

If not provided, Wrangler will use reasonable defaults.
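For example, a minimal wrangler config file might override just a couple of defaults and leave the rest untouched (key names as in the default listing that follows):

```yaml
MODEL_ROADWAY:
    ML_OFFSET_METERS: -5
EDITS:
    EXISTING_VALUE_CONFLICT: error
```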

Default Wrangler Configuration Values

If not explicitly provided, the following default values are used:

IDS:
    TRANSIT_SHAPE_ID_METHOD: scalar
    TRANSIT_SHAPE_ID_SCALAR: 1000000
    ROAD_SHAPE_ID_METHOD: scalar
    ROAD_SHAPE_ID_SCALAR: 1000
    ML_LINK_ID_METHOD: scalar
    ML_LINK_ID_RANGE: (950000, 999999)
    ML_LINK_ID_SCALAR: 3000000
    ML_NODE_ID_METHOD: range
    ML_NODE_ID_RANGE: (950000, 999999)
    ML_NODE_ID_SCALAR: 15000
EDITS:
    EXISTING_VALUE_CONFLICT: warn
    OVERWRITE_SCOPED: conflicting
MODEL_ROADWAY:
    ML_OFFSET_METERS: -10
    ADDITIONAL_COPY_FROM_GP_TO_ML: []
    ADDITIONAL_COPY_TO_ACCESS_EGRESS: []
CPU:
    EST_PD_READ_SPEED:
        csv: 0.03
        parquet: 0.005
        geojson: 0.03
        json: 0.15
        txt: 0.04
Extended usage

Load the default configuration:

from network_wrangler.configs import DefaultConfig

Access the configuration:

from network_wrangler.configs import DefaultConfig
DefaultConfig.MODEL_ROADWAY.ML_OFFSET_METERS
>> -10

Modify the default configuration in-line:

from network_wrangler.configs import DefaultConfig

DefaultConfig.MODEL_ROADWAY.ML_OFFSET_METERS = 20

Load a configuration from a file:

from network_wrangler.configs import load_wrangler_config

config = load_wrangler_config("path/to/config.yaml")

Set a configuration value:

config.MODEL_ROADWAY.ML_OFFSET_METERS = 10

CpuConfig

Bases: ConfigItem

CPU Configuration - Will not change any outcomes.

Attributes:

Name Type Description
EST_PD_READ_SPEED dict[str, float]

Read sec / MB - WILL DEPEND ON SPECIFIC COMPUTER

Source code in network_wrangler/configs/wrangler.py
@dataclass
class CpuConfig(ConfigItem):
    """CPU Configuration -  Will not change any outcomes.

    Attributes:
        EST_PD_READ_SPEED: Read sec / MB - WILL DEPEND ON SPECIFIC COMPUTER
    """

    EST_PD_READ_SPEED: dict[str, float] = Field(
        default_factory=lambda: {
            "csv": 0.03,
            "parquet": 0.005,
            "geojson": 0.03,
            "json": 0.15,
            "txt": 0.04,
        }
    )
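EST_PD_READ_SPEED is used to estimate how long a file will take to read: seconds-per-MB times file size. With the default values above (est_read_seconds is an illustrative helper, not a library function):

```python
# Defaults from CpuConfig above (subset)
EST_PD_READ_SPEED = {"csv": 0.03, "parquet": 0.005}

def est_read_seconds(size_mb: float, file_format: str) -> float:
    """Estimated read time = (sec / MB) * size in MB."""
    return EST_PD_READ_SPEED[file_format] * size_mb

assert est_read_seconds(200, "csv") == 6.0      # 200 MB csv: ~6 s
assert est_read_seconds(200, "parquet") == 1.0  # parquet reads ~6x faster
```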

EditsConfig

Bases: ConfigItem

Configuration for Edits.

Attributes:

Name Type Description
EXISTING_VALUE_CONFLICT Literal['warn', 'error', 'skip']

Only used if ‘existing’ provided in project card and existing doesn’t match the existing network value. One of error, warn, or skip. error will raise an error, warn will warn the user, and skip will skip the change for that specific property (note it will still apply any remaining property changes). Defaults to warn. Can be overridden by setting existing_value_conflict in a roadway_property_change project card.

OVERWRITE_SCOPED Literal['conflicting', 'all', 'error']

How to handle conflicts with existing values. Should be one of “conflicting”, “all”, or “error”. “conflicting” will only overwrite values where the scope only partially overlaps with the existing value. “all” will overwrite all the scoped values. “error” will error if there is any overlap. Default is “conflicting”. Can be changed at the project-level by setting overwrite_scoped in a roadway_property_change project card.

Source code in network_wrangler/configs/wrangler.py
@dataclass
class EditsConfig(ConfigItem):
    """Configuration for Edits.

    Attributes:
        EXISTING_VALUE_CONFLICT: Only used if 'existing' provided in project card and
            `existing` doesn't match the existing network value. One of `error`, `warn`, or `skip`.
            `error` will raise an error, `warn` will warn the user, and `skip` will skip the change
            for that specific property (note it will still apply any remaining property changes).
            Defaults to `warn`. Can be overridden by setting `existing_value_conflict` in
            a `roadway_property_change` project card.

        OVERWRITE_SCOPED: How to handle conflicts with existing values.
            Should be one of "conflicting", "all", or "error".
            "conflicting" will only overwrite values where the scope only partially overlaps with
            the existing value. "all" will overwrite all the scoped values. "error" will error if
            there is any overlap. Default is "conflicting". Can be changed at the project-level
            by setting `overwrite_scoped` in a `roadway_property_change` project card.
    """

    EXISTING_VALUE_CONFLICT: Literal["warn", "error", "skip"] = "warn"
    OVERWRITE_SCOPED: Literal["conflicting", "all", "error"] = "conflicting"

IdGenerationConfig

Bases: ConfigItem

ID Generation Configuration.

Attributes:

Name Type Description
TRANSIT_SHAPE_ID_METHOD Literal['scalar']

method for creating a shape_id for a transit shape. Should be “scalar”.

TRANSIT_SHAPE_ID_SCALAR int

scalar value to add to general purpose lane to create a shape_id for a transit shape.

ROAD_SHAPE_ID_METHOD Literal['scalar']

method for creating a shape_id for a roadway shape. Should be “scalar”.

ROAD_SHAPE_ID_SCALAR int

scalar value to add to general purpose lane to create a shape_id for a roadway shape.

ML_LINK_ID_METHOD Literal['range', 'scalar']

method for creating a model_link_id for an associated link for a parallel managed lane.

ML_LINK_ID_RANGE tuple[int, int]

range of model_link_ids to use when creating an associated link for a parallel managed lane.

ML_LINK_ID_SCALAR int

scalar value to add to general purpose lane to create a model_link_id when creating an associated link for a parallel managed lane.

ML_NODE_ID_METHOD Literal['range', 'scalar']

method for creating a model_node_id for an associated node for a parallel managed lane.

ML_NODE_ID_RANGE tuple[int, int]

range of model_node_ids to use when creating an associated node for a parallel managed lane.

ML_NODE_ID_SCALAR int

scalar value to add to a general purpose lane node id to create a model_node_id when creating an associated node for a parallel managed lane.

Source code in network_wrangler/configs/wrangler.py
@dataclass
class IdGenerationConfig(ConfigItem):
    """Model Roadway Configuration.

    Attributes:
        TRANSIT_SHAPE_ID_METHOD: method for creating a shape_id for a transit shape.
            Should be "scalar".
        TRANSIT_SHAPE_ID_SCALAR: scalar value to add to general purpose lane to create a
            shape_id for a transit shape.
        ROAD_SHAPE_ID_METHOD: method for creating a shape_id for a roadway shape.
            Should be "scalar".
        ROAD_SHAPE_ID_SCALAR: scalar value to add to general purpose lane to create a
            shape_id for a roadway shape.
        ML_LINK_ID_METHOD: method for creating a model_link_id for an associated
            link for a parallel managed lane.
        ML_LINK_ID_RANGE: range of model_link_ids to use when creating an associated
            link for a parallel managed lane.
        ML_LINK_ID_SCALAR: scalar value to add to general purpose lane to create a
            model_link_id when creating an associated link for a parallel managed lane.
        ML_NODE_ID_METHOD: method for creating a model_node_id for an associated node
            for a parallel managed lane.
        ML_NODE_ID_RANGE: range of model_node_ids to use when creating an associated
            node for a parallel managed lane.
        ML_NODE_ID_SCALAR: scalar value to add to a general purpose lane node id to create
            a model_node_id when creating an associated node for a parallel managed lane.
    """

    TRANSIT_SHAPE_ID_METHOD: Literal["scalar"] = "scalar"
    TRANSIT_SHAPE_ID_SCALAR: int = 1000000
    ROAD_SHAPE_ID_METHOD: Literal["scalar"] = "scalar"
    ROAD_SHAPE_ID_SCALAR: int = 1000
    ML_LINK_ID_METHOD: Literal["range", "scalar"] = "scalar"
    ML_LINK_ID_RANGE: tuple[int, int] = (950000, 999999)
    ML_LINK_ID_SCALAR: int = 3000000
    ML_NODE_ID_METHOD: Literal["range", "scalar"] = "range"
    ML_NODE_ID_RANGE: tuple[int, int] = (950000, 999999)
    ML_NODE_ID_SCALAR: int = 15000
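The two ID-generation methods work differently: "scalar" adds a fixed offset to the source ID, while "range" takes the next unused ID from a reserved interval. A self-contained sketch of the behavior (new_id is illustrative, not the library's function):

```python
def new_id(existing_id: int, method: str, scalar: int = 15000,
           id_range: tuple = (950000, 999999),
           used_ids: frozenset = frozenset()) -> int:
    """Illustrate 'scalar' vs 'range' ID generation."""
    if method == "scalar":
        return existing_id + scalar        # offset from the source id
    if method == "range":
        for candidate in range(id_range[0], id_range[1] + 1):
            if candidate not in used_ids:  # first free id in the interval
                return candidate
        raise ValueError("ID range exhausted")
    raise ValueError(f"Unknown ID method: {method}")

assert new_id(123, "scalar") == 15123
assert new_id(123, "range", used_ids=frozenset({950000})) == 950001
```

The scalar method keeps a recognizable relationship between the managed-lane ID and its general-purpose counterpart; the range method guarantees new IDs stay inside a reserved block.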

ModelRoadwayConfig

Bases: ConfigItem

Model Roadway Configuration.

Attributes:

Name Type Description
ML_OFFSET_METERS int

Offset in meters for managed lanes.

ADDITIONAL_COPY_FROM_GP_TO_ML list[str]

Additional fields to copy from general purpose to managed lanes.

ADDITIONAL_COPY_TO_ACCESS_EGRESS list[str]

Additional fields to copy to access and egress links.

Source code in network_wrangler/configs/wrangler.py
@dataclass
class ModelRoadwayConfig(ConfigItem):
    """Model Roadway Configuration.

    Attributes:
        ML_OFFSET_METERS: Offset in meters for managed lanes.
        ADDITIONAL_COPY_FROM_GP_TO_ML: Additional fields to copy from general purpose to managed
            lanes.
        ADDITIONAL_COPY_TO_ACCESS_EGRESS: Additional fields to copy to access and egress links.
    """

    ML_OFFSET_METERS: int = -10
    ADDITIONAL_COPY_FROM_GP_TO_ML: list[str] = Field(default_factory=list)
    ADDITIONAL_COPY_TO_ACCESS_EGRESS: list[str] = Field(default_factory=list)

WranglerConfig

Bases: ConfigItem

Configuration for Network Wrangler.

Attributes:

Name Type Description
IDS IdGenerationConfig

Parameters governing how new ids are generated.

MODEL_ROADWAY ModelRoadwayConfig

Parameters governing how the model roadway is created.

CPU CpuConfig

Parameters for accessing CPU information. Will not change any outcomes.

EDITS EditsConfig

Parameters governing how edits are handled.

Source code in network_wrangler/configs/wrangler.py
@dataclass
class WranglerConfig(ConfigItem):
    """Configuration for Network Wrangler.

    Attributes:
        IDS: Parameters governing how new ids are generated.
        MODEL_ROADWAY: Parameters governing how the model roadway is created.
        CPU: Parameters for accessing CPU information. Will not change any outcomes.
        EDITS: Parameters governing how edits are handled.
    """

    IDS: IdGenerationConfig = IdGenerationConfig()
    MODEL_ROADWAY: ModelRoadwayConfig = ModelRoadwayConfig()
    CPU: CpuConfig = CpuConfig()
    EDITS: EditsConfig = EditsConfig()

Scenario configuration for Network Wrangler.

You can build a scenario and write out its results from a scenario configuration file using the code below. This is very useful when you run a specific scenario repeatedly with minor variations, because you can keep your config file under version control. In addition to the completed roadway and transit files, the output provides a record of how the scenario was run.

Usage
    from scenario import build_scenario_from_config
    my_scenario = build_scenario_from_config(my_scenario_config)

Where my_scenario_config can be a:

  • Path to a scenario config file in yaml/toml/json (recommended),
  • Dictionary which is in the same structure as a scenario config file, or
  • A ScenarioConfig() instance.

Notes on relative paths in scenario configs

  • Relative paths are recognized by a preceding “.”.
  • Relative paths within output_scenario for roadway, transit, and project_cards are interpreted to be relative to output_scenario.path.
  • All other relative paths are interpreted to be relative to directory of the scenario config file. (Or if scenario config is provided as a dictionary, relative paths will be interpreted as relative to the current working directory.)
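The resolution rule can be illustrated with plain pathlib (a sketch of the behavior, not the library's internal helper):

```python
from pathlib import Path

def resolve_rel(path: str, base: Path) -> Path:
    """Resolve `path` against `base` unless it is already absolute."""
    p = Path(path)
    return p if p.is_absolute() else (base / p).resolve()

base = Path("/scenarios/2050")
assert resolve_rel("./cards", base) == Path("/scenarios/2050/cards")  # relative: joined to base
assert resolve_rel("/data/cards", base) == Path("/data/cards")        # absolute: used as-is
```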
Example Scenario Config
name: "my_scenario"
base_scenario:
    roadway:
        dir: "path/to/roadway_network"
        file_format: "geojson"
        read_in_shapes: True
    transit:
        dir: "path/to/transit_network"
        file_format: "txt"
    applied_projects:
        - "project1"
        - "project2"
    conflicts:
        "project3": ["project1", "project2"]
        "project4": ["project1"]
projects:
    project_card_filepath:
        - "path/to/projectA.yaml"
        - "path/to/projectB.yaml"
    filter_tags:
        - "tag1"
output_scenario:
    overwrite: True
    roadway:
        out_dir: "path/to/output/roadway"
        prefix: "my_scenario"
        file_format: "geojson"
        true_shape: False
    transit:
        out_dir: "path/to/output/transit"
        prefix: "my_scenario"
        file_format: "txt"
    project_cards:
        out_dir: "path/to/output/project_cards"

wrangler_config: "path/to/wrangler_config.yaml"
Extended Usage

Load a configuration from a file:

from network_wrangler.configs import load_scenario_config

my_scenario_config = load_scenario_config("path/to/config.yaml")

Access the configuration:

my_scenario_config.base_transit_network.path
>> path/to/transit_network

ProjectCardOutputConfig

Bases: ConfigItem

Configuration for outputting project cards in a scenario.

Attributes:

Name Type Description
out_dir

Path to write the project card files to if you don’t want to use the default.

write

If True, will write the project cards. Defaults to True.

Source code in network_wrangler/configs/scenario.py
class ProjectCardOutputConfig(ConfigItem):
    """Configuration for outputing project cards in a scenario.

    Attributes:
        out_dir: Path to write the project card files to if you don't want to use the default.
        write: If True, will write the project cards. Defaults to True.
    """

    def __init__(
        self,
        base_path: Path = DEFAULT_BASE_DIR,
        out_dir: Path = DEFAULT_PROJECT_OUT_DIR,
        write: bool = DEFAULT_PROJECT_WRITE,
    ):
        """Constructor for ProjectCardOutputConfig."""
        if out_dir is not None and not Path(out_dir).is_absolute():
            self.out_dir = (base_path / Path(out_dir)).resolve()
        else:
            self.out_dir = Path(out_dir)
        self.write = write

__init__(base_path=DEFAULT_BASE_DIR, out_dir=DEFAULT_PROJECT_OUT_DIR, write=DEFAULT_PROJECT_WRITE)

Constructor for ProjectCardOutputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_path: Path = DEFAULT_BASE_DIR,
    out_dir: Path = DEFAULT_PROJECT_OUT_DIR,
    write: bool = DEFAULT_PROJECT_WRITE,
):
    """Constructor for ProjectCardOutputConfig."""
    if out_dir is not None and not Path(out_dir).is_absolute():
        self.out_dir = (base_path / Path(out_dir)).resolve()
    else:
        self.out_dir = Path(out_dir)
    self.write = write

ProjectsConfig

Bases: ConfigItem

Configuration for projects in a scenario.

Attributes:

Name Type Description
project_card_filepath

where the project card is. A single path, list of paths, a directory, or a glob pattern. Defaults to None.

filter_tags

List of tags to filter the project cards by.

Source code in network_wrangler/configs/scenario.py
class ProjectsConfig(ConfigItem):
    """Configuration for projects in a scenario.

    Attributes:
        project_card_filepath: where the project card is.  A single path, list of paths,
            a directory, or a glob pattern. Defaults to None.
        filter_tags: List of tags to filter the project cards by.
    """

    def __init__(
        self,
        base_path: Path = DEFAULT_BASE_DIR,
        project_card_filepath: ProjectCardFilepaths = DEFAULT_PROJECT_IN_PATHS,
        filter_tags: list[str] = DEFAULT_PROJECT_TAGS,
    ):
        """Constructor for ProjectsConfig."""
        self.project_card_filepath = _resolve_rel_paths(project_card_filepath, base_path=base_path)
        self.filter_tags = filter_tags

__init__(base_path=DEFAULT_BASE_DIR, project_card_filepath=DEFAULT_PROJECT_IN_PATHS, filter_tags=DEFAULT_PROJECT_TAGS)

Constructor for ProjectsConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_path: Path = DEFAULT_BASE_DIR,
    project_card_filepath: ProjectCardFilepaths = DEFAULT_PROJECT_IN_PATHS,
    filter_tags: list[str] = DEFAULT_PROJECT_TAGS,
):
    """Constructor for ProjectsConfig."""
    self.project_card_filepath = _resolve_rel_paths(project_card_filepath, base_path=base_path)
    self.filter_tags = filter_tags

RoadwayNetworkInputConfig

Bases: ConfigItem

Configuration for the road network in a scenario.

Attributes:

Name Type Description
dir

Path to directory with roadway network files.

file_format

File format for the roadway network files. Should be one of RoadwayFileTypes. Defaults to “geojson”.

read_in_shapes

If True, will read in the shapes of the roadway network. Defaults to False.

boundary_geocode

Geocode of the boundary. Will use this to filter the roadway network.

boundary_file

Path to the boundary file. If provided and both boundary_gdf and boundary_geocode are not provided, will use this to filter the roadway network.

Source code in network_wrangler/configs/scenario.py
class RoadwayNetworkInputConfig(ConfigItem):
    """Configuration for the road network in a scenario.

    Attributes:
        dir: Path to directory with roadway network files.
        file_format: File format for the roadway network files. Should be one of RoadwayFileTypes.
            Defaults to "geojson".
        read_in_shapes: If True, will read in the shapes of the roadway network. Defaults to False.
        boundary_geocode: Geocode of the boundary. Will use this to filter the roadway network.
        boundary_file: Path to the boundary file. If provided and both boundary_gdf and
            boundary_geocode are not provided, will use this to filter the roadway network.
    """

    def __init__(
        self,
        base_path: Path = DEFAULT_BASE_DIR,
        dir: Path = DEFAULT_ROADWAY_IN_DIR,
        file_format: RoadwayFileTypes = DEFAULT_ROADWAY_IN_FORMAT,
        read_in_shapes: bool = DEFAULT_ROADWAY_SHAPE_READ,
        boundary_geocode: Optional[str] = None,
        boundary_file: Optional[Path] = None,
    ):
        """Constructor for RoadwayNetworkInputConfig."""
        if dir is not None and not Path(dir).is_absolute():
            self.dir = (base_path / Path(dir)).resolve()
        else:
            self.dir = Path(dir)
        self.file_format = file_format
        self.read_in_shapes = read_in_shapes
        self.boundary_geocode = boundary_geocode
        self.boundary_file = boundary_file

__init__(base_path=DEFAULT_BASE_DIR, dir=DEFAULT_ROADWAY_IN_DIR, file_format=DEFAULT_ROADWAY_IN_FORMAT, read_in_shapes=DEFAULT_ROADWAY_SHAPE_READ, boundary_geocode=None, boundary_file=None)

Constructor for RoadwayNetworkInputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_path: Path = DEFAULT_BASE_DIR,
    dir: Path = DEFAULT_ROADWAY_IN_DIR,
    file_format: RoadwayFileTypes = DEFAULT_ROADWAY_IN_FORMAT,
    read_in_shapes: bool = DEFAULT_ROADWAY_SHAPE_READ,
    boundary_geocode: Optional[str] = None,
    boundary_file: Optional[Path] = None,
):
    """Constructor for RoadwayNetworkInputConfig."""
    if dir is not None and not Path(dir).is_absolute():
        self.dir = (base_path / Path(dir)).resolve()
    else:
        self.dir = Path(dir)
    self.file_format = file_format
    self.read_in_shapes = read_in_shapes
    self.boundary_geocode = boundary_geocode
    self.boundary_file = boundary_file

RoadwayNetworkOutputConfig

Bases: ConfigItem

Configuration for writing out the resulting roadway network for a scenario.

Attributes:

Name Type Description
out_dir

Path to write the roadway network files to if you don’t want to use the default.

prefix

Prefix to add to the file name. If not provided will use the scenario name.

file_format

File format to write the roadway network to. Should be one of RoadwayFileTypes. Defaults to “geojson”.

true_shape

If True, will write the true shape of the roadway network. Defaults to False.

write

If True, will write the roadway network. Defaults to True.

Source code in network_wrangler/configs/scenario.py
class RoadwayNetworkOutputConfig(ConfigItem):
    """Configuration for writing out the resulting roadway network for a scenario.

    Attributes:
        out_dir: Path to write the roadway network files to if you don't want to use the default.
        prefix: Prefix to add to the file name. If not provided will use the scenario name.
        file_format: File format to write the roadway network to. Should be one of
            RoadwayFileTypes. Defaults to "geojson".
        true_shape: If True, will write the true shape of the roadway network. Defaults to False.
        write: If True, will write the roadway network. Defaults to True.
    """

    def __init__(
        self,
        out_dir: Path = DEFAULT_ROADWAY_OUT_DIR,
        base_path: Path = DEFAULT_BASE_DIR,
        convert_complex_link_properties_to_single_field: bool = False,
        prefix: Optional[str] = None,
        file_format: RoadwayFileTypes = DEFAULT_ROADWAY_OUT_FORMAT,
        true_shape: bool = False,
        write: bool = DEFAULT_ROADWAY_WRITE,
    ):
        """Constructor for RoadwayNetworkOutputConfig."""
        if out_dir is not None and not Path(out_dir).is_absolute():
            self.out_dir = (base_path / Path(out_dir)).resolve()
        else:
            self.out_dir = Path(out_dir)

        self.convert_complex_link_properties_to_single_field = (
            convert_complex_link_properties_to_single_field
        )
        self.prefix = prefix
        self.file_format = file_format
        self.true_shape = true_shape
        self.write = write

__init__(out_dir=DEFAULT_ROADWAY_OUT_DIR, base_path=DEFAULT_BASE_DIR, convert_complex_link_properties_to_single_field=False, prefix=None, file_format=DEFAULT_ROADWAY_OUT_FORMAT, true_shape=False, write=DEFAULT_ROADWAY_WRITE)

Constructor for RoadwayNetworkOutputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    out_dir: Path = DEFAULT_ROADWAY_OUT_DIR,
    base_path: Path = DEFAULT_BASE_DIR,
    convert_complex_link_properties_to_single_field: bool = False,
    prefix: Optional[str] = None,
    file_format: RoadwayFileTypes = DEFAULT_ROADWAY_OUT_FORMAT,
    true_shape: bool = False,
    write: bool = DEFAULT_ROADWAY_WRITE,
):
    """Constructor for RoadwayNetworkOutputConfig."""
    if out_dir is not None and not Path(out_dir).is_absolute():
        self.out_dir = (base_path / Path(out_dir)).resolve()
    else:
        self.out_dir = Path(out_dir)

    self.convert_complex_link_properties_to_single_field = (
        convert_complex_link_properties_to_single_field
    )
    self.prefix = prefix
    self.file_format = file_format
    self.true_shape = true_shape
    self.write = write
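
The relative-vs-absolute handling above is shared by several of these config constructors. A minimal standalone sketch of that pattern (the helper name `resolve_out_dir` is illustrative, not part of the API):

```python
from pathlib import Path

def resolve_out_dir(out_dir, base_path):
    """Resolve out_dir against base_path unless it is already absolute."""
    if out_dir is not None and not Path(out_dir).is_absolute():
        return (Path(base_path) / Path(out_dir)).resolve()
    return Path(out_dir)

# Relative paths are anchored at base_path; absolute paths pass through.
relative = resolve_out_dir("roadway", "/data/scenario")
absolute = resolve_out_dir("/exports/roadway", "/data/scenario")
```

Because the check uses `is_absolute()`, a config can mix absolute export paths with relative ones anchored at the scenario's base path.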

ScenarioConfig

Bases: ConfigItem

Scenario configuration for Network Wrangler.

Attributes:

  • base_path: Base path of the scenario. Defaults to cwd.
  • name: Name of the scenario.
  • base_scenario: Information about the base scenario.
  • projects: Information about the projects to apply on top of the base scenario.
  • output_scenario: Information about how to output the scenario.
  • wrangler_config: Wrangler configuration to use.

Source code in network_wrangler/configs/scenario.py
class ScenarioConfig(ConfigItem):
    """Scenario configuration for Network Wrangler.

    Attributes:
        base_path: base path of the scenario. Defaults to cwd.
        name: Name of the scenario.
        base_scenario: information about the base scenario
        projects: information about the projects to apply on top of the base scenario
        output_scenario: information about how to output the scenario
        wrangler_config: wrangler configuration to use
    """

    def __init__(
        self,
        base_scenario: dict,
        projects: dict,
        output_scenario: dict,
        base_path: Path = DEFAULT_BASE_DIR,
        name: str = DEFAULT_SCENARIO_NAME,
        wrangler_config=DefaultConfig,
    ):
        """Constructor for ScenarioConfig."""
        self.base_path = Path(base_path) if base_path is not None else Path.cwd()
        self.name = name
        self.base_scenario = ScenarioInputConfig(**base_scenario, base_path=base_path)
        self.projects = ProjectsConfig(**projects, base_path=base_path)
        self.output_scenario = ScenarioOutputConfig(**output_scenario, base_path=base_path)
        self.wrangler_config = wrangler_config

__init__(base_scenario, projects, output_scenario, base_path=DEFAULT_BASE_DIR, name=DEFAULT_SCENARIO_NAME, wrangler_config=DefaultConfig)

Constructor for ScenarioConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_scenario: dict,
    projects: dict,
    output_scenario: dict,
    base_path: Path = DEFAULT_BASE_DIR,
    name: str = DEFAULT_SCENARIO_NAME,
    wrangler_config=DefaultConfig,
):
    """Constructor for ScenarioConfig."""
    self.base_path = Path(base_path) if base_path is not None else Path.cwd()
    self.name = name
    self.base_scenario = ScenarioInputConfig(**base_scenario, base_path=base_path)
    self.projects = ProjectsConfig(**projects, base_path=base_path)
    self.output_scenario = ScenarioOutputConfig(**output_scenario, base_path=base_path)
    self.wrangler_config = wrangler_config
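
Putting the pieces together, a scenario configuration is a nested dictionary whose sections feed the sub-configs above. A hedged sketch of that shape (keys inferred from the constructor signatures shown here; all concrete values are placeholders):

```python
# Illustrative scenario configuration; section keys follow the
# ScenarioConfig / ScenarioInputConfig / ScenarioOutputConfig signatures above.
scenario_config = {
    "name": "build_2030",                 # placeholder scenario name
    "base_path": ".",                     # relative paths resolve against this
    "base_scenario": {                    # -> ScenarioInputConfig(**...)
        "roadway": {"dir": "base/roadway"},
        "transit": {"dir": "base/transit"},
        "applied_projects": [],
        "conflicts": {},
    },
    "projects": {},                       # -> ProjectsConfig(**...), shape not shown here
    "output_scenario": {                  # -> ScenarioOutputConfig(**...)
        "path": "out",
        "overwrite": True,
    },
}

required_sections = {"base_scenario", "projects", "output_scenario"}
```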

ScenarioInputConfig

Bases: ConfigItem

Configuration for the inputs of a scenario.

Attributes:

  • roadway (Optional[RoadwayNetworkInputConfig]): Configuration for reading in the roadway network.
  • transit (Optional[TransitNetworkInputConfig]): Configuration for reading in the transit network.
  • applied_projects: List of projects already applied to the base scenario.
  • conflicts: Dict of projects that conflict with the applied_projects.

Source code in network_wrangler/configs/scenario.py
class ScenarioInputConfig(ConfigItem):
    """Configuration for the writing the output of a scenario.

    Attributes:
        roadway: Configuration for writing out the roadway network.
        transit: Configuration for writing out the transit network.
        applied_projects: List of projects to apply to the base scenario.
        conflicts: Dict of projects that conflict with the applied_projects.
    """

    def __init__(
        self,
        base_path: Path = DEFAULT_BASE_DIR,
        roadway: Optional[dict] = None,
        transit: Optional[dict] = None,
        applied_projects: Optional[list[str]] = None,
        conflicts: Optional[dict] = None,
    ):
        """Constructor for ScenarioInputConfig."""
        if roadway is not None:
            self.roadway: Optional[RoadwayNetworkInputConfig] = RoadwayNetworkInputConfig(
                **roadway, base_path=base_path
            )
        else:
            self.roadway = None

        if transit is not None:
            self.transit: Optional[TransitNetworkInputConfig] = TransitNetworkInputConfig(
                **transit, base_path=base_path
            )
        else:
            self.transit = None

        self.applied_projects = applied_projects if applied_projects is not None else []
        self.conflicts = conflicts if conflicts is not None else {}

__init__(base_path=DEFAULT_BASE_DIR, roadway=None, transit=None, applied_projects=None, conflicts=None)

Constructor for ScenarioInputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_path: Path = DEFAULT_BASE_DIR,
    roadway: Optional[dict] = None,
    transit: Optional[dict] = None,
    applied_projects: Optional[list[str]] = None,
    conflicts: Optional[dict] = None,
):
    """Constructor for ScenarioInputConfig."""
    if roadway is not None:
        self.roadway: Optional[RoadwayNetworkInputConfig] = RoadwayNetworkInputConfig(
            **roadway, base_path=base_path
        )
    else:
        self.roadway = None

    if transit is not None:
        self.transit: Optional[TransitNetworkInputConfig] = TransitNetworkInputConfig(
            **transit, base_path=base_path
        )
    else:
        self.transit = None

    self.applied_projects = applied_projects if applied_projects is not None else []
    self.conflicts = conflicts if conflicts is not None else {}

ScenarioOutputConfig

Bases: ConfigItem

Configuration for writing the output of a scenario.

Attributes:

  • roadway: Configuration for writing out the roadway network.
  • transit: Configuration for writing out the transit network.
  • project_cards (Optional[ProjectCardOutputConfig]): Configuration for writing out the project cards.
  • overwrite: If True, will overwrite the files if they already exist. Defaults to True.

Source code in network_wrangler/configs/scenario.py
class ScenarioOutputConfig(ConfigItem):
    """Configuration for the writing the output of a scenario.

    Attributes:
        roadway: Configuration for writing out the roadway network.
        transit: Configuration for writing out the transit network.
        project_cards: Configuration for writing out the project cards.
        overwrite: If True, will overwrite the files if they already exist. Defaults to True
    """

    def __init__(
        self,
        path: Path = DEFAULT_OUTPUT_DIR,
        base_path: Path = DEFAULT_BASE_DIR,
        roadway: Optional[dict] = None,
        transit: Optional[dict] = None,
        project_cards: Optional[dict] = None,
        overwrite: bool = True,
    ):
        """Constructor for ScenarioOutputConfig."""
        if not Path(path).is_absolute():
            self.path = (base_path / Path(path)).resolve()
        else:
            self.path = Path(path)

        roadway = roadway if roadway else RoadwayNetworkOutputConfig().to_dict()
        transit = transit if transit else TransitNetworkOutputConfig().to_dict()
        self.roadway = RoadwayNetworkOutputConfig(**roadway, base_path=self.path)
        self.transit = TransitNetworkOutputConfig(**transit, base_path=self.path)

        if project_cards is not None:
            self.project_cards: Optional[ProjectCardOutputConfig] = ProjectCardOutputConfig(
                **project_cards, base_path=self.path
            )
        else:
            self.project_cards = None

        self.overwrite = overwrite

__init__(path=DEFAULT_OUTPUT_DIR, base_path=DEFAULT_BASE_DIR, roadway=None, transit=None, project_cards=None, overwrite=True)

Constructor for ScenarioOutputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    path: Path = DEFAULT_OUTPUT_DIR,
    base_path: Path = DEFAULT_BASE_DIR,
    roadway: Optional[dict] = None,
    transit: Optional[dict] = None,
    project_cards: Optional[dict] = None,
    overwrite: bool = True,
):
    """Constructor for ScenarioOutputConfig."""
    if not Path(path).is_absolute():
        self.path = (base_path / Path(path)).resolve()
    else:
        self.path = Path(path)

    roadway = roadway if roadway else RoadwayNetworkOutputConfig().to_dict()
    transit = transit if transit else TransitNetworkOutputConfig().to_dict()
    self.roadway = RoadwayNetworkOutputConfig(**roadway, base_path=self.path)
    self.transit = TransitNetworkOutputConfig(**transit, base_path=self.path)

    if project_cards is not None:
        self.project_cards: Optional[ProjectCardOutputConfig] = ProjectCardOutputConfig(
            **project_cards, base_path=self.path
        )
    else:
        self.project_cards = None

    self.overwrite = overwrite

TransitNetworkInputConfig

Bases: ConfigItem

Configuration for the transit network in a scenario.

Attributes:

  • dir: Path to the transit network files. Defaults to ".".
  • file_format: File format for the transit network files. Should be one of TransitFileTypes. Defaults to "txt".

Source code in network_wrangler/configs/scenario.py
class TransitNetworkInputConfig(ConfigItem):
    """Configuration for the transit network in a scenario.

    Attributes:
        dir: Path to the transit network files. Defaults to ".".
        file_format: File format for the transit network files. Should be one of TransitFileTypes.
            Defaults to "txt".
    """

    def __init__(
        self,
        base_path: Path = DEFAULT_BASE_DIR,
        dir: Path = DEFAULT_TRANSIT_IN_DIR,
        file_format: TransitFileTypes = DEFAULT_TRANSIT_IN_FORMAT,
    ):
        """Constructor for TransitNetworkInputConfig."""
        if dir is not None and not Path(dir).is_absolute():
            self.feed = (base_path / Path(dir)).resolve()
        else:
            self.feed = Path(dir)
        self.file_format = file_format

__init__(base_path=DEFAULT_BASE_DIR, dir=DEFAULT_TRANSIT_IN_DIR, file_format=DEFAULT_TRANSIT_IN_FORMAT)

Constructor for TransitNetworkInputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_path: Path = DEFAULT_BASE_DIR,
    dir: Path = DEFAULT_TRANSIT_IN_DIR,
    file_format: TransitFileTypes = DEFAULT_TRANSIT_IN_FORMAT,
):
    """Constructor for TransitNetworkInputConfig."""
    if dir is not None and not Path(dir).is_absolute():
        self.feed = (base_path / Path(dir)).resolve()
    else:
        self.feed = Path(dir)
    self.file_format = file_format

TransitNetworkOutputConfig

Bases: ConfigItem

Configuration for the transit network in a scenario.

Attributes:

  • out_dir: Path to write the transit network files to if you don't want to use the default.
  • prefix: Prefix to add to the file name. If not provided, will use the scenario name.
  • file_format: File format to write the transit network to. Should be one of TransitFileTypes. Defaults to "txt".
  • write: If True, will write the transit network. Defaults to True.

Source code in network_wrangler/configs/scenario.py
class TransitNetworkOutputConfig(ConfigItem):
    """Configuration for the transit network in a scenario.

    Attributes:
        out_dir: Path to write the transit network files to if you don't want to use the default.
        prefix: Prefix to add to the file name. If not provided will use the scenario name.
        file_format: File format to write the transit network to. Should be one of
            TransitFileTypes. Defaults to "txt".
        write: If True, will write the transit network. Defaults to True.
    """

    def __init__(
        self,
        base_path: Path = DEFAULT_BASE_DIR,
        out_dir: Path = DEFAULT_TRANSIT_OUT_DIR,
        prefix: Optional[str] = None,
        file_format: TransitFileTypes = DEFAULT_TRANSIT_OUT_FORMAT,
        write: bool = DEFAULT_TRANSIT_WRITE,
    ):
        """Constructor for TransitNetworkOutputCOnfig."""
        if out_dir is not None and not Path(out_dir).is_absolute():
            self.out_dir = (base_path / Path(out_dir)).resolve()
        else:
            self.out_dir = Path(out_dir)
        self.write = write
        self.prefix = prefix
        self.file_format = file_format

__init__(base_path=DEFAULT_BASE_DIR, out_dir=DEFAULT_TRANSIT_OUT_DIR, prefix=None, file_format=DEFAULT_TRANSIT_OUT_FORMAT, write=DEFAULT_TRANSIT_WRITE)

Constructor for TransitNetworkOutputConfig.

Source code in network_wrangler/configs/scenario.py
def __init__(
    self,
    base_path: Path = DEFAULT_BASE_DIR,
    out_dir: Path = DEFAULT_TRANSIT_OUT_DIR,
    prefix: Optional[str] = None,
    file_format: TransitFileTypes = DEFAULT_TRANSIT_OUT_FORMAT,
    write: bool = DEFAULT_TRANSIT_WRITE,
):
    """Constructor for TransitNetworkOutputCOnfig."""
    if out_dir is not None and not Path(out_dir).is_absolute():
        self.out_dir = (base_path / Path(out_dir)).resolve()
    else:
        self.out_dir = Path(out_dir)
    self.write = write
    self.prefix = prefix
    self.file_format = file_format

Projects

Projects are how you manipulate the networks. Each project type is defined in a module in the projects folder and accepts a RoadwayNetwork and/or TransitNetwork as input, returning the same objects (modified) as output.

Roadway

The roadway module contains submodules which define and extend the links, nodes, and shapes dataframe objects within a RoadwayNetwork object, as well as other classes and methods which support and extend the RoadwayNetwork class.

Roadway Network Objects

Submodules which define and extend the links, nodes, and shapes dataframe objects within a RoadwayNetwork object. Includes classes which define:

  • dataframe schemas to be used for dataframe validation using pandera
  • methods which extend the dataframes

Submodules:

  • network_wrangler.roadway.links.io
  • network_wrangler.roadway.links.create
  • network_wrangler.roadway.links.delete
  • network_wrangler.roadway.links.edit
  • network_wrangler.roadway.links.filters
  • network_wrangler.roadway.links.geo
  • network_wrangler.roadway.links.scopes
  • network_wrangler.roadway.links.summary
  • network_wrangler.roadway.links.validate
  • network_wrangler.roadway.links.df_accessors

Roadway Nodes

  • network_wrangler.roadway.nodes.io
  • network_wrangler.roadway.nodes.create
  • network_wrangler.roadway.nodes.delete
  • network_wrangler.roadway.nodes.edit
  • network_wrangler.roadway.nodes.filters
  • network_wrangler.roadway.nodes

Roadway Shapes

  • network_wrangler.roadway.shapes.io
  • network_wrangler.roadway.shapes.create
  • network_wrangler.roadway.shapes.edit
  • network_wrangler.roadway.shapes.delete
  • network_wrangler.roadway.shapes.filters
  • network_wrangler.roadway.shapes.shapes

Roadway Projects

  • network_wrangler.roadway.projects.add
  • network_wrangler.roadway.projects.calculate
  • network_wrangler.roadway.projects.delete
  • network_wrangler.roadway.projects.edit_property

Roadway Supporting Modules

  • network_wrangler.roadway.io
  • network_wrangler.roadway.clip
  • network_wrangler.roadway.model_roadway
  • network_wrangler.roadway.utils
  • network_wrangler.roadway.validate
  • network_wrangler.roadway.segment
  • network_wrangler.roadway.subnet
  • network_wrangler.roadway.graph

Transit

Feed

Main functionality for GTFS tables including Feed object.

Feed

Bases: DBModelMixin

Wrapper class around Wrangler flavored GTFS feed.

Most functionality derives from mixin class DBModelMixin which provides:

  • validation of tables to schemas when setting a table attribute (e.g. self.trips = trips_df)
  • validation of fks when setting a table attribute (e.g. self.trips = trips_df)
  • hashing and deep copy functionality
  • overload of __eq__ to apply only to tables in table_names.
  • convenience methods for accessing tables

Attributes:

  • table_names (list[str]): list of table names in GTFS feed.
  • tables (list[DataFrame]): list of tables as dataframes.
  • stop_times (DataFrame[WranglerStopTimesTable]): stop_times dataframe with roadway node_ids.
  • stops (DataFrame[WranglerStopsTable]): stops dataframe.
  • shapes (DataFrame[WranglerShapesTable]): shapes dataframe.
  • trips (DataFrame[WranglerTripsTable]): trips dataframe.
  • frequencies (DataFrame[WranglerFrequenciesTable]): frequencies dataframe.
  • routes (DataFrame[RoutesTable]): routes dataframe.
  • agencies (Optional[DataFrame[AgenciesTable]]): agencies dataframe.
  • net (Optional[TransitNetwork]): TransitNetwork object.

Source code in network_wrangler/transit/feed/feed.py
class Feed(DBModelMixin):
    """Wrapper class around Wrangler flavored GTFS feed.

    Most functionality derives from mixin class DBModelMixin which provides:

    - validation of tables to schemas when setting a table attribute (e.g. self.trips = trips_df)
    - validation of fks when setting a table attribute (e.g. self.trips = trips_df)
    - hashing and deep copy functionality
    - overload of __eq__ to apply only to tables in table_names.
    - convenience methods for accessing tables

    Attributes:
        table_names (list[str]): list of table names in GTFS feed.
        tables (list[DataFrame]): list of tables as dataframes.
        stop_times (DataFrame[WranglerStopTimesTable]): stop_times dataframe with roadway node_ids
        stops (DataFrame[WranglerStopsTable]): stops dataframe
        shapes (DataFrame[WranglerShapesTable]): shapes dataframe
        trips (DataFrame[WranglerTripsTable]): trips dataframe
        frequencies (DataFrame[WranglerFrequenciesTable]): frequencies dataframe
        routes (DataFrame[RoutesTable]): routes dataframe
        agencies (Optional[DataFrame[AgenciesTable]]): agencies dataframe
        net (Optional[TransitNetwork]): TransitNetwork object
    """

    # the ordering here matters because the stops need to be added before stop_times if
    # stop times needs to be converted
    _table_models: ClassVar[dict] = {
        "agencies": AgenciesTable,
        "frequencies": WranglerFrequenciesTable,
        "routes": RoutesTable,
        "shapes": WranglerShapesTable,
        "stops": WranglerStopsTable,
        "trips": WranglerTripsTable,
        "stop_times": WranglerStopTimesTable,
    }

    # Define the converters if the table needs to be converted to a Wrangler table.
    # Format: "table_name": converter_function
    _converters: ClassVar[dict[str, Callable]] = {}

    table_names: ClassVar[list[str]] = [
        "frequencies",
        "routes",
        "shapes",
        "stops",
        "trips",
        "stop_times",
    ]

    optional_table_names: ClassVar[list[str]] = ["agencies"]

    def __init__(self, **kwargs):
        """Create a Feed object from a dictionary of DataFrames representing a GTFS feed.

        Args:
            kwargs: A dictionary containing DataFrames representing the tables of a GTFS feed.
        """
        self._net = None
        self.feed_path: Optional[Path] = None
        self.initialize_tables(**kwargs)

        # Set extra provided attributes but just FYI in logger.
        extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names}
        if extra_attr:
            WranglerLogger.info(f"Adding additional attributes to Feed: {extra_attr.keys()}")
        for k, v in extra_attr.items():
            self.__setattr__(k, v)

    def set_by_id(
        self,
        table_name: str,
        set_df: pd.DataFrame,
        id_property: str = "index",
        properties: Optional[list[str]] = None,
    ):
        """Set one or more property values based on an ID property for a given table.

        Args:
            table_name (str): Name of the table to modify.
            set_df (pd.DataFrame): DataFrame with columns `<id_property>` and `value` containing
                values to set for the specified property where `<id_property>` is unique.
            id_property: Property to use as ID to set by. Defaults to "index".
            properties: List of properties to set which are in set_df. If not specified, will set
                all properties.
        """
        if not set_df[id_property].is_unique:
            msg = f"{id_property} must be unique in set_df."
            _dupes = set_df[id_property][set_df[id_property].duplicated()]
            WranglerLogger.error(msg + f" Found duplicates: {_dupes.to_list()}")

            raise ValueError(msg)
        table_df = self.get_table(table_name)
        updated_df = update_df_by_col_value(table_df, set_df, id_property, properties=properties)
        self.__dict__[table_name] = updated_df
__init__(**kwargs)

Create a Feed object from a dictionary of DataFrames representing a GTFS feed.

Parameters:

  • kwargs: A dictionary containing DataFrames representing the tables of a GTFS feed.
Source code in network_wrangler/transit/feed/feed.py
def __init__(self, **kwargs):
    """Create a Feed object from a dictionary of DataFrames representing a GTFS feed.

    Args:
        kwargs: A dictionary containing DataFrames representing the tables of a GTFS feed.
    """
    self._net = None
    self.feed_path: Optional[Path] = None
    self.initialize_tables(**kwargs)

    # Set extra provided attributes but just FYI in logger.
    extra_attr = {k: v for k, v in kwargs.items() if k not in self.table_names}
    if extra_attr:
        WranglerLogger.info(f"Adding additional attributes to Feed: {extra_attr.keys()}")
    for k, v in extra_attr.items():
        self.__setattr__(k, v)
set_by_id(table_name, set_df, id_property='index', properties=None)

Set one or more property values based on an ID property for a given table.

Parameters:

  • table_name (str, required): Name of the table to modify.
  • set_df (DataFrame, required): DataFrame with columns <id_property> and value containing values to set for the specified property, where <id_property> is unique.
  • id_property (str): Property to use as ID to set by. Defaults to "index".
  • properties (Optional[list[str]]): List of properties to set which are in set_df. If not specified, will set all properties. Defaults to None.
Source code in network_wrangler/transit/feed/feed.py
def set_by_id(
    self,
    table_name: str,
    set_df: pd.DataFrame,
    id_property: str = "index",
    properties: Optional[list[str]] = None,
):
    """Set one or more property values based on an ID property for a given table.

    Args:
        table_name (str): Name of the table to modify.
        set_df (pd.DataFrame): DataFrame with columns `<id_property>` and `value` containing
            values to set for the specified property where `<id_property>` is unique.
        id_property: Property to use as ID to set by. Defaults to "index".
        properties: List of properties to set which are in set_df. If not specified, will set
            all properties.
    """
    if not set_df[id_property].is_unique:
        msg = f"{id_property} must be unique in set_df."
        _dupes = set_df[id_property][set_df[id_property].duplicated()]
        WranglerLogger.error(msg + f" Found duplicates: {_dupes.to_list()}")

        raise ValueError(msg)
    table_df = self.get_table(table_name)
    updated_df = update_df_by_col_value(table_df, set_df, id_property, properties=properties)
    self.__dict__[table_name] = updated_df
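
The id-matched overwrite that set_by_id performs can be sketched with plain pandas on toy data (update_df_by_col_value is approximated here by an index-aligned DataFrame.update, which is an assumption, not the library's implementation):

```python
import pandas as pd

# Toy 'trips' table and an update frame keyed on trip_id.
trips = pd.DataFrame({"trip_id": ["t1", "t2", "t3"], "headway": [10, 15, 20]})
set_df = pd.DataFrame({"trip_id": ["t2", "t3"], "headway": [12, 18]})

# Mirror the uniqueness check in set_by_id before applying the update.
assert set_df["trip_id"].is_unique

updated = trips.set_index("trip_id")
updated.update(set_df.set_index("trip_id"))  # overwrite matching rows/columns
updated = updated.reset_index()
```

Note that DataFrame.update aligns on the index and may upcast numeric columns to float, so compare values rather than dtypes afterwards.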

merge_shapes_to_stop_times(stop_times, shapes, trips)

Add shape_id and shape_pt_sequence to stop_times dataframe.

Parameters:

  • stop_times (DataFrame[WranglerStopTimesTable], required): stop_times dataframe to add shape_id and shape_pt_sequence to.
  • shapes (DataFrame[WranglerShapesTable], required): shapes dataframe to add to stop_times.
  • trips (DataFrame[WranglerTripsTable], required): trips dataframe to link stop_times to shapes.

Returns:

  • DataFrame[WranglerStopTimesTable]: stop_times dataframe with shape_id and shape_pt_sequence added.

Source code in network_wrangler/transit/feed/feed.py
def merge_shapes_to_stop_times(
    stop_times: DataFrame[WranglerStopTimesTable],
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
) -> DataFrame[WranglerStopTimesTable]:
    """Add shape_id and shape_pt_sequence to stop_times dataframe.

    Args:
        stop_times: stop_times dataframe to add shape_id and shape_pt_sequence to.
        shapes: shapes dataframe to add to stop_times.
        trips: trips dataframe to link stop_times to shapes

    Returns:
        stop_times dataframe with shape_id and shape_pt_sequence added.
    """
    stop_times_w_shape_id = stop_times.merge(
        trips[["trip_id", "shape_id"]], on="trip_id", how="left"
    )

    stop_times_w_shapes = stop_times_w_shape_id.merge(
        shapes,
        how="left",
        left_on=["shape_id", "stop_id"],
        right_on=["shape_id", "shape_model_node_id"],
    )
    stop_times_w_shapes = stop_times_w_shapes.drop(columns=["shape_model_node_id"])
    return stop_times_w_shapes
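
With toy tables (data invented for illustration), the two merges above behave as follows:

```python
import pandas as pd

stop_times = pd.DataFrame({"trip_id": ["t1", "t1"], "stop_id": [1, 3]})
trips = pd.DataFrame({"trip_id": ["t1"], "shape_id": ["s1"]})
shapes = pd.DataFrame({
    "shape_id": ["s1", "s1", "s1"],
    "shape_pt_sequence": [1, 2, 3],
    "shape_model_node_id": [1, 2, 3],
})

# Step 1: attach each trip's shape_id to its stop_times records.
st = stop_times.merge(trips[["trip_id", "shape_id"]], on="trip_id", how="left")

# Step 2: join shape points where the stop's node matches the shape's node,
# then drop the join key that duplicates stop_id.
st = st.merge(
    shapes,
    how="left",
    left_on=["shape_id", "stop_id"],
    right_on=["shape_id", "shape_model_node_id"],
).drop(columns=["shape_model_node_id"])
```

Each stop_times row ends up carrying the shape_pt_sequence of the matching shape point.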

stop_count_by_trip(stop_times)

Returns dataframe with trip_id and stop_count from stop_times.

Source code in network_wrangler/transit/feed/feed.py
def stop_count_by_trip(
    stop_times: DataFrame[WranglerStopTimesTable],
) -> pd.DataFrame:
    """Returns dataframe with trip_id and stop_count from stop_times."""
    stops_count = stop_times.groupby("trip_id").size()
    return stops_count.reset_index(name="stop_count")
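
For example, the groupby-size pattern on a toy stop_times table:

```python
import pandas as pd

stop_times = pd.DataFrame({
    "trip_id": ["t1", "t1", "t1", "t2"],
    "stop_id": [1, 2, 3, 9],
})

# One row per trip with the number of stop_times records it has.
stop_counts = stop_times.groupby("trip_id").size().reset_index(name="stop_count")
```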

Filters and queries of a gtfs frequencies table.

frequencies_for_trips(frequencies, trips)

Filter frequencies dataframe to records associated with trips table.

Source code in network_wrangler/transit/feed/frequencies.py
def frequencies_for_trips(
    frequencies: DataFrame[WranglerFrequenciesTable], trips: DataFrame[WranglerTripsTable]
) -> DataFrame[WranglerFrequenciesTable]:
    """Filter frequenceis dataframe to records associated with trips table."""
    _sel_trips = trips.trip_id.unique().tolist()
    filtered_frequencies = frequencies[frequencies.trip_id.isin(_sel_trips)]
    WranglerLogger.debug(
        f"Filtered frequencies to {len(filtered_frequencies)}/{len(frequencies)} \
                         records that referenced one of {len(trips)} trips."
    )
    return filtered_frequencies
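
The same isin-based filtering pattern on toy data:

```python
import pandas as pd

trips = pd.DataFrame({"trip_id": ["t1", "t2"]})
frequencies = pd.DataFrame({
    "trip_id": ["t1", "t2", "t3"],
    "headway_secs": [600, 900, 300],
})

# Keep only frequency records whose trip_id appears in the trips table.
sel_trips = trips.trip_id.unique().tolist()
filtered = frequencies[frequencies.trip_id.isin(sel_trips)]
```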

Filters and queries of a gtfs routes table and route_ids.

route_ids_for_trip_ids(trips, trip_ids)

Returns route ids for given list of trip_ids.

Source code in network_wrangler/transit/feed/routes.py
def route_ids_for_trip_ids(trips: DataFrame[WranglerTripsTable], trip_ids: list[str]) -> list[str]:
    """Returns route ids for given list of trip_ids."""
    return trips[trips["trip_id"].isin(trip_ids)].route_id.unique().tolist()

routes_for_trip_ids(routes, trips, trip_ids)

Returns route records for given list of trip_ids.

Source code in network_wrangler/transit/feed/routes.py
def routes_for_trip_ids(
    routes: DataFrame[RoutesTable], trips: DataFrame[WranglerTripsTable], trip_ids: list[str]
) -> DataFrame[RoutesTable]:
    """Returns route records for given list of trip_ids."""
    route_ids = route_ids_for_trip_ids(trips, trip_ids)
    return routes.loc[routes.route_id.isin(route_ids)]

routes_for_trips(routes, trips)

Filter routes dataframe to records associated with trip records.

Source code in network_wrangler/transit/feed/routes.py
def routes_for_trips(
    routes: DataFrame[RoutesTable], trips: DataFrame[WranglerTripsTable]
) -> DataFrame[RoutesTable]:
    """Filter routes dataframe to records associated with trip records."""
    _sel_routes = trips.route_id.unique().tolist()
    filtered_routes = routes[routes.route_id.isin(_sel_routes)]
    WranglerLogger.debug(
        f"Filtered routes to {len(filtered_routes)}/{len(routes)} \
                         records that referenced one of {len(trips)} trips."
    )
    return filtered_routes

Filters and queries of a gtfs shapes table and node patterns.

find_nearest_stops(shapes, trips, stop_times, trip_id, node_id, pickup_dropoff='either')

Returns node_ids (before and after) of nearest node_ids that are stops for a given trip_id.

Parameters:

  • shapes (WranglerShapesTable, required): WranglerShapesTable
  • trips (WranglerTripsTable, required): WranglerTripsTable
  • stop_times (WranglerStopTimesTable, required): WranglerStopTimesTable
  • trip_id (str, required): trip id to find nearest stops for.
  • node_id (int, required): node_id to find nearest stops for.
  • pickup_dropoff (PickupDropoffAvailability): logic for selecting stops based on pickup and dropoff availability at the stop. Defaults to "either".
      - "either": either pickup_type or dropoff_type > 0
      - "both": both pickup_type and dropoff_type > 0
      - "pickup_only": only pickup > 0
      - "dropoff_only": only dropoff > 0

Returns:

  • tuple[int, int]: node_ids for stop before and stop after.

Source code in network_wrangler/transit/feed/shapes.py
def find_nearest_stops(
    shapes: WranglerShapesTable,
    trips: WranglerTripsTable,
    stop_times: WranglerStopTimesTable,
    trip_id: str,
    node_id: int,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> tuple[int, int]:
    """Returns node_ids (before and after) of nearest node_ids that are stops for a given trip_id.

    Args:
        shapes: WranglerShapesTable
        trips: WranglerTripsTable
        stop_times: WranglerStopTimesTable
        trip_id: trip id to find nearest stops for
        node_id: node_id to find nearest stops for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type or dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0

    Returns:
        tuple: node_ids for stop before and stop after
    """
    shapes = shapes_with_stop_id_for_trip_id(
        shapes, trips, stop_times, trip_id, pickup_dropoff=pickup_dropoff
    )
    WranglerLogger.debug(f"Looking for stops near node_id: {node_id}")
    if node_id not in shapes["shape_model_node_id"].values:
        msg = f"Node ID {node_id} not in shapes for trip {trip_id}"
        raise ValueError(msg)
    # Find index of node_id in shapes
    node_idx = shapes[shapes["shape_model_node_id"] == node_id].index[0]

    # Find stops before and after new stop in shapes sequence
    nodes_before = shapes.loc[: node_idx - 1]
    stops_before = nodes_before.loc[nodes_before["stop_id"].notna()]
    stop_node_before = 0 if stops_before.empty else stops_before.iloc[-1]["shape_model_node_id"]

    nodes_after = shapes.loc[node_idx + 1 :]
    stops_after = nodes_after.loc[nodes_after["stop_id"].notna()]
    stop_node_after = 0 if stops_after.empty else stops_after.iloc[0]["shape_model_node_id"]

    return stop_node_before, stop_node_after
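
The before/after scan can be illustrated without pandas. A pure-Python sketch over an invented (node_id, stop_id) sequence, returning 0 when no stop exists on a side, mirroring the function above:

```python
def nearest_stops(sequence, node_id):
    """Return (stop node before, stop node after) for node_id in sequence."""
    idx = next(i for i, (node, _) in enumerate(sequence) if node == node_id)
    before = [node for node, stop in sequence[:idx] if stop is not None]
    after = [node for node, stop in sequence[idx + 1:] if stop is not None]
    return (before[-1] if before else 0, after[0] if after else 0)

# Shape sequence: nodes 11 and 13 are stops; the others are pass-through nodes.
shape_sequence = [(10, None), (11, "A"), (12, None), (13, "B"), (14, None)]
```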

node_pattern_for_shape_id(shapes, shape_id)

Returns node pattern of a shape.

Source code in network_wrangler/transit/feed/shapes.py
def node_pattern_for_shape_id(shapes: DataFrame[WranglerShapesTable], shape_id: str) -> list[int]:
    """Returns node pattern of a shape."""
    shape_df = shapes.loc[shapes["shape_id"] == shape_id]
    shape_df = shape_df.sort_values(by=["shape_pt_sequence"])
    return shape_df["shape_model_node_id"].to_list()

shape_id_for_trip_id(trips, trip_id)

Returns a shape_id for a given trip_id.

Source code in network_wrangler/transit/feed/shapes.py
def shape_id_for_trip_id(trips: WranglerTripsTable, trip_id: str) -> str:
    """Returns a shape_id for a given trip_id."""
    return trips.loc[trips.trip_id == trip_id, "shape_id"].values[0]

shape_ids_for_trip_ids(trips, trip_ids)

Returns a list of shape_ids for a given list of trip_ids.

Source code in network_wrangler/transit/feed/shapes.py
def shape_ids_for_trip_ids(trips: DataFrame[WranglerTripsTable], trip_ids: list[str]) -> list[str]:
    """Returns a list of shape_ids for a given list of trip_ids."""
    return trips[trips["trip_id"].isin(trip_ids)].shape_id.unique().tolist()

Filter shapes dataframe to records associated with links dataframe.

EX:

shapes = pd.DataFrame({
    "shape_id": ["1", "1", "1", "1", "2", "2", "2", "2", "2"],
    "shape_pt_sequence": [1, 2, 3, 4, 1, 2, 3, 4, 5],
    "shape_model_node_id": [1, 2, 3, 4, 2, 3, 1, 5, 4]
})

links_df = pd.DataFrame({
    "A": [1, 2, 3],
    "B": [2, 3, 4]
})

shapes

shape_id   shape_pt_sequence   shape_model_node_id   should retain
1          1                   1                     TRUE
1          2                   2                     TRUE
1          3                   3                     TRUE
1          4                   4                     TRUE
1          5                   5                     FALSE
2          1                   1                     TRUE
2          2                   2                     TRUE
2          3                   3                     TRUE
2          4                   1                     FALSE
2          5                   5                     FALSE
2          6                   4                     FALSE
2          7                   1                     FALSE - not largest segment
2          8                   2                     FALSE - not largest segment

links_df

A   B
1   2
2   3
3   4

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_road_links(
    shapes: DataFrame[WranglerShapesTable], links_df: pd.DataFrame
) -> DataFrame[WranglerShapesTable]:
    """Filter shapes dataframe to records associated with links dataframe.

    EX:

    > shapes = pd.DataFrame({
        "shape_id": ["1", "1", "1", "1", "2", "2", "2", "2", "2"],
        "shape_pt_sequence": [1, 2, 3, 4, 1, 2, 3, 4, 5],
        "shape_model_node_id": [1, 2, 3, 4, 2, 3, 1, 5, 4]
    })

    > links_df = pd.DataFrame({
        "A": [1, 2, 3],
        "B": [2, 3, 4]
    })

    > shapes

    shape_id   shape_pt_sequence   shape_model_node_id *should retain*
    1          1                  1                        TRUE
    1          2                  2                        TRUE
    1          3                  3                        TRUE
    1          4                  4                        TRUE
    1          5                  5                       FALSE
    2          1                  1                        TRUE
    2          2                  2                        TRUE
    2          3                  3                        TRUE
    2          4                  1                       FALSE
    2          5                  5                       FALSE
    2          6                  4                       FALSE
    2          7                  1                       FALSE - not largest segment
    2          8                  2                       FALSE - not largest segment

    > links_df

    A   B
    1   2
    2   3
    3   4
    """
    """
    > shape_links

    shape_id  shape_pt_sequence_A  shape_model_node_id_A shape_pt_sequence_B shape_model_node_id_B
    1          1                        1                       2                        2
    1          2                        2                       3                        3
    1          3                        3                       4                        4
    1          4                        4                       5                        5
    2          1                        1                       2                        2
    2          2                        2                       3                        3
    2          3                        3                       4                        1
    2          4                        1                       5                        5
    2          5                        5                       6                        4
    2          6                        4                       7                        1
    2          7                        1                       8                        2
    """
    shape_links = shapes_to_shape_links(shapes)

    """
    > shape_links_w_links

    shape_id  shape_pt_sequence_A shape_pt_sequence_B  A  B
    1          1                         2             1  2
    1          2                         3             2  3
    1          3                         4             3  4
    2          1                         2             1  2
    2          2                         3             2  3
    2          7                         8             1  2
    """

    shape_links_w_links = shape_links.merge(
        links_df[["A", "B"]],
        how="inner",
        on=["A", "B"],
    )

    """
    Find largest segment of each shape_id that is in the links

    > longest_shape_segments
    shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq
    1          1                        1                       4
    2          1                        1                       3
    """
    longest_shape_segments = shape_links_to_longest_shape_segments(shape_links_w_links)

    """
    > shapes

    shape_id   shape_pt_sequence   shape_model_node_id
    1          1                  1
    1          2                  2
    1          3                  3
    1          4                  4
    2          1                  1
    2          2                  2
    2          3                  3
    """
    filtered_shapes = filter_shapes_to_segments(shapes, longest_shape_segments)
    filtered_shapes = filtered_shapes.reset_index(drop=True)
    return filtered_shapes
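A minimal sketch of the link-matching step, using the toy data from the docstring. The consecutive-point pairing below stands in for `shapes_to_shape_links` and is an assumption about its behavior, not the library's implementation:

```python
import pandas as pd

# Toy data from the docstring example: shape "2" detours off the roadway network.
shapes = pd.DataFrame({
    "shape_id": ["1"] * 4 + ["2"] * 5,
    "shape_pt_sequence": [1, 2, 3, 4, 1, 2, 3, 4, 5],
    "shape_model_node_id": [1, 2, 3, 4, 2, 3, 1, 5, 4],
})
links_df = pd.DataFrame({"A": [1, 2, 3], "B": [2, 3, 4]})

# Pair consecutive shape points into candidate links (A -> B) within each shape.
shape_links = shapes.copy()
shape_links["A"] = shape_links["shape_model_node_id"]
shape_links["B"] = shape_links.groupby("shape_id")["shape_model_node_id"].shift(-1)
shape_links = shape_links.dropna(subset=["B"]).astype({"B": int})

# Inner merge keeps only shape links that exist on the roadway network.
on_network = shape_links.merge(links_df, on=["A", "B"], how="inner")
```

Shape "1" keeps all three of its links; shape "2" keeps only (2, 3), matching the "should retain" column in the example above (before the largest-segment filter).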

shapes_for_shape_id(shapes, shape_id)

Returns shape records for a given shape_id.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_shape_id(
    shapes: DataFrame[WranglerShapesTable], shape_id: str
) -> DataFrame[WranglerShapesTable]:
    """Returns shape records for a given shape_id."""
    shapes = shapes.loc[shapes.shape_id == shape_id]
    return shapes.sort_values(by=["shape_pt_sequence"])

shapes_for_trip_id(shapes, trips, trip_id)

Returns shape records for a single given trip_id.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_trip_id(
    shapes: DataFrame[WranglerShapesTable], trips: DataFrame[WranglerTripsTable], trip_id: str
) -> DataFrame[WranglerShapesTable]:
    """Returns shape records for a single given trip_id."""
    from .shapes import shape_id_for_trip_id

    shape_id = shape_id_for_trip_id(trips, trip_id)
    return shapes.loc[shapes.shape_id == shape_id]

shapes_for_trip_ids(shapes, trips, trip_ids)

Returns shape records for list of trip_ids.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_trip_ids(
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
    trip_ids: list[str],
) -> DataFrame[WranglerShapesTable]:
    """Returns shape records for list of trip_ids."""
    shape_ids = shape_ids_for_trip_ids(trips, trip_ids)
    return shapes.loc[shapes.shape_id.isin(shape_ids)]

shapes_for_trips(shapes, trips)

Filter shapes dataframe to records associated with trips table.

Source code in network_wrangler/transit/feed/shapes.py
def shapes_for_trips(
    shapes: DataFrame[WranglerShapesTable], trips: DataFrame[WranglerTripsTable]
) -> DataFrame[WranglerShapesTable]:
    """Filter shapes dataframe to records associated with trips table."""
    _sel_shapes = trips.shape_id.unique().tolist()
    filtered_shapes = shapes[shapes.shape_id.isin(_sel_shapes)]
    WranglerLogger.debug(
        f"Filtered shapes to {len(filtered_shapes)}/{len(shapes)} \
                         records that referenced one of {len(trips)} trips."
    )
    return filtered_shapes

shapes_with_stop_id_for_trip_id(shapes, trips, stop_times, trip_id, pickup_dropoff='either')

Returns shapes.txt for a given trip_id with the stop_id added based on pickup_type.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `shapes` | `DataFrame[WranglerShapesTable]` | WranglerShapesTable | *required* |
| `trips` | `DataFrame[WranglerTripsTable]` | WranglerTripsTable | *required* |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | WranglerStopTimesTable | *required* |
| `trip_id` | `str` | trip id to select | *required* |
| `pickup_dropoff` | `PickupDropoffAvailability` | Logic for selecting stops based on pickup and dropoff availability at stop: "either": either pickup_type or dropoff_type > 0; "both": both pickup_type and dropoff_type > 0; "pickup_only": only pickup > 0; "dropoff_only": only dropoff > 0 | `'either'` |
Source code in network_wrangler/transit/feed/shapes.py
def shapes_with_stop_id_for_trip_id(
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> DataFrame[WranglerShapesTable]:
    """Returns shapes.txt for a given trip_id with the stop_id added based on pickup_type.

    Args:
        shapes: WranglerShapesTable
        trips: WranglerTripsTable
        stop_times: WranglerStopTimesTable
        trip_id: trip id to select
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0
    """
    from .stop_times import stop_times_for_pickup_dropoff_trip_id

    shapes = shapes_for_trip_id(shapes, trips, trip_id)
    trip_stop_times = stop_times_for_pickup_dropoff_trip_id(
        stop_times, trip_id, pickup_dropoff=pickup_dropoff
    )

    stop_times_cols = [
        "stop_id",
        "trip_id",
        "pickup_type",
        "drop_off_type",
    ]

    shape_with_trip_stops = shapes.merge(
        trip_stop_times[stop_times_cols],
        how="left",
        right_on="stop_id",
        left_on="shape_model_node_id",
    )
    shape_with_trip_stops = shape_with_trip_stops.sort_values(by=["shape_pt_sequence"])
    return shape_with_trip_stops
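The left-merge that tags stops onto shape points can be illustrated with toy frames. Column names follow the source; the data is invented:

```python
import pandas as pd

# Toy join of stops onto shape points: stop_id values match shape_model_node_id,
# so a left merge marks each shape point that is a stop for the trip.
shapes = pd.DataFrame({
    "shape_pt_sequence": [1, 2, 3],
    "shape_model_node_id": [10, 20, 30],
})
trip_stop_times = pd.DataFrame({"stop_id": [10, 30], "pickup_type": [0, 0]})

shape_with_stops = shapes.merge(
    trip_stop_times, how="left", left_on="shape_model_node_id", right_on="stop_id"
).sort_values("shape_pt_sequence")
```

Because the merge is a left join, non-stop shape points survive with a null `stop_id`, which is exactly what downstream helpers test with `notna()`.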

shapes_with_stops_for_shape_id(shapes, trips, stop_times, shape_id)

Returns a DataFrame containing shapes with associated stops for a given shape_id.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `shapes` | `DataFrame[WranglerShapesTable]` | DataFrame containing shape data. | *required* |
| `trips` | `DataFrame[WranglerTripsTable]` | DataFrame containing trip data. | *required* |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | DataFrame containing stop times data. | *required* |
| `shape_id` | `str` | The shape_id for which to retrieve shapes with stops. | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `DataFrame[WranglerShapesTable]` | DataFrame containing shapes with associated stops. |

Source code in network_wrangler/transit/feed/shapes.py
def shapes_with_stops_for_shape_id(
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    shape_id: str,
) -> DataFrame[WranglerShapesTable]:
    """Returns a DataFrame containing shapes with associated stops for a given shape_id.

    Parameters:
        shapes (DataFrame[WranglerShapesTable]): DataFrame containing shape data.
        trips (DataFrame[WranglerTripsTable]): DataFrame containing trip data.
        stop_times (DataFrame[WranglerStopTimesTable]): DataFrame containing stop times data.
        shape_id (str): The shape_id for which to retrieve shapes with stops.

    Returns:
        DataFrame[WranglerShapesTable]: DataFrame containing shapes with associated stops.
    """
    from .trips import trip_ids_for_shape_id

    trip_ids = trip_ids_for_shape_id(trips, shape_id)
    all_shape_stop_times = concat_with_attr(
        [shapes_with_stop_id_for_trip_id(shapes, trips, stop_times, t) for t in trip_ids]
    )
    shapes_with_stops = all_shape_stop_times[all_shape_stop_times["stop_id"].notna()]
    shapes_with_stops = shapes_with_stops.sort_values(by=["shape_pt_sequence"])
    return shapes_with_stops

Filters and queries of a gtfs stop_times table.

stop_times_for_longest_segments(stop_times)

Find the longest segment of each trip_id that is in the stop_times.

Segment ends defined based on interruptions in stop_sequence.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_longest_segments(
    stop_times: DataFrame[WranglerStopTimesTable],
) -> pd.DataFrame:
    """Find the longest segment of each trip_id that is in the stop_times.

    Segment ends defined based on interruptions in `stop_sequence`.
    """
    stop_times = stop_times.sort_values(by=["trip_id", "stop_sequence"])

    stop_times["prev_stop_sequence"] = stop_times.groupby("trip_id")["stop_sequence"].shift(1)
    stop_times["gap"] = (stop_times["stop_sequence"] - stop_times["prev_stop_sequence"]).ne(
        1
    ) | stop_times["prev_stop_sequence"].isna()

    stop_times["segment_id"] = stop_times["gap"].cumsum()
    # WranglerLogger.debug(f"stop_times with segment_id:\n{stop_times}")

    # Calculate the length of each segment
    segment_lengths = (
        stop_times.groupby(["trip_id", "segment_id"]).size().reset_index(name="segment_length")
    )

    # Identify the longest segment for each trip
    idx = segment_lengths.groupby("trip_id")["segment_length"].idxmax()
    longest_segments = segment_lengths.loc[idx]

    # Merge longest segment info back to stop_times
    stop_times = stop_times.merge(
        longest_segments[["trip_id", "segment_id"]],
        on=["trip_id", "segment_id"],
        how="inner",
    )

    # Drop temporary columns used for calculations
    stop_times.drop(columns=["prev_stop_sequence", "gap", "segment_id"], inplace=True)
    # WranglerLogger.debug(f"stop_timesw/longest segments:\n{stop_times}")
    return stop_times
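The segment-detection trick above (flag each break in `stop_sequence`, then `cumsum` the flags to label contiguous segments) can be sketched on toy data:

```python
import pandas as pd

# Toy stop_times: trip "t1" has a break in stop_sequence between 3 and 7,
# producing two segments; the first (length 3) is the longest and is kept.
stop_times = pd.DataFrame({
    "trip_id": ["t1"] * 5,
    "stop_sequence": [1, 2, 3, 7, 8],
})

stop_times = stop_times.sort_values(["trip_id", "stop_sequence"])
prev_seq = stop_times.groupby("trip_id")["stop_sequence"].shift(1)
# A "gap" starts a new segment: sequence jump != 1, or first stop of a trip.
gap = (stop_times["stop_sequence"] - prev_seq).ne(1) | prev_seq.isna()
stop_times["segment_id"] = gap.cumsum()

# Keep only the longest segment per trip.
seg_len = stop_times.groupby(["trip_id", "segment_id"]).size().reset_index(name="n")
longest = seg_len.loc[seg_len.groupby("trip_id")["n"].idxmax()]
kept = stop_times.merge(longest[["trip_id", "segment_id"]], on=["trip_id", "segment_id"])
```

The cumulative sum assigns segment 1 to stops 1-3 and segment 2 to stops 7-8; only segment 1 survives.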

stop_times_for_min_stops(stop_times, min_stops)

Filter stop_times dataframe to only the records which have >= min_stops for the trip.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | stop_times table to filter | *required* |
| `min_stops` | `int` | minimum number of stops required to keep a trip in stop_times | *required* |
Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_min_stops(
    stop_times: DataFrame[WranglerStopTimesTable], min_stops: int
) -> DataFrame[WranglerStopTimesTable]:
    """Filter stop_times dataframe to only the records which have >= min_stops for the trip.

    Args:
        stop_times: stop_times table to filter
        min_stops: minimum number of stops required to keep a trip in stop_times
    """
    stop_ct_by_trip_df = stop_count_by_trip(stop_times)

    # Filter to obtain DataFrame of trips with stop counts >= min_stops
    min_stop_ct_trip_df = stop_ct_by_trip_df[stop_ct_by_trip_df.stop_count >= min_stops]
    if len(min_stop_ct_trip_df) == 0:
        msg = f"No trips meet threshold of minimum stops: {min_stops}"
        raise ValueError(msg)
    WranglerLogger.debug(
        f"Found {len(min_stop_ct_trip_df)} trips with a minimum of {min_stops} stops."
    )

    # Filter the original stop_times DataFrame to only include trips with >= min_stops
    filtered_stop_times = stop_times.merge(
        min_stop_ct_trip_df["trip_id"], on="trip_id", how="inner"
    )
    WranglerLogger.debug(
        f"Filter stop times to {len(filtered_stop_times)}/{len(stop_times)}\
            w/a minimum of {min_stops} stops."
    )

    return filtered_stop_times

stop_times_for_pickup_dropoff_trip_id(stop_times, trip_id, pickup_dropoff='either')

Filters stop_times for a given trip_id based on pickup type.

GTFS values for pickup_type and drop_off_type:

- 0 or empty - Regularly scheduled pickup/dropoff.
- 1 - No pickup/dropoff available.
- 2 - Must phone agency to arrange pickup/dropoff.
- 3 - Must coordinate with driver to arrange pickup/dropoff.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | A WranglerStopTimesTable to query. | *required* |
| `trip_id` | `str` | trip_id to get stop pattern for | *required* |
| `pickup_dropoff` | `PickupDropoffAvailability` | Logic for selecting stops based on pickup and dropoff availability at stop: "any": all stoptime records; "either": either pickup_type or dropoff_type != 1; "both": both pickup_type and dropoff_type != 1; "pickup_only": dropoff = 1 and pickup != 1; "dropoff_only": pickup = 1 and dropoff != 1 | `'either'` |
Source code in network_wrangler/transit/feed/stop_times.py
@validate_call_pyd
def stop_times_for_pickup_dropoff_trip_id(
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> DataFrame[WranglerStopTimesTable]:
    """Filters stop_times for a given trip_id based on pickup type.

    GTFS values for pickup_type and drop_off_type"
        0 or empty - Regularly scheduled pickup/dropoff.
        1 - No pickup/dropoff available.
        2 - Must phone agency to arrange pickup/dropoff.
        3 - Must coordinate with driver to arrange pickup/dropoff.

    Args:
        stop_times: A WranglerStopTimesTable to query.
        trip_id: trip_id to get stop pattern for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "any": all stoptime records
            "either": either pickup_type or dropoff_type != 1
            "both": both pickup_type and dropoff_type != 1
            "pickup_only": dropoff = 1; pickup != 1
            "dropoff_only":  pickup = 1; dropoff != 1
    """
    trip_stop_pattern = stop_times_for_trip_id(stop_times, trip_id)

    if pickup_dropoff == "any":
        return trip_stop_pattern

    pickup_type_selection = {
        "either": (trip_stop_pattern.pickup_type != 1) | (trip_stop_pattern.drop_off_type != 1),
        "both": (trip_stop_pattern.pickup_type != 1) & (trip_stop_pattern.drop_off_type != 1),
        "pickup_only": (trip_stop_pattern.pickup_type != 1)
        & (trip_stop_pattern.drop_off_type == 1),
        "dropoff_only": (trip_stop_pattern.drop_off_type != 1)
        & (trip_stop_pattern.pickup_type == 1),
    }

    selection = pickup_type_selection[pickup_dropoff]
    trip_stops = trip_stop_pattern[selection]

    return trip_stops
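The mask-per-mode dictionary above is easy to verify on a toy stop pattern (data invented; in GTFS a value of 1 means "not available"):

```python
import pandas as pd

# Toy stop pattern for one trip; 1 = pickup/dropoff not available.
trip_stops = pd.DataFrame({
    "stop_id": ["a", "b", "c", "d"],
    "pickup_type": [0, 1, 0, 1],
    "drop_off_type": [0, 0, 1, 1],
})

# One boolean mask per selection mode, mirroring the source's dictionary.
masks = {
    "either": (trip_stops.pickup_type != 1) | (trip_stops.drop_off_type != 1),
    "both": (trip_stops.pickup_type != 1) & (trip_stops.drop_off_type != 1),
    "pickup_only": (trip_stops.pickup_type != 1) & (trip_stops.drop_off_type == 1),
    "dropoff_only": (trip_stops.drop_off_type != 1) & (trip_stops.pickup_type == 1),
}
```

Stop "a" allows both, "b" allows only dropoff, "c" allows only pickup, and "d" allows neither, so each mask selects a different subset.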

stop_times_for_route_ids(stop_times, trips, route_ids)

Returns stop_time records for a list of route_ids.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_route_ids(
    stop_times: DataFrame[WranglerStopTimesTable],
    trips: DataFrame[WranglerTripsTable],
    route_ids: list[str],
) -> DataFrame[WranglerStopTimesTable]:
    """Returns a stop_time records for a list of route_ids."""
    trip_ids = trips.loc[trips.route_id.isin(route_ids)].trip_id.unique()
    return stop_times_for_trip_ids(stop_times, trip_ids)

stop_times_for_shapes(stop_times, shapes, trips)

Filter stop_times dataframe to records associated with shapes dataframe.

Where multiple segments of stop_times are found to match shapes, retain only the longest.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | stop_times dataframe to filter | *required* |
| `shapes` | `DataFrame[WranglerShapesTable]` | shapes dataframe to filter stop_times to | *required* |
| `trips` | `DataFrame[WranglerTripsTable]` | trips to link stop_times to shapes | *required* |

Returns:

| Type | Description |
| ---- | ----------- |
| `DataFrame[WranglerStopTimesTable]` | filtered stop_times dataframe |

EX (* = should be retained):

stop_times

trip_id   stop_sequence   stop_id
*t1       1               1
*t1       2               2
*t1       3               3
t1        4               5
*t2       1               1
*t2       2               3
t2        3               7

shapes

shape_id   shape_pt_sequence   shape_model_node_id
s1         1                   1
s1         2                   2
s1         3                   3
s1         4                   4
s2         1                   1
s2         2                   2
s2         3                   3

trips

trip_id   shape_id
t1        s1
t2        s2

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_shapes(
    stop_times: DataFrame[WranglerStopTimesTable],
    shapes: DataFrame[WranglerShapesTable],
    trips: DataFrame[WranglerTripsTable],
) -> DataFrame[WranglerStopTimesTable]:
    """Filter stop_times dataframe to records associated with shapes dataframe.

    Where multiple segments of stop_times are found to match shapes, retain only the longest.

    Args:
        stop_times: stop_times dataframe to filter
        shapes: shapes dataframe to filter stop_times to.
        trips: trips to link stop_times to shapes

    Returns:
        filtered stop_times dataframe

    EX:
    * should be retained
    > stop_times

    trip_id   stop_sequence   stop_id
    *t1          1                  1
    *t1          2                  2
    *t1          3                  3
    t1           4                  5
    *t2          1                  1
    *t2          2                  3
    t2           3                  7

    > shapes

    shape_id   shape_pt_sequence   shape_model_node_id
    s1          1                  1
    s1          2                  2
    s1          3                  3
    s1          4                  4
    s2          1                  1
    s2          2                  2
    s2          3                  3

    > trips

    trip_id   shape_id
    t1          s1
    t2          s2
    """
    """
    > stop_times_w_shapes

    trip_id   stop_sequence   stop_id    shape_id   shape_pt_sequence
    *t1          1                  1        s1          1
    *t1          2                  2        s1          2
    *t1          3                  3        s1          3
    t1           4                  5        NA          NA
    *t2          1                  1        s2          1
    *t2          2                  3        s2          2
    t2           3                  7        NA          NA

    """
    stop_times_w_shapes = merge_shapes_to_stop_times(stop_times, shapes, trips)
    # WranglerLogger.debug(f"stop_times_w_shapes :\n{stop_times_w_shapes}")
    """
    > stop_times_w_shapes

    trip_id   stop_sequence   stop_id   shape_id   shape_pt_sequence
    *t1          1               1        s1          1
    *t1          2               2        s1          2
    *t1          3               3        s1          3
    *t2          1               1        s2          1
    *t2          2               3        s2          2

    """
    filtered_stop_times = stop_times_w_shapes[stop_times_w_shapes["shape_pt_sequence"].notna()]
    # WranglerLogger.debug(f"filtered_stop_times:\n{filtered_stop_times}")

    # Filter out any stop_times the shape_pt_sequence is not ascending
    valid_stop_times = filtered_stop_times.groupby("trip_id").filter(
        lambda x: x["shape_pt_sequence"].is_monotonic_increasing
    )
    # WranglerLogger.debug(f"valid_stop_times:\n{valid_stop_times}")

    valid_stop_times = valid_stop_times.drop(columns=["shape_id", "shape_pt_sequence"])

    longest_valid_stop_times = stop_times_for_longest_segments(valid_stop_times)
    longest_valid_stop_times = longest_valid_stop_times.reset_index(drop=True)

    return longest_valid_stop_times
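The monotonicity check used above (dropping trips whose matched `shape_pt_sequence` runs backwards) can be sketched on toy data:

```python
import pandas as pd

# Toy matched stop_times: trip "t2"'s shape_pt_sequence goes backwards,
# so its stops cannot follow the shape in order and the trip is dropped.
df = pd.DataFrame({
    "trip_id": ["t1", "t1", "t2", "t2"],
    "shape_pt_sequence": [1, 3, 5, 2],
})

# groupby().filter() keeps only groups whose predicate is True.
valid = df.groupby("trip_id").filter(
    lambda g: g["shape_pt_sequence"].is_monotonic_increasing
)
```

Only trip "t1" survives, since its sequence 1, 3 is non-decreasing while "t2"'s 5, 2 is not.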

stop_times_for_stops(stop_times, stops)

Filter stop_times dataframe to only have stop_times associated with stops records.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_stops(
    stop_times: DataFrame[WranglerStopTimesTable], stops: DataFrame[WranglerStopsTable]
) -> DataFrame[WranglerStopTimesTable]:
    """Filter stop_times dataframe to only have stop_times associated with stops records."""
    _sel_stops = stops.stop_id.unique().tolist()
    filtered_stop_times = stop_times[stop_times.stop_id.isin(_sel_stops)]
    WranglerLogger.debug(
        f"Filtered stop_times to {len(filtered_stop_times)}/{len(stop_times)} \
                         records that referenced one of {len(stops)} stops."
    )
    return filtered_stop_times

stop_times_for_trip_id(stop_times, trip_id)

Returns stop_time records for a given trip_id.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_trip_id(
    stop_times: DataFrame[WranglerStopTimesTable], trip_id: str
) -> DataFrame[WranglerStopTimesTable]:
    """Returns a stop_time records for a given trip_id."""
    stop_times = stop_times.loc[stop_times.trip_id == trip_id]
    return stop_times.sort_values(by=["stop_sequence"])

stop_times_for_trip_ids(stop_times, trip_ids)

Returns stop_time records for a given list of trip_ids.

Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_trip_ids(
    stop_times: DataFrame[WranglerStopTimesTable], trip_ids: list[str]
) -> DataFrame[WranglerStopTimesTable]:
    """Returns a stop_time records for a given list of trip_ids."""
    stop_times = stop_times.loc[stop_times.trip_id.isin(trip_ids)]
    return stop_times.sort_values(by=["stop_sequence"])

stop_times_for_trip_node_segment(stop_times, trip_id, node_id_start, node_id_end, include_start=True, include_end=True)

Returns stop_times for a given trip_id between two nodes, optionally including those nodes.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | WranglerStopTimesTable | *required* |
| `trip_id` | `str` | trip id to select | *required* |
| `node_id_start` | `int` | int of the starting node | *required* |
| `node_id_end` | `int` | int of the ending node | *required* |
| `include_start` | `bool` | bool indicating if the start node should be included in the segment. Defaults to True. | `True` |
| `include_end` | `bool` | bool indicating if the end node should be included in the segment. Defaults to True. | `True` |
Source code in network_wrangler/transit/feed/stop_times.py
def stop_times_for_trip_node_segment(
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    node_id_start: int,
    node_id_end: int,
    include_start: bool = True,
    include_end: bool = True,
) -> DataFrame[WranglerStopTimesTable]:
    """Returns stop_times for a given trip_id between two nodes or with those nodes included.

    Args:
        stop_times: WranglerStopTimesTable
        trip_id: trip id to select
        node_id_start: int of the starting node
        node_id_end: int of the ending node
        include_start: bool indicating if the start node should be included in the segment.
            Defaults to True.
        include_end: bool indicating if the end node should be included in the segment.
            Defaults to True.
    """
    stop_times = stop_times_for_trip_id(stop_times, trip_id)
    start_idx = stop_times[stop_times["stop_id"] == node_id_start].index[0]
    end_idx = stop_times[stop_times["stop_id"] == node_id_end].index[0]
    if not include_start:
        start_idx += 1
    if include_end:
        end_idx += 1
    return stop_times.loc[start_idx:end_idx]
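A sketch of index-based segment slicing on toy data. Note that pandas `.loc` slicing is end-inclusive, so this version moves the bounds inward to exclude an endpoint; the data and `segment` helper are illustrative, not from the library:

```python
import pandas as pd

# Toy stop pattern with a clean RangeIndex.
stop_times = pd.DataFrame({"stop_id": [10, 20, 30, 40, 50]})

def segment(df, start_stop, end_stop, include_start=True, include_end=True):
    """Slice rows between two stop_ids; .loc slicing includes both endpoints."""
    start_idx = df[df["stop_id"] == start_stop].index[0]
    end_idx = df[df["stop_id"] == end_stop].index[0]
    if not include_start:
        start_idx += 1
    if not include_end:
        end_idx -= 1
    return df.loc[start_idx:end_idx]
```

Because `.loc` is label-based and inclusive, `segment(stop_times, 20, 40)` yields stops 20, 30, and 40, and the `include_*` flags shave one stop off either end.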

Filters and queries of a gtfs stops table and stop_ids.

node_is_stop(stops, stop_times, node_id, trip_id, pickup_dropoff='either')

Returns boolean(s) indicating whether a node (or list of nodes) is a stop for a given trip_id.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `stops` | `DataFrame[WranglerStopsTable]` | WranglerStopsTable | *required* |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | WranglerStopTimesTable | *required* |
| `node_id` | `Union[int, list[int]]` | node ID for roadway | *required* |
| `trip_id` | `str` | trip_id to get stop pattern for | *required* |
| `pickup_dropoff` | `PickupDropoffAvailability` | Logic for selecting stops based on pickup and dropoff availability at stop: "either": either pickup_type or dropoff_type > 0; "both": both pickup_type and dropoff_type > 0; "pickup_only": only pickup > 0; "dropoff_only": only dropoff > 0 | `'either'` |
Source code in network_wrangler/transit/feed/stops.py
def node_is_stop(
    stops: DataFrame[WranglerStopsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    node_id: Union[int, list[int]],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> Union[bool, list[bool]]:
    """Returns boolean indicating if a (or list of) node(s)) is (are) stops for a given trip_id.

    Args:
        stops: WranglerStopsTable
        stop_times: WranglerStopTimesTable
        node_id: node ID for roadway
        trip_id: trip_id to get stop pattern for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0
    """
    trip_stop_nodes = stops_for_trip_id(stops, stop_times, trip_id, pickup_dropoff=pickup_dropoff)[
        "stop_id"
    ]
    if isinstance(node_id, list):
        return [n in trip_stop_nodes.values for n in node_id]
    return node_id in trip_stop_nodes.values

stop_id_pattern_for_trip(stop_times, trip_id, pickup_dropoff='either')

Returns the stop pattern for a given trip_id as a list of stop_ids.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| `stop_times` | `DataFrame[WranglerStopTimesTable]` | WranglerStopTimesTable | *required* |
| `trip_id` | `str` | trip_id to get stop pattern for | *required* |
| `pickup_dropoff` | `PickupDropoffAvailability` | Logic for selecting stops based on pickup and dropoff availability at stop: "either": either pickup_type or dropoff_type > 0; "both": both pickup_type and dropoff_type > 0; "pickup_only": only pickup > 0; "dropoff_only": only dropoff > 0 | `'either'` |
Source code in network_wrangler/transit/feed/stops.py
@validate_call_pyd
def stop_id_pattern_for_trip(
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "either",
) -> list[str]:
    """Returns a stop pattern for a given trip_id given by a list of stop_ids.

    Args:
        stop_times: WranglerStopTimesTable
        trip_id: trip_id to get stop pattern for
        pickup_dropoff: str indicating logic for selecting stops based on pickup and dropoff
            availability at stop. Defaults to "either".
            "either": either pickup_type or dropoff_type > 0
            "both": both pickup_type and dropoff_type > 0
            "pickup_only": only pickup > 0
            "dropoff_only": only dropoff > 0
    """
    from .stop_times import stop_times_for_pickup_dropoff_trip_id

    trip_stops = stop_times_for_pickup_dropoff_trip_id(
        stop_times, trip_id, pickup_dropoff=pickup_dropoff
    )
    return trip_stops.stop_id.to_list()

stops_for_stop_times(stops, stop_times)

Filter stops dataframe to only have stops associated with stop_times records.

Source code in network_wrangler/transit/feed/stops.py
def stops_for_stop_times(
    stops: DataFrame[WranglerStopsTable], stop_times: DataFrame[WranglerStopTimesTable]
) -> DataFrame[WranglerStopsTable]:
    """Filter stops dataframe to only have stops associated with stop_times records."""
    _sel_stops_ge_min = stop_times.stop_id.unique().tolist()
    filtered_stops = stops[stops.stop_id.isin(_sel_stops_ge_min)]
    WranglerLogger.debug(
        f"Filtered stops to {len(filtered_stops)}/{len(stops)} \
                         records that referenced one of {len(stop_times)} stop_times."
    )
    return filtered_stops

stops_for_trip_id(stops, stop_times, trip_id, pickup_dropoff='any')

Returns stops.txt which are used for a given trip_id.

Source code in network_wrangler/transit/feed/stops.py
def stops_for_trip_id(
    stops: DataFrame[WranglerStopsTable],
    stop_times: DataFrame[WranglerStopTimesTable],
    trip_id: str,
    pickup_dropoff: PickupDropoffAvailability = "any",
) -> DataFrame[WranglerStopsTable]:
    """Returns stops.txt which are used for a given trip_id."""
    stop_ids = stop_id_pattern_for_trip(stop_times, trip_id, pickup_dropoff=pickup_dropoff)
    return stops.loc[stops.stop_id.isin(stop_ids)]

Filters and queries of a GTFS trips table and trip_ids.

trip_ids_for_shape_id(trips, shape_id)

Returns a list of trip_ids for a given shape_id.

Source code in network_wrangler/transit/feed/trips.py
def trip_ids_for_shape_id(trips: DataFrame[WranglerTripsTable], shape_id: str) -> list[str]:
    """Returns a list of trip_ids for a given shape_id."""
    return trips_for_shape_id(trips, shape_id)["trip_id"].unique().tolist()

trips_for_shape_id(trips, shape_id)

Returns trips records for a given shape_id.

Source code in network_wrangler/transit/feed/trips.py
def trips_for_shape_id(
    trips: DataFrame[WranglerTripsTable], shape_id: str
) -> DataFrame[WranglerTripsTable]:
    """Returns a trips records for a given shape_id."""
    return trips.loc[trips.shape_id == shape_id]

trips_for_stop_times(trips, stop_times)

Filter trips dataframe to records associated with stop_time records.

Source code in network_wrangler/transit/feed/trips.py
def trips_for_stop_times(
    trips: DataFrame[WranglerTripsTable], stop_times: DataFrame[WranglerStopTimesTable]
) -> DataFrame[WranglerTripsTable]:
    """Filter trips dataframe to records associated with stop_time records."""
    _sel_trips = stop_times.trip_id.unique().tolist()
    filtered_trips = trips[trips.trip_id.isin(_sel_trips)]
    WranglerLogger.debug(
        f"Filtered trips to {len(filtered_trips)}/{len(trips)} \
                         records that referenced one of {len(stop_times)} stop_times."
    )
    return filtered_trips

Functions for translating transit tables into visualizable links relatable to the roadway network.

Converts shapes DataFrame to shape links DataFrame.

Parameters:

- shapes (DataFrame[WranglerShapesTable]): The input shapes DataFrame. Required.

Returns:

- pd.DataFrame: The resulting shape links DataFrame.

Source code in network_wrangler/transit/feed/transit_links.py
def shapes_to_shape_links(shapes: DataFrame[WranglerShapesTable]) -> pd.DataFrame:
    """Converts shapes DataFrame to shape links DataFrame.

    Args:
        shapes (DataFrame[WranglerShapesTable]): The input shapes DataFrame.

    Returns:
        pd.DataFrame: The resulting shape links DataFrame.
    """
    return point_seq_to_links(
        shapes,
        id_field="shape_id",
        seq_field="shape_pt_sequence",
        node_id_field="shape_model_node_id",
    )
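point_seq_to_links itself is defined elsewhere in network_wrangler; a hypothetical re-creation of the idea — consecutive points in a shape become A-to-B link records — might look like this:

```python
import pandas as pd

# Toy shapes table: one shape with three ordered points.
shapes = pd.DataFrame(
    {
        "shape_id": ["sh1"] * 3,
        "shape_pt_sequence": [1, 2, 3],
        "shape_model_node_id": [10, 20, 30],
    }
)

# Pair each point with the next one in the same shape to form links.
links = shapes.sort_values("shape_pt_sequence").copy()
links["A"] = links["shape_model_node_id"]
links["B"] = links.groupby("shape_id")["shape_model_node_id"].shift(-1)
links = links.dropna(subset=["B"]).astype({"B": int})
print(links[["A", "B"]].values.tolist())  # [[10, 20], [20, 30]]
```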

Converts stop times to stop times links.

Parameters:

- stop_times (DataFrame[WranglerStopTimesTable]): The stop times data. Required.
- from_field (str, optional): The name of the field representing the 'from' stop. Defaults to "A".
- to_field (str, optional): The name of the field representing the 'to' stop. Defaults to "B".

Returns:

- pd.DataFrame: The resulting stop times links.

Source code in network_wrangler/transit/feed/transit_links.py
def stop_times_to_stop_times_links(
    stop_times: DataFrame[WranglerStopTimesTable],
    from_field: str = "A",
    to_field: str = "B",
) -> pd.DataFrame:
    """Converts stop times to stop times links.

    Args:
        stop_times (DataFrame[WranglerStopTimesTable]): The stop times data.
        from_field (str, optional): The name of the field representing the 'from' stop.
            Defaults to "A".
        to_field (str, optional): The name of the field representing the 'to' stop.
            Defaults to "B".

    Returns:
        pd.DataFrame: The resulting stop times links.
    """
    return point_seq_to_links(
        stop_times,
        id_field="trip_id",
        seq_field="stop_sequence",
        node_id_field="stop_id",
        from_field=from_field,
        to_field=to_field,
    )

Returns a DataFrame containing unique shape links based on the provided shapes DataFrame.

Parameters:

- shapes (DataFrame[WranglerShapesTable]): The input DataFrame containing shape information. Required.
- from_field (str, optional): The name of the column representing the 'from' field. Defaults to "A".
- to_field (str, optional): The name of the column representing the 'to' field. Defaults to "B".

Returns:

- pd.DataFrame: DataFrame containing unique shape links based on the provided shapes df.

Source code in network_wrangler/transit/feed/transit_links.py
def unique_shape_links(
    shapes: DataFrame[WranglerShapesTable], from_field: str = "A", to_field: str = "B"
) -> pd.DataFrame:
    """Returns a DataFrame containing unique shape links based on the provided shapes DataFrame.

    Parameters:
        shapes (DataFrame[WranglerShapesTable]): The input DataFrame containing shape information.
        from_field (str, optional): The name of the column representing the 'from' field.
            Defaults to "A".
        to_field (str, optional): The name of the column representing the 'to' field.
            Defaults to "B".

    Returns:
        pd.DataFrame: DataFrame containing unique shape links based on the provided shapes df.
    """
    shape_links = shapes_to_shape_links(shapes)
    # WranglerLogger.debug(f"Shape links: \n {shape_links[['shape_id', from_field, to_field]]}")

    _agg_dict: dict[str, Union[type, str]] = {"shape_id": list}
    _opt_fields = [f"shape_pt_{v}_{t}" for v in ["lat", "lon"] for t in [from_field, to_field]]
    for f in _opt_fields:
        if f in shape_links:
            _agg_dict[f] = "first"

    unique_shape_links = shape_links.groupby([from_field, to_field]).agg(_agg_dict).reset_index()
    return unique_shape_links
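The groupby/agg step is easiest to see on toy data: two shapes sharing the same A-to-B link collapse into one record whose shape_id field holds a list.

```python
import pandas as pd

# Toy shape links: sh1 and sh2 both traverse link 10 -> 20.
shape_links = pd.DataFrame(
    {
        "A": [10, 20, 10],
        "B": [20, 30, 20],
        "shape_id": ["sh1", "sh1", "sh2"],
    }
)

# Collapse duplicate links, collecting the shapes that use each one.
unique_links = shape_links.groupby(["A", "B"]).agg({"shape_id": list}).reset_index()
print(unique_links.to_dict("records"))
# [{'A': 10, 'B': 20, 'shape_id': ['sh1', 'sh2']}, {'A': 20, 'B': 30, 'shape_id': ['sh1']}]
```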

Returns a DataFrame containing unique stop time links based on the given stop times DataFrame.

Parameters:

- stop_times (DataFrame[WranglerStopTimesTable]): The DataFrame containing stop times data. Required.
- from_field (str, optional): The name of the column representing the 'from' field in the stop times DataFrame. Defaults to "A".
- to_field (str, optional): The name of the column representing the 'to' field in the stop times DataFrame. Defaults to "B".

Returns:

- pd.DataFrame: A DataFrame containing unique stop time links with columns 'from_field', 'to_field', and 'trip_id'.

Source code in network_wrangler/transit/feed/transit_links.py
def unique_stop_time_links(
    stop_times: DataFrame[WranglerStopTimesTable],
    from_field: str = "A",
    to_field: str = "B",
) -> pd.DataFrame:
    """Returns a DataFrame containing unique stop time links based on the given stop times DataFrame.

    Parameters:
        stop_times (DataFrame[WranglerStopTimesTable]): The DataFrame containing stop times data.
        from_field (str, optional): The name of the column representing the 'from' field in the
            stop times DataFrame. Defaults to "A".
        to_field (str, optional): The name of the column representing the 'to' field in the stop
            times DataFrame. Defaults to "B".

    Returns:
        pd.DataFrame: A DataFrame containing unique stop time links with columns 'from_field',
            'to_field', and 'trip_id'.
    """
    links = stop_times_to_stop_times_links(stop_times, from_field=from_field, to_field=to_field)
    unique_links = links.groupby([from_field, to_field])["trip_id"].apply(list).reset_index()
    return unique_links
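The analogous collapse for stop-time links, sketched on toy data (column names assumed from the defaults above):

```python
import pandas as pd

# Two trips share link 1 -> 2; one trip uses link 2 -> 3.
links = pd.DataFrame({"A": [1, 1, 2], "B": [2, 2, 3], "trip_id": ["t1", "t2", "t1"]})
unique_links = links.groupby(["A", "B"])["trip_id"].apply(list).reset_index()
print(unique_links.trip_id.to_list())  # [['t1', 't2'], ['t1']]
```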

Functions to create segments from shapes and shape_links.

filter_shapes_to_segments(shapes, segments)

Filter shapes dataframe to records associated with segments dataframe.

Parameters:

- shapes (DataFrame[WranglerShapesTable]): shapes dataframe to filter. Required.
- segments (DataFrame): segments dataframe to filter by, with shape_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq. Should have one record per shape_id. Required.

Returns:

- DataFrame[WranglerShapesTable]: filtered shapes dataframe.

Source code in network_wrangler/transit/feed/transit_segments.py
def filter_shapes_to_segments(
    shapes: DataFrame[WranglerShapesTable], segments: pd.DataFrame
) -> DataFrame[WranglerShapesTable]:
    """Filter shapes dataframe to records associated with segments dataframe.

    Args:
        shapes: shapes dataframe to filter
        segments: segments dataframe to filter by with shape_id, segment_start_shape_pt_seq,
            segment_end_shape_pt_seq . Should have one record per shape_id.

    Returns:
        filtered shapes dataframe
    """
    shapes_w_segs = shapes.merge(segments, on="shape_id", how="left")

    # Retain only those points within the segment sequences
    filtered_shapes = shapes_w_segs[
        (shapes_w_segs["shape_pt_sequence"] >= shapes_w_segs["segment_start_shape_pt_seq"])
        & (shapes_w_segs["shape_pt_sequence"] <= shapes_w_segs["segment_end_shape_pt_seq"])
    ]

    drop_cols = [
        "segment_id",
        "segment_start_shape_pt_seq",
        "segment_end_shape_pt_seq",
        "segment_length",
    ]
    filtered_shapes = filtered_shapes.drop(columns=drop_cols)

    return filtered_shapes
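The merge-and-mask step can be illustrated with toy data (same column names as the docstring):

```python
import pandas as pd

# One shape with points 1..4; the segment keeps sequences 2..3.
shapes = pd.DataFrame({"shape_id": ["sh1"] * 4, "shape_pt_sequence": [1, 2, 3, 4]})
segments = pd.DataFrame(
    {"shape_id": ["sh1"], "segment_start_shape_pt_seq": [2], "segment_end_shape_pt_seq": [3]}
)

# Attach segment bounds to each point, then keep points inside the bounds.
merged = shapes.merge(segments, on="shape_id", how="left")
kept = merged[
    (merged.shape_pt_sequence >= merged.segment_start_shape_pt_seq)
    & (merged.shape_pt_sequence <= merged.segment_end_shape_pt_seq)
]
print(kept.shape_pt_sequence.to_list())  # [2, 3]
```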

Find the longest segment of each shape_id that is in the links.

Parameters:

- shape_links: DataFrame with shape_id, shape_pt_sequence_A, shape_pt_sequence_B. Required.

Returns:

- pd.DataFrame: DataFrame with shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq.

Source code in network_wrangler/transit/feed/transit_segments.py
def shape_links_to_longest_shape_segments(shape_links) -> pd.DataFrame:
    """Find the longest segment of each shape_id that is in the links.

    Args:
        shape_links: DataFrame with shape_id, shape_pt_sequence_A, shape_pt_sequence_B

    Returns:
        DataFrame with shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq
    """
    segments = shape_links_to_segments(shape_links)
    idx = segments.groupby("shape_id")["segment_length"].idxmax()
    longest_segments = segments.loc[idx]
    return longest_segments
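The groupby + idxmax selection, sketched on toy segment records shaped like the output of shape_links_to_segments:

```python
import pandas as pd

# Toy segments: sh1 has two segments (lengths 3 and 7), sh2 has one (length 5).
segments = pd.DataFrame(
    {
        "shape_id": ["sh1", "sh1", "sh2"],
        "segment_id": [0, 1, 0],
        "segment_length": [3, 7, 5],
    }
)

# idxmax gives the row label of the longest segment per shape.
idx = segments.groupby("shape_id")["segment_length"].idxmax()
longest = segments.loc[idx]
print(longest.segment_length.to_list())  # [7, 5]
```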

Convert shape_links to segments by shape_id with segments of continuous shape_pt_sequence.

Returns:

- pd.DataFrame: DataFrame with shape_id, segment_id, segment_start_shape_pt_seq, segment_end_shape_pt_seq.

Source code in network_wrangler/transit/feed/transit_segments.py
def shape_links_to_segments(shape_links) -> pd.DataFrame:
    """Convert shape_links to segments by shape_id with segments of continuous shape_pt_sequence.

    Returns: DataFrame with shape_id, segment_id, segment_start_shape_pt_seq,
        segment_end_shape_pt_seq
    """
    shape_links["gap"] = shape_links.groupby("shape_id")["shape_pt_sequence_A"].diff().gt(1)
    shape_links["segment_id"] = shape_links.groupby("shape_id")["gap"].cumsum()

    # Define segment starts and ends
    segment_definitions = (
        shape_links.groupby(["shape_id", "segment_id"])
        .agg(
            segment_start_shape_pt_seq=("shape_pt_sequence_A", "min"),
            segment_end_shape_pt_seq=("shape_pt_sequence_B", "max"),
        )
        .reset_index()
    )

    # Optionally calculate segment lengths for further uses
    segment_definitions["segment_length"] = (
        segment_definitions["segment_end_shape_pt_seq"]
        - segment_definitions["segment_start_shape_pt_seq"]
        + 1
    )

    return segment_definitions
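The gap/cumsum trick above is easiest to see on toy data: a jump in shape_pt_sequence_A (here 2 to 7) starts a new segment_id within the shape.

```python
import pandas as pd

# One shape whose links have a break between sequences 3 and 7.
shape_links = pd.DataFrame(
    {
        "shape_id": ["sh1"] * 4,
        "shape_pt_sequence_A": [1, 2, 7, 8],
        "shape_pt_sequence_B": [2, 3, 8, 9],
    }
)

# A gap is any step in shape_pt_sequence_A greater than 1; cumsum of the
# boolean gap flags numbers the resulting continuous runs.
shape_links["gap"] = shape_links.groupby("shape_id")["shape_pt_sequence_A"].diff().gt(1)
shape_links["segment_id"] = shape_links.groupby("shape_id")["gap"].cumsum()
print(shape_links.segment_id.to_list())  # [0, 0, 1, 1]
```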

Transit Projects

Functions for adding a transit route to a TransitNetwork.

apply_transit_route_addition(net, transit_route_addition, reference_road_net=None)

Add transit route to TransitNetwork.

Parameters:

- net (TransitNetwork): Network to modify. Required.
- transit_route_addition (dict): route dictionary to add to the feed. Required.
- reference_road_net (RoadwayNetwork, optional): Reference roadway network to use for adding shapes and stops. Defaults to None.

Returns:

- TransitNetwork: Modified network.

Source code in network_wrangler/transit/projects/add_route.py
def apply_transit_route_addition(
    net: TransitNetwork,
    transit_route_addition: dict,
    reference_road_net: Optional[RoadwayNetwork] = None,
) -> TransitNetwork:
    """Add transit route to TransitNetwork.

    Args:
        net (TransitNetwork): Network to modify.
        transit_route_addition: route dictionary to add to the feed.
        reference_road_net: (RoadwayNetwork, optional): Reference roadway network to use for adding shapes and stops. Defaults to None.

    Returns:
        TransitNetwork: Modified network.
    """
    WranglerLogger.debug("Applying add transit route project.")

    add_routes = transit_route_addition["routes"]

    road_net = net.road_net if reference_road_net is None else reference_road_net
    if road_net is None:
        WranglerLogger.error(
            "! Must have a reference road network set in order to update transit \
                         routing.  Either provide as an input to this function or set it for the \
                         transit network: >> transit_net.road_net = ..."
        )
        msg = "Must have a reference road network set in order to update transit routing."
        raise TransitRouteAddError(msg)

    net.feed = _add_route_to_feed(net.feed, add_routes, road_net)

    return net

Module for applying calculated transit projects to a transit network object.

These projects are stored in the project card pycode property as Python code strings, which are executed to change the transit network object.

apply_calculated_transit(net, pycode)

Changes transit network object by executing pycode.

Parameters:

- net (TransitNetwork): transit network to manipulate. Required.
- pycode (str): python code which changes values in the transit network object. Required.
Source code in network_wrangler/transit/projects/calculate.py
def apply_calculated_transit(
    net: TransitNetwork,
    pycode: str,
) -> TransitNetwork:
    """Changes transit network object by executing pycode.

    Args:
        net: transit network to manipulate
        pycode: python code which changes values in the transit network object
    """
    WranglerLogger.debug("Applying calculated transit project.")
    exec(pycode)

    return net
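The exec pattern can be illustrated with a stand-in object in place of a real TransitNetwork (StubNet and its headway attribute are hypothetical, for illustration only):

```python
# Illustration of the exec() pattern used by apply_calculated_transit.
class StubNet:
    def __init__(self):
        self.headway = 30

net = StubNet()
# A code string such as might appear in a project card's pycode property.
pycode = "net.headway = 15"
exec(pycode)
print(net.headway)  # 15
```

Because exec runs arbitrary Python, project cards that use pycode should only come from trusted sources.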

Functions for deleting transit service from a TransitNetwork.

apply_transit_service_deletion(net, selection, clean_shapes=False, clean_routes=False)

Delete transit service from a TransitNetwork.

Parameters:

- net (TransitNetwork): Network to modify. Required.
- selection (TransitSelection): TransitSelection object, created from a selection dictionary. Required.
- clean_shapes (bool, optional): If True, remove shapes not used by any trips. Defaults to False.
- clean_routes (bool, optional): If True, remove routes not used by any trips. Defaults to False.

Returns:

- TransitNetwork: Modified network.

Source code in network_wrangler/transit/projects/delete_service.py
def apply_transit_service_deletion(
    net: TransitNetwork,
    selection: TransitSelection,
    clean_shapes: Optional[bool] = False,
    clean_routes: Optional[bool] = False,
) -> TransitNetwork:
    """Delete transit service to TransitNetwork.

    Args:
        net (TransitNetwork): Network to modify.
        selection: TransitSelection object, created from a selection dictionary.
        clean_shapes (bool, optional): If True, remove shapes not used by any trips.
            Defaults to False.
        clean_routes (bool, optional): If True, remove routes not used by any trips.
            Defaults to False.

    Returns:
        TransitNetwork: Modified network.
    """
    WranglerLogger.debug("Applying delete transit service project.")

    trip_ids = selection.selected_trips
    net.feed = _delete_trips_from_feed(
        net.feed, trip_ids, clean_shapes=clean_shapes, clean_routes=clean_routes
    )

    return net

Functions for editing transit properties in a TransitNetwork.

apply_transit_property_change(net, selection, property_changes, project_name=None)

Apply changes to transit properties.

Parameters:

- net (TransitNetwork): Network to modify. Required.
- selection (TransitSelection): Selection of trips to modify. Required.
- property_changes (dict): Dictionary of properties to change. Required.
- project_name (str, optional): Name of the project. Defaults to None.

Returns:

- TransitNetwork: Modified network.

Source code in network_wrangler/transit/projects/edit_property.py
def apply_transit_property_change(
    net: TransitNetwork,
    selection: TransitSelection,
    property_changes: dict,
    project_name: Optional[str] = None,
) -> TransitNetwork:
    """Apply changes to transit properties.

    Args:
        net (TransitNetwork): Network to modify.
        selection (TransitSelection): Selection of trips to modify.
        property_changes (dict): Dictionary of properties to change.
        project_name (str, optional): Name of the project. Defaults to None.

    Returns:
        TransitNetwork: Modified network.
    """
    WranglerLogger.debug("Applying transit property change project.")
    for property, property_change in property_changes.items():
        net = _apply_transit_property_change_to_table(
            net,
            selection,
            property,
            property_change,
            project_name=project_name,
        )
    return net

Functions for editing the transit route shapes and stop patterns.

apply_transit_routing_change(net, selection, routing_change, reference_road_net=None, project_name=None)

Apply a routing change to the transit network, including stop updates.

Parameters:

- net (TransitNetwork): TransitNetwork object to apply routing change to. Required.
- selection (TransitSelection): TransitSelection object, created from a selection dictionary. Required.
- routing_change (dict): Routing Change dictionary. Required. E.g.:

    {
        "existing": [46665, 150855],
        "set": [-46665, 150855, 46665, 150855],
    }

- reference_road_net (RoadwayNetwork, optional): Reference roadway network to use for updating shapes and stops. Defaults to None.
- project_name (str, optional): Name of the project. Defaults to None.
Source code in network_wrangler/transit/projects/edit_routing.py
def apply_transit_routing_change(
    net: TransitNetwork,
    selection: TransitSelection,
    routing_change: dict,
    reference_road_net: Optional[RoadwayNetwork] = None,
    project_name: Optional[str] = None,
) -> TransitNetwork:
    """Apply a routing change to the transit network, including stop updates.

    Args:
        net (TransitNetwork): TransitNetwork object to apply routing change to.
        selection (Selection): TransitSelection object, created from a selection dictionary.
        routing_change (dict): Routing Change dictionary, e.g.
            ```python
            {
                "existing": [46665, 150855],
                "set": [-46665, 150855, 46665, 150855],
            }
            ```
        shape_id_scalar (int, optional): Initial scalar value to add to duplicated shape_ids to
            create a new shape_id. Defaults to SHAPE_ID_SCALAR.
        reference_road_net (RoadwayNetwork, optional): Reference roadway network to use for
            updating shapes and stops. Defaults to None.
        project_name (str, optional): Name of the project. Defaults to None.
    """
    WranglerLogger.debug("Applying transit routing change project.")
    WranglerLogger.debug(f"...selection: {selection.selection_dict}")
    WranglerLogger.debug(f"...routing: {routing_change}")

    # ---- Secure all inputs needed --------------
    updated_feed = copy.deepcopy(net.feed)
    trip_ids = selection.selected_trips
    if project_name:
        updated_feed.trips.loc[updated_feed.trips.trip_id.isin(trip_ids), "projects"] += (
            f"{project_name},"
        )

    road_net = net.road_net if reference_road_net is None else reference_road_net
    if road_net is None:
        WranglerLogger.error(
            "! Must have a reference road network set in order to update transit \
                         routing.  Either provide as an input to this function or set it for the \
                         transit network: >> transit_net.road_net = ..."
        )
        msg = "Must have a reference road network set in order to update transit routing."
        raise TransitRoutingChangeError(msg)

    # ---- update each shape that is used by selected trips to use new routing -------
    shape_ids = shape_ids_for_trip_ids(updated_feed.trips, trip_ids)
    # WranglerLogger.debug(f"shape_ids: {shape_ids}")
    for shape_id in shape_ids:
        updated_feed.shapes, updated_feed.trips = _update_shapes_and_trips(
            updated_feed,
            shape_id,
            trip_ids,
            routing_change["set"],
            net.config.IDS.TRANSIT_SHAPE_ID_SCALAR,
            road_net,
            routing_existing=routing_change.get("existing", []),
            project_name=project_name,
        )
    # WranglerLogger.debug(f"updated_feed.shapes: \n{updated_feed.shapes}")
    # WranglerLogger.debug(f"updated_feed.trips: \n{updated_feed.trips}")
    # ---- Check if any stops need adding to stops.txt and add if they do ----------
    updated_feed.stops = _update_stops(
        updated_feed, routing_change["set"], road_net, project_name=project_name
    )
    # WranglerLogger.debug(f"updated_feed.stops: \n{updated_feed.stops}")
    # ---- Update stop_times --------------------------------------------------------
    for trip_id in trip_ids:
        updated_feed.stop_times = _update_stop_times_for_trip(
            updated_feed,
            trip_id,
            routing_change["set"],
            routing_change.get("existing", []),
        )

    # ---- Check result -------------------------------------------------------------
    _show_col = [
        "trip_id",
        "stop_id",
        "stop_sequence",
        "departure_time",
        "arrival_time",
    ]
    _ex_stoptimes = updated_feed.stop_times.loc[
        updated_feed.stop_times.trip_id == trip_ids[0], _show_col
    ]
    # WranglerLogger.debug(f"stop_times for first updated trip: \n {_ex_stoptimes}")

    # ---- Update transit network with updated feed.
    net.feed = updated_feed
    # WranglerLogger.debug(f"net.feed.stops: \n {net.feed.stops}")
    return net

Transit Helper Modules

Functions to clip a TransitNetwork object to a boundary.

Clipped transit is an independent transit network that is a subset of the original transit network.

Example usage:

from network_wrangler.transit import load_transit, write_transit
from network_wrangler.transit.clip import clip_transit

stpaul_transit = load_transit(example_dir / "stpaul")
boundary_file = test_dir / "data" / "ecolab.geojson"
clipped_network = clip_transit(stpaul_transit, boundary_file=boundary_file)
write_transit(clipped_network, out_dir, prefix="ecolab", format="geojson", true_shape=True)

clip_feed_to_boundary(feed, ref_nodes_df, boundary_gdf=None, boundary_geocode=None, boundary_file=None, min_stops=DEFAULT_MIN_STOPS)

Clips a transit Feed object to a boundary and returns the resulting GeoDataFrames.

Retains only the stops within the boundary and trips that traverse them subject to a minimum number of stops per trip as defined by min_stops.

Parameters:

- feed (Feed): Feed object to be clipped. Required.
- ref_nodes_df (gpd.GeoDataFrame): geodataframe with node geometry to reference. Required.
- boundary_geocode (Union[str, dict], optional): A geocode string or dictionary representing the boundary. Defaults to None.
- boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if boundary_geocode is None. Defaults to None.
- boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary. Only used if boundary_geocode and boundary_file are None. Defaults to None.
- min_stops (int): minimum number of stops needed to retain a transit trip within the clipped area. Defaults to DEFAULT_MIN_STOPS, which is set to 2.
Source code in network_wrangler/transit/clip.py
def clip_feed_to_boundary(
    feed: Feed,
    ref_nodes_df: gpd.GeoDataFrame,
    boundary_gdf: Optional[gpd.GeoDataFrame] = None,
    boundary_geocode: Optional[Union[str, dict]] = None,
    boundary_file: Optional[Union[str, Path]] = None,
    min_stops: int = DEFAULT_MIN_STOPS,
) -> Feed:
    """Clips a transit Feed object to a boundary and returns the resulting GeoDataFrames.

    Retains only the stops within the boundary and trips that traverse them subject to a minimum
    number of stops per trip as defined by `min_stops`.

    Args:
        feed: Feed object to be clipped.
        ref_nodes_df: geodataframe with node geometry to reference
        boundary_geocode (Union[str, dict], optional): A geocode string or dictionary
            representing the boundary. Defaults to None.
        boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if
            boundary_geocode is None. Defaults to None.
        boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary.
            Only used if boundary_geocode and boundary_file are None. Defaults to None.
        min_stops: minimum number of stops needed to retain a transit trip within clipped area.
            Defaults to DEFAULT_MIN_STOPS which is set to 2.

    Returns: Feed object trimmed to the boundary.
    """
    WranglerLogger.info("Clipping transit network to boundary.")

    boundary_gdf = get_bounding_polygon(
        boundary_gdf=boundary_gdf,
        boundary_geocode=boundary_geocode,
        boundary_file=boundary_file,
    )

    shape_links_gdf = shapes_to_shape_links_gdf(feed.shapes, ref_nodes_df=ref_nodes_df)

    # make sure boundary_gdf.crs == network.crs
    if boundary_gdf.crs != shape_links_gdf.crs:
        boundary_gdf = boundary_gdf.to_crs(shape_links_gdf.crs)

    # get the boundary as a single polygon
    boundary = boundary_gdf.geometry.union_all()
    # get the shape_links that intersect the boundary
    clipped_shape_links = shape_links_gdf[shape_links_gdf.geometry.intersects(boundary)]

    # nodes within clipped_shape_links
    node_ids = list(set(clipped_shape_links.A.to_list() + clipped_shape_links.B.to_list()))
    WranglerLogger.debug(f"Clipping network to {len(node_ids)} nodes.")
    if not node_ids:
        msg = "No nodes found within the boundary."
        raise ValueError(msg)
    return _clip_feed_to_nodes(feed, node_ids, min_stops=min_stops)

clip_feed_to_roadway(feed, roadway_net, min_stops=DEFAULT_MIN_STOPS)

Returns a copy of transit feed clipped to the roadway network.

Parameters:

- feed (Feed): Transit Feed to clip. Required.
- roadway_net (RoadwayNetwork): Roadway network to clip to. Required.
- min_stops (int): minimum number of stops needed to retain a transit trip within the clipped area. Defaults to DEFAULT_MIN_STOPS, which is set to 2.

Raises:

- ValueError: If no stops found within the roadway network.

Returns:

- Feed: Clipped deep copy of feed limited to the roadway network.

Source code in network_wrangler/transit/clip.py
def clip_feed_to_roadway(
    feed: Feed,
    roadway_net: RoadwayNetwork,
    min_stops: int = DEFAULT_MIN_STOPS,
) -> Feed:
    """Returns a copy of transit feed clipped to the roadway network.

    Args:
        feed (Feed): Transit Feed to clip.
        roadway_net: Roadway network to clip to.
        min_stops: minimum number of stops needed to retain a transit trip within clipped area.
            Defaults to DEFAULT_MIN_STOPS which is set to 2.

    Raises:
        ValueError: If no stops found within the roadway network.

    Returns:
        Feed: Clipped deep copy of feed limited to the roadway network.
    """
    WranglerLogger.info("Clipping transit network to roadway network.")

    clipped_feed = _remove_links_from_feed(feed, roadway_net.links_df, min_stops=min_stops)

    return clipped_feed

clip_transit(network, node_ids=None, boundary_geocode=None, boundary_file=None, boundary_gdf=None, ref_nodes_df=None, roadway_net=None, min_stops=DEFAULT_MIN_STOPS)

Returns a new TransitNetwork clipped to a boundary as determined by arguments.

Will clip based on which arguments are provided as prioritized below:

  1. If node_ids provided, will clip based on node_ids
  2. If boundary_geocode provided, will clip based on a search in OSM for that jurisdiction boundary using reference geometry from ref_nodes_df, roadway_net, or roadway_path
  3. If boundary_file provided, will clip based on that polygon using reference geometry from ref_nodes_df, roadway_net, or roadway_path
  4. If boundary_gdf provided, will clip based on that geodataframe using reference geometry from ref_nodes_df, roadway_net, or roadway_path
  5. If roadway_net provided, will clip based on that roadway network

Parameters:

- network (TransitNetwork): TransitNetwork to clip. Required.
- node_ids (list[str], optional): A list of node_ids to clip to. Defaults to None.
- boundary_geocode (Union[str, dict], optional): A geocode string or dictionary representing the boundary. Only used if node_ids is None. Defaults to None.
- boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if node_ids and boundary_geocode are None. Defaults to None.
- boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary. Only used if node_ids, boundary_geocode, and boundary_file are None. Defaults to None.
- ref_nodes_df (gpd.GeoDataFrame, optional): GeoDataFrame of geographic references for node_ids. Only used if node_ids is None and one of boundary_* is not None. Defaults to None.
- roadway_net (RoadwayNetwork, optional): Roadway Network instance to clip transit network to. Only used if node_ids is None and all of boundary_* are None. Defaults to None.
- min_stops (int): minimum number of stops needed to retain a transit trip within the clipped area. Defaults to DEFAULT_MIN_STOPS, which is set to 2.
Source code in network_wrangler/transit/clip.py
def clip_transit(
    network: Union[TransitNetwork, str, Path],
    node_ids: Optional[Union[None, list[str]]] = None,
    boundary_geocode: Optional[Union[str, dict, None]] = None,
    boundary_file: Optional[Union[str, Path]] = None,
    boundary_gdf: Optional[Union[None, gpd.GeoDataFrame]] = None,
    ref_nodes_df: Optional[Union[None, gpd.GeoDataFrame]] = None,
    roadway_net: Optional[Union[None, RoadwayNetwork]] = None,
    min_stops: int = DEFAULT_MIN_STOPS,
) -> TransitNetwork:
    """Returns a new TransitNetwork clipped to a boundary as determined by arguments.

    Will clip based on which arguments are provided as prioritized below:

    1. If `node_ids` provided, will clip based on `node_ids`
    2. If `boundary_geocode` provided, will clip based on a search in OSM for that jurisdiction
        boundary using reference geometry from `ref_nodes_df`, `roadway_net`, or `roadway_path`
    3. If `boundary_file` provided, will clip based on that polygon using reference geometry
        from `ref_nodes_df`, `roadway_net`, or `roadway_path`
    4. If `boundary_gdf` provided, will clip based on that geodataframe using reference geometry
        from `ref_nodes_df`, `roadway_net`, or `roadway_path`
    5. If `roadway_net` provided, will clip based on that roadway network

    Args:
        network (TransitNetwork): TransitNetwork to clip.
        node_ids (list[str], optional): A list of node_ids to clip to. Defaults to None.
        boundary_geocode (Union[str, dict], optional): A geocode string or dictionary
            representing the boundary. Only used if node_ids are None. Defaults to None.
        boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if
            node_ids and boundary_geocode are None. Defaults to None.
        boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary.
            Only used if node_ids, boundary_geocode and boundary_file are None. Defaults to None.
        ref_nodes_df: GeoDataFrame of geographic references for node_ids.  Only used if
            node_ids is None and one of boundary_* is not None.
        roadway_net: Roadway Network  instance to clip transit network to.  Only used if
            node_ids is None and all of boundary_* are None
        min_stops: minimum number of stops needed to retain a transit trip within clipped area.
            Defaults to DEFAULT_MIN_STOPS which is set to 2.
    """
    if not isinstance(network, TransitNetwork):
        network = load_transit(network)
    set_roadway_network = False
    feed = network.feed

    if node_ids is not None:
        clipped_feed = _clip_feed_to_nodes(feed, node_ids=node_ids, min_stops=min_stops)
    elif any(i is not None for i in [boundary_file, boundary_geocode, boundary_gdf]):
        if ref_nodes_df is None:
            ref_nodes_df = get_nodes(transit_net=network, roadway_net=roadway_net)

        clipped_feed = clip_feed_to_boundary(
            feed,
            ref_nodes_df,
            boundary_file=boundary_file,
            boundary_geocode=boundary_geocode,
            boundary_gdf=boundary_gdf,
            min_stops=min_stops,
        )
    elif roadway_net is not None:
        clipped_feed = clip_feed_to_roadway(feed, roadway_net=roadway_net)
        set_roadway_network = True
    else:
        msg = "Missing required arguments from clip_transit"
        raise ValueError(msg)

    # create a new TransitNetwork object with the clipped feed dataframes
    clipped_net = TransitNetwork(clipped_feed)

    if set_roadway_network:
        WranglerLogger.info("Setting roadway network for clipped transit network.")
        clipped_net.road_net = roadway_net
    return clipped_net
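The argument priority above can be reduced to a small dispatch sketch, independent of network_wrangler (the function name and return values here are illustrative only):

```python
def pick_clip_strategy(
    node_ids=None,
    boundary_geocode=None,
    boundary_file=None,
    boundary_gdf=None,
    roadway_net=None,
):
    """Mirror clip_transit's argument priority: explicit node ids win,
    then any boundary_* argument, then a roadway network, else an error."""
    if node_ids is not None:
        return "nodes"
    if any(a is not None for a in (boundary_file, boundary_geocode, boundary_gdf)):
        return "boundary"
    if roadway_net is not None:
        return "roadway"
    raise ValueError("Missing required arguments from clip_transit")


# node_ids takes priority even when a boundary is also supplied
assert pick_clip_strategy(node_ids=[1, 2], boundary_file="area.geojson") == "nodes"
```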

Utilities for working with transit geodataframes.

Translates shapes to shape links geodataframe using geometry from ref_nodes_df if provided.

TODO: Add join to links and then shapes to get true geometry.

Parameters:

- `shapes` (`DataFrame[WranglerShapesTable]`): Feed shapes table. Required.
- `ref_nodes_df` (`Optional[DataFrame[RoadNodesTable]]`): If specified, will use geometry from these nodes. Otherwise, will use geometry in the shapes file. Defaults to `None`.
- `from_field` (`str`): Field used for the link's from node `model_node_id`. Defaults to `"A"`.
- `to_field` (`str`): Field used for the link's to node `model_node_id`. Defaults to `"B"`.
- `crs` (`int`): Coordinate reference system. Shouldn't be changed unless you know what you are doing. Defaults to `LAT_LON_CRS`, which is WGS84 lat/long.

Returns:

- `gpd.GeoDataFrame`: Shape links with one row per link and a linestring geometry.

Source code in network_wrangler/transit/geo.py
def shapes_to_shape_links_gdf(
    shapes: DataFrame[WranglerShapesTable],
    ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None,
    from_field: str = "A",
    to_field: str = "B",
    crs: int = LAT_LON_CRS,
) -> gpd.GeoDataFrame:
    """Translates shapes to shape links geodataframe using geometry from ref_nodes_df if provided.

    TODO: Add join to links and then shapes to get true geometry.

    Args:
        shapes: Feed shapes table
        ref_nodes_df: If specified, will use geometry from these nodes.  Otherwise, will use
            geometry in shapes file. Defaults to None.
        from_field: Field used for the link's from node `model_node_id`. Defaults to "A".
        to_field: Field used for the link's to node `model_node_id`. Defaults to "B".
        crs (int, optional): Coordinate reference system. Shouldn't be changed unless you know
            what you are doing. Defaults to LAT_LON_CRS which is WGS84 lat/long.

    Returns:
        gpd.GeoDataFrame: Shape links with one row per link and a linestring geometry.
    """
    if ref_nodes_df is not None:
        shapes = update_shapes_geometry(shapes, ref_nodes_df)
    tr_links = unique_shape_links(shapes, from_field=from_field, to_field=to_field)
    # WranglerLogger.debug(f"tr_links :\n{tr_links }")

    geometry = linestring_from_lats_lons(
        tr_links,
        [f"shape_pt_lat_{from_field}", f"shape_pt_lat_{to_field}"],
        [f"shape_pt_lon_{from_field}", f"shape_pt_lon_{to_field}"],
    )
    # WranglerLogger.debug(f"geometry\n{geometry}")
    shapes_gdf = gpd.GeoDataFrame(tr_links, geometry=geometry, crs=crs).set_crs(LAT_LON_CRS)
    return shapes_gdf
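The heavy lifting here is pairing consecutive shape points into A/B links (done by `unique_shape_links`, not shown above). A minimal pandas sketch of that pairing, under the assumption that shape points carry `shape_id`, `shape_pt_sequence`, and lat/lon columns:

```python
import pandas as pd


def consecutive_shape_links(shapes, from_field="A", to_field="B"):
    """Pair each shape point with the next point of the same shape_id,
    producing one row per link with *_A / *_B lat/lon columns."""
    s = shapes.sort_values(["shape_id", "shape_pt_sequence"])
    nxt = s.groupby("shape_id").shift(-1)  # next point within each shape
    links = pd.DataFrame(
        {
            "shape_id": s["shape_id"],
            f"shape_pt_lat_{from_field}": s["shape_pt_lat"],
            f"shape_pt_lon_{from_field}": s["shape_pt_lon"],
            f"shape_pt_lat_{to_field}": nxt["shape_pt_lat"],
            f"shape_pt_lon_{to_field}": nxt["shape_pt_lon"],
        }
    )
    # the last point of each shape has no successor; drop those rows
    return links.dropna().reset_index(drop=True)
```

The resulting `shape_pt_lat_A`/`shape_pt_lon_A` and `_B` columns are exactly the field names `linestring_from_lats_lons` consumes above.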

shapes_to_trip_shapes_gdf(shapes, ref_nodes_df=None, crs=LAT_LON_CRS)

Geodataframe with one polyline shape per shape_id.

TODO: add information about the route and trips.

Parameters:

- `shapes` (`DataFrame[WranglerShapesTable]`): WranglerShapesTable. Required.
- `trips`: WranglerTripsTable. Required.
- `ref_nodes_df` (`Optional[DataFrame[RoadNodesTable]]`): If specified, will use geometry from these nodes. Otherwise, will use geometry in the shapes file. Defaults to `None`.
- `crs` (`int`): Coordinate reference system. Defaults to `LAT_LON_CRS` (4326).
Source code in network_wrangler/transit/geo.py
def shapes_to_trip_shapes_gdf(
    shapes: DataFrame[WranglerShapesTable],
    # trips: WranglerTripsTable,
    ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None,
    crs: int = LAT_LON_CRS,
) -> gpd.GeoDataFrame:
    """Geodataframe with one polyline shape per shape_id.

    TODO: add information about the route and trips.

    Args:
        shapes: WranglerShapesTable
        trips: WranglerTripsTable
        ref_nodes_df: If specified, will use geometry from these nodes.  Otherwise, will use
            geometry in shapes file. Defaults to None.
        crs: int, optional, default 4326
    """
    if ref_nodes_df is not None:
        shapes = update_shapes_geometry(shapes, ref_nodes_df)

    shape_geom = (
        shapes[["shape_id", "shape_pt_lat", "shape_pt_lon"]]
        .groupby("shape_id")
        .agg(list)
        .apply(lambda x: LineString(zip(x[1], x[0])), axis=1)
    )

    route_shapes_gdf = gpd.GeoDataFrame(
        data=shape_geom.index, geometry=shape_geom.values, crs=crs
    ).set_crs(LAT_LON_CRS)

    return route_shapes_gdf
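The grouped one-liner above hides a coordinate swap: shape points are stored as (lat, lon), but `LineString` expects (x, y) = (lon, lat) pairs. A pandas-only sketch of the grouping, with plain tuples standing in for shapely geometry:

```python
import pandas as pd

# Tiny illustrative shapes table: two shapes, stored as (lat, lon)
shapes = pd.DataFrame(
    {
        "shape_id": ["s1", "s1", "s2"],
        "shape_pt_lat": [10.0, 11.0, 30.0],
        "shape_pt_lon": [20.0, 21.0, 40.0],
    }
)

# .agg(list) collects each shape's points in order; zipping lon before lat
# swaps the (lat, lon) columns into the (x, y) order a LineString expects.
coords = (
    shapes.groupby("shape_id")
    .agg(list)
    .apply(lambda row: list(zip(row["shape_pt_lon"], row["shape_pt_lat"])), axis=1)
)

assert coords["s1"] == [(20.0, 10.0), (21.0, 11.0)]
```

In the real function each coordinate list is wrapped in `LineString(...)` to build the geometry column.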

Stop times geodataframe as links using geometry from stops.txt or optionally another df.

Parameters:

- `stop_times` (`DataFrame[WranglerStopTimesTable]`): Feed stop times table. Required.
- `stops` (`DataFrame[WranglerStopsTable]`): Feed stops table. Required.
- `ref_nodes_df` (`Optional[DataFrame[RoadNodesTable]]`): If specified, will use geometry from these nodes. Otherwise, will use geometry in the shapes file. Defaults to `None`.
- `from_field` (`str`): Field used for the link's from node `model_node_id`. Defaults to `"A"`.
- `to_field` (`str`): Field used for the link's to node `model_node_id`. Defaults to `"B"`.
Source code in network_wrangler/transit/geo.py
def stop_times_to_stop_time_links_gdf(
    stop_times: DataFrame[WranglerStopTimesTable],
    stops: DataFrame[WranglerStopsTable],
    ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None,
    from_field: str = "A",
    to_field: str = "B",
) -> gpd.GeoDataFrame:
    """Stop times geodataframe as links using geometry from stops.txt or optionally another df.

    Args:
        stop_times (WranglerStopTimesTable): Feed stop times table.
        stops (WranglerStopsTable): Feed stops table.
        ref_nodes_df (pd.DataFrame, optional): If specified, will use geometry from these nodes.
            Otherwise, will use geometry in shapes file. Defaults to None.
        from_field: Field used for the link's from node `model_node_id`. Defaults to "A".
        to_field: Field used for the link's to node `model_node_id`. Defaults to "B".
    """
    from ..utils.geo import linestring_from_lats_lons

    if ref_nodes_df is not None:
        stops = update_stops_geometry(stops, ref_nodes_df)

    lat_fields = []
    lon_fields = []
    tr_links = unique_stop_time_links(stop_times, from_field=from_field, to_field=to_field)
    for f in (from_field, to_field):
        tr_links = tr_links.merge(
            stops[["stop_id", "stop_lat", "stop_lon"]],
            right_on="stop_id",
            left_on=f,
            how="left",
        )
        lon_f = f"{f}_X"
        lat_f = f"{f}_Y"
        tr_links = tr_links.rename(columns={"stop_lon": lon_f, "stop_lat": lat_f})
        lon_fields.append(lon_f)
        lat_fields.append(lat_f)

    geometry = linestring_from_lats_lons(tr_links, lat_fields, lon_fields)
    return gpd.GeoDataFrame(tr_links, geometry=geometry).set_crs(LAT_LON_CRS)
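The loop above merges the stops table once per endpoint, renaming the coordinate columns each time so the two merges don't collide. A self-contained pandas sketch of that pattern (the tiny frames here are illustrative):

```python
import pandas as pd

# One link from stop s1 to stop s2, plus a stops lookup table
links = pd.DataFrame({"A": ["s1"], "B": ["s2"]})
stops = pd.DataFrame(
    {
        "stop_id": ["s1", "s2"],
        "stop_lat": [10.0, 11.0],
        "stop_lon": [20.0, 21.0],
    }
)

# Merge the stops table once per endpoint, renaming the coordinate columns
# each time so the second merge does not clash with the first.
for f in ("A", "B"):
    links = (
        links.merge(
            stops[["stop_id", "stop_lat", "stop_lon"]],
            left_on=f,
            right_on="stop_id",
            how="left",
        )
        .rename(columns={"stop_lon": f"{f}_X", "stop_lat": f"{f}_Y"})
        .drop(columns="stop_id")
    )

assert links.loc[0, "A_X"] == 20.0 and links.loc[0, "B_Y"] == 11.0
```

The resulting `A_X`/`A_Y`/`B_X`/`B_Y` columns are then fed to `linestring_from_lats_lons` to build the link geometry.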

stop_times_to_stop_time_points_gdf(stop_times, stops, ref_nodes_df=None)

Stoptimes geodataframe as points using geometry from stops.txt or optionally another df.

Parameters:

- `stop_times` (`DataFrame[WranglerStopTimesTable]`): Feed stop times table. Required.
- `stops` (`DataFrame[WranglerStopsTable]`): Feed stops table. Required.
- `ref_nodes_df` (`Optional[DataFrame[RoadNodesTable]]`): If specified, will use geometry from these nodes. Otherwise, will use geometry in the shapes file. Defaults to `None`.
Source code in network_wrangler/transit/geo.py
def stop_times_to_stop_time_points_gdf(
    stop_times: DataFrame[WranglerStopTimesTable],
    stops: DataFrame[WranglerStopsTable],
    ref_nodes_df: Optional[DataFrame[RoadNodesTable]] = None,
) -> gpd.GeoDataFrame:
    """Stoptimes geodataframe as points using geometry from stops.txt or optionally another df.

    Args:
        stop_times (WranglerStopTimesTable): Feed stop times table.
        stops (WranglerStopsTable): Feed stops table.
        ref_nodes_df (pd.DataFrame, optional): If specified, will use geometry from these nodes.
            Otherwise, will use geometry in shapes file. Defaults to None.
    """
    if ref_nodes_df is not None:
        stops = update_stops_geometry(stops, ref_nodes_df)

    stop_times_geo = stop_times.merge(
        stops[["stop_id", "stop_lat", "stop_lon"]],
        right_on="stop_id",
        left_on="stop_id",
        how="left",
    )
    return gpd.GeoDataFrame(
        stop_times_geo,
        geometry=gpd.points_from_xy(stop_times_geo["stop_lon"], stop_times_geo["stop_lat"]),
        crs=LAT_LON_CRS,
    )

update_shapes_geometry(shapes, ref_nodes_df)

Returns shapes table with geometry updated from ref_nodes_df.

NOTE: does not update “geometry” field if it exists.

Source code in network_wrangler/transit/geo.py
def update_shapes_geometry(
    shapes: DataFrame[WranglerShapesTable], ref_nodes_df: DataFrame[RoadNodesTable]
) -> DataFrame[WranglerShapesTable]:
    """Returns shapes table with geometry updated from ref_nodes_df.

    NOTE: does not update "geometry" field if it exists.
    """
    return update_point_geometry(
        shapes,
        ref_nodes_df,
        id_field="shape_model_node_id",
        lon_field="shape_pt_lon",
        lat_field="shape_pt_lat",
    )

update_stops_geometry(stops, ref_nodes_df)

Returns stops table with geometry updated from ref_nodes_df.

NOTE: does not update “geometry” field if it exists.

Source code in network_wrangler/transit/geo.py
def update_stops_geometry(
    stops: DataFrame[WranglerStopsTable], ref_nodes_df: DataFrame[RoadNodesTable]
) -> DataFrame[WranglerStopsTable]:
    """Returns stops table with geometry updated from ref_nodes_df.

    NOTE: does not update "geometry" field if it exists.
    """
    return update_point_geometry(
        stops, ref_nodes_df, id_field="stop_id", lon_field="stop_lon", lat_field="stop_lat"
    )

Functions for reading and writing transit feeds and networks.

convert_transit_serialization(input_path, output_format, out_dir='.', input_file_format='csv', out_prefix='', overwrite=True)

Converts a transit network from one serialization to another.

Parameters:

- `input_path` (`Union[str, Path]`): Path to the input network. Required.
- `output_format` (`TransitFileTypes`): Format of the output files. Should be `txt`, `csv`, or `parquet`. Required.
- `out_dir` (`Union[Path, str]`): Directory to write the network to. Defaults to the current directory (`"."`).
- `input_file_format` (`TransitFileTypes`): File format of the files to read. Should be `txt`, `csv`, or `parquet`. Defaults to `"csv"`.
- `out_prefix` (`str`): Prefix to add to the file name. Defaults to `""`.
- `overwrite` (`bool`): If True, will overwrite the files if they already exist. Defaults to `True`.
Source code in network_wrangler/transit/io.py
def convert_transit_serialization(
    input_path: Union[str, Path],
    output_format: TransitFileTypes,
    out_dir: Union[Path, str] = ".",
    input_file_format: TransitFileTypes = "csv",
    out_prefix: str = "",
    overwrite: bool = True,
):
    """Converts a transit network from one serialization to another.

    Args:
        input_path: path to the input network
        output_format: the format of the output files. Should be txt, csv, or parquet.
        out_dir: directory to write the network to. Defaults to current directory.
        input_file_format: the file_format of the files to read. Should be txt, csv, or parquet.
            Defaults to "csv"
        out_prefix: prefix to add to the file name. Defaults to ""
        overwrite: if True, will overwrite the files if they already exist. Defaults to True
    """
    WranglerLogger.info(
        f"Loading transit net from {input_path} with input type {input_file_format}"
    )
    net = load_transit(input_path, file_format=input_file_format)
    WranglerLogger.info(f"Writing transit network to {out_dir} in {output_format} format.")
    write_transit(
        net,
        prefix=out_prefix,
        out_dir=out_dir,
        file_format=output_format,
        overwrite=overwrite,
    )

load_feed_from_dfs(feed_dfs)

Create a TransitNetwork object from a dictionary of DataFrames representing a GTFS feed.

Parameters:

- `feed_dfs` (`dict`): A dictionary containing DataFrames representing the tables of a GTFS feed. Required.

Returns:

- `Feed`: A Feed object representing the transit network.

Raises:

- `ValueError`: If the feed_dfs dictionary does not contain all the required tables.

Example:

```python
feed_dfs = {
    "agency": agency_df,
    "routes": routes_df,
    "stops": stops_df,
    "trips": trips_df,
    "stop_times": stop_times_df,
}
feed = load_feed_from_dfs(feed_dfs)
```

Source code in network_wrangler/transit/io.py
def load_feed_from_dfs(feed_dfs: dict) -> Feed:
    """Create a TransitNetwork object from a dictionary of DataFrames representing a GTFS feed.

    Args:
        feed_dfs (dict): A dictionary containing DataFrames representing the tables of a GTFS feed.

    Returns:
        Feed: A Feed object representing the transit network.

    Raises:
        ValueError: If the feed_dfs dictionary does not contain all the required tables.

    Example:
        >>> feed_dfs = {
        ...     "agency": agency_df,
        ...     "routes": routes_df,
        ...     "stops": stops_df,
        ...     "trips": trips_df,
        ...     "stop_times": stop_times_df,
        ... }
        >>> feed = load_feed_from_dfs(feed_dfs)
    """
    if not all(table in feed_dfs for table in Feed.table_names):
        msg = f"feed_dfs must contain the following tables: {Feed.table_names}"
        raise ValueError(msg)

    feed = Feed(**feed_dfs)

    return feed

load_feed_from_path(feed_path, file_format='txt')

Create a Feed object from the path to a GTFS transit feed.

Parameters:

- `feed_path` (`Union[Path, str]`): The path to the GTFS transit feed. Required.
- `file_format` (`TransitFileTypes`): Format of the files to read. Defaults to `"txt"`.

Returns:

- `Feed`: The TransitNetwork object created from the GTFS transit feed.

Source code in network_wrangler/transit/io.py
def load_feed_from_path(
    feed_path: Union[Path, str], file_format: TransitFileTypes = "txt"
) -> Feed:
    """Create a Feed object from the path to a GTFS transit feed.

    Args:
        feed_path (Union[Path, str]): The path to the GTFS transit feed.
        file_format: the format of the files to read. Defaults to "txt"

    Returns:
        Feed: The TransitNetwork object created from the GTFS transit feed.
    """
    feed_path = _feed_path_ref(Path(feed_path))  # unzips if needs to be unzipped

    if not feed_path.is_dir():
        msg = f"Feed path not a directory: {feed_path}"
        raise NotADirectoryError(msg)

    WranglerLogger.info(f"Reading GTFS feed tables from {feed_path}")

    feed_possible_files = {
        table: list(feed_path.glob(f"*{table}.{file_format}")) for table in Feed.table_names
    }

    # make sure we have all the tables we need
    _missing_files = [t for t, v in feed_possible_files.items() if not v]

    if _missing_files:
        WranglerLogger.debug(f"!!! Missing transit files: {_missing_files}")
        msg = f"Required GTFS Feed table(s) not in {feed_path}: \n  {_missing_files}"
        raise RequiredTableError(msg)

    # but don't want to have more than one file per search
    _ambiguous_files = [t for t, v in feed_possible_files.items() if len(v) > 1]
    if _ambiguous_files:
        WranglerLogger.warning(
            f"! More than one file matches following tables. \
                               Using the first on the list: {_ambiguous_files}"
        )

    feed_files = {t: f[0] for t, f in feed_possible_files.items()}
    feed_dfs = {table: _read_table_from_file(table, file) for table, file in feed_files.items()}

    return load_feed_from_dfs(feed_dfs)
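The table-discovery logic above (glob once per table, then check for missing or ambiguous matches) can be sketched with only the standard library. `find_feed_tables` is an illustrative helper, not part of network_wrangler, and it sorts matches for determinism where the library simply takes the first glob result:

```python
import tempfile
from pathlib import Path


def find_feed_tables(feed_path, table_names, file_format="txt"):
    """Glob for one file per table, reporting tables with no match
    (missing) or more than one match (ambiguous)."""
    candidates = {
        t: sorted(Path(feed_path).glob(f"*{t}.{file_format}")) for t in table_names
    }
    missing = [t for t, v in candidates.items() if not v]
    ambiguous = [t for t, v in candidates.items() if len(v) > 1]
    chosen = {t: v[0] for t, v in candidates.items() if v}
    return chosen, missing, ambiguous


with tempfile.TemporaryDirectory() as d:
    for name in ("agency.txt", "stops.txt", "extra_stops.txt"):
        (Path(d) / name).touch()
    chosen, missing, ambiguous = find_feed_tables(d, ["agency", "stops", "trips"])
    assert missing == ["trips"]    # no *trips.txt file present
    assert ambiguous == ["stops"]  # both stops.txt and extra_stops.txt match
```

Note that the `*{table}` pattern is deliberately loose: any file ending in the table name matches, which is what makes the ambiguity check necessary.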

load_transit(feed, file_format='txt', config=DefaultConfig)

Create a TransitNetwork object.

This function takes in a `feed` parameter, which can be one of the following types:

- `Feed`: A Feed object representing a transit feed.
- `dict[str, pd.DataFrame]`: A dictionary of DataFrames representing transit data.
- `str` or `Path`: A string or a Path object representing the path to a transit feed file.

Parameters:

- `feed` (`Union[Feed, GtfsModel, dict[str, DataFrame], str, Path]`): Feed object, dict of transit data frames, or path to transit feed data. Required.
- `file_format` (`TransitFileTypes`): Format of the files to read. Defaults to `"txt"`.
- `config` (`WranglerConfig`): WranglerConfig object. Defaults to `DefaultConfig`.

Returns:

- A TransitNetwork object representing the loaded transit network.

Raises:

- `ValueError`: If the `feed` parameter is not one of the supported types.

Example usage:

transit_network_from_zip = load_transit("path/to/gtfs.zip")

transit_network_from_unzipped_dir = load_transit("path/to/files")

transit_network_from_parquet = load_transit("path/to/files", file_format="parquet")

dfs_of_transit_data = {"routes": routes_df, "stops": stops_df, "trips": trips_df...}
transit_network_from_dfs = load_transit(dfs_of_transit_data)

Source code in network_wrangler/transit/io.py
def load_transit(
    feed: Union[Feed, GtfsModel, dict[str, pd.DataFrame], str, Path],
    file_format: TransitFileTypes = "txt",
    config: WranglerConfig = DefaultConfig,
) -> "TransitNetwork":
    """Create a TransitNetwork object.

    This function takes in a `feed` parameter, which can be one of the following types:
    - `Feed`: A Feed object representing a transit feed.
    - `dict[str, pd.DataFrame]`: A dictionary of DataFrames representing transit data.
    - `str` or `Path`: A string or a Path object representing the path to a transit feed file.

    Args:
        feed: Feed object, dict of transit data frames, or path to transit feed data
        file_format: the format of the files to read. Defaults to "txt"
        config: WranglerConfig object. Defaults to DefaultConfig.

    Returns:
    A TransitNetwork object representing the loaded transit network.

    Raises:
    ValueError: If the `feed` parameter is not one of the supported types.

    Example usage:
    ```
    transit_network_from_zip = load_transit("path/to/gtfs.zip")

    transit_network_from_unzipped_dir = load_transit("path/to/files")

    transit_network_from_parquet = load_transit("path/to/files", file_format="parquet")

    dfs_of_transit_data = {"routes": routes_df, "stops": stops_df, "trips": trips_df...}
    transit_network_from_dfs = load_transit(dfs_of_transit_data)
    ```

    """
    if isinstance(feed, (Path, str)):
        feed = Path(feed)
        feed_obj = load_feed_from_path(feed, file_format=file_format)
        feed_obj.feed_path = feed
    elif isinstance(feed, dict):
        feed_obj = load_feed_from_dfs(feed)
    elif isinstance(feed, GtfsModel):
        feed_obj = Feed(**feed.__dict__)
    else:
        if not isinstance(feed, Feed):
            msg = f"TransitNetwork must be seeded with a Feed, dict of dfs or Path. Found {type(feed)}"
            raise ValueError(msg)
        feed_obj = feed

    return TransitNetwork(feed_obj, config=config)

write_feed_geo(feed, ref_nodes_df, out_dir, file_format='geojson', out_prefix=None, overwrite=True)

Write a Feed object to a directory in a geospatial format.

Parameters:

- `feed` (`Feed`): Feed object to write. Required.
- `ref_nodes_df` (`gpd.GeoDataFrame`): Reference nodes dataframe to use for geometry. Required.
- `out_dir` (`Union[str, Path]`): Directory to write the network to. Required.
- `file_format` (`Literal["geojson", "shp", "parquet"]`): Format of the output files. Defaults to `"geojson"`.
- `out_prefix`: Prefix to add to the file name. Defaults to `None`.
- `overwrite` (`bool`): If True, will overwrite the files if they already exist. Defaults to `True`.
Source code in network_wrangler/transit/io.py
def write_feed_geo(
    feed: Feed,
    ref_nodes_df: gpd.GeoDataFrame,
    out_dir: Union[str, Path],
    file_format: Literal["geojson", "shp", "parquet"] = "geojson",
    out_prefix=None,
    overwrite: bool = True,
) -> None:
    """Write a Feed object to a directory in a geospatial format.

    Args:
        feed: Feed object to write
        ref_nodes_df: Reference nodes dataframe to use for geometry
        out_dir: directory to write the network to
        file_format: the format of the output files. Defaults to "geojson"
        out_prefix: prefix to add to the file name
        overwrite: if True, will overwrite the files if they already exist. Defaults to True
    """
    from .geo import shapes_to_shape_links_gdf

    out_dir = Path(out_dir)
    if not out_dir.is_dir():
        if out_dir.parent.is_dir():
            out_dir.mkdir()
        else:
            msg = f"Output directory {out_dir} and its parent path do not exist"
            raise FileNotFoundError(msg)

    prefix = f"{out_prefix}_" if out_prefix else ""
    shapes_outpath = out_dir / f"{prefix}trn_shapes.{file_format}"
    shapes_gdf = shapes_to_shape_links_gdf(feed.shapes, ref_nodes_df=ref_nodes_df)
    write_table(shapes_gdf, shapes_outpath, overwrite=overwrite)

    stops_outpath = out_dir / f"{prefix}trn_stops.{file_format}"
    stops_gdf = to_points_gdf(feed.stops, ref_nodes_df=ref_nodes_df)
    write_table(stops_gdf, stops_outpath, overwrite=overwrite)

write_transit(transit_net, out_dir='.', prefix=None, file_format='txt', overwrite=True)

Writes a network in the transit network standard.

Parameters:

- `transit_net`: A TransitNetwork instance. Required.
- `out_dir` (`Union[Path, str]`): Directory to write the network to. Defaults to the current directory (`"."`).
- `prefix` (`Optional[Union[Path, str]]`): Prefix to add to the file name. Defaults to `None`.
- `file_format` (`Literal["txt", "csv", "parquet"]`): Format of the output files. Defaults to `"txt"`, which is csv with a txt file extension.
- `overwrite` (`bool`): If True, will overwrite the files if they already exist. Defaults to `True`.
Source code in network_wrangler/transit/io.py
def write_transit(
    transit_net,
    out_dir: Union[Path, str] = ".",
    prefix: Optional[Union[Path, str]] = None,
    file_format: Literal["txt", "csv", "parquet"] = "txt",
    overwrite: bool = True,
) -> None:
    """Writes a network in the transit network standard.

    Args:
        transit_net: a TransitNetwork instance
        out_dir: directory to write the network to
        file_format: the format of the output files. Defaults to "txt" which is csv with txt
            file format.
        prefix: prefix to add to the file name
        overwrite: if True, will overwrite the files if they already exist. Defaults to True
    """
    out_dir = Path(out_dir)
    prefix = f"{prefix}_" if prefix else ""
    for table in transit_net.feed.table_names:
        df = transit_net.feed.get_table(table)
        outpath = out_dir / f"{prefix}{table}.{file_format}"
        write_table(df, outpath, overwrite=overwrite)
    WranglerLogger.info(f"Wrote {len(transit_net.feed.tables)} files to {out_dir}")
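Output naming follows a simple rule: an optional `<prefix>_` followed by `<table>.<file_format>`. A sketch of just the path construction (the helper name is illustrative):

```python
from pathlib import Path


def transit_out_paths(table_names, out_dir=".", prefix=None, file_format="txt"):
    """Build one output path per feed table the way write_transit does:
    an optional '<prefix>_' followed by '<table>.<file_format>'."""
    out_dir = Path(out_dir)
    prefix = f"{prefix}_" if prefix else ""
    return {t: out_dir / f"{prefix}{t}.{file_format}" for t in table_names}


paths = transit_out_paths(
    ["stops", "trips"], out_dir="net", prefix="2030", file_format="parquet"
)
assert paths["stops"].name == "2030_stops.parquet"
```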

ModelTransit class and functions for managing consistency between roadway and transit networks.

NOTE: this is not thoroughly tested and may not be fully functional.

ModelTransit

ModelTransit class for managing consistency between roadway and transit networks.

Source code in network_wrangler/transit/model_transit.py
class ModelTransit:
    """ModelTransit class for managing consistency between roadway and transit networks."""

    def __init__(
        self,
        transit_net: TransitNetwork,
        roadway_net: RoadwayNetwork,
        shift_transit_to_managed_lanes: bool = True,
    ):
        """ModelTransit class for managing consistency between roadway and transit networks."""
        self.transit_net = transit_net
        self.roadway_net = roadway_net
        self._roadway_net_hash = None
        self._transit_feed_hash = None
        self._transit_shifted_to_ML = shift_transit_to_managed_lanes

    @property
    def model_roadway_net(self):
        """ModelRoadwayNetwork associated with this ModelTransit."""
        return self.roadway_net.model_net

    @property
    def consistent_nets(self) -> bool:
        """Indicate if roadway and transit networks have changed since self.m_feed updated."""
        return bool(
            self.roadway_net.network_hash == self._roadway_net_hash
            and self.transit_net.feed_hash == self._transit_feed_hash
        )

    @property
    def m_feed(self):
        """TransitNetwork.feed with updates for consistency with associated ModelRoadwayNetwork."""
        if self.consistent_nets:
            return self._m_feed
        # NOTE: look at this
        # If networks have changed, update model transit and update reference hash
        self._roadway_net_hash = copy.deepcopy(self.roadway_net.network_hash)
        self._transit_feed_hash = copy.deepcopy(self.transit_net.feed_hash)

        if not self._transit_shifted_to_ML:
            self._m_feed = copy.deepcopy(self.transit_net.feed)
            return self._m_feed
        return None
consistent_nets: bool property

Indicate if roadway and transit networks have changed since self.m_feed updated.

m_feed property

TransitNetwork.feed with updates for consistency with associated ModelRoadwayNetwork.

model_roadway_net property

ModelRoadwayNetwork associated with this ModelTransit.

__init__(transit_net, roadway_net, shift_transit_to_managed_lanes=True)

ModelTransit class for managing consistency between roadway and transit networks.

Source code in network_wrangler/transit/model_transit.py
def __init__(
    self,
    transit_net: TransitNetwork,
    roadway_net: RoadwayNetwork,
    shift_transit_to_managed_lanes: bool = True,
):
    """ModelTransit class for managing consistency between roadway and transit networks."""
    self.transit_net = transit_net
    self.roadway_net = roadway_net
    self._roadway_net_hash = None
    self._transit_feed_hash = None
    self._transit_shifted_to_ML = shift_transit_to_managed_lanes
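
The `m_feed` property only rebuilds its result when either network's hash differs from the copies stored at the last computation. That hash-guarded caching pattern, reduced to a self-contained sketch (class and attribute names here are illustrative):

```python
class HashGuardedCache:
    """Recompute a derived value only when the source's hash has changed,
    the pattern m_feed uses with network_hash / feed_hash."""

    def __init__(self, source):
        self.source = source       # any object exposing a `hash` attribute
        self._stored_hash = None
        self._cached = None
        self.computations = 0      # instrumentation for the example

    @property
    def value(self):
        if self._cached is not None and self._stored_hash == self.source.hash:
            return self._cached    # source unchanged: reuse cached result
        self._stored_hash = self.source.hash
        self._cached = f"derived-from-{self.source.hash}"
        self.computations += 1
        return self._cached


class FakeNet:
    def __init__(self):
        self.hash = "h1"


net = FakeNet()
cache = HashGuardedCache(net)
cache.value
cache.value                # second access hits the cache
assert cache.computations == 1
net.hash = "h2"            # simulate an edit to the network
cache.value
assert cache.computations == 2
```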

Classes and functions for selecting transit trips from a transit network.

Usage:

Create a TransitSelection object by providing a TransitNetwork object and a selection dictionary:

```python
selection_dict = {
    "links": {...},
    "nodes": {...},
    "route_properties": {...},
    "trip_properties": {...},
    "timespans": {...},
}
transit_selection = TransitSelection(transit_network, selection_dict)
```

Access the selected trip ids or dataframe as follows:

```python
selected_trips = transit_selection.selected_trips
selected_trips_df = transit_selection.selected_trips_df
```

Note: The selection dictionary should conform to the SelectTransitTrips model defined in the models.projects.transit_selection module.
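
Selections are also validated for consistency with the network: every field under `trip_properties` or `route_properties` must exist as a column of the corresponding feed table, or a `TransitSelectionNetworkConsistencyError` is raised. A minimal sketch of that field check (the column list is illustrative):

```python
def missing_selection_fields(property_dict, table_columns):
    """Return selection fields that are not columns of the corresponding
    feed table -- the consistency check TransitSelection performs."""
    return sorted(set(property_dict or {}) - set(table_columns))


trips_columns = ["trip_id", "route_id", "direction_id", "service_id"]
assert missing_selection_fields({"route_id": "10"}, trips_columns) == []
assert missing_selection_fields({"headsign": "Downtown"}, trips_columns) == ["headsign"]
```

In the real class a non-empty result raises rather than being returned, so a bad selection fails fast instead of silently matching nothing.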

TransitSelection

Object to perform and store information about a selection from a project card “facility”.

Attributes:

- `selection_dict` (`dict`): Dictionary of selection criteria.
- `selected_trips` (`list`): List of selected trips.
- `selected_trips_df` (`DataFrame[WranglerTripsTable]`): DataFrame of selected trips.
- `sel_key` (`str`): Hash of selection_dict.
- `net` (`TransitNetwork`): Network to select from.
Source code in network_wrangler/transit/selection.py
class TransitSelection:
    """Object to perform and store information about a selection from a project card "facility".

    Attributes:
        selection_dict: dict: Dictionary of selection criteria
        selected_trips: list: List of selected trips
        selected_trips_df: pd.DataFrame: DataFrame of selected trips
        sel_key: str: Hash of selection_dict
        net: TransitNetwork: Network to select from
    """

    def __init__(
        self,
        net: TransitNetwork,
        selection_dict: Union[dict, SelectTransitTrips],
    ):
        """Constructor for TransitSelection object.

        Args:
            net (TransitNetwork): Transit network object to select from.
            selection_dict: Selection dictionary conforming to SelectTransitTrips
        """
        self.net = net
        self.selection_dict = selection_dict

        # Initialize
        self._selected_trips_df = None
        self.sel_key = dict_to_hexkey(selection_dict)
        self._stored_feed_hash = copy.deepcopy(self.net.feed.hash)

        WranglerLogger.debug(f"...created TransitSelection object: {selection_dict}")

    def __nonzero__(self):
        """Return True if there are selected trips."""
        return len(self.selected_trips_df) > 0

    @property
    def selection_dict(self):
        """Getter for selection_dict."""
        return self._selection_dict

    @selection_dict.setter
    def selection_dict(self, value: Union[dict, SelectTransitTrips]):
        self._selection_dict = self.validate_selection_dict(value)

    def validate_selection_dict(self, selection_dict: Union[dict, SelectTransitTrips]) -> dict:
        """Check that selection dictionary has valid and used properties consistent with network.

        Checks that selection_dict is a valid TransitSelectionDict:
            - query vars exist in respective Feed tables
        Args:
            selection_dict (dict): selection dictionary

        Raises:
            TransitSelectionNetworkConsistencyError: If not consistent with transit network
            ValidationError: if format not consistent with SelectTransitTrips
        """
        if not isinstance(selection_dict, SelectTransitTrips):
            selection_dict = SelectTransitTrips(**selection_dict)
        selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
        WranglerLogger.debug(f"SELECT DICT - before Validation: \n{selection_dict}")
        _trip_selection_fields = list((selection_dict.get("trip_properties", {}) or {}).keys())
        _missing_trip_fields = set(_trip_selection_fields) - set(self.net.feed.trips.columns)

        if _missing_trip_fields:
            msg = f"Fields in trip selection dictionary but not trips.txt: {_missing_trip_fields}"
            raise TransitSelectionNetworkConsistencyError(msg)

        _route_selection_fields = list((selection_dict.get("route_properties", {}) or {}).keys())
        _missing_route_fields = set(_route_selection_fields) - set(self.net.feed.routes.columns)

        if _missing_route_fields:
            msg = (
                f"Fields in route selection dictionary but not routes.txt: {_missing_route_fields}"
            )
            raise TransitSelectionNetworkConsistencyError(msg)
        return selection_dict

    @property
    def selected_trips(self) -> list:
        """List of selected trip_ids."""
        if self.selected_trips_df is None:
            return []
        return self.selected_trips_df.trip_id.tolist()

    @property
    def selected_trips_df(self) -> DataFrame[WranglerTripsTable]:
        """Lazily evaluates selection for trips or returns stored value in self._selected_trips_df.

        Will re-evaluate if the current network hash is different than the stored one from the
        last selection.

        Returns:
            DataFrame[WranglerTripsTable] of selected trips
        """
        if (self._selected_trips_df is not None) and self._stored_feed_hash == self.net.feed_hash:
            return self._selected_trips_df

        self._selected_trips_df = self._select_trips()
        self._stored_feed_hash = copy.deepcopy(self.net.feed_hash)
        return self._selected_trips_df

    @property
    def selected_frequencies_df(self) -> DataFrame[WranglerFrequenciesTable]:
        """DataFrame of selected frequencies."""
        sel_freq_df = self.net.feed.frequencies.loc[
            self.net.feed.frequencies.trip_id.isin(self.selected_trips_df.trip_id)
        ]
        # if timespans are selected, filter to those that overlap
        if self.selection_dict.get("timespans"):
            sel_freq_df = filter_df_to_overlapping_timespans(
                sel_freq_df, self.selection_dict.get("timespans")
            )
        return sel_freq_df

    @property
    def selected_shapes_df(self) -> DataFrame[WranglerShapesTable]:
        """DataFrame of selected shapes.

        Can visualize the selected shapes quickly using the following code:

        ```python
        all_routes = net.feed.shapes.plot(color="gray")
        selection.selected_shapes_df.plot(ax=all_routes, color="red")
        ```

        """
        return self.net.feed.shapes.loc[
            self.net.feed.shapes.shape_id.isin(self.selected_trips_df.shape_id)
        ]

    def _select_trips(self) -> DataFrame[WranglerTripsTable]:
        """Selects transit trips based on selection dictionary.

        Returns:
            DataFrame[WranglerTripsTable]: trips_df DataFrame of selected trips
        """
        return _filter_trips_by_selection_dict(
            self.net.feed,
            self.selection_dict,
        )
selected_frequencies_df: DataFrame[WranglerFrequenciesTable] property

DataFrame of selected frequencies.

selected_shapes_df: DataFrame[WranglerShapesTable] property

DataFrame of selected shapes.

Can visualize the selected shapes quickly using the following code:

```python
all_routes = net.feed.shapes.plot(color="gray")
selection.selected_shapes_df.plot(ax=all_routes, color="red")
```
selected_trips: list property

List of selected trip_ids.

selected_trips_df: DataFrame[WranglerTripsTable] property

Lazily evaluates selection for trips or returns stored value in self._selected_trips_df.

Will re-evaluate if the current network hash is different than the stored one from the last selection.

Returns:

- DataFrame[WranglerTripsTable]: DataFrame of selected trips.

selection_dict property writable

Getter for selection_dict.

__init__(net, selection_dict)

Constructor for TransitSelection object.

Parameters:

- net (TransitNetwork): Transit network object to select from. Required.
- selection_dict (Union[dict, SelectTransitTrips]): Selection dictionary conforming to SelectTransitTrips. Required.
Source code in network_wrangler/transit/selection.py
def __init__(
    self,
    net: TransitNetwork,
    selection_dict: Union[dict, SelectTransitTrips],
):
    """Constructor for TransitSelection object.

    Args:
        net (TransitNetwork): Transit network object to select from.
        selection_dict: Selection dictionary conforming to SelectTransitTrips
    """
    self.net = net
    self.selection_dict = selection_dict

    # Initialize
    self._selected_trips_df = None
    self.sel_key = dict_to_hexkey(selection_dict)
    self._stored_feed_hash = copy.deepcopy(self.net.feed.hash)

    WranglerLogger.debug(f"...created TransitSelection object: {selection_dict}")
__nonzero__()

Return True if there are selected trips.

Source code in network_wrangler/transit/selection.py
def __nonzero__(self):
    """Return True if there are selected trips."""
    return len(self.selected_trips_df) > 0
validate_selection_dict(selection_dict)

Check that selection dictionary has valid and used properties consistent with network.

Checks that selection_dict is a valid TransitSelectionDict:

- query vars exist in respective Feed tables

Raises:

- TransitSelectionNetworkConsistencyError: if not consistent with transit network.
- ValidationError: if format not consistent with SelectTransitTrips.

Source code in network_wrangler/transit/selection.py
def validate_selection_dict(self, selection_dict: Union[dict, SelectTransitTrips]) -> dict:
    """Check that selection dictionary has valid and used properties consistent with network.

    Checks that selection_dict is a valid TransitSelectionDict:
        - query vars exist in respective Feed tables
    Args:
        selection_dict (dict): selection dictionary

    Raises:
        TransitSelectionNetworkConsistencyError: If not consistent with transit network
        ValidationError: if format not consistent with SelectTransitTrips
    """
    if not isinstance(selection_dict, SelectTransitTrips):
        selection_dict = SelectTransitTrips(**selection_dict)
    selection_dict = selection_dict.model_dump(exclude_none=True, by_alias=True)
    WranglerLogger.debug(f"SELECT DICT - before Validation: \n{selection_dict}")
    _trip_selection_fields = list((selection_dict.get("trip_properties", {}) or {}).keys())
    _missing_trip_fields = set(_trip_selection_fields) - set(self.net.feed.trips.columns)

    if _missing_trip_fields:
        msg = f"Fields in trip selection dictionary but not trips.txt: {_missing_trip_fields}"
        raise TransitSelectionNetworkConsistencyError(msg)

    _route_selection_fields = list((selection_dict.get("route_properties", {}) or {}).keys())
    _missing_route_fields = set(_route_selection_fields) - set(self.net.feed.routes.columns)

    if _missing_route_fields:
        msg = (
            f"Fields in route selection dictionary but not routes.txt: {_missing_route_fields}"
        )
        raise TransitSelectionNetworkConsistencyError(msg)
    return selection_dict

Functions to check for transit network validity and consistency with roadway network.

shape_links_without_road_links(tr_shapes, rd_links_df)

Validate that links in transit shapes exist in referenced roadway links.

Parameters:

- tr_shapes (DataFrame[WranglerShapesTable]): transit shapes from shapes.txt to validate foreign key to. Required.
- rd_links_df (DataFrame[RoadLinksTable]): Links dataframe from roadway network to validate. Required.

Returns:

- DataFrame: df with shape_id and A, B.

Source code in network_wrangler/transit/validate.py
def shape_links_without_road_links(
    tr_shapes: DataFrame[WranglerShapesTable],
    rd_links_df: DataFrame[RoadLinksTable],
) -> pd.DataFrame:
    """Validate that links in transit shapes exist in referenced roadway links.

    Args:
        tr_shapes: transit shapes from shapes.txt to validate foreign key to.
        rd_links_df: Links dataframe from roadway network to validate

    Returns:
        df with shape_id and A, B
    """
    tr_shape_links = unique_shape_links(tr_shapes)
    # WranglerLogger.debug(f"Unique shape links: \n {tr_shape_links}")
    rd_links_transit_ok = rd_links_df[
        (rd_links_df["drive_access"]) | (rd_links_df["bus_only"]) | (rd_links_df["rail_only"])
    ]

    merged_df = tr_shape_links.merge(
        rd_links_transit_ok[["A", "B"]],
        how="left",
        on=["A", "B"],
        indicator=True,
    )

    missing_links_df = merged_df.loc[merged_df._merge == "left_only", ["shape_id", "A", "B"]]
    if len(missing_links_df):
        WranglerLogger.error(
            f"! Transit shape links missing in roadway network: \n {missing_links_df}"
        )
    return missing_links_df[["shape_id", "A", "B"]]
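The missing-link check above is a standard pandas anti-join: a left merge with `indicator=True`, keeping the `left_only` rows. A minimal, self-contained sketch with made-up node ids (not from the package's data):

```python
import pandas as pd

# Hypothetical transit shape links and roadway links, keyed on node pairs (A, B).
tr_shape_links = pd.DataFrame({"shape_id": ["s1", "s1", "s2"], "A": [1, 2, 3], "B": [2, 3, 9]})
rd_links = pd.DataFrame({"A": [1, 2], "B": [2, 3]})

# Left merge with indicator, then keep rows found only on the left side.
merged = tr_shape_links.merge(rd_links, how="left", on=["A", "B"], indicator=True)
missing = merged.loc[merged["_merge"] == "left_only", ["shape_id", "A", "B"]]
print(missing)  # only the (3, 9) link lacks a roadway counterpart
```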

stop_times_without_road_links(tr_stop_times, rd_links_df)

Validate that links in transit stop_times exist in referenced roadway links.

Parameters:

- tr_stop_times (DataFrame[WranglerStopTimesTable]): transit stop_times from stop_times.txt to validate foreign key to. Required.
- rd_links_df (DataFrame[RoadLinksTable]): Links dataframe from roadway network to validate. Required.

Returns:

- DataFrame: df with trip_id and A, B.

Source code in network_wrangler/transit/validate.py
def stop_times_without_road_links(
    tr_stop_times: DataFrame[WranglerStopTimesTable],
    rd_links_df: DataFrame[RoadLinksTable],
) -> pd.DataFrame:
    """Validate that links in transit shapes exist in referenced roadway links.

    Args:
        tr_stop_times: transit stop_times from stop_times.txt to validate foreign key to.
        rd_links_df: Links dataframe from roadway network to validate

    Returns:
        df with trip_id and A, B
    """
    tr_links = unique_stop_time_links(tr_stop_times)

    rd_links_transit_ok = rd_links_df[
        (rd_links_df["drive_access"]) | (rd_links_df["bus_only"]) | (rd_links_df["rail_only"])
    ]

    merged_df = tr_links.merge(
        rd_links_transit_ok[["A", "B"]],
        how="left",
        on=["A", "B"],
        indicator=True,
    )

    missing_links_df = merged_df.loc[merged_df._merge == "left_only", ["trip_id", "A", "B"]]
    if len(missing_links_df):
        WranglerLogger.error(
            f"! Transit stop_time links missing in roadway network: \n {missing_links_df}"
        )
    return missing_links_df[["trip_id", "A", "B"]]

transit_nodes_without_road_nodes(feed, nodes_df, rd_field='model_node_id')

Validate that all of a transit feed's node foreign keys exist in referenced roadway nodes.

Parameters:

- feed (Feed): Transit Feed to query. Required.
- nodes_df (DataFrame): Nodes dataframe from roadway network to validate foreign key to. Defaults to self.roadway_net.nodes_df. Required.
- rd_field (str): field in roadway nodes to check against. Defaults to "model_node_id".

Returns:

- list[int]: node ids found in the transit feed but not in the roadway nodes.

Source code in network_wrangler/transit/validate.py
def transit_nodes_without_road_nodes(
    feed: Feed,
    nodes_df: DataFrame[RoadNodesTable],
    rd_field: str = "model_node_id",
) -> list[int]:
    """Validate all of a transit feeds node foreign keys exist in referenced roadway nodes.

    Args:
        feed: Transit Feed to query.
        nodes_df (pd.DataFrame, optional): Nodes dataframe from roadway network to validate
            foreign key to. Defaults to self.roadway_net.nodes_df
        rd_field: field in roadway nodes to check against. Defaults to "model_node_id"

    Returns:
        list of node ids found in the transit feed but not in the roadway nodes
    """
    feed_nodes_series = [
        feed.stops["stop_id"],
        feed.shapes["shape_model_node_id"],
        feed.stop_times["stop_id"],
    ]
    tr_nodes = set(concat_with_attr(feed_nodes_series).unique())
    rd_nodes = set(nodes_df[rd_field].unique().tolist())
    # nodes in tr_nodes but not rd_nodes
    missing_tr_nodes = list(tr_nodes - rd_nodes)

    if missing_tr_nodes:
        WranglerLogger.error(
            f"! Transit nodes in missing in roadway network: \n {missing_tr_nodes}"
        )
    return missing_tr_nodes

transit_road_net_consistency(feed, road_net)

Checks foreign key and network link relationships between transit feed and a road_net.

Parameters:

- feed (Feed): Transit Feed. Required.
- road_net (RoadwayNetwork): Roadway network to check relationship with. Required.

Returns:

- bool: boolean indicating if road_net is consistent with transit network.

Source code in network_wrangler/transit/validate.py
def transit_road_net_consistency(feed: Feed, road_net: RoadwayNetwork) -> bool:
    """Checks foreign key and network link relationships between transit feed and a road_net.

    Args:
        feed: Transit Feed.
        road_net (RoadwayNetwork): Roadway network to check relationship with.

    Returns:
        bool: boolean indicating if road_net is consistent with transit network.
    """
    _missing_links = shape_links_without_road_links(feed.shapes, road_net.links_df)
    _missing_nodes = transit_nodes_without_road_nodes(feed, road_net.nodes_df)
    _consistency = _missing_links.empty and not _missing_nodes
    return _consistency

validate_transit_in_dir(dir, file_format='txt', road_dir=None, road_file_format='geojson')

Validates a transit network in a directory to the wrangler data model specifications.

Parameters:

- dir (Path): The transit network file directory. Required.
- file_format (str): The format of the transit network files. Defaults to "txt".
- road_dir (Path): The roadway network file directory. Defaults to None.
- road_file_format (str): The format of the roadway network files. Defaults to "geojson".
Source code in network_wrangler/transit/validate.py
def validate_transit_in_dir(
    dir: Path,
    file_format: TransitFileTypes = "txt",
    road_dir: Optional[Path] = None,
    road_file_format: RoadwayFileTypes = "geojson",
) -> bool:
    """Validates a roadway network in a directory to the wrangler data model specifications.

    Args:
        dir (Path): The transit network file directory.
        file_format (str): The format of roadway network file name. Defaults to "txt".
        road_dir (Path): The roadway network file directory. Defaults to None.
        road_file_format (str): The format of roadway network file name. Defaults to "geojson".
        output_dir (str): The output directory for the validation report. Defaults to ".".
    """
    from .io import load_transit

    try:
        t = load_transit(dir, file_format=file_format)
    except SchemaErrors as e:
        WranglerLogger.error(f"!!! [Transit Network invalid] - Failed Loading to Feed object\n{e}")
        return False
    if road_dir is not None:
        from ..roadway import load_roadway_from_dir
        from .network import TransitRoadwayConsistencyError

        try:
            r = load_roadway_from_dir(road_dir, file_format=road_file_format)
        except FileNotFoundError:
            WranglerLogger.error(f"! Roadway network not found in {road_dir}")
            return False
        except Exception as e:
            WranglerLogger.error(f"! Error loading roadway network. \
                                 Skipping validation of road to transit network.\n{e}")
            return True
        try:
            t.road_net = r
        except TransitRoadwayConsistencyError as e:
            WranglerLogger.error(f"!!! [Tranit Network inconsistent] Error in road to transit \
                                 network consistency.\n{e}")
            return False

    return True

Utils and Functions

General utility functions used throughout package.

DictionaryMergeError

Bases: Exception

Error raised when there is a conflict in merging two dictionaries.

Source code in network_wrangler/utils/utils.py
class DictionaryMergeError(Exception):
    """Error raised when there is a conflict in merging two dictionaries."""

check_one_or_one_superset_present(mixed_list, all_fields_present)

Checks that exactly one of the fields in mixed_list is in fields_present or one superset.

Source code in network_wrangler/utils/utils.py
def check_one_or_one_superset_present(
    mixed_list: list[Union[str, list[str]]], all_fields_present: list[str]
) -> bool:
    """Checks that exactly one of the fields in mixed_list is in fields_present or one superset."""
    normalized_list = normalize_to_lists(mixed_list)

    list_items_present = [i for i in normalized_list if set(i).issubset(all_fields_present)]

    if len(list_items_present) == 1:
        return True

    return list_elements_subset_of_single_element(list_items_present)

combine_unique_unhashable_list(list1, list2)

Combines lists preserving order of first and removing duplicates.

Parameters:

- list1 (list): The first list. Required.
- list2 (list): The second list. Required.

Returns:

- list: A new list containing the elements from list1 followed by the unique elements from list2.

Example:

```python
>>> list1 = [1, 2, 3]
>>> list2 = [2, 3, 4, 5]
>>> combine_unique_unhashable_list(list1, list2)
[1, 2, 3, 4, 5]
```

Source code in network_wrangler/utils/utils.py
def combine_unique_unhashable_list(list1: list, list2: list):
    """Combines lists preserving order of first and removing duplicates.

    Args:
        list1 (list): The first list.
        list2 (list): The second list.

    Returns:
        list: A new list containing the elements from list1 followed by the
        unique elements from list2.

    Example:
        >>> list1 = [1, 2, 3]
        >>> list2 = [2, 3, 4, 5]
        >>> combine_unique_unhashable_list(list1, list2)
        [1, 2, 3, 4, 5]
    """
    return [item for item in list1 if item not in list2] + list2

delete_keys_from_dict(dictionary, keys)

Removes list of keys from potentially nested dictionary.

SOURCE: https://stackoverflow.com/questions/3405715/ User: @mseifert

Parameters:

- dictionary (dict): dictionary to remove keys from. Required.
- keys (list): list of keys to remove. Required.
Source code in network_wrangler/utils/utils.py
def delete_keys_from_dict(dictionary: dict, keys: list) -> dict:
    """Removes list of keys from potentially nested dictionary.

    SOURCE: https://stackoverflow.com/questions/3405715/
    User: @mseifert

    Args:
        dictionary: dictionary to remove keys from
        keys: list of keys to remove

    """
    keys_set = list(set(keys))  # Just an optimization for the "if key in keys" lookup.

    modified_dict = {}
    for key, value in dictionary.items():
        if key not in keys_set:
            if isinstance(value, dict):
                modified_dict[key] = delete_keys_from_dict(value, keys_set)
            else:
                modified_dict[key] = (
                    value  # or copy.deepcopy(value) if a copy is desired for non-dicts.
                )
    return modified_dict
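A quick usage sketch (the function body is repeated from the source above so the snippet runs standalone; the config dict is made up for illustration):

```python
def delete_keys_from_dict(dictionary: dict, keys: list) -> dict:
    """Removes list of keys from potentially nested dictionary."""
    keys_set = list(set(keys))  # optimization for the "if key in keys" lookup
    modified_dict = {}
    for key, value in dictionary.items():
        if key not in keys_set:
            if isinstance(value, dict):
                # Recurse into nested dicts so the key is removed at every level.
                modified_dict[key] = delete_keys_from_dict(value, keys_set)
            else:
                modified_dict[key] = value
    return modified_dict

config = {"name": "net", "meta": {"hash": "abc", "crs": 4326}, "hash": "xyz"}
cleaned = delete_keys_from_dict(config, ["hash"])
print(cleaned)  # {'name': 'net', 'meta': {'crs': 4326}}
```

Note that the function returns a new dict; the input is left unmodified.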

dict_to_hexkey(d)

Converts a dictionary to a hexdigest of the sha1 hash of the dictionary.

Parameters:

- d (dict): dictionary to convert to string. Required.

Returns:

- str: hexdigest of the sha1 hash of dictionary.

Source code in network_wrangler/utils/utils.py
def dict_to_hexkey(d: dict) -> str:
    """Converts a dictionary to a hexdigest of the sha1 hash of the dictionary.

    Args:
        d (dict): dictionary to convert to string

    Returns:
        str: hexdigest of the sha1 hash of dictionary
    """
    return hashlib.sha1(str(d).encode()).hexdigest()
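A usage sketch (function repeated from above; the selection dict is a made-up example):

```python
import hashlib

def dict_to_hexkey(d: dict) -> str:
    """Hexdigest of the sha1 hash of the dict's string representation."""
    return hashlib.sha1(str(d).encode()).hexdigest()

key = dict_to_hexkey({"trip_properties": {"route_id": "7"}})
print(key)  # a 40-character hex digest, stable for an identical dict
```

Because the hash is taken over `str(d)`, it reflects key insertion order: two dicts with the same items inserted in different order produce different keys.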

findkeys(node, kv)

Returns values of all keys in various objects.

Adapted from arainchi on Stack Overflow: https://stackoverflow.com/questions/9807634/find-all-occurrences-of-a-key-in-nested-dictionaries-and-lists

Source code in network_wrangler/utils/utils.py
def findkeys(node, kv):
    """Returns values of all keys in various objects.

    Adapted from arainchi on Stack Overflow:
    https://stackoverflow.com/questions/9807634/find-all-occurrences-of-a-key-in-nested-dictionaries-and-lists
    """
    if isinstance(node, list):
        for i in node:
            for x in findkeys(i, kv):
                yield x
    elif isinstance(node, dict):
        if kv in node:
            yield node[kv]
        for j in node.values():
            for x in findkeys(j, kv):
                yield x
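A usage sketch (generator repeated from above; the project-card-like dict is made up for illustration):

```python
def findkeys(node, kv):
    """Yield values of key kv anywhere in nested lists and dicts."""
    if isinstance(node, list):
        for i in node:
            for x in findkeys(i, kv):
                yield x
    elif isinstance(node, dict):
        if kv in node:
            yield node[kv]
        for j in node.values():
            for x in findkeys(j, kv):
                yield x

card = {"project": "A", "changes": [{"trip_id": "t1"}, {"nested": {"trip_id": "t2"}}]}
print(list(findkeys(card, "trip_id")))  # ['t1', 't2']
```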

get_overlapping_range(ranges)

Returns the overlapping range for a list of ranges or tuples defining ranges.

Parameters:

- ranges (list[Union[tuple[int, int], range]]): A list of ranges or tuples defining ranges. Required.

Returns:

- Union[None, range]: The overlapping range if found, otherwise None.

Example:

```python
>>> ranges = [(1, 5), (3, 7), (6, 10)]
>>> get_overlapping_range(ranges)
range(3, 5)
```

Source code in network_wrangler/utils/utils.py
def get_overlapping_range(ranges: list[Union[tuple[int, int], range]]) -> Union[None, range]:
    """Returns the overlapping range for a list of ranges or tuples defining ranges.

    Args:
        ranges (list[Union[tuple[int, int], range]]): A list of ranges or tuples defining ranges.

    Returns:
        Union[None, range]: The overlapping range if found, otherwise None.

    Example:
        >>> ranges = [(1, 5), (3, 7), (6, 10)]
        >>> get_overlapping_range(ranges)
        range(3, 5)

    """
    # check that any tuples have two values
    if any(isinstance(r, tuple) and len(r) != 2 for r in ranges):  # noqa: PLR2004
        msg = "Tuple ranges must have two values."
        WranglerLogger.error(msg)
        raise ValueError(msg)

    _ranges = [r if isinstance(r, range) else range(r[0], r[1]) for r in ranges]

    _overlap_start = max(r.start for r in _ranges)
    _overlap_end = min(r.stop for r in _ranges)

    if _overlap_start < _overlap_end:
        return range(_overlap_start, _overlap_end)
    return None

list_elements_subset_of_single_element(mixed_list)

Find the first list in the mixed_list.

Source code in network_wrangler/utils/utils.py
@validate_call
def list_elements_subset_of_single_element(mixed_list: list[Union[str, list[str]]]) -> bool:
    """Find the first list in the mixed_list."""
    potential_supersets = []
    for item in mixed_list:
        if isinstance(item, list) and len(item) > 0:
            potential_supersets.append(set(item))

    # If no list is found, return False
    if not potential_supersets:
        return False

    normalized_list = normalize_to_lists(mixed_list)

    valid_supersets = []
    for ss in potential_supersets:
        if all(ss.issuperset(i) for i in normalized_list):
            valid_supersets.append(ss)

    return len(valid_supersets) == 1

make_slug(text, delimiter='_')

Makes a slug from text.

Source code in network_wrangler/utils/utils.py
def make_slug(text: str, delimiter: str = "_") -> str:
    """Makes a slug from text."""
    text = re.sub("[,.;@#?!&$']+", "", text.lower())
    return re.sub("[\ ]+", delimiter, text)

merge_dicts(right, left, path=None)

Merges the contents of nested dict left into nested dict right.

Raises errors in case of namespace conflicts.

Parameters:

- right (dict): dict, modified in place. Required.
- left (dict): dict to be merged into right. Required.
- path (list): sequence of keys to be reported in case of error in merging nested dictionaries. Defaults to None.
Source code in network_wrangler/utils/utils.py
def merge_dicts(right, left, path=None):
    """Merges the contents of nested dict left into nested dict right.

    Raises errors in case of namespace conflicts.

    Args:
        right: dict, modified in place
        left: dict to be merged into right
        path: default None, sequence of keys to be reported in case of
            error in merging nested dictionaries
    """
    if path is None:
        path = []
    for key in left:
        if key in right:
            if isinstance(right[key], dict) and isinstance(left[key], dict):
                merge_dicts(right[key], left[key], [*path, str(key)])
            else:
                path = ".".join([*path, str(key)])
                msg = f"duplicate keys in source dict files: {path}"
                WranglerLogger.error(msg)
                raise DictionaryMergeError(msg)
        else:
            right[key] = left[key]
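A minimal sketch of the merge behavior (logging omitted, and a plain ValueError stands in for DictionaryMergeError so the snippet runs standalone; the config dicts are made up):

```python
def merge_dicts(right, left, path=None):
    """Merge nested dict left into right in place; raise on key conflicts."""
    if path is None:
        path = []
    for key in left:
        if key in right:
            if isinstance(right[key], dict) and isinstance(left[key], dict):
                merge_dicts(right[key], left[key], [*path, str(key)])
            else:
                # Conflicting non-dict values: report the dotted key path.
                msg = f"duplicate keys in source dict files: {'.'.join([*path, str(key)])}"
                raise ValueError(msg)
        else:
            right[key] = left[key]

base = {"roadway": {"links": "links.json"}}
extra = {"roadway": {"nodes": "nodes.geojson"}, "transit": "feed/"}
merge_dicts(base, extra)
print(base)
# {'roadway': {'links': 'links.json', 'nodes': 'nodes.geojson'}, 'transit': 'feed/'}
```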

normalize_to_lists(mixed_list)

Turn a mixed list of scalars and lists into a list of lists.

Source code in network_wrangler/utils/utils.py
def normalize_to_lists(mixed_list: list[Union[str, list]]) -> list[list]:
    """Turn a mixed list of scalars and lists into a list of lists."""
    normalized_list = []
    for item in mixed_list:
        if isinstance(item, str):
            normalized_list.append([item])
        else:
            normalized_list.append(item)
    return normalized_list
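A usage sketch (function repeated from above; the field names are made up):

```python
def normalize_to_lists(mixed_list):
    """Turn a mixed list of scalars and lists into a list of lists."""
    normalized_list = []
    for item in mixed_list:
        if isinstance(item, str):
            normalized_list.append([item])  # wrap bare strings
        else:
            normalized_list.append(item)  # pass lists through unchanged
    return normalized_list

print(normalize_to_lists(["osm_link_id", ["ref", "name"]]))
# [['osm_link_id'], ['ref', 'name']]
```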

split_string_prefix_suffix_from_num(input_string)

Split a string prefix and suffix from last number.

Parameters:

- input_string (str): The input string to be processed. Required.

Returns:

- tuple: A tuple containing the prefix (including preceding numbers), the last numeric part as an integer, and the suffix.

Notes

This function uses regular expressions to split a string into three parts: the prefix, the last numeric part, and the suffix. The prefix includes any preceding numbers, the last numeric part is converted to an integer, and the suffix includes any non-digit characters after the last numeric part.

Examples:

```python
>>> split_string_prefix_suffix_from_num("abc123def456")
('abc', 123, 'def456')
>>> split_string_prefix_suffix_from_num("hello")
('hello', 0, '')
>>> split_string_prefix_suffix_from_num("123")
('', 123, '')
```
Source code in network_wrangler/utils/utils.py
def split_string_prefix_suffix_from_num(input_string: str):
    """Split a string prefix and suffix from *last* number.

    Args:
        input_string (str): The input string to be processed.

    Returns:
        tuple: A tuple containing the prefix (including preceding numbers),
               the last numeric part as an integer, and the suffix.

    Notes:
        This function uses regular expressions to split a string into three parts:
        the prefix, the last numeric part, and the suffix. The prefix includes any
        preceding numbers, the last numeric part is converted to an integer, and
        the suffix includes any non-digit characters after the last numeric part.

    Examples:
        >>> split_string_prefix_suffix_from_num("abc123def456")
        ('abc', 123, 'def456')

        >>> split_string_prefix_suffix_from_num("hello")
        ('hello', 0, '')

        >>> split_string_prefix_suffix_from_num("123")
        ('', 123, '')

    """
    input_string = str(input_string)
    pattern = re.compile(r"(.*?)(\d+)(\D*)$")
    match = pattern.match(input_string)

    if match:
        # Extract the groups: prefix (including preceding numbers), last numeric part, suffix
        prefix, numeric_part, suffix = match.groups()
        # Convert the numeric part to an integer
        num_variable = int(numeric_part)
        return prefix, num_variable, suffix
    return input_string, 0, ""

topological_sort(adjacency_list, visited_list)

Topological sorting for Acyclic Directed Graph.

Parameters:

- adjacency_list (dict): A dictionary representing the adjacency list of the graph.
- visited_list (dict): A mapping representing the visited status of each vertex in the graph.

Returns:

- output_stack (list): A list containing the vertices in topological order.

This function performs a topological sort on an acyclic directed graph. It takes an adjacency list and a visited list as input. The adjacency list represents the connections between vertices in the graph, and the visited list keeps track of the visited status of each vertex.

The function uses a recursive helper function to perform the topological sort. It starts by iterating over each vertex in the visited list. For each unvisited vertex, it calls the helper function, which recursively visits all the neighbors of the vertex and adds them to the output stack in reverse order. Finally, it returns the output stack, which contains the vertices in topological order.

Source code in network_wrangler/utils/utils.py
def topological_sort(adjacency_list, visited_list):
    """Topological sorting for Acyclic Directed Graph.

    Parameters:
    - adjacency_list (dict): A dictionary representing the adjacency list of the graph.
    - visited_list (list): A list representing the visited status of each vertex in the graph.

    Returns:
    - output_stack (list): A list containing the vertices in topological order.

    This function performs a topological sort on an acyclic directed graph. It takes an adjacency
    list and a visited list as input. The adjacency list represents the connections between
    vertices in the graph, and the visited list keeps track of the visited status of each vertex.

    The function uses a recursive helper function to perform the topological sort. It starts by
    iterating over each vertex in the visited list. For each unvisited vertex, it calls the helper
    function, which recursively visits all the neighbors of the vertex and adds them to the output
    stack in reverse order. Finally, it returns the output stack, which contains the vertices in
    topological order.
    """
    output_stack = []

    def _topology_sort_util(vertex):
        if not visited_list[vertex]:
            visited_list[vertex] = True
            for neighbor in adjacency_list[vertex]:
                _topology_sort_util(neighbor)
            output_stack.insert(0, vertex)

    for vertex in visited_list:
        _topology_sort_util(vertex)

    return output_stack
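A usage sketch (function repeated from above; the project names and prerequisite graph are made up to mirror how project ordering might look):

```python
def topological_sort(adjacency_list, visited_list):
    """Topological sorting for an acyclic directed graph."""
    output_stack = []

    def _topology_sort_util(vertex):
        if not visited_list[vertex]:
            visited_list[vertex] = True
            for neighbor in adjacency_list[vertex]:
                _topology_sort_util(neighbor)
            # Prepend after visiting all descendants so prerequisites come first.
            output_stack.insert(0, vertex)

    for vertex in visited_list:
        _topology_sort_util(vertex)

    return output_stack

# Hypothetical ordering: "widen" depends on "repave", which depends on "survey".
adjacency = {"survey": ["repave"], "repave": ["widen"], "widen": []}
visited = {"survey": False, "repave": False, "widen": False}
print(topological_sort(adjacency, visited))  # ['survey', 'repave', 'widen']
```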

Helper functions for reading and writing files to reduce boilerplate.

FileReadError

Bases: Exception

Raised when there is an error reading a file.

Source code in network_wrangler/utils/io_table.py
class FileReadError(Exception):
    """Raised when there is an error reading a file."""

FileWriteError

Bases: Exception

Raised when there is an error writing a file.

Source code in network_wrangler/utils/io_table.py
class FileWriteError(Exception):
    """Raised when there is an error writing a file."""

convert_file_serialization(input_file, output_file, overwrite=True, boundary_gdf=None, boundary_geocode=None, boundary_file=None, node_filter_s=None, chunk_size=None)

Convert a file serialization format to another and optionally filter to a boundary.

If the input file is a JSON file that is larger than a reasonable portion of available memory, and the output file is a Parquet file, the JSON file will be read in chunks.

If the input file is a geographic data type (shp, geojson, geoparquet) and a boundary is provided, the data will be filtered to the boundary.

Parameters:

- input_file (Path): Path to the input JSON or GEOJSON file. Required.
- output_file (Path): Path to the output Parquet file. Required.
- overwrite (bool): If True, overwrite the output file if it exists. Defaults to True.
- boundary_gdf (Optional[GeoDataFrame]): GeoDataFrame to filter the input data to. Only used for geographic data. Defaults to None.
- boundary_geocode (Optional[str]): Geocode to filter the input data to. Only used for geographic data. Defaults to None.
- boundary_file (Optional[Path]): File to load as a boundary to filter the input data to. Only used for geographic data. Defaults to None.
- node_filter_s (Optional[Series]): If provided, will filter links in .json file to only those that connect to nodes. Defaults to None.
- chunk_size (Optional[int]): Number of JSON objects to process in each chunk. Only works for JSON to Parquet. If None, will determine if chunking is needed and what size. Defaults to None.
Source code in network_wrangler/utils/io_table.py
def convert_file_serialization(
    input_file: Path,
    output_file: Path,
    overwrite: bool = True,
    boundary_gdf: Optional[gpd.GeoDataFrame] = None,
    boundary_geocode: Optional[str] = None,
    boundary_file: Optional[Path] = None,
    node_filter_s: Optional[pd.Series] = None,
    chunk_size: Optional[int] = None,
):
    """Convert a file serialization format to another and optionally filter to a boundary.

    If the input file is a JSON file that is larger than a reasonable portion of available
    memory, *and* the output file is a Parquet file the JSON file will be read in chunks.

    If the input file is a Geographic data type (shp, geojson, geoparquet) and a boundary is
    provided, the data will be filtered to the boundary.

    Args:
        input_file: Path to the input JSON or GEOJSON file.
        output_file: Path to the output Parquet file.
        overwrite: If True, overwrite the output file if it exists.
        boundary_gdf: GeoDataFrame to filter the input data to. Only used for geographic data.
            Defaults to None.
        boundary_geocode: Geocode to filter the input data to. Only used for geographic data.
            Defaults to None.
        boundary_file: File to load as a boundary to filter the input data to. Only used for
            geographic data. Defaults to None.
        node_filter_s: If provided, will filter links in .json file to only those that connect to
            nodes. Defaults to None.
        chunk_size: Number of JSON objects to process in each chunk. Only works for
            JSON to Parquet. If None, will determine if chunking needed and what size.
    """
    WranglerLogger.debug(f"Converting {input_file} to {output_file}.")

    if output_file.exists() and not overwrite:
        msg = f"File {output_file} already exists and overwrite is False."
        raise FileExistsError(msg)

    if Path(input_file).suffix == ".json" and Path(output_file).suffix == ".parquet":
        if chunk_size is None:
            chunk_size = _suggest_json_chunk_size(input_file)
        if chunk_size is None:
            df = read_table(input_file)
            if node_filter_s is not None and "A" in df.columns and "B" in df.columns:
                df = df[df["A"].isin(node_filter_s) | df["B"].isin(node_filter_s)]
            write_table(df, output_file, overwrite=overwrite)
        else:
            _json_to_parquet_in_chunks(input_file, output_file, chunk_size)
        return  # JSON -> Parquet handled; don't fall through and re-read the file below

    df = read_table(
        input_file,
        boundary_gdf=boundary_gdf,
        boundary_geocode=boundary_geocode,
        boundary_file=boundary_file,
    )
    if node_filter_s is not None and "A" in df.columns and "B" in df.columns:
        df = df[df["A"].isin(node_filter_s) | df["B"].isin(node_filter_s)]
    write_table(df, output_file, overwrite=overwrite)
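The node_filter_s step above keeps only links whose A or B endpoint is in the supplied node series. A minimal pandas sketch of that predicate, using toy link and node values:

```python
import pandas as pd

# Toy links table with from-node (A) and to-node (B) columns.
links = pd.DataFrame({"A": [1, 2, 3], "B": [2, 5, 6], "dist": [1.0, 2.0, 3.0]})
keep_nodes = pd.Series([1, 6])

# Same predicate as convert_file_serialization: keep links touching any kept node.
filtered = links[links["A"].isin(keep_nodes) | links["B"].isin(keep_nodes)]
```

Links that touch a kept node at either end survive the filter; the middle link (2, 5) is dropped.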

prep_dir(outdir, overwrite=True)

Prepare a directory for writing files.

Source code in network_wrangler/utils/io_table.py
def prep_dir(outdir: Path, overwrite: bool = True):
    """Prepare a directory for writing files."""
    if not overwrite and outdir.exists() and len(list(outdir.iterdir())) > 0:
        msg = f"Directory {outdir} is not empty and overwrite is False."
        raise FileExistsError(msg)
    outdir.mkdir(parents=True, exist_ok=True)

    # clean out existing files
    for f in outdir.iterdir():
        if f.is_file():
            f.unlink()
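The prepare-then-clean behavior is easy to exercise with the standard library alone; the sketch below re-implements the same checks (it is not the library function itself):

```python
import tempfile
from pathlib import Path

def prep_dir_sketch(outdir: Path, overwrite: bool = True) -> None:
    # Mirrors prep_dir: refuse to clobber a non-empty dir unless overwrite=True.
    if not overwrite and outdir.exists() and any(outdir.iterdir()):
        raise FileExistsError(f"Directory {outdir} is not empty and overwrite is False.")
    outdir.mkdir(parents=True, exist_ok=True)
    # Clean out existing files so stale outputs never mix with new ones.
    for f in outdir.iterdir():
        if f.is_file():
            f.unlink()

tmp = Path(tempfile.mkdtemp())
(tmp / "stale.txt").write_text("old output")
prep_dir_sketch(tmp, overwrite=True)  # the stale file is removed
```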

read_table(filename, sub_filename=None, boundary_gdf=None, boundary_geocode=None, boundary_file=None, read_speed=DefaultConfig.CPU.EST_PD_READ_SPEED)

Read file and return a dataframe or geodataframe.

If filename is a zip file, will unzip to a temporary directory.

If filename is a geojson or shapefile, will filter the data to the boundary_gdf, boundary_geocode, or boundary_file if provided. Note that you can only provide one of these boundary filters.

If filename is a geoparquet file, will filter the data to the bounding box of the boundary_gdf, boundary_geocode, or boundary_file if provided. Note that you can only provide one of these boundary filters.

NOTE: if you are accessing multiple files from this zip file you will want to unzip it first and THEN access the table files so you don’t create multiple duplicate unzipped tmp dirs.

Parameters:

Name Type Description Default
filename Path

filename to load.

required
sub_filename Optional[str]

if the file is a zip, the sub_filename to load.

None
boundary_gdf Optional[GeoDataFrame]

GeoDataFrame to filter the input data to. Only used for geographic data. Defaults to None.

None
boundary_geocode Optional[str]

Geocode to filter the input data to. Only used for geographic data. Defaults to None.

None
boundary_file Optional[Path]

File to load as a boundary to filter the input data to. Only used for geographic data. Defaults to None.

None
read_speed dict

dictionary of read speeds for different file types. Defaults to DefaultConfig.CPU.EST_PD_READ_SPEED.

EST_PD_READ_SPEED
Source code in network_wrangler/utils/io_table.py
def read_table(
    filename: Path,
    sub_filename: Optional[str] = None,
    boundary_gdf: Optional[gpd.GeoDataFrame] = None,
    boundary_geocode: Optional[str] = None,
    boundary_file: Optional[Path] = None,
    read_speed: dict = DefaultConfig.CPU.EST_PD_READ_SPEED,
) -> Union[pd.DataFrame, gpd.GeoDataFrame]:
    """Read file and return a dataframe or geodataframe.

    If filename is a zip file, will unzip to a temporary directory.

    If filename is a geojson or shapefile, will filter the data
    to the boundary_gdf, boundary_geocode, or boundary_file if provided. Note that you can only
    provide one of these boundary filters.

    If filename is a geoparquet file, will filter the data to the *bounding box* of the
    boundary_gdf, boundary_geocode, or boundary_file if provided. Note that you can only
    provide one of these boundary filters.

    NOTE:  if you are accessing multiple files from this zip file you will want to unzip it first
    and THEN access the table files so you don't create multiple duplicate unzipped tmp dirs.

    Args:
        filename (Path): filename to load.
        sub_filename: if the file is a zip, the sub_filename to load.
        boundary_gdf: GeoDataFrame to filter the input data to. Only used for geographic data.
            Defaults to None.
        boundary_geocode: Geocode to filter the input data to. Only used for geographic data.
            Defaults to None.
        boundary_file: File to load as a boundary to filter the input data to. Only used for
            geographic data. Defaults to None.
        read_speed: dictionary of read speeds for different file types. Defaults to
            DefaultConfig.CPU.EST_PD_READ_SPEED.
    """
    filename = Path(filename)
    if not filename.exists():
        msg = f"Input file {filename} does not exist."
        raise FileNotFoundError(msg)
    if filename.stat().st_size == 0:
        msg = f"File {filename} is empty."
        raise FileExistsError(msg)
    if filename.suffix == ".zip":
        if not sub_filename:
            msg = "sub_filename must be provided for zip files."
            raise ValueError(msg)
        filename = unzip_file(filename) / sub_filename
    WranglerLogger.debug(
        f"Estimated read time: {_estimate_read_time_of_file(filename, read_speed)}."
    )

    # will result in None if no boundary is provided
    mask_gdf = get_bounding_polygon(
        boundary_gdf=boundary_gdf,
        boundary_geocode=boundary_geocode,
        boundary_file=boundary_file,
    )

    if any(x in filename.suffix for x in ["geojson", "shp", "csv"]):
        try:
            # masking only supported by fiona engine, which is slower.
            if mask_gdf is None:
                return gpd.read_file(filename, engine="pyogrio")
            return gpd.read_file(filename, mask=mask_gdf, engine="fiona")
        except Exception as err:
            if "csv" in filename.suffix:
                return pd.read_csv(filename)
            raise FileReadError from err
    elif "parquet" in filename.suffix:
        return _read_parquet_table(filename, mask_gdf)
    elif "json" in filename.suffix:
        with filename.open() as f:
            return pd.read_json(f, orient="records")
    msg = f"Filetype {filename.suffix} not implemented."
    raise NotImplementedError(msg)

unzip_file(path)

Unzips a file to a temporary directory and returns the directory path.

Source code in network_wrangler/utils/io_table.py
def unzip_file(path: Path) -> Path:
    """Unzips a file to a temporary directory and returns the directory path."""
    tmpdir = tempfile.mkdtemp()
    shutil.unpack_archive(path, tmpdir)

    # Lazy cleanup: remove the temporary directory at interpreter exit.
    # (Requires `import atexit`. weakref.finalize cannot be used on the str
    # returned by mkdtemp, since str objects do not support weak references.)
    atexit.register(shutil.rmtree, tmpdir, ignore_errors=True)

    return Path(tmpdir)
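A standard-library illustration of the unpack step: build a small zip, then unpack it to a fresh temporary directory with the same shutil.unpack_archive call the function uses. The file names here are made up for the example:

```python
import shutil
import tempfile
from pathlib import Path

# Build a tiny archive to unpack.
work = Path(tempfile.mkdtemp())
src = work / "data"
src.mkdir()
(src / "links.json").write_text("[]")
archive = shutil.make_archive(str(work / "bundle"), "zip", root_dir=src)

# Unpack into a fresh temporary directory, as unzip_file does.
dest = Path(tempfile.mkdtemp())
shutil.unpack_archive(archive, dest)
```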

write_table(df, filename, overwrite=False, **kwargs)

Write a dataframe or geodataframe to a file.

Parameters:

Name Type Description Default
df DataFrame

dataframe to write.

required
filename Path

filename to write to.

required
overwrite bool

whether to overwrite the file if it exists. Defaults to False.

False
kwargs

additional arguments to pass to the writer.

{}
Source code in network_wrangler/utils/io_table.py
def write_table(
    df: Union[pd.DataFrame, gpd.GeoDataFrame],
    filename: Path,
    overwrite: bool = False,
    **kwargs,
) -> None:
    """Write a dataframe or geodataframe to a file.

    Args:
        df (pd.DataFrame): dataframe to write.
        filename (Path): filename to write to.
        overwrite (bool): whether to overwrite the file if it exists. Defaults to False.
        kwargs: additional arguments to pass to the writer.

    """
    filename = Path(filename)
    if filename.exists() and not overwrite:
        msg = f"File {filename} already exists and overwrite is False."
        raise FileExistsError(msg)

    if not filename.parent.exists():
        filename.parent.mkdir(parents=True)

    WranglerLogger.debug(f"Writing to {filename}.")

    if "shp" in filename.suffix:
        df.to_file(filename, index=False, **kwargs)
    elif "parquet" in filename.suffix:
        df.to_parquet(filename, index=False, **kwargs)
    elif "csv" in filename.suffix or "txt" in filename.suffix:
        df.to_csv(filename, index=False, date_format="%H:%M:%S", **kwargs)
    elif "geojson" in filename.suffix:
        # required due to issues with list-like columns
        if isinstance(df, gpd.GeoDataFrame):
            data = df.to_json(drop_id=True)
        else:
            data = df.to_json(orient="records", index=False)
        with filename.open("w", encoding="utf-8") as file:
            file.write(data)
    elif "json" in filename.suffix:
        with filename.open("w") as f:
            f.write(df.to_json(orient="records"))
    else:
        msg = f"Filetype {filename.suffix} not implemented."
        raise NotImplementedError(msg)
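The suffix dispatch above reduces, for CSV, to a plain pandas round trip. A minimal sketch with the Wrangler-specific kwargs omitted:

```python
import tempfile
from pathlib import Path

import pandas as pd

df = pd.DataFrame({"A": [1, 2], "B": [3, 4]})
out = Path(tempfile.mkdtemp()) / "links.csv"

# write_table's CSV branch boils down to to_csv(index=False);
# reading it back recovers the same table.
df.to_csv(out, index=False)
roundtrip = pd.read_csv(out)
```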

Utility functions for loading dictionaries from files.

load_dict(path)

Load a dictionary from a file.

Source code in network_wrangler/utils/io_dict.py
def load_dict(path: Path) -> dict:
    """Load a dictionary from a file."""
    path = Path(path)
    if not path.is_file():
        msg = f"Specified dict file {path} not found."
        raise FileNotFoundError(msg)

    if path.suffix.lower() == ".toml":
        return _load_toml(path)
    if path.suffix.lower() == ".json":
        return _load_json(path)
    if path.suffix.lower() == ".yaml" or path.suffix.lower() == ".yml":
        return _load_yaml(path)
    msg = f"Filetype {path.suffix} not implemented."
    raise NotImplementedError(msg)
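A minimal sketch of the suffix dispatch for the JSON branch only; the TOML and YAML branches follow the same shape. load_dict_sketch is a stand-in, not the library function:

```python
import json
import tempfile
from pathlib import Path

def load_dict_sketch(path: Path) -> dict:
    # Mirrors load_dict's behavior for .json files only.
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"Specified dict file {path} not found.")
    if path.suffix.lower() == ".json":
        return json.loads(path.read_text())
    raise NotImplementedError(f"Filetype {path.suffix} not implemented.")

cfg = Path(tempfile.mkdtemp()) / "config.json"
cfg.write_text('{"crs": 4326, "units": "miles"}')
loaded = load_dict_sketch(cfg)
```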

load_merge_dict(path)

Load and merge multiple dictionaries from files.

Source code in network_wrangler/utils/io_dict.py
def load_merge_dict(path: Union[Path, list[Path]]) -> dict:
    """Load and merge multiple dictionaries from files."""
    if not isinstance(path, list):
        path = [path]
    data = load_dict(path[0])
    for path_item in path[1:]:
        merge_dicts(data, load_dict(path_item))
    return data
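load_merge_dict folds later files into the first via merge_dicts. Assuming merge_dicts performs a recursive merge in which later values win (an assumption about its behavior, not shown above), the semantics look like:

```python
def merge_dicts_sketch(base: dict, update: dict) -> dict:
    # Hypothetical recursive merge: nested dicts merge key-by-key,
    # later scalar values overwrite earlier ones.
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            merge_dicts_sketch(base[key], value)
        else:
            base[key] = value
    return base

defaults = {"cpu": {"threads": 4}, "log_level": "INFO"}
overrides = {"cpu": {"threads": 8}, "out_dir": "build"}
merged = merge_dicts_sketch(defaults, overrides)
```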

Helper functions for data models.

DatamodelDataframeIncompatableError

Bases: Exception

Raised when a data model and a dataframe are not compatible.

Source code in network_wrangler/utils/models.py
class DatamodelDataframeIncompatableError(Exception):
    """Raised when a data model and a dataframe are not compatible."""

TableValidationError

Bases: Exception

Raised when a table validation fails.

Source code in network_wrangler/utils/models.py
class TableValidationError(Exception):
    """Raised when a table validation fails."""

coerce_extra_fields_to_type_in_df(data, model, df)

Coerce extra fields in data that aren’t specified in Pydantic model to the type in the df.

Note: will not coerce lists of submodels, etc.

Parameters:

Name Type Description Default
data dict

The data to coerce.

required
model BaseModel

The Pydantic model to validate the data against.

required
df DataFrame

The DataFrame to coerce the data to.

required
Source code in network_wrangler/utils/models.py
def coerce_extra_fields_to_type_in_df(
    data: BaseModel, model: BaseModel, df: pd.DataFrame
) -> BaseModel:
    """Coerce extra fields in data that aren't specified in Pydantic model to the type in the df.

    Note: will not coerce lists of submodels, etc.

    Args:
        data (dict): The data to coerce.
        model (BaseModel): The Pydantic model to validate the data against.
        df (pd.DataFrame): The DataFrame to coerce the data to.
    """
    out_data = copy.deepcopy(data)

    # Coerce submodels
    for field in submodel_fields_in_model(model, data):
        out_data.__dict__[field] = coerce_extra_fields_to_type_in_df(
            data.__dict__[field], model.__annotations__[field], df
        )

    for field in extra_attributes_undefined_in_model(data, model):
        try:
            v = coerce_val_to_df_types(field, data.model_extra[field], df)
        except ValueError as err:
            raise DatamodelDataframeIncompatableError() from err
        out_data.model_extra[field] = v
    return out_data

default_from_datamodel(data_model, field)

Returns default value from pandera data model for a given field name.

Source code in network_wrangler/utils/models.py
def default_from_datamodel(data_model: pa.DataFrameModel, field: str):
    """Returns default value from pandera data model for a given field name."""
    if field in data_model.__fields__ and hasattr(data_model.__fields__[field][1], "default"):
        return data_model.__fields__[field][1].default
    return None

empty_df_from_datamodel(model, crs=LAT_LON_CRS)

Create an empty DataFrame or GeoDataFrame with the specified columns.

Parameters:

Name Type Description Default
model BaseModel

A pandera data model to create empty [Geo]DataFrame from.

required
crs int

if schema has geometry, will use this as the geometry’s crs. Defaults to LAT_LONG_CRS

LAT_LON_CRS
Source code in network_wrangler/utils/models.py
def empty_df_from_datamodel(
    model: DataFrameModel, crs: int = LAT_LON_CRS
) -> Union[gpd.GeoDataFrame, pd.DataFrame]:
    """Create an empty DataFrame or GeoDataFrame with the specified columns.

    Args:
        model (BaseModel): A pandera data model to create empty [Geo]DataFrame from.
        crs: if schema has geometry, will use this as the geometry's crs. Defaults to LAT_LONG_CRS
    Returns:
        An empty [Geo]DataFrame that validates to the specified model.
    """
    schema = model.to_schema()
    data: dict[str, list] = {col: [] for col in schema.columns}

    if "geometry" in data:
        return model(gpd.GeoDataFrame(data, crs=crs))

    return model(pd.DataFrame(data))

extra_attributes_undefined_in_model(instance, model)

Find the extra attributes in a pydantic model that are not defined in the model.

Source code in network_wrangler/utils/models.py
def extra_attributes_undefined_in_model(instance: BaseModel, model: BaseModel) -> list:
    """Find the extra attributes in a pydantic model that are not defined in the model."""
    defined_fields = model.model_fields
    all_attributes = list(instance.model_dump(exclude_none=True, by_alias=True).keys())
    extra_attributes = [a for a in all_attributes if a not in defined_fields]
    return extra_attributes

fill_df_with_defaults_from_model(df, model)

Fill a DataFrame with default values from a Pandera DataFrameModel.

Parameters:

Name Type Description Default
df

DataFrame to fill with default values.

required
model

Pandera DataFrameModel to get default values from.

required
Source code in network_wrangler/utils/models.py
def fill_df_with_defaults_from_model(df, model):
    """Fill a DataFrame with default values from a Pandera DataFrameModel.

    Args:
        df: DataFrame to fill with default values.
        model: Pandera DataFrameModel to get default values from.
    """
    for c in df.columns:
        default_value = default_from_datamodel(model, c)
        if default_value is None:
            df[c] = df[c].where(pd.notna(df[c]), None)
        else:
            df[c] = df[c].fillna(default_value)
    return df

identify_model(data, models)

Identify the model that the input data conforms to.

Parameters:

Name Type Description Default
data Union[DataFrame, dict]

The input data to identify.

required
models list[DataFrameModel, BaseModel]

A list of models to validate the input data against.

required
Source code in network_wrangler/utils/models.py
def identify_model(
    data: Union[pd.DataFrame, dict], models: list
) -> Union[DataFrameModel, BaseModel]:
    """Identify the model that the input data conforms to.

    Args:
        data (Union[pd.DataFrame, dict]): The input data to identify.
        models (list[DataFrameModel,BaseModel]): A list of models to validate the input
          data against.
    """
    for m in models:
        try:
            if isinstance(data, pd.DataFrame):
                validate_df_to_model(data, m)
            else:
                m(**data)
            return m
        except ValidationError:
            continue
        except SchemaError:
            continue

    WranglerLogger.error(
        f"The input data isn't consistent with any provided data model.\
                         \nInput data: {data}\
                         \nData Models: {models}"
    )
    msg = "The input data isn't consistent with any provided data model."
    raise TableValidationError(msg)

order_fields_from_data_model(df, model)

Order the fields in a DataFrame to match the order in a Pandera DataFrameModel.

Will add any fields that are not in the model to the end of the DataFrame. Will not add any fields that are in the model but not in the DataFrame.

Parameters:

Name Type Description Default
df DataFrame

DataFrame to order.

required
model DataFrameModel

Pandera DataFrameModel to order the DataFrame to.

required
Source code in network_wrangler/utils/models.py
def order_fields_from_data_model(df: pd.DataFrame, model: DataFrameModel) -> pd.DataFrame:
    """Order the fields in a DataFrame to match the order in a Pandera DataFrameModel.

    Will add any fields that are not in the model to the end of the DataFrame.
    Will not add any fields that are in the model but not in the DataFrame.

    Args:
        df: DataFrame to order.
        model: Pandera DataFrameModel to order the DataFrame to.
    """
    model_fields = list(model.__fields__.keys())
    df_model_fields = [f for f in model_fields if f in df.columns]
    df_additional_fields = [f for f in df.columns if f not in model_fields]
    return df[df_model_fields + df_additional_fields]
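Because the function only reorders columns, its effect can be sketched with plain pandas and a stubbed list of model fields; the field names below are illustrative:

```python
import pandas as pd

# Stand-in for the model's declared field order; extra df columns go to the end.
model_fields = ["model_link_id", "A", "B"]
df = pd.DataFrame(columns=["B", "extra", "A", "model_link_id"])

# Same two comprehensions as order_fields_from_data_model.
df_model_fields = [f for f in model_fields if f in df.columns]
df_additional_fields = [f for f in df.columns if f not in model_fields]
ordered = df[df_model_fields + df_additional_fields]
```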

submodel_fields_in_model(model, instance=None)

Find the fields in a pydantic model that are submodels.

Source code in network_wrangler/utils/models.py
def submodel_fields_in_model(model: type, instance: Optional[BaseModel] = None) -> list:
    """Find the fields in a pydantic model that are submodels."""
    types = get_type_hints(model)
    model_type = (ModelMetaclass, BaseModel)
    submodels = [f for f in model.model_fields if isinstance(types.get(f), model_type)]
    if instance is not None:
        defined = list(instance.model_dump(exclude_none=True, by_alias=True).keys())
        return [f for f in submodels if f in defined]
    return submodels

validate_call_pyd(func)

Decorator to validate the function i/o using Pydantic models without Pandera.

Source code in network_wrangler/utils/models.py
def validate_call_pyd(func):
    """Decorator to validate the function i/o using Pydantic models without Pandera."""

    @wraps(func)
    def wrapper(*args, **kwargs):
        type_hints = get_type_hints(func)
        # Modify the type hints to replace pandera DataFrame models with pandas DataFrames
        modified_type_hints = {
            key: value
            for key, value in type_hints.items()
            if not _is_type_from_type_hint(value, PanderaDataFrame)
        }

        new_func = func
        new_func.__annotations__ = modified_type_hints
        validated_func = validate_call(new_func, config={"arbitrary_types_allowed": True})

        return validated_func(*args, **kwargs)

    return wrapper

validate_df_to_model(df, model, output_file=Path('validation_failure_cases.csv'))

Wrapper to validate a DataFrame against a Pandera DataFrameModel with better logging.

Also copies the attrs from the input DataFrame to the validated DataFrame.

Parameters:

Name Type Description Default
df DataFrame

DataFrame to validate.

required
model type

Pandera DataFrameModel to validate against.

required
output_file Path

Optional file to write validation errors to. Defaults to validation_failure_cases.csv.

Path('validation_failure_cases.csv')
Source code in network_wrangler/utils/models.py
@validate_call(config={"arbitrary_types_allowed": True})
def validate_df_to_model(
    df: DataFrame, model: type, output_file: Path = Path("validation_failure_cases.csv")
) -> DataFrame:
    """Wrapper to validate a DataFrame against a Pandera DataFrameModel with better logging.

    Also copies the attrs from the input DataFrame to the validated DataFrame.

    Args:
        df: DataFrame to validate.
        model: Pandera DataFrameModel to validate against.
        output_file: Optional file to write validation errors to. Defaults to
            validation_failure_cases.csv.
    """
    attrs = copy.deepcopy(df.attrs)
    err_msg = f"Validation to {model.__name__} failed."
    try:
        model_df = model.validate(df, lazy=True)
        model_df = fill_df_with_defaults_from_model(model_df, model)
        model_df.attrs = attrs
        return model_df
    except (TypeError, ValueError) as e:
        WranglerLogger.error(f"Validation to {model.__name__} failed.\n{e}")
        raise TableValidationError(err_msg) from e
    except SchemaErrors as e:
        # Log the summary of errors
        WranglerLogger.error(
            f"Validation to {model.__name__} failed with {len(e.failure_cases)} \
            errors: \n{e.failure_cases}"
        )

        # If there are many errors, save them to a file
        if len(e.failure_cases) > SMALL_RECS:
            error_file = output_file
            e.failure_cases.to_csv(error_file)
            WranglerLogger.info(f"Detailed error cases written to {error_file}")
        else:
            # Otherwise log the errors directly
            WranglerLogger.error("Detailed failure cases:\n%s", e.failure_cases)
        raise TableValidationError(err_msg) from e
    except SchemaError as e:
        WranglerLogger.error(f"Validation to {model.__name__} failed with error: {e}")
        WranglerLogger.error(f"Failure Cases:\n{e.failure_cases}")
        raise TableValidationError(err_msg) from e

Functions to help with network manipulations in dataframes.

point_seq_to_links(point_seq_df, id_field, seq_field, node_id_field, from_field='A', to_field='B')

Translates a df with tidy data representing a sequence of points into links.

Parameters:

Name Type Description Default
point_seq_df DataFrame

Dataframe with source breadcrumbs

required
id_field str

Trace ID

required
seq_field str

Order of breadcrumbs within ID_field

required
node_id_field str

field denoting the node ID

required
from_field str

Field to export from_field to. Defaults to “A”.

'A'
to_field str

Field to export to_field to. Defaults to “B”.

'B'

Returns:

Type Description
DataFrame

pd.DataFrame: Link records with id_field, from_field, to_field

Source code in network_wrangler/utils/net.py
def point_seq_to_links(
    point_seq_df: DataFrame,
    id_field: str,
    seq_field: str,
    node_id_field: str,
    from_field: str = "A",
    to_field: str = "B",
) -> DataFrame:
    """Translates a df with tidy data representing a sequence of points into links.

    Args:
        point_seq_df (pd.DataFrame): Dataframe with source breadcrumbs
        id_field (str): Trace ID
        seq_field (str): Order of breadcrumbs within ID_field
        node_id_field (str): field denoting the node ID
        from_field (str, optional): Field to export from_field to. Defaults to "A".
        to_field (str, optional): Field to export to_field to. Defaults to "B".

    Returns:
        pd.DataFrame: Link records with id_field, from_field, to_field
    """
    point_seq_df = point_seq_df.sort_values(by=[id_field, seq_field])

    links = point_seq_df.add_suffix(f"_{from_field}").join(
        point_seq_df.shift(-1).add_suffix(f"_{to_field}")
    )

    links = links[links[f"{id_field}_{to_field}"] == links[f"{id_field}_{from_field}"]]

    links = links.drop(columns=[f"{id_field}_{to_field}"])
    links = links.rename(
        columns={
            f"{id_field}_{from_field}": id_field,
            f"{node_id_field}_{from_field}": from_field,
            f"{node_id_field}_{to_field}": to_field,
        }
    )

    links = links.dropna(subset=[from_field, to_field])
    # Since join with a shift() has some NAs, we need to recast the columns to int
    _int_cols = [to_field, f"{seq_field}_{to_field}"]
    links[_int_cols] = links[_int_cols].astype(int)
    return links
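The shift-and-join trick at the heart of point_seq_to_links can be seen on a toy breadcrumb table; the column names here are illustrative:

```python
import pandas as pd

# Two traces of node breadcrumbs (trace id, sequence order, node id).
points = pd.DataFrame(
    {
        "trip_id": ["t1", "t1", "t1", "t2", "t2"],
        "seq": [1, 2, 3, 1, 2],
        "node_id": [10, 11, 12, 20, 21],
    }
)
points = points.sort_values(by=["trip_id", "seq"])

# Pair each breadcrumb with the next row, as point_seq_to_links does.
links = points.add_suffix("_A").join(points.shift(-1).add_suffix("_B"))
# Drop pairs that cross from one trace into the next.
links = links[links["trip_id_B"] == links["trip_id_A"]]
links = links.rename(columns={"trip_id_A": "trip_id", "node_id_A": "A", "node_id_B": "B"})
links = links.dropna(subset=["A", "B"])
links["B"] = links["B"].astype(int)  # shift() introduced NaNs, so B became float
```

Three consecutive points within a trace become two links; the boundary between traces produces none.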

Functions related to parsing and comparing time objects and series.

Internal function terminology for timespan scopes:

  • matching: a scope that could be applied for a given timespan combination. This includes the default timespan as well as scopes wholly contained within it.
  • overlapping: a timespan that fully or partially overlaps a given timespan. This includes the default timespan, all matching timespans, and all timespans with at least one minute of overlap.
  • conflicting: a timespan that is overlapping but not matching. By definition, default scope values are not conflicting.
  • independent: a timespan that is not overlapping.

TimespanDfQueryError

Bases: Exception

Error for timespan query errors.

Source code in network_wrangler/utils/time.py
class TimespanDfQueryError(Exception):
    """Error for timespan query errors."""

calc_overlap_duration_with_query(start_time_s, end_time_s, start_time_q, end_time_q)

Calculate the overlap duration between series of start/end times and a query start/end time.

Parameters:

Name Type Description Default
start_time_s Series[datetime]

Series of start times to calculate overlap with.

required
end_time_s Series[datetime]

Series of end times to calculate overlap with.

required
start_time_q datetime

Query start time to calculate overlap with.

required
end_time_q datetime

Query end time to calculate overlap with.

required
Source code in network_wrangler/utils/time.py
def calc_overlap_duration_with_query(
    start_time_s: pd.Series[datetime],
    end_time_s: pd.Series[datetime],
    start_time_q: datetime,
    end_time_q: datetime,
) -> pd.Series[timedelta]:
    """Calculate the overlap series of start and end times and a query start and end times.

    Args:
        start_time_s: Series of start times to calculate overlap with.
        end_time_s: Series of end times to calculate overlap with.
        start_time_q: Query start time to calculate overlap with.
        end_time_q: Query end time to calculate overlap with.
    """
    overlap_start = start_time_s.combine(start_time_q, max)
    overlap_end = end_time_s.combine(end_time_q, min)
    # Note: values are overlap durations in minutes (`_s` denotes a Series).
    overlap_duration_s = (overlap_end - overlap_start).dt.total_seconds() / 60

    return overlap_duration_s
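The clip-to-query-window logic can be demonstrated on a toy pair of series. Note the result is in minutes, matching the division by 60 in the source:

```python
from datetime import datetime

import pandas as pd

day = datetime(2024, 1, 1)
start_s = pd.Series([day.replace(hour=6), day.replace(hour=9)])
end_s = pd.Series([day.replace(hour=10), day.replace(hour=17)])

# Clip each record's span to a hypothetical 08:00-09:30 query window.
q_start = day.replace(hour=8)
q_end = day.replace(hour=9, minute=30)
overlap_start = start_s.combine(q_start, max)
overlap_end = end_s.combine(q_end, min)
overlap_minutes = (overlap_end - overlap_start).dt.total_seconds() / 60
```

The 06:00-10:00 record overlaps the window for 90 minutes; the 09:00-17:00 record for 30.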

convert_timespan_to_start_end_dt(timespan_s)

Convert a timespan string ['12:00','14:00'] to start_time & end_time datetime cols in df.

Source code in network_wrangler/utils/time.py
def convert_timespan_to_start_end_dt(timespan_s: pd.Series[str]) -> pd.DataFrame:
    """Convert a timespan string ['12:00','14:00'] to start_time & end_time datetime cols in df."""
    start_time = timespan_s.apply(lambda x: str_to_time(x[0]))
    end_time = timespan_s.apply(lambda x: str_to_time(x[1]))
    return pd.DataFrame({"start_time": start_time, "end_time": end_time})

dt_contains(timespan1, timespan2)

Check timespan1 inclusively contains timespan2.

If the end time is less than the start time, it is assumed to be the next day.

Parameters:

Name Type Description Default
timespan1 list[time]

The first timespan represented as a list containing the start time and end time.

required
timespan2 list[time]

The second timespan represented as a list containing the start time and end time.

required

Returns:

Name Type Description
bool bool

True if the first timespan contains the second timespan, False otherwise.

Source code in network_wrangler/utils/time.py
@validate_call
def dt_contains(timespan1: list[datetime], timespan2: list[datetime]) -> bool:
    """Check timespan1 inclusively contains timespan2.

    If the end time is less than the start time, it is assumed to be the next day.

    Args:
        timespan1 (list[time]): The first timespan represented as a list containing the start
            time and end time.
        timespan2 (list[time]): The second timespan represented as a list containing the start
            time and end time.

    Returns:
        bool: True if the first timespan contains the second timespan, False otherwise.
    """
    start_time_dt, end_time_dt = timespan1

    if end_time_dt < start_time_dt:
        end_time_dt = end_time_dt + timedelta(days=1)

    start_time_dt2, end_time_dt2 = timespan2

    if end_time_dt2 < start_time_dt2:
        end_time_dt2 = end_time_dt2 + timedelta(days=1)

    return (start_time_dt <= start_time_dt2) and (end_time_dt >= end_time_dt2)
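A standard-library sketch of the containment check, including the wrap-past-midnight rule. dt_contains_sketch mirrors the logic above; it is not the library function:

```python
from datetime import datetime, timedelta

def dt_contains_sketch(timespan1, timespan2) -> bool:
    # End times "earlier" than start times are treated as next-day.
    start1, end1 = timespan1
    if end1 < start1:
        end1 += timedelta(days=1)
    start2, end2 = timespan2
    if end2 < start2:
        end2 += timedelta(days=1)
    # Inclusive containment on both ends.
    return start1 <= start2 and end1 >= end2

d = datetime(2024, 1, 1)
am_peak = [d.replace(hour=6), d.replace(hour=9)]
commute = [d.replace(hour=7), d.replace(hour=8)]
overnight = [d.replace(hour=22), d.replace(hour=2)]  # wraps past midnight
```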

dt_list_overlaps(timespans)

Check if any of the timespans overlap.

overlapping: a timespan that fully or partially overlaps a given timespan. This includes all timespans with at least one minute of overlap.

Source code in network_wrangler/utils/time.py
def dt_list_overlaps(timespans: list[list[datetime]]) -> bool:
    """Check if any of the timespans overlap.

    `overlapping`: a timespan that fully or partially overlaps a given timespan.
    This includes all timespans with at least one minute of overlap.
    """
    return bool(filter_dt_list_to_overlaps(timespans))

dt_overlap_duration(timespan1, timespan2)

Check if two timespans overlap and return the amount of overlap.

If the end time is less than the start time, it is assumed to be the next day.

Source code in network_wrangler/utils/time.py
@validate_call
def dt_overlap_duration(timespan1: list[datetime], timespan2: list[datetime]) -> timedelta:
    """Check if two timespans overlap and return the amount of overlap.

    If the end time is less than the start time, it is assumed to be the next day.
    """
    start1, end1 = timespan1
    start2, end2 = timespan2
    if end1 < start1:
        end1 += timedelta(days=1)
    if end2 < start2:
        end2 += timedelta(days=1)
    overlap_start = max(start1, start2)
    overlap_end = min(end1, end2)
    return max(overlap_end - overlap_start, timedelta(0))

dt_overlaps(timespan1, timespan2)

Check if two timespans overlap.

If the end time is less than the start time, it is assumed to be the next day.

overlapping: a timespan that fully or partially overlaps a given timespan. This includes all timespans with at least one minute of overlap.

Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def dt_overlaps(timespan1: list[datetime], timespan2: list[datetime]) -> bool:
    """Check if two timespans overlap.

    If the end time is less than the start time, it is assumed to be the next day.

    `overlapping`: a timespan that fully or partially overlaps a given timespan.
    This includes all timespans with at least one minute of overlap.
    """
    time1_start, time1_end = timespan1
    time2_start, time2_end = timespan2

    if time1_end < time1_start:
        time1_end += timedelta(days=1)
    if time2_end < time2_start:
        time2_end += timedelta(days=1)

    return (time1_start < time2_end) and (time2_start < time1_end)
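The strict inequalities mean two timespans that merely touch at an endpoint do not count as overlapping. A self-contained sketch of the same logic:

```python
from datetime import datetime, timedelta

def dt_overlaps_sketch(timespan1, timespan2) -> bool:
    # Same strict-inequality check as dt_overlaps, with next-day wrapping.
    start1, end1 = timespan1
    start2, end2 = timespan2
    if end1 < start1:
        end1 += timedelta(days=1)
    if end2 < start2:
        end2 += timedelta(days=1)
    return start1 < end2 and start2 < end1

d = datetime(2024, 1, 1)
am = [d.replace(hour=6), d.replace(hour=9)]
mid = [d.replace(hour=8), d.replace(hour=10)]
pm = [d.replace(hour=9), d.replace(hour=10)]  # touches am at 09:00 only
```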

dt_to_seconds_from_midnight(dt)

Convert a datetime object to the number of seconds since midnight.

Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def dt_to_seconds_from_midnight(dt: datetime) -> int:
    """Convert a datetime object to the number of seconds since midnight."""
    return round((dt - dt.replace(hour=0, minute=0, second=0, microsecond=0)).total_seconds())
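The conversion subtracts the datetime's own midnight, so only the time-of-day component matters. A standalone sketch of the same arithmetic (the `seconds_from_midnight` name is illustrative):

```python
from datetime import datetime

def seconds_from_midnight(dt):
    # Subtract the same day's midnight; round() handles any microseconds.
    midnight = dt.replace(hour=0, minute=0, second=0, microsecond=0)
    return round((dt - midnight).total_seconds())

print(seconds_from_midnight(datetime(2024, 1, 1, 6, 30)))  # 23400
```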

duration_dt(start_time_dt, end_time_dt)

Returns a datetime.timedelta object representing the duration of the timespan.

If end_time is less than start_time, the duration will assume that it crosses over midnight.

Source code in network_wrangler/utils/time.py
def duration_dt(start_time_dt: datetime, end_time_dt: datetime) -> timedelta:
    """Returns a datetime.timedelta object representing the duration of the timespan.

    If end_time is less than start_time, the duration will assume that it crosses over
    midnight.
    """
    if end_time_dt < start_time_dt:
        return timedelta(
            hours=24 - start_time_dt.hour + end_time_dt.hour,
            minutes=end_time_dt.minute - start_time_dt.minute,
            seconds=end_time_dt.second - start_time_dt.second,
        )
    return end_time_dt - start_time_dt
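For example, a span from 22:30 to 01:15 crosses midnight and yields 2 hours 45 minutes. A standalone sketch of the same arithmetic (the `duration` name and dates are illustrative):

```python
from datetime import datetime, timedelta

def duration(start, end):
    # Same rule as duration_dt: an end time "before" the start crosses midnight.
    if end < start:
        return timedelta(
            hours=24 - start.hour + end.hour,
            minutes=end.minute - start.minute,
            seconds=end.second - start.second,
        )
    return end - start

d = datetime(2024, 1, 1)
print(duration(d.replace(hour=22, minute=30), d.replace(hour=1, minute=15)))  # 2:45:00
```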

filter_df_to_max_overlapping_timespans(orig_df, query_timespan, strict_match=False, min_overlap_minutes=1, keep_max_of_cols=None)

Filters dataframe for entries that have maximum overlap with the given query timespan.

If the end time is less than the start time, it is assumed to be the next day.

Parameters:

- orig_df (DataFrame, required): dataframe to query timespans for, with start_time and end_time fields.
- query_timespan (list[TimeString], required): TimespanString of format ['HH:MM','HH:MM'] to query orig_df for overlapping records.
- strict_match (bool, default False): if True, the returned df only contains records that fully contain the query timespan, and min_overlap_minutes does not apply.
- min_overlap_minutes (int, default 1): minimum number of minutes the timespans need to overlap to keep a record.
- keep_max_of_cols (Optional[list[str]], default None): list of fields to return the maximum value of overlap for. If None, falls back to ['model_link_id'].
Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def filter_df_to_max_overlapping_timespans(
    orig_df: pd.DataFrame,
    query_timespan: list[TimeString],
    strict_match: bool = False,
    min_overlap_minutes: int = 1,
    keep_max_of_cols: Optional[list[str]] = None,
) -> pd.DataFrame:
    """Filters dataframe for entries that have maximum overlap with the given query timespan.

    If the end time is less than the start time, it is assumed to be the next day.

    Args:
        orig_df: dataframe to query timespans for with `start_time` and `end_time` fields.
        query_timespan: TimespanString of format ['HH:MM','HH:MM'] to query orig_df for overlapping
            records.
        strict_match: boolean indicating if the returned df should only contain
            records that fully contain the query timespan. If set to True, min_overlap_minutes
            does not apply. Defaults to False.
        min_overlap_minutes: minimum number of minutes the timespans need to overlap to keep.
            Defaults to 1.
        keep_max_of_cols: list of fields to return the maximum value of overlap for.  If None,
            will return all overlapping time periods. Defaults to `['model_link_id']`
    """
    if keep_max_of_cols is None:
        keep_max_of_cols = ["model_link_id"]
    if "start_time" not in orig_df.columns or "end_time" not in orig_df.columns:
        msg = "DataFrame must have 'start_time' and 'end_time' columns"
        WranglerLogger.error(msg)
        raise TimespanDfQueryError(msg)
    q_start, q_end = str_to_time_list(query_timespan)

    real_end = orig_df["end_time"].copy()
    real_end.loc[orig_df["end_time"] < orig_df["start_time"]] += pd.Timedelta(days=1)

    orig_df["overlap_duration"] = calc_overlap_duration_with_query(
        orig_df["start_time"],
        real_end,
        q_start,
        q_end,
    )
    if strict_match:
        overlap_df = orig_df.loc[(orig_df.start_time <= q_start) & (real_end >= q_end)]
    else:
        overlap_df = orig_df.loc[orig_df.overlap_duration > min_overlap_minutes]
    WranglerLogger.debug(f"overlap_df: \n{overlap_df}")
    if keep_max_of_cols:
        # keep only the maximum overlap
        idx = overlap_df.groupby(keep_max_of_cols)["overlap_duration"].idxmax()
        overlap_df = overlap_df.loc[idx]
    return overlap_df
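The overlap computation can be sketched with plain pandas by clipping each record's span to the query window (`calc_overlap_duration_with_query` is a library helper; the clipping below is an illustrative stand-in, and the sample dataframe is made up):

```python
import pandas as pd

df = pd.DataFrame({
    "model_link_id": [1, 1, 2],
    "start_time": pd.to_datetime(["2024-01-01 06:00", "2024-01-01 09:00", "2024-01-01 07:30"]),
    "end_time": pd.to_datetime(["2024-01-01 09:00", "2024-01-01 12:00", "2024-01-01 08:30"]),
})
q_start, q_end = pd.Timestamp("2024-01-01 07:00"), pd.Timestamp("2024-01-01 10:00")

# Overlap in minutes: clip each span to the query window, floor at zero.
df["overlap_duration"] = (
    (df["end_time"].clip(upper=q_end) - df["start_time"].clip(lower=q_start))
    .dt.total_seconds()
    .div(60)
    .clip(lower=0)
)

# As with keep_max_of_cols, keep the best-overlapping record per link.
best = df.loc[df.groupby("model_link_id")["overlap_duration"].idxmax()]
print(best[["model_link_id", "overlap_duration"]])
```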

filter_df_to_overlapping_timespans(orig_df, query_timespans)

Filters dataframe for entries that have any overlap with ANY of the given query timespans.

If the end time is less than the start time, it is assumed to be the next day.

Parameters:

- orig_df (DataFrame, required): dataframe to query timespans for, with start_time and end_time fields.
- query_timespans (list[TimespanString], required): list of TimespanStrings of format ['HH:MM','HH:MM'] to query orig_df for overlapping records.
Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def filter_df_to_overlapping_timespans(
    orig_df: pd.DataFrame,
    query_timespans: list[TimespanString],
) -> pd.DataFrame:
    """Filters dataframe for entries that have any overlap with ANY of the given query timespans.

    If the end time is less than the start time, it is assumed to be the next day.

    Args:
        orig_df: dataframe to query timespans for with `start_time` and `end_time` fields.
        query_timespans: List of a list of TimespanStr of format ['HH:MM','HH:MM'] to query orig_df
            for overlapping records.
    """
    if "start_time" not in orig_df.columns or "end_time" not in orig_df.columns:
        msg = "DataFrame must have 'start_time' and 'end_time' columns"
        WranglerLogger.error(msg)
        raise TimespanDfQueryError(msg)

    mask = pd.Series([False] * len(orig_df), index=orig_df.index)
    for query_timespan in query_timespans:
        q_start_time, q_end_time = str_to_time_list(query_timespan)
        end_time_s = orig_df["end_time"].copy()
        end_time_s.loc[orig_df["end_time"] < orig_df["start_time"]] += pd.Timedelta(days=1)
        this_ts_mask = (orig_df["start_time"] < q_end_time) & (q_start_time < end_time_s)
        mask |= this_ts_mask
    return orig_df.loc[mask]

filter_dt_list_to_overlaps(timespans)

Filter a list of timespans to only include those that overlap.

overlapping: a timespan that fully or partially overlaps a given timespan. This includes all timespans where at least one minute overlaps.

Source code in network_wrangler/utils/time.py
@validate_call
def filter_dt_list_to_overlaps(timespans: list[list[datetime]]) -> list[list[datetime]]:
    """Filter a list of timespans to only include those that overlap.

    `overlapping`: a timespan that fully or partially overlaps a given timespan.
    This includes all timespans where at least one minute overlaps.
    """
    overlaps = []
    for i in range(len(timespans)):
        for j in range(i + 1, len(timespans)):
            if dt_overlaps(timespans[i], timespans[j]):
                overlaps += [timespans[i], timespans[j]]

    # remove dupes
    return list(map(list, set(map(tuple, overlaps))))

format_seconds_to_legible_str(seconds)

Formats seconds into a human-friendly string for log files.

Source code in network_wrangler/utils/time.py
def format_seconds_to_legible_str(seconds: int) -> str:
    """Formats seconds into a human-friendly string for log files."""
    if seconds < 60:  # noqa: PLR2004
        return f"{int(seconds)} seconds"
    if seconds < 3600:  # noqa: PLR2004
        return f"{int(seconds // 60)} minutes"
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    return f"{hours} hours and {minutes} minutes"

is_increasing(datetimes)

Check if a list of datetime objects is increasing in time.

Source code in network_wrangler/utils/time.py
def is_increasing(datetimes: list[datetime]) -> bool:
    """Check if a list of datetime objects is increasing in time."""
    return all(datetimes[i] <= datetimes[i + 1] for i in range(len(datetimes) - 1))

seconds_from_midnight_to_str(seconds)

Convert the number of seconds since midnight to a TimeString (HH:MM).

Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def seconds_from_midnight_to_str(seconds: int) -> TimeString:
    """Convert the number of seconds since midnight to a TimeString (HH:MM)."""
    return str(timedelta(seconds=seconds))

str_to_seconds_from_midnight(time_str)

Convert a TimeString (HH:MM<:SS>) to the number of seconds since midnight.

Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def str_to_seconds_from_midnight(time_str: TimeString) -> int:
    """Convert a TimeString (HH:MM<:SS>) to the number of seconds since midnight."""
    dt = str_to_time(time_str)
    return dt_to_seconds_from_midnight(dt)

str_to_time(time_str, base_date=None)

Convert TimeString (HH:MM<:SS>) to datetime object.

If HH > 24, will subtract 24 to be within 24 hours. Timespans will be treated as the next day.

Parameters:

- time_str (TimeString, required): TimeString in HH:MM:SS or HH:MM format.
- base_date (Optional[date], default None): date to base the datetime on; if not provided, uses today.
Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def str_to_time(time_str: TimeString, base_date: Optional[date] = None) -> datetime:
    """Convert TimeString (HH:MM<:SS>) to datetime object.

    If HH > 24, will subtract 24 to be within 24 hours. Timespans will be treated as the next day.

    Args:
        time_str: TimeString in HH:MM:SS or HH:MM format.
        base_date: optional date to base the datetime on. Defaults to None.
            If not provided, will use today.
    """
    # Set the base date to today if not provided
    if base_date is None:
        base_date = date.today()

    # Split the time string to extract hours, minutes, and seconds
    parts = time_str.split(":")
    hours = int(parts[0])
    minutes = int(parts[1])
    seconds = int(parts[2]) if len(parts) == 3 else 0  # noqa: PLR2004

    if hours >= 24:  # noqa: PLR2004
        add_days = hours // 24
        base_date += timedelta(days=add_days)
        hours -= 24 * add_days

    # Create a time object with the adjusted hours, minutes, and seconds
    adjusted_time = datetime.strptime(f"{hours:02}:{minutes:02}:{seconds:02}", "%H:%M:%S").time()

    # Combine the base date with the adjusted time and add the extra days if needed
    combined_datetime = datetime.combine(base_date, adjusted_time)

    return combined_datetime
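For example, "25:30" rolls one day past the base date and becomes 01:30 on the following day. A standalone sketch of the same parsing rule (the `parse_time` name is illustrative):

```python
from datetime import date, datetime, time, timedelta

def parse_time(time_str, base_date=None):
    # Same rule as str_to_time: hours of 24 or more roll the base date forward.
    base_date = base_date or date.today()
    parts = [int(p) for p in time_str.split(":")]
    hours, minutes = parts[0], parts[1]
    seconds = parts[2] if len(parts) == 3 else 0
    base_date += timedelta(days=hours // 24)
    return datetime.combine(base_date, time(hours % 24, minutes, seconds))

print(parse_time("25:30", date(2024, 1, 1)))     # 2024-01-02 01:30:00
print(parse_time("08:15:30", date(2024, 1, 1)))  # 2024-01-01 08:15:30
```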

str_to_time_list(timespan)

Convert list of TimeStrings (HH:MM<:SS>) to list of datetime objects.

Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def str_to_time_list(timespan: list[TimeString]) -> list[datetime]:
    """Convert list of TimeStrings (HH:MM<:SS>) to list of datetime.time objects."""
    timespan_dt: list[datetime] = list(map(str_to_time, timespan))
    if not is_increasing(timespan_dt):
        timespan_dt = [timespan_dt[0], timespan_dt[1] + timedelta(days=1)]
        WranglerLogger.warning(f"Timespan is not in increasing order: {timespan}.\
            End time will be treated as next day.")
    return timespan_dt

str_to_time_series(time_str_s, base_date=None)

Convert mixed panda series datetime and TimeString (HH:MM<:SS>) to datetime object.

If HH > 24, will subtract 24 to be within 24 hours. Timespans will be treated as the next day.

Parameters:

- time_str_s (Series, required): Pandas Series of TimeStrings in HH:MM:SS or HH:MM format.
- base_date (Optional[Union[Series, date]], default None): date to base the datetime on; if not provided, uses today. Can be either a single date or a Series the same length as time_str_s.
Source code in network_wrangler/utils/time.py
def str_to_time_series(
    time_str_s: pd.Series, base_date: Optional[Union[pd.Series, date]] = None
) -> pd.Series:
    """Convert mixed panda series datetime and TimeString (HH:MM<:SS>) to datetime object.

    If HH > 24, will subtract 24 to be within 24 hours. Timespans will be treated as the next day.

    Args:
        time_str_s: Pandas Series of TimeStrings in HH:MM:SS or HH:MM format.
        base_date: optional date to base the datetime on. Defaults to None.
            If not provided, will use today. Can be either a single instance or a series of
            same length as time_str_s
    """
    # check strings are in the correct format, leave existing date times alone
    is_string = time_str_s.apply(lambda x: isinstance(x, str))
    time_strings = time_str_s[is_string]
    result = time_str_s.copy()
    if is_string.any():
        result[is_string] = _all_str_to_time_series(time_strings, base_date)
    result = result.astype("datetime64[ns]")
    return result

timespan_str_list_to_dt(timespans)

Convert list of TimespanStrings to list of [start, end] datetime pairs.

Source code in network_wrangler/utils/time.py
@validate_call(config={"arbitrary_types_allowed": True})
def timespan_str_list_to_dt(timespans: list[TimespanString]) -> list[list[datetime]]:
    """Convert list of TimespanStrings to list of datetime.time objects."""
    return [str_to_time_list(ts) for ts in timespans]

timespans_overlap(timespan1, timespan2)

Check if two timespan strings overlap.

overlapping: a timespan that fully or partially overlaps a given timespan. This includes all timespans where at least one minute overlaps.

Source code in network_wrangler/utils/time.py
def timespans_overlap(timespan1: list[TimespanString], timespan2: list[TimespanString]) -> bool:
    """Check if two timespan strings overlap.

    `overlapping`: a timespan that fully or partially overlaps a given timespan.
    This includes all timespans where at least one minute overlaps.
    """
    timespan1 = str_to_time_list(timespan1)
    timespan2 = str_to_time_list(timespan2)
    return dt_overlaps(timespan1, timespan2)

Utility functions for pandas data manipulation.

DataSegmentationError

Bases: Exception

Raised when there is an error segmenting data.

Source code in network_wrangler/utils/data.py
class DataSegmentationError(Exception):
    """Raised when there is an error segmenting data."""

InvalidJoinFieldError

Bases: Exception

Raised when the join field is not unique.

Source code in network_wrangler/utils/data.py
class InvalidJoinFieldError(Exception):
    """Raised when the join field is not unique."""

MissingPropertiesError

Bases: Exception

Raised when properties are missing from the dataframe.

Source code in network_wrangler/utils/data.py
class MissingPropertiesError(Exception):
    """Raised when properties are missing from the dataframe."""

coerce_dict_to_df_types(d, df, skip_keys=None, return_skipped=False)

Coerce dictionary values to match the type of a dataframe columns matching dict keys.

Will also coerce a list of values.

Parameters:

- d (dict, required): dictionary to coerce, with singleton or list values.
- df (DataFrame, required): dataframe to get types from.
- skip_keys (Optional[list], default None): list of dict keys to skip. Defaults to [].
- return_skipped (bool, default False): keep the uncoerced, skipped keys/vals in the resulting dict.

Returns:

- dict[str, CoerceTypes]: dict with coerced types.

Source code in network_wrangler/utils/data.py
def coerce_dict_to_df_types(
    d: dict[str, CoerceTypes],
    df: pd.DataFrame,
    skip_keys: Optional[list] = None,
    return_skipped: bool = False,
) -> dict[str, CoerceTypes]:
    """Coerce dictionary values to match the type of a dataframe columns matching dict keys.

    Will also coerce a list of values.

    Args:
        d (dict): dictionary to coerce with singleton or list values
        df (pd.DataFrame): dataframe to get types from
        skip_keys: list of dict keys to skip. Defaults to [].
        return_skipped: keep the uncoerced, skipped keys/vals in the resulting dict.
            Defaults to False.

    Returns:
        dict: dict with coerced types
    """
    if skip_keys is None:
        skip_keys = []
    coerced_dict: dict[str, CoerceTypes] = {}
    for k, vals in d.items():
        if k in skip_keys:
            if return_skipped:
                coerced_dict[k] = vals
            continue
        if k not in df.columns:
            msg = f"Key {k} not in dataframe columns."
            raise ValueError(msg)
        if pd.api.types.infer_dtype(df[k]) == "integer":
            if isinstance(vals, list):
                coerced_v: CoerceTypes = [int(float(v)) for v in vals]
            else:
                coerced_v = int(float(vals))
        elif pd.api.types.infer_dtype(df[k]) == "floating":
            coerced_v = [float(v) for v in vals] if isinstance(vals, list) else float(vals)
        elif pd.api.types.infer_dtype(df[k]) == "boolean":
            coerced_v = [bool(v) for v in vals] if isinstance(vals, list) else bool(vals)
        elif isinstance(vals, list):
            coerced_v = [str(v) for v in vals]
        else:
            coerced_v = str(vals)
        coerced_dict[k] = coerced_v
    return coerced_dict
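The coercion rules follow the dataframe's inferred dtypes rather than the input value types; a condensed sketch of the integer, float, and fallback-string paths (the dataframe, dict, and `coerce` name here are made up for illustration):

```python
import pandas as pd

df = pd.DataFrame({"lanes": [1, 2, 3], "speed": [25.0, 30.0, 35.0], "name": ["a", "b", "c"]})

def coerce(d, df):
    # Condensed version of the rules: numeric column dtypes win over
    # string inputs; everything else is stringified.
    out = {}
    for k, vals in d.items():
        kind = pd.api.types.infer_dtype(df[k])
        is_list = isinstance(vals, list)
        if kind == "integer":
            out[k] = [int(float(v)) for v in vals] if is_list else int(float(vals))
        elif kind == "floating":
            out[k] = [float(v) for v in vals] if is_list else float(vals)
        else:
            out[k] = [str(v) for v in vals] if is_list else str(vals)
    return out

print(coerce({"lanes": ["2", "3"], "speed": "40", "name": 5}, df))
# {'lanes': [2, 3], 'speed': 40.0, 'name': '5'}
```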

coerce_gdf(df, geometry=None, in_crs=LAT_LON_CRS)

Coerce a DataFrame to a GeoDataFrame, optionally with a new geometry.

Source code in network_wrangler/utils/data.py
def coerce_gdf(
    df: pd.DataFrame, geometry: GeoSeries = None, in_crs: int = LAT_LON_CRS
) -> GeoDataFrame:
    """Coerce a DataFrame to a GeoDataFrame, optionally with a new geometry."""
    if isinstance(df, GeoDataFrame):
        if df.crs is None:
            df.crs = in_crs
        return df

    if "geometry" not in df and geometry is None:
        msg = "Must give geometry argument if don't have Geometry in dataframe"
        raise ValueError(msg)

    geometry = geometry if geometry is not None else df["geometry"]
    if not isinstance(geometry, GeoSeries):
        try:
            geometry = GeoSeries(geometry)
        except Exception:
            geometry = geometry.apply(wkt.loads)
    df = GeoDataFrame(df, geometry=geometry, crs=in_crs)

    return df

coerce_val_to_df_types(field, val, df)

Coerce field value to match the type of a matching dataframe columns.

Parameters:

- field (str, required): field to look up.
- val (CoerceTypes, required): value or list of values to coerce.
- df (DataFrame, required): dataframe to get types from.
Source code in network_wrangler/utils/data.py
def coerce_val_to_df_types(  # noqa: PLR0911
    field: str,
    val: CoerceTypes,
    df: pd.DataFrame,
) -> CoerceTypes:
    """Coerce field value to match the type of a matching dataframe columns.

    Args:
        field: field to lookup
        val: value or list of values to coerce
        df (pd.DataFrame): dataframe to get types from

    Returns: coerced value or list of values
    """
    if field not in df.columns:
        msg = f"Field {field} not in dataframe columns."
        raise ValueError(msg)
    if pd.api.types.infer_dtype(df[field]) == "integer":
        if isinstance(val, list):
            return [int(float(v)) for v in val]
        return int(float(val))
    if pd.api.types.infer_dtype(df[field]) == "floating":
        if isinstance(val, list):
            return [float(v) for v in val]
        return float(val)
    if pd.api.types.infer_dtype(df[field]) == "boolean":
        if isinstance(val, list):
            return [bool(v) for v in val]
        return bool(val)
    if isinstance(val, list):
        return [str(v) for v in val]
    return str(val)

coerce_val_to_series_type(val, s)

Coerces a value to match type of pandas series.

Will try not to fail: if given a value that can't be converted to a number, it returns a string.

Parameters:

- val (required): any type of singleton value.
- s (Series, required): series to match the type to.
Source code in network_wrangler/utils/data.py
def coerce_val_to_series_type(val, s: pd.Series) -> Union[float, str, bool]:
    """Coerces a value to match type of pandas series.

    Will try not to fail so if you give it a value that can't convert to a number, it will
    return a string.

    Args:
        val: Any type of singleton value
        s (pd.Series): series to match the type to
    """
    # WranglerLogger.debug(f"Input val: {val} of type {type(val)} to match with series type \
    #    {pd.api.types.infer_dtype(s)}.")
    if pd.api.types.infer_dtype(s) in ["integer", "floating"]:
        try:
            v: Union[float, str, bool] = float(val)
        except (ValueError, TypeError):
            v = str(val)
    elif pd.api.types.infer_dtype(s) == "boolean":
        v = bool(val)
    else:
        v = str(val)
    # WranglerLogger.debug(f"Return value: {v}")
    return v

compare_df_values(df1, df2, join_col=None, ignore=None, atol=1e-05)

Compare overlapping part of dataframes and returns where there are differences.

Source code in network_wrangler/utils/data.py
def compare_df_values(
    df1, df2, join_col: Optional[str] = None, ignore: Optional[list[str]] = None, atol=1e-5
):
    """Compare overlapping part of dataframes and returns where there are differences."""
    if ignore is None:
        ignore = []
    comp_c = [
        c
        for c in df1.columns
        if c in df2.columns and c not in ignore and not isinstance(df1[c], GeoSeries)
    ]
    if join_col is None:
        comp_df = df1[comp_c].merge(
            df2[comp_c],
            how="inner",
            right_index=True,
            left_index=True,
            suffixes=["_a", "_b"],
        )
    else:
        comp_df = df1[comp_c].merge(df2[comp_c], how="inner", on=join_col, suffixes=["_a", "_b"])

    # Filter columns by data type
    numeric_cols = [col for col in comp_c if np.issubdtype(df1[col].dtype, np.number)]
    ll_cols = list(set(list_like_columns(df1) + list_like_columns(df2)))
    other_cols = [col for col in comp_c if col not in numeric_cols and col not in ll_cols]

    # For numeric columns, use np.isclose
    if numeric_cols:
        numeric_a = comp_df[[f"{col}_a" for col in numeric_cols]]
        numeric_b = comp_df[[f"{col}_b" for col in numeric_cols]]
        is_close = np.isclose(numeric_a, numeric_b, atol=atol, equal_nan=True)
        comp_df[numeric_cols] = ~is_close

    if ll_cols:
        for ll_c in ll_cols:
            comp_df[ll_c] = diff_list_like_series(comp_df[ll_c + "_a"], comp_df[ll_c + "_b"])

    # For non-numeric columns, use direct comparison
    if other_cols:
        for col in other_cols:
            comp_df[col] = (comp_df[f"{col}_a"] != comp_df[f"{col}_b"]) & ~(
                comp_df[f"{col}_a"].isna() & comp_df[f"{col}_b"].isna()
            )

    # Filter columns and rows where no differences
    cols_w_diffs = [col for col in comp_c if comp_df[col].any()]
    out_cols = [col for subcol in cols_w_diffs for col in (f"{subcol}_a", f"{subcol}_b", subcol)]
    comp_df = comp_df[out_cols]
    comp_df = comp_df.loc[comp_df[cols_w_diffs].any(axis=1)]

    return comp_df

compare_lists(list1, list2)

Compare two lists; returns True if they differ.

Source code in network_wrangler/utils/data.py
def compare_lists(list1, list2) -> bool:
    """Compare two lists; returns True if they differ."""
    list1 = convert_numpy_to_list(list1)
    list2 = convert_numpy_to_list(list2)
    return list1 != list2

concat_with_attr(dfs, **kwargs)

Concatenate a list of dataframes and retain the attributes of the first dataframe.

Source code in network_wrangler/utils/data.py
def concat_with_attr(dfs: list[pd.DataFrame], **kwargs) -> pd.DataFrame:
    """Concatenate a list of dataframes and retain the attributes of the first dataframe."""
    import copy

    if not dfs:
        msg = "No dataframes to concatenate."
        raise ValueError(msg)
    attrs = copy.deepcopy(dfs[0].attrs)
    df = pd.concat(dfs, **kwargs)
    df.attrs = attrs
    return df

convert_numpy_to_list(item)

Function to recursively convert numpy arrays to lists.

Source code in network_wrangler/utils/data.py
def convert_numpy_to_list(item):
    """Function to recursively convert numpy arrays to lists."""
    if isinstance(item, np.ndarray):
        return item.tolist()
    if isinstance(item, list):
        return [convert_numpy_to_list(sub_item) for sub_item in item]
    if isinstance(item, dict):
        return {key: convert_numpy_to_list(value) for key, value in item.items()}
    return item

dict_fields_in_df(d, df)

Check if all fields in dict are in dataframe.

Source code in network_wrangler/utils/data.py
def dict_fields_in_df(d: dict, df: pd.DataFrame) -> bool:
    """Check if all fields in dict are in dataframe."""
    missing_fields = [f for f in d if f not in df.columns]
    if missing_fields:
        msg = f"Fields in dictionary missing from dataframe: {missing_fields}."
        WranglerLogger.error(msg)
        raise ValueError(msg)
    return True

dict_to_query(selection_dict)

Generates a query string from selection_dict.

Parameters:

- selection_dict (Mapping[str, Any], required): selection dictionary.

Returns:

- str: query string.

Source code in network_wrangler/utils/data.py
def dict_to_query(
    selection_dict: Mapping[str, Any],
) -> str:
    """Generates the query of from selection_dict.

    Args:
        selection_dict: selection dictionary

    Returns:
        _type_: Query value
    """
    WranglerLogger.debug("Building selection query")

    def _kv_to_query_part(k, v, _q_part=""):
        if isinstance(v, list):
            _q_part += "(" + " or ".join([_kv_to_query_part(k, i) for i in v]) + ")"
            return _q_part
        if isinstance(v, str):
            return k + '.str.contains("' + v + '")'
        return k + "==" + str(v)

    query = "(" + " and ".join([_kv_to_query_part(k, v) for k, v in selection_dict.items()]) + ")"
    WranglerLogger.debug(f"Selection query: \n{query}")
    return query
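For instance, a list value expands to an OR group and a string value becomes a str.contains clause; the resulting query can be run with pandas' python engine, which supports .str methods. A standalone sketch (the `to_query` name and sample dataframe are illustrative):

```python
import pandas as pd

def to_query(selection):
    # Same shape as dict_to_query: lists expand to OR groups,
    # strings become .str.contains() clauses, other values to equality.
    def part(k, v):
        if isinstance(v, list):
            return "(" + " or ".join(part(k, i) for i in v) + ")"
        if isinstance(v, str):
            return f'{k}.str.contains("{v}")'
        return f"{k}=={v}"

    return "(" + " and ".join(part(k, v) for k, v in selection.items()) + ")"

q = to_query({"lanes": [2, 3], "name": "Main"})
print(q)  # ((lanes==2 or lanes==3) and name.str.contains("Main"))

df = pd.DataFrame({"lanes": [2, 4], "name": ["Main St", "Side St"]})
print(df.query(q, engine="python"))  # keeps only the Main St row
```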

diff_dfs(df1, df2, ignore=None)

Returns True if two dataframes are different and log differences.

Source code in network_wrangler/utils/data.py
def diff_dfs(df1, df2, ignore: Optional[list[str]] = None) -> bool:
    """Returns True if two dataframes are different and log differences."""
    if ignore is None:
        ignore = []
    diff = False
    if set(df1.columns) != set(df2.columns):
        WranglerLogger.warning(
            f" Columns are different 1vs2 \n    {set(df1.columns) ^ set(df2.columns)}"
        )
        common_cols = [col for col in df1.columns if col in df2.columns]
        df1 = df1[common_cols]
        df2 = df2[common_cols]
        diff = True

    cols_to_compare = [col for col in df1.columns if col not in ignore]
    df1 = df1[cols_to_compare]
    df2 = df2[cols_to_compare]

    if len(df1) != len(df2):
        WranglerLogger.warning(
            f" Length is different /" f"DF1: {len(df1)} vs /" f"DF2: {len(df2)}\n /"
        )
        diff = True

    diff_df = compare_df_values(df1, df2)

    if not diff_df.empty:
        WranglerLogger.error(f"!!! Differences dfs: \n{diff_df}")
        return True

    if not diff:
        WranglerLogger.info("...no differences in df found.")
    return diff

diff_list_like_series(s1, s2)

Compare two series that contain list-like items as strings.

Source code in network_wrangler/utils/data.py
def diff_list_like_series(s1, s2) -> bool:
    """Compare two series that contain list-like items as strings."""
    diff_df = concat_with_attr([s1, s2], axis=1, keys=["s1", "s2"])
    # diff_df["diff"] = diff_df.apply(lambda x: str(x["s1"]) != str(x["s2"]), axis=1)
    diff_df["diff"] = diff_df.apply(lambda x: compare_lists(x["s1"], x["s2"]), axis=1)
    if diff_df["diff"].any():
        WranglerLogger.info("List-Like differences:")
        WranglerLogger.info(diff_df)
        return True
    return False

fk_in_pk(pk, fk, ignore_nan=True)

Check if all foreign keys are in the primary keys, optionally ignoring NaN.

Source code in network_wrangler/utils/data.py
def fk_in_pk(
    pk: Union[pd.Series, list], fk: Union[pd.Series, list], ignore_nan: bool = True
) -> tuple[bool, list]:
    """Check if all foreign keys are in the primary keys, optionally ignoring NaN."""
    if isinstance(fk, list):
        fk = pd.Series(fk)

    if ignore_nan:
        fk = fk.dropna()

    missing_flag = ~fk.isin(pk)

    if missing_flag.any():
        WranglerLogger.warning(
            f"Following keys referenced in {fk.name} but missing in\
            primary key table: \n{fk[missing_flag]} "
        )
        return False, fk[missing_flag].tolist()

    return True, []
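A quick illustration with made-up key series: NaNs are dropped under ignore_nan, so only genuinely dangling foreign keys are reported:

```python
import pandas as pd

pk = pd.Series([1, 2, 3], name="model_node_id")
fk = pd.Series([2, 3, 9, None], name="A")

fk_clean = fk.dropna()               # the ignore_nan=True path
missing = fk_clean[~fk_clean.isin(pk)]
print(missing.tolist())  # [9.0] -- 9 has no match in the primary keys
```

Note the float result: the None in the foreign-key series forces a float64 dtype, which is why the dangling key comes back as 9.0.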

isin_dict(df, d, ignore_missing=True, strict_str=False)

Filter the dataframe using a dictionary - faster than using isin.

Uses merge to filter the dataframe by the dictionary keys and values.

Parameters:

- df (DataFrame, required): dataframe to filter.
- d (dict, required): dictionary with keys as column names and values as values to filter by.
- ignore_missing (bool, default True): if True, ignores missing values in the selection dict.
- strict_str (bool, default False): if True, disallows partial string matches and forces case-matching. Overridden to True if the key is in STRICT_MATCH_FIELDS or if ignore_missing is False.
Source code in network_wrangler/utils/data.py
def isin_dict(
    df: pd.DataFrame, d: dict, ignore_missing: bool = True, strict_str: bool = False
) -> pd.DataFrame:
    """Filter the dataframe using a dictionary - faster than using isin.

    Uses merge to filter the dataframe by the dictionary keys and values.

    Args:
        df: dataframe to filter
        d: dictionary with keys as column names and values as values to filter by
        ignore_missing: if True, will ignore missing values in the selection dict.
        strict_str: if True, will not allow partial string matches and will force case-matching.
            Defaults to False. If False, will be overridden if key is in STRICT_MATCH_FIELDS or if
            ignore_missing is False.
    """
    sel_links_mask = np.zeros(len(df), dtype=bool)
    missing = {}
    for col, vals in d.items():
        if vals is None:
            continue
        if col not in df.columns:
            msg = f"Key {col} not in dataframe columns."
            raise DataframeSelectionError(msg)
        _strict_str = strict_str or col in STRICT_MATCH_FIELDS or not ignore_missing
        vals_list = [vals] if not isinstance(vals, list) else vals

        index_name = df.index.name if df.index.name is not None else "index"
        _df = df[[col]].reset_index(names=index_name)

        if isinstance(vals_list[0], str) and not _strict_str:
            vals_list = [val.lower() for val in vals_list]
            _df[col] = _df[col].str.lower()

            # Use str.contains for partial matching
            mask = np.zeros(len(_df), dtype=bool)
            for val in vals_list:
                mask |= _df[col].str.contains(val, case=False, na=False)
            selected = _df[mask].set_index(index_name)
        else:
            vals_df = pd.DataFrame({col: vals_list}, index=range(len(vals_list)))
            merged_df = _df.merge(vals_df, on=col, how="outer", indicator=True)
            selected = merged_df[merged_df["_merge"] == "both"].set_index(index_name)
            _missing_vals = merged_df[merged_df["_merge"] == "right_only"][col].tolist()
            if _missing_vals:
                missing[col] = _missing_vals
                WranglerLogger.warning(f"Missing values in selection dict for {col}: {missing}")

        sel_links_mask |= df.index.isin(selected.index)

    if not ignore_missing and any(missing):
        msg = "Missing values in selection dict."
        raise DataframeSelectionError(msg)

    return df.loc[sel_links_mask]

list_like_columns(df, item_type=None)

Find columns in a dataframe that contain list-like items that can’t be json-serialized.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | | dataframe to check | *required* |
| `item_type` | `Optional[type]` | if not None, will only return columns where all items are of this type by checking only the first item in the column. Defaults to None. | `None` |
Source code in network_wrangler/utils/data.py
def list_like_columns(df, item_type: Optional[type] = None) -> list[str]:
    """Find columns in a dataframe that contain list-like items that can't be json-serialized.

    Args:
        df: dataframe to check
        item_type: if not None, will only return columns where all items are of this type by
            checking **only** the first item in the column.  Defaults to None.
    """
    list_like_columns = []

    for column in df.columns:
        if df[column].apply(lambda x: isinstance(x, (list, ndarray))).any():
            if item_type is not None and not isinstance(df[column].iloc[0], item_type):
                continue
            list_like_columns.append(column)
    return list_like_columns
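The detection can be exercised with a small self-contained sketch; the helper body is restated here so the snippet runs standalone, mirroring the documented logic:

```python
import numpy as np
import pandas as pd

def list_like_columns(df, item_type=None):
    # flag columns holding lists/arrays; optionally require the first item
    # of the column to be of item_type
    cols = []
    for column in df.columns:
        if df[column].apply(lambda x: isinstance(x, (list, np.ndarray))).any():
            if item_type is not None and not isinstance(df[column].iloc[0], item_type):
                continue
            cols.append(column)
    return cols

df = pd.DataFrame({"a": [1, 2], "b": [[1, 2], [3]]})
print(list_like_columns(df))                        # ['b']
print(list_like_columns(df, item_type=np.ndarray))  # []
```

Note that because only the first item is type-checked, a column mixing lists and arrays is classified by its first element alone.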

segment_data_by_selection(item_list, data, field=None, end_val=0)

Segment a dataframe or series into before, middle, and end segments based on item_list.

- selected segment: everything from the first to the last item in item_list, inclusive of both.
- before segment: everything before the selected segment.
- after segment: everything after the selected segment.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `item_list` | `list` | List of items to segment data by. If longer than two, will only use the first and last items. | *required* |
| `data` | `Union[Series, DataFrame]` | Data to segment into before, middle, and after. | *required* |
| `field` | `str` | If a dataframe, specifies which field to reference. Defaults to None. | `None` |
| `end_val` | `int` | Sentinel value meaning "until the end" or "from the beginning". Defaults to 0. | `0` |

Raises:

| Type | Description |
| --- | --- |
| `DataSegmentationError` | If item list isn't found in data in correct order. |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `tuple` | `tuple[Union[Series, list, DataFrame], Union[Series, list, DataFrame], Union[Series, list, DataFrame]]` | data broken out by before, selected segment, and after. |

Source code in network_wrangler/utils/data.py
def segment_data_by_selection(
    item_list: list,
    data: Union[list, pd.DataFrame, pd.Series],
    field: Optional[str] = None,
    end_val=0,
) -> tuple[
    Union[pd.Series, list, pd.DataFrame],
    Union[pd.Series, list, pd.DataFrame],
    Union[pd.Series, list, pd.DataFrame],
]:
    """Segment a dataframe or series into before, middle, and end segments based on item_list.

    selected segment = everything from the first to last item in item_list inclusive of the first
        and last items.
    Before segment = everything before
    After segment = everything after

    Args:
        item_list (list): List of items to segment data by. If longer than two, will only
            use the first and last items.
        data (Union[pd.Series, pd.DataFrame]): Data to segment into before, middle, and after.
        field (str, optional): If a dataframe, specifies which field to reference.
            Defaults to None.
        end_val (int, optional): Sentinel value meaning "until the end" or "from the
            beginning". Defaults to 0.

    Raises:
        DataSegmentationError: If item list isn't found in data in correct order.

    Returns:
        tuple: data broken out by before, selected segment, and after.
    """
    ref_data = data
    if isinstance(data, pd.DataFrame):
        ref_data = data[field].tolist()
    elif isinstance(data, pd.Series):
        ref_data = data.tolist()

    # ------- Replace "to the end" indicators with first or last value --------
    start_item, end_item = item_list[0], item_list[-1]
    if start_item == end_val:
        start_item = ref_data[0]
    if end_item == end_val:
        end_item = ref_data[-1]

    # --------Find the start and end indices -----------------------------------
    start_idxs = list({i for i, item in enumerate(ref_data) if item == start_item})
    if not start_idxs:
        msg = f"Segment start item: {start_item} not in data."
        raise DataSegmentationError(msg)
    if len(start_idxs) > 1:
        WranglerLogger.warning(
            f"Found multiple starting locations for data segment: {start_item}.\
                                Choosing first ... largest segment being selected."
        )
    start_idx = min(start_idxs)

    # find the end node starting from the start index.
    end_idxs = [i + start_idx for i, item in enumerate(ref_data[start_idx:]) if item == end_item]
    # WranglerLogger.debug(f"End indexes: {end_idxs}")
    if not end_idxs:
        msg = f"Segment end item: {end_item} not in data after starting idx."
        raise DataSegmentationError(msg)
    if len(end_idxs) > 1:
        WranglerLogger.warning(
            f"Found multiple ending locations for data segment: {end_item}.\
                                Choosing last ... largest segment being selected."
        )
    end_idx = max(end_idxs) + 1
    # WranglerLogger.debug(
    # f"Segmenting data fr {start_item} idx:{start_idx} to {end_item} idx:{end_idx}.\n{ref_data}")
    # -------- Extract the segments --------------------------------------------
    if isinstance(data, pd.DataFrame):
        before_segment = data.iloc[:start_idx]
        selected_segment = data.iloc[start_idx:end_idx]
        after_segment = data.iloc[end_idx:]
    else:
        before_segment = data[:start_idx]
        selected_segment = data[start_idx:end_idx]
        after_segment = data[end_idx:]

    if isinstance(data, (pd.DataFrame, pd.Series)):
        before_segment = before_segment.reset_index(drop=True)
        selected_segment = selected_segment.reset_index(drop=True)
        after_segment = after_segment.reset_index(drop=True)

    # WranglerLogger.debug(f"Segmented data into before, selected, and after.\n \
    #    Before:\n{before_segment}\nSelected:\n{selected_segment}\nAfter:\n{after_segment}")

    return before_segment, selected_segment, after_segment
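The segmentation semantics can be illustrated with a minimal list-only sketch. The `segment_list` helper below is hypothetical and simplified: it takes the first matching start and the first matching end after it, whereas the real function also handles DataFrames/Series and warns on duplicate matches:

```python
def segment_list(item_list, data, end_val=0):
    # minimal sketch of segment_data_by_selection for plain lists;
    # end_val is the sentinel meaning "from the beginning" / "until the end"
    start_item, end_item = item_list[0], item_list[-1]
    if start_item == end_val:
        start_item = data[0]
    if end_item == end_val:
        end_item = data[-1]
    start_idx = data.index(start_item)             # first occurrence
    end_idx = data.index(end_item, start_idx) + 1  # first match at/after start
    return data[:start_idx], data[start_idx:end_idx], data[end_idx:]

print(segment_list([2, 5], [1, 2, 3, 4, 5, 6]))       # ([1], [2, 3, 4, 5], [6])
print(segment_list([0, 3], [1, 2, 3, 4], end_val=0))  # ([], [1, 2, 3], [4])
```

The second call shows the sentinel behavior: a start item equal to `end_val` means "from the beginning of the data".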

segment_data_by_selection_min_overlap(selection_list, data, field, replacements_list, end_val=0)

Segments data based on item_list reducing overlap with replacement list.

- selected segment: everything from the first to the last item in item_list, inclusive of both, unless the first or last item overlaps with the replacement list.
- before segment: everything before the selected segment.
- after segment: everything after the selected segment.

Example:

```python
selection_list = [2, 5]
data = pd.DataFrame({"i": [1, 2, 3, 4, 5, 6]})
field = "i"
replacements_list = [2, 22, 33]
```

Returns:

| Type | Description |
| --- | --- |
| `list` | `[22, 33]` |
| `tuple[Union[Series, DataFrame], Union[Series, DataFrame], Union[Series, DataFrame]]` | `[1], [2, 3, 4, 5], [6]` |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `selection_list` | `list` | List of items to segment data by. If longer than two, will only use the first and last items. | *required* |
| `data` | `Union[Series, DataFrame]` | Data to segment into before, middle, and after. | *required* |
| `field` | `str` | Specifies which field to reference. | *required* |
| `replacements_list` | `list` | List of items to eventually replace the selected segment with. | *required* |
| `end_val` | `int` | Sentinel value meaning "until the end" or "from the beginning". Defaults to 0. | `0` |

Returns a tuple containing:

| Type | Description |
| --- | --- |
| `list` | updated replacement_list |
| `tuple[Union[Series, DataFrame], Union[Series, DataFrame], Union[Series, DataFrame]]` | tuple of before, selected segment, and after data |
Source code in network_wrangler/utils/data.py
def segment_data_by_selection_min_overlap(
    selection_list: list,
    data: pd.DataFrame,
    field: str,
    replacements_list: list,
    end_val=0,
) -> tuple[
    list,
    tuple[
        Union[pd.Series, pd.DataFrame],
        Union[pd.Series, pd.DataFrame],
        Union[pd.Series, pd.DataFrame],
    ],
]:
    """Segments data based on item_list reducing overlap with replacement list.

    *selected segment*: everything from the first to last item in item_list inclusive of the first
        and last items but not if first and last items overlap with replacement list.
    Before segment = everything before
    After segment = everything after

    Example:
    selection_list = [2,5]
    data = pd.DataFrame({"i":[1,2,3,4,5,6]})
    field = "i"
    replacements_list = [2,22,33]

    returns:
        [22,33]
        [1], [2,3,4,5], [6]

    Args:
        selection_list (list): List of items to segment data by. If longer than two, will only
            use the first and last items.
        data (Union[pd.Series, pd.DataFrame]): Data to segment into before, middle, and after.
        field (str): Specifies which field to reference.
        replacements_list (list): List of items to eventually replace the selected segment with.
        end_val (int, optional): Sentinel value meaning "until the end" or "from the
            beginning". Defaults to 0.

    Returns: tuple containing:
        - updated replacement_list
        - tuple of before, selected segment, and after data
    """
    before_segment, segment_df, after_segment = segment_data_by_selection(
        selection_list, data, field=field, end_val=end_val
    )
    if not isinstance(segment_df, pd.DataFrame):
        msg = "segment_df should be a DataFrame - something is wrong."
        raise ValueError(msg)

    if replacements_list and replacements_list[0] == segment_df[field].iat[0]:
        # move first item from selected segment to the before_segment df
        replacements_list = replacements_list[1:]
        before_segment = concat_with_attr(
            [before_segment, segment_df.iloc[:1]], ignore_index=True, sort=False
        )
        segment_df = segment_df.iloc[1:]
        # WranglerLogger.debug(f"item start overlaps with replacement. Repl: {replacements_list}")
    if replacements_list and replacements_list[-1] == data[field].iat[-1]:
        # move last item from selected segment to the after_segment df
        replacements_list = replacements_list[:-1]
        after_segment = concat_with_attr(
            [data.iloc[-1:], after_segment], ignore_index=True, sort=False
        )
        segment_df = segment_df.iloc[:-1]
        # WranglerLogger.debug(f"item end overlaps with replacement. Repl: {replacements_list}")

    return replacements_list, (before_segment, segment_df, after_segment)

update_df_by_col_value(destination_df, source_df, join_col, properties=None, fail_if_missing=True)

Updates destination_df with ALL values in source_df for specified props with same join_col.

Source_df can contain a subset of the IDs in destination_df. If fail_if_missing is True, destination_df must contain all the IDs in source_df, ensuring all source_df values are contained in the resulting df.

```
>> destination_df
trip_id  property1  property2
1         10      100
2         20      200
3         30      300
4         40      400

>> source_df
trip_id  property1  property2
2         25      250
3         35      350

>> updated_df
trip_id  property1  property2
0        1       10      100
1        2       25      250
2        3       35      350
3        4       40      400
```

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `destination_df` | `DataFrame` | Dataframe to modify. | *required* |
| `source_df` | `DataFrame` | Dataframe with updated columns. | *required* |
| `join_col` | `str` | Column to join on. | *required* |
| `properties` | `list[str]` | List of properties to use. If None, will default to all in source_df. | `None` |
| `fail_if_missing` | `bool` | If True, will raise an error if there are IDs in source_df that are missing from destination_df. | `True` |
Source code in network_wrangler/utils/data.py
def update_df_by_col_value(
    destination_df: pd.DataFrame,
    source_df: pd.DataFrame,
    join_col: str,
    properties: Optional[list[str]] = None,
    fail_if_missing: bool = True,
) -> pd.DataFrame:
    """Updates destination_df with ALL values in source_df for specified props with same join_col.

    Source_df can contain a subset of the IDs in destination_df.
    If fail_if_missing is True, destination_df must contain all
    the IDs in source_df, ensuring all source_df values are contained in the resulting df.

    ```
    >> destination_df
    trip_id  property1  property2
    1         10      100
    2         20      200
    3         30      300
    4         40      400

    >> source_df
    trip_id  property1  property2
    2         25      250
    3         35      350

    >> updated_df
    trip_id  property1  property2
    0        1       10      100
    1        2       25      250
    2        3       35      350
    3        4       40      400
    ```

    Args:
        destination_df (pd.DataFrame): Dataframe to modify.
        source_df (pd.DataFrame): Dataframe with updated columns
        join_col (str): column to join on
        properties (list[str]): List of properties to use. If None, will default to all
            in source_df.
        fail_if_missing (bool): If True, will raise an error if there are missing IDs in
            destination_df that exist in source_df.
    """
    # 1. Identify which properties should be updated; and if they exist in both DFs.
    if properties is None:
        properties = [
            c for c in source_df.columns if c in destination_df.columns and c != join_col
        ]
    else:
        _dest_miss = _df_missing_cols(destination_df, [*properties, join_col])
        if _dest_miss:
            msg = f"Properties missing from destination_df: {_dest_miss}"
            raise MissingPropertiesError(msg)
        _source_miss = _df_missing_cols(source_df, [*properties, join_col])
        if _source_miss:
            msg = f"Properties missing from source_df: {_source_miss}"
            raise MissingPropertiesError(msg)

    # 2. Identify if there are IDs missing from destination_df that exist in source_df
    if fail_if_missing:
        missing_ids = set(source_df[join_col]) - set(destination_df[join_col])
        if missing_ids:
            msg = f"IDs in source_df missing from destination_df: \n{missing_ids}"
            raise InvalidJoinFieldError(msg)

    WranglerLogger.debug(f"Updating properties for {len(source_df)} records: {properties}.")

    if not source_df[join_col].is_unique:
        msg = f"Can't join from source_df when join_col: {join_col} is not unique."
        raise InvalidJoinFieldError(msg)

    if not destination_df[join_col].is_unique:
        return _update_props_from_one_to_many(destination_df, source_df, join_col, properties)

    return _update_props_for_common_idx(destination_df, source_df, join_col, properties)
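For the simple case of a unique join key, the documented behavior can be reproduced in plain pandas; this is a sketch of the semantics, not the library call itself:

```python
import pandas as pd

destination_df = pd.DataFrame(
    {"trip_id": [1, 2, 3, 4], "property1": [10, 20, 30, 40]}
)
source_df = pd.DataFrame({"trip_id": [2, 3], "property1": [25, 35]})

# align both frames on the join column and overwrite with source values
updated_df = destination_df.set_index("trip_id")
updated_df.update(source_df.set_index("trip_id"))
updated_df = updated_df.reset_index()

print(updated_df["property1"].astype(int).tolist())  # [10, 25, 35, 40]
```

`DataFrame.update` may upcast integer columns to float, hence the `astype(int)` when reading the result back; the library helper additionally validates properties, join-key uniqueness, and missing IDs.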

validate_existing_value_in_df(df, idx, field, expected_value)

Validate if df[field]==expected_value for all indices in idx.

Source code in network_wrangler/utils/data.py
def validate_existing_value_in_df(df: pd.DataFrame, idx: list[int], field: str, expected_value):
    """Validate if df[field]==expected_value for all indices in idx."""
    if field not in df.columns:
        WranglerLogger.warning(f"!! {field} Not an existing field.")
        return False
    if not df.loc[idx, field].eq(expected_value).all():
        WranglerLogger.warning(
            f"Existing value defined for {field} in project card \
            does not match the value in the selection links. \n\
            Specified Existing: {expected_value}\n\
            Actual Existing: \n {df.loc[idx, field]}."
        )
        return False
    return True
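The core check reduces to a vectorized equality test over the selected indices, which can be seen in isolation:

```python
import pandas as pd

df = pd.DataFrame({"lanes": [2, 2, 3]})

# do all selected rows carry the expected existing value?
print(df.loc[[0, 1], "lanes"].eq(2).all())  # True
print(df.loc[[0, 2], "lanes"].eq(2).all())  # False
```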

Helper geographic manipulation functions.

InvalidCRSError

Bases: Exception

Raised when a point is not valid for a given coordinate reference system.

Source code in network_wrangler/utils/geo.py
class InvalidCRSError(Exception):
    """Raised when a point is not valid for a given coordinate reference system."""

check_point_valid_for_crs(point, crs)

Check if a point is valid for a given coordinate reference system.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `point` | `Point` | Shapely Point | *required* |
| `crs` | `int` | coordinate reference system as an EPSG code | *required* |
Source code in network_wrangler/utils/geo.py
def check_point_valid_for_crs(point: Point, crs: int):
    """Check if a point is valid for a given coordinate reference system.

    Args:
        point: Shapely Point
        crs: coordinate reference system as an EPSG code

    raises: InvalidCRSError if point is not valid for the given crs
    """
    crs = CRS.from_user_input(crs)
    minx, miny, maxx, maxy = crs.area_of_use.bounds
    ok_bounds = True
    if not minx <= point.x <= maxx:
        WranglerLogger.error(f"Invalid X coordinate for CRS {crs}: {point.x}")
        ok_bounds = False
    if not miny <= point.y <= maxy:
        WranglerLogger.error(f"Invalid Y coordinate for CRS {crs}: {point.y}")
        ok_bounds = False

    if not ok_bounds:
        msg = f"Invalid coordinate for CRS {crs}: {point.x}, {point.y}"
        raise InvalidCRSError(msg)

get_bearing(lat1, lon1, lat2, lon2)

Calculate the bearing (forward azimuth) between the two points.

returns: bearing in radians

Source code in network_wrangler/utils/geo.py
def get_bearing(lat1, lon1, lat2, lon2):
    """Calculate the bearing (forward azimuth) between the two points.

    returns: bearing in radians
    """
    # bearing in degrees
    brng = Geodesic.WGS84.Inverse(lat1, lon1, lat2, lon2)["azi1"]

    # convert bearing to radians
    brng = math.radians(brng)

    return brng
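For intuition, the same forward azimuth can be approximated on a sphere with the standard library alone; the helper itself uses geographiclib's WGS84 geodesic, which accounts for the Earth's ellipsoidal shape:

```python
import math

def initial_bearing_radians(lat1, lon1, lat2, lon2):
    # spherical-earth forward azimuth; positive clockwise from north
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dlon = math.radians(lon2 - lon1)
    x = math.sin(dlon) * math.cos(phi2)
    y = math.cos(phi1) * math.sin(phi2) - math.sin(phi1) * math.cos(phi2) * math.cos(dlon)
    return math.atan2(x, y)

# due east along the equator -> pi/2 radians (90 degrees)
print(round(initial_bearing_radians(0, 0, 0, 1), 6))  # 1.570796
```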

get_bounding_polygon(boundary_geocode=None, boundary_file=None, boundary_gdf=None, crs=LAT_LON_CRS)

Get the bounding polygon for a given boundary.

Will return None if no arguments given. Will raise a ValueError if more than one given.

This function retrieves the bounding polygon for a given boundary. The boundary can be provided as a GeoDataFrame, a geocode string or dictionary, or a boundary file. The resulting polygon geometry is returned as a GeoSeries.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `boundary_geocode` | `Union[str, dict]` | A geocode string or dictionary representing the boundary. Defaults to None. | `None` |
| `boundary_file` | `Union[str, Path]` | A path to the boundary file. Only used if boundary_geocode is None. Defaults to None. | `None` |
| `boundary_gdf` | `GeoDataFrame` | A GeoDataFrame representing the boundary. Only used if boundary_geocode and boundary_file are None. Defaults to None. | `None` |
| `crs` | `int` | The coordinate reference system (CRS) code. Defaults to 4326 (WGS84). | `LAT_LON_CRS` |

Returns:

| Type | Description |
| --- | --- |
| `GeoSeries` | The polygon geometry representing the bounding polygon. |

Source code in network_wrangler/utils/geo.py
def get_bounding_polygon(
    boundary_geocode: Optional[Union[str, dict]] = None,
    boundary_file: Optional[Union[str, Path]] = None,
    boundary_gdf: Optional[gpd.GeoDataFrame] = None,
    crs: int = LAT_LON_CRS,  # WGS84
) -> Optional[gpd.GeoSeries]:
    """Get the bounding polygon for a given boundary.

    Will return None if no arguments given. Will raise a ValueError if more than one given.

    This function retrieves the bounding polygon for a given boundary. The boundary can be provided
    as a GeoDataFrame, a geocode string or dictionary, or a boundary file. The resulting polygon
    geometry is returned as a GeoSeries.

    Args:
        boundary_geocode (Union[str, dict], optional): A geocode string or dictionary
            representing the boundary. Defaults to None.
        boundary_file (Union[str, Path], optional): A path to the boundary file. Only used if
            boundary_geocode is None. Defaults to None.
        boundary_gdf (gpd.GeoDataFrame, optional): A GeoDataFrame representing the boundary.
            Only used if boundary_geocode and boundary_file are None. Defaults to None.
        crs (int, optional): The coordinate reference system (CRS) code. Defaults to 4326 (WGS84).

    Returns:
        gpd.GeoSeries: The polygon geometry representing the bounding polygon.
    """
    import osmnx as ox

    nargs = sum(x is not None for x in [boundary_gdf, boundary_geocode, boundary_file])
    if nargs == 0:
        return None
    if nargs != 1:
        msg = "Exactly one of boundary_gdf, boundary_geocode, or boundary_file must be provided."
        raise ValueError(msg)

    OK_BOUNDARY_SUFF = [".shp", ".geojson", ".parquet"]

    if boundary_geocode is not None:
        boundary_gdf = ox.geocode_to_gdf(boundary_geocode)
    elif boundary_file is not None:
        boundary_file = Path(boundary_file)
        if boundary_file.suffix not in OK_BOUNDARY_SUFF:
            msg = f"Boundary file must have one of the following suffixes: {OK_BOUNDARY_SUFF}"
            raise ValueError(msg)
        if not boundary_file.exists():
            msg = f"Boundary file {boundary_file} does not exist"
            raise FileNotFoundError(msg)
        if boundary_file.suffix == ".parquet":
            boundary_gdf = gpd.read_parquet(boundary_file)
        else:
            boundary_gdf = gpd.read_file(boundary_file)
            if boundary_file.suffix == ".geojson":  # geojson standard is WGS84
                boundary_gdf.crs = crs

    if boundary_gdf is None:
        msg = "One of boundary_gdf, boundary_geocode or boundary_file must be provided."
        raise ValueError(msg)

    if boundary_gdf.crs is not None:
        boundary_gdf = boundary_gdf.to_crs(crs)
    # make sure boundary_gdf is a polygon
    if len(boundary_gdf.geom_type[boundary_gdf.geom_type != "Polygon"]) > 0:
        msg = "boundary_gdf must all be Polygons"
        raise ValueError(msg)
    # get the boundary as a single polygon
    boundary_gs = gpd.GeoSeries([boundary_gdf.geometry.union_all()], crs=crs)

    return boundary_gs

get_point_geometry_from_linestring(polyline_geometry, pos=0)

Get a point geometry from a linestring geometry.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `polyline_geometry` | | shapely LineString instance | *required* |
| `pos` | `int` | position in the linestring to get the point from. Defaults to 0. | `0` |
Source code in network_wrangler/utils/geo.py
def get_point_geometry_from_linestring(polyline_geometry, pos: int = 0):
    """Get a point geometry from a linestring geometry.

    Args:
        polyline_geometry: shapely LineString instance
        pos: position in the linestring to get the point from. Defaults to 0.
    """
    # WranglerLogger.debug(
    #    f"get_point_geometry_from_linestring.polyline_geometry.coords[0]: \
    #    {polyline_geometry.coords[0]}."
    # )

    # Note: when upgrading to shapely 2.0, will need to use following command
    # _point_coords = get_coordinates(polyline_geometry).tolist()[pos]
    return point_from_xy(*polyline_geometry.coords[pos])

length_of_linestring_miles(gdf)

Returns a Series with the linestring length in miles.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `gdf` | `Union[GeoSeries, GeoDataFrame]` | GeoDataFrame with linestring geometry. If given a GeoSeries, will attempt to convert to a GeoDataFrame. | *required* |
Source code in network_wrangler/utils/geo.py
def length_of_linestring_miles(gdf: Union[gpd.GeoSeries, gpd.GeoDataFrame]) -> pd.Series:
    """Returns a Series with the linestring length in miles.

    Args:
        gdf: GeoDataFrame with linestring geometry.  If given a GeoSeries will attempt to convert
            to a GeoDataFrame.
    """
    # WranglerLogger.debug(f"length_of_linestring_miles.gdf input:\n{gdf}.")
    if isinstance(gdf, gpd.GeoSeries):
        gdf = gpd.GeoDataFrame(geometry=gdf)

    p_crs = gdf.estimate_utm_crs()
    gdf = gdf.to_crs(p_crs)
    METERS_IN_MILES = 1609.34
    length_miles = gdf.geometry.length / METERS_IN_MILES
    length_s = pd.Series(length_miles, index=gdf.index)

    return length_s

linestring_from_lats_lons(df, lat_fields, lon_fields)

Create a LineString geometry from a DataFrame with lon/lat fields.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | | DataFrame with columns for lon/lat fields. | *required* |
| `lat_fields` | | list of column names for the lat fields. | *required* |
| `lon_fields` | | list of column names for the lon fields. | *required* |
Source code in network_wrangler/utils/geo.py
def linestring_from_lats_lons(df, lat_fields, lon_fields) -> gpd.GeoSeries:
    """Create a LineString geometry from a DataFrame with lon/lat fields.

    Args:
        df: DataFrame with columns for lon/lat fields.
        lat_fields: list of column names for the lat fields.
        lon_fields: list of column names for the lon fields.
    """
    if len(lon_fields) != len(lat_fields):
        msg = "lon_fields and lat_fields lists must have the same length"
        raise ValueError(msg)

    line_geometries = gpd.GeoSeries(
        [
            LineString([(row[lon], row[lat]) for lon, lat in zip(lon_fields, lat_fields)])
            for _, row in df.iterrows()
        ]
    )

    return line_geometries

linestring_from_nodes(links_df, nodes_df, from_node='A', to_node='B', node_pk='model_node_id')

Creates a LineString geometry GeoSeries from a DataFrame of links and a DataFrame of nodes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `links_df` | `DataFrame` | DataFrame with columns for from_node and to_node. | *required* |
| `nodes_df` | `GeoDataFrame` | GeoDataFrame with geometry column. | *required* |
| `from_node` | `str` | column name in links_df for the from node. Defaults to "A". | `'A'` |
| `to_node` | `str` | column name in links_df for the to node. Defaults to "B". | `'B'` |
| `node_pk` | `str` | primary key column name in nodes_df. Defaults to "model_node_id". | `'model_node_id'` |
Source code in network_wrangler/utils/geo.py
def linestring_from_nodes(
    links_df: pd.DataFrame,
    nodes_df: gpd.GeoDataFrame,
    from_node: str = "A",
    to_node: str = "B",
    node_pk: str = "model_node_id",
) -> gpd.GeoSeries:
    """Creates a LineString geometry GeoSeries from a DataFrame of links and a DataFrame of nodes.

    Args:
        links_df: DataFrame with columns for from_node and to_node.
        nodes_df: GeoDataFrame with geometry column.
        from_node: column name in links_df for the from node. Defaults to "A".
        to_node: column name in links_df for the to node. Defaults to "B".
        node_pk: primary key column name in nodes_df. Defaults to "model_node_id".
    """
    assert "geometry" in nodes_df.columns, "nodes_df must have a 'geometry' column"

    idx_name = "index" if links_df.index.name is None else links_df.index.name
    # WranglerLogger.debug(f"Index name: {idx_name}")
    required_link_cols = [from_node, to_node]

    if not all(col in links_df.columns for col in required_link_cols):
        WranglerLogger.error(
            f"links_df.columns missing required columns.\n\
                            links_df.columns: {links_df.columns}\n\
                            required_link_cols: {required_link_cols}"
        )
        msg = f"links_df must have columns {required_link_cols} to create linestring from nodes"
        raise ValueError(msg)

    links_geo_df = copy.deepcopy(links_df[required_link_cols])
    # need to continuously reset the index to make sure the index is the same as the link index
    links_geo_df = (
        links_geo_df.reset_index()
        .merge(
            nodes_df[[node_pk, "geometry"]],
            left_on=from_node,
            right_on=node_pk,
            how="left",
        )
        .set_index(idx_name)
    )

    links_geo_df = links_geo_df.rename(columns={"geometry": "geometry_A"})

    links_geo_df = (
        links_geo_df.reset_index()
        .merge(
            nodes_df[[node_pk, "geometry"]],
            left_on=to_node,
            right_on=node_pk,
            how="left",
        )
        .set_index(idx_name)
    )

    links_geo_df = links_geo_df.rename(columns={"geometry": "geometry_B"})

    # makes sure all nodes exist
    _missing_geo_links_df = links_geo_df[
        links_geo_df["geometry_A"].isnull() | links_geo_df["geometry_B"].isnull()
    ]
    if not _missing_geo_links_df.empty:
        missing_nodes = _missing_geo_links_df[[from_node, to_node]].values
        WranglerLogger.error(
            f"Cannot create link geometry from nodes because the nodes are\
                             missing from the network. Missing nodes: {missing_nodes}"
        )
        msg = "Cannot create link geometry from nodes because the nodes are missing from the network."
        raise MissingNodesError(msg)

    # create geometry from points
    links_geo_df["geometry"] = links_geo_df.apply(
        lambda row: LineString([row["geometry_A"], row["geometry_B"]]), axis=1
    )

    # convert to GeoDataFrame
    links_gdf = gpd.GeoDataFrame(links_geo_df["geometry"], geometry=links_geo_df["geometry"])
    return links_gdf["geometry"]

location_ref_from_point(geometry, sequence=1, bearing=None, distance_to_next_ref=None)

Generates a shared street point location reference.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `geometry` | `Point` | Point shapely geometry | *required* |
| `sequence` | `int` | Sequence if part of polyline. Defaults to 1. | `1` |
| `bearing` | `float` | Direction of line if part of polyline. Defaults to None. | `None` |
| `distance_to_next_ref` | `float` | Distance to next point if part of polyline. Defaults to None. | `None` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `LocationReference` | `LocationReference` | As defined by sharedStreets.io schema |

Source code in network_wrangler/utils/geo.py
def location_ref_from_point(
    geometry: Point,
    sequence: int = 1,
    bearing: Optional[float] = None,
    distance_to_next_ref: Optional[float] = None,
) -> LocationReference:
    """Generates a shared street point location reference.

    Args:
        geometry (Point): Point shapely geometry
        sequence (int, optional): Sequence if part of polyline. Defaults to 1.
        bearing (float, optional): Direction of line if part of polyline. Defaults to None.
        distance_to_next_ref (float, optional): Distance to next point if part of polyline.
            Defaults to None.

    Returns:
        LocationReference: As defined by sharedStreets.io schema
    """
    lr = {
        "point": LatLongCoordinates(geometry.coords[0]),
    }

    for arg in ["sequence", "bearing", "distance_to_next_ref"]:
        if locals()[arg] is not None:
            lr[arg] = locals()[arg]

    return LocationReference(**lr)

location_refs_from_linestring(geometry)

Generates a shared street location reference from linestring.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `geometry` | `LineString` | Shapely LineString instance | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `LocationReferences` | `list[LocationReference]` | As defined by sharedStreets.io schema |

Source code in network_wrangler/utils/geo.py
def location_refs_from_linestring(geometry: LineString) -> list[LocationReference]:
    """Generates a shared street location reference from linestring.

    Args:
        geometry (LineString): Shapely LineString instance

    Returns:
        LocationReferences: As defined by sharedStreets.io schema
    """
    return [
        location_ref_from_point(
            Point(coord),
            sequence=i + 1,
            # geometry.coords yields (x, y) tuples, so wrap them in Points
            # before measuring the distance to the next coordinate
            distance_to_next_ref=Point(coord).distance(Point(geometry.coords[i + 1])),
            bearing=get_bearing(*coord, *geometry.coords[i + 1]),
        )
        for i, coord in enumerate(geometry.coords[:-1])
    ]

offset_geometry_meters(geo_s, offset_distance_meters)

Offset a GeoSeries of LineStrings by a given distance in meters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `geo_s` | `GeoSeries` | GeoSeries of LineStrings to offset. | *required* |
| `offset_distance_meters` | `float` | distance in meters to offset the LineStrings. | *required* |
Source code in network_wrangler/utils/geo.py
def offset_geometry_meters(geo_s: gpd.GeoSeries, offset_distance_meters: float) -> gpd.GeoSeries:
    """Offset a GeoSeries of LineStrings by a given distance in meters.

    Args:
        geo_s: GeoSeries of LineStrings to offset.
        offset_distance_meters: distance in meters to offset the LineStrings.
    """
    if geo_s.empty:
        return geo_s
    og_crs = geo_s.crs
    meters_crs = _id_utm_crs(geo_s)
    geo_s = geo_s.to_crs(meters_crs)
    offset_geo = geo_s.apply(lambda x: x.offset_curve(offset_distance_meters))
    offset_geo = gpd.GeoSeries(offset_geo)
    return offset_geo.to_crs(og_crs)

offset_point_with_distance_and_bearing(lon, lat, distance, bearing)

Get the new lon-lat (in degrees) given current point (lon-lat), distance and bearing.

Parameters:

- lon (float): Longitude of original point. Required.
- lat (float): Latitude of original point. Required.
- distance (float): Distance in meters to offset point by. Required.
- bearing (float): Direction to offset point to, in radians. Required.
Source code in network_wrangler/utils/geo.py
def offset_point_with_distance_and_bearing(
    lon: float, lat: float, distance: float, bearing: float
) -> list[float]:
    """Get the new lon-lat (in degrees) given current point (lon-lat), distance and bearing.

    Args:
        lon: longitude of original point
        lat: latitude of original point
        distance: distance in meters to offset point by
        bearing: direction to offset point to in radians

    Returns:
        list of new offset lon-lat
    """
    # Earth's radius in meters
    radius = 6378137

    # convert the lat long from degree to radians
    lat_radians = math.radians(lat)
    lon_radians = math.radians(lon)

    # calculate the new lat long in radians
    out_lat_radians = math.asin(
        math.sin(lat_radians) * math.cos(distance / radius)
        + math.cos(lat_radians) * math.sin(distance / radius) * math.cos(bearing)
    )

    out_lon_radians = lon_radians + math.atan2(
        math.sin(bearing) * math.sin(distance / radius) * math.cos(lat_radians),
        math.cos(distance / radius) - math.sin(lat_radians) * math.sin(lat_radians),
    )
    # convert the new lat long back to degree
    out_lat = math.degrees(out_lat_radians)
    out_lon = math.degrees(out_lon_radians)

    return [out_lon, out_lat]
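As a sanity check, the formula above can be reproduced standalone: offsetting a point 1000 m due north (bearing 0) from the origin should move latitude by 1000/6378137 radians, about 0.009 degrees, and leave longitude unchanged. The `offset_point` helper below is a hypothetical standalone reproduction, not the library function itself.

```python
import math

def offset_point(lon, lat, distance, bearing):
    # Standalone reproduction of the great-circle offset formula above
    # (spherical Earth, radius 6378137 m; bearing in radians).
    radius = 6378137
    lat_r, lon_r = math.radians(lat), math.radians(lon)
    out_lat_r = math.asin(
        math.sin(lat_r) * math.cos(distance / radius)
        + math.cos(lat_r) * math.sin(distance / radius) * math.cos(bearing)
    )
    out_lon_r = lon_r + math.atan2(
        math.sin(bearing) * math.sin(distance / radius) * math.cos(lat_r),
        math.cos(distance / radius) - math.sin(lat_r) * math.sin(lat_r),
    )
    return [math.degrees(out_lon_r), math.degrees(out_lat_r)]

# 1000 m due north from (0, 0): latitude grows by ~0.008983 degrees.
lon, lat = offset_point(0.0, 0.0, 1000.0, 0.0)
```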

point_from_xy(x, y, xy_crs=LAT_LON_CRS, point_crs=LAT_LON_CRS)

Creates point geometry from x and y coordinates.

Parameters:

- x: x coordinate, in xy_crs. Required.
- y: y coordinate, in xy_crs. Required.
- xy_crs (int): Coordinate reference system in EPSG code for x/y inputs. Defaults to LAT_LON_CRS (4326, WGS84).
- point_crs (int): Coordinate reference system in EPSG code for point output. Defaults to LAT_LON_CRS (4326, WGS84).
Source code in network_wrangler/utils/geo.py
def point_from_xy(x, y, xy_crs: int = LAT_LON_CRS, point_crs: int = LAT_LON_CRS):
    """Creates point geometry from x and y coordinates.

    Args:
        x: x coordinate, in xy_crs
        y: y coordinate, in xy_crs
        xy_crs: coordinate reference system in EPSG code for x/y inputs. Defaults to 4326 (WGS84)
        point_crs: coordinate reference system in EPSG code for point output.
            Defaults to 4326 (WGS84)

    Returns: Shapely Point in point_crs
    """
    point = Point(x, y)

    if xy_crs == point_crs:
        check_point_valid_for_crs(point, point_crs)
        return point

    if (xy_crs, point_crs) not in transformers:
        # store transformers in dictionary because they are an "expensive" operation
        transformers[(xy_crs, point_crs)] = Transformer.from_proj(
            Proj(init="epsg:" + str(xy_crs)),
            Proj(init="epsg:" + str(point_crs)),
            always_xy=True,  # required b/c Proj v6+ uses lon/lat instead of x/y
        )

    return transform(transformers[(xy_crs, point_crs)].transform, point)
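The `transformers` dictionary above exists because constructing a pyproj `Transformer` is expensive. The memoization pattern itself can be sketched with a stand-in factory; `build_transformer` below is hypothetical, standing in for `Transformer.from_proj`.

```python
# Construct an "expensive" object once per key and reuse it on later lookups.
_cache: dict[tuple[int, int], object] = {}
build_count = 0

def build_transformer(from_crs: int, to_crs: int):
    # Stand-in for an expensive constructor such as pyproj Transformer.from_proj.
    global build_count
    build_count += 1  # track how many times the expensive constructor runs
    return (from_crs, to_crs)

def get_transformer(from_crs: int, to_crs: int):
    key = (from_crs, to_crs)
    if key not in _cache:
        _cache[key] = build_transformer(from_crs, to_crs)
    return _cache[key]

t1 = get_transformer(4326, 3857)
t2 = get_transformer(4326, 3857)  # cache hit: no second construction
```

`functools.lru_cache` would achieve the same effect; the explicit dictionary simply keeps the cache inspectable at module level.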

to_points_gdf(table, ref_nodes_df=None, ref_road_net=None)

Convert a table to a GeoDataFrame.

If the table is already a GeoDataFrame, return it as is. Otherwise, attempt to convert the table to a GeoDataFrame using the following methods:

1. If the table has a 'geometry' column, return a GeoDataFrame using that column.
2. If the table has 'lat' and 'lon' columns, return a GeoDataFrame using those columns.
3. If the table has a '*model_node_id' or 'stop_id' column, return a GeoDataFrame using that column and the nodes_df provided.

If none of the above, raise a ValueError.

Parameters:

- table (DataFrame): DataFrame to convert to GeoDataFrame. Required.
- ref_nodes_df (Optional[GeoDataFrame]): GeoDataFrame of nodes to use to convert model_node_id to geometry. Defaults to None.
- ref_road_net (Optional[RoadwayNetwork]): RoadwayNetwork object to use to convert model_node_id to geometry. Defaults to None.

Returns:

- GeoDataFrame: GeoDataFrame representation of the table.

Source code in network_wrangler/utils/geo.py
def to_points_gdf(
    table: pd.DataFrame,
    ref_nodes_df: Optional[gpd.GeoDataFrame] = None,
    ref_road_net: Optional[RoadwayNetwork] = None,
) -> gpd.GeoDataFrame:
    """Convert a table to a GeoDataFrame.

    If the table is already a GeoDataFrame, return it as is. Otherwise, attempt to convert the
    table to a GeoDataFrame using the following methods:
    1. If the table has a 'geometry' column, return a GeoDataFrame using that column.
    2. If the table has 'lat' and 'lon' columns, return a GeoDataFrame using those columns.
    3. If the table has a '*model_node_id' or 'stop_id' column, return a GeoDataFrame using that column and the
         nodes_df provided.
    If none of the above, raise a ValueError.

    Args:
        table: DataFrame to convert to GeoDataFrame.
        ref_nodes_df: GeoDataFrame of nodes to use to convert model_node_id to geometry.
        ref_road_net: RoadwayNetwork object to use to convert model_node_id to geometry.

    Returns:
        GeoDataFrame: GeoDataFrame representation of the table.
    """
    if isinstance(table, gpd.GeoDataFrame):
        return table

    WranglerLogger.debug("Converting GTFS table to GeoDataFrame")
    if "geometry" in table.columns:
        return gpd.GeoDataFrame(table, geometry="geometry")

    lat_cols = list(filter(lambda col: "lat" in col, table.columns))
    lon_cols = list(filter(lambda col: "lon" in col, table.columns))
    model_node_id_cols = [
        c for c in ["model_node_id", "stop_id", "shape_model_node_id"] if c in table.columns
    ]

    if not (lat_cols and lon_cols) and not model_node_id_cols:
        WranglerLogger.error(
            f"Needed either lat/long or *model_node_id columns to convert \
            to GeoDataFrame. Columns found: {table.columns}"
        )
        if not (lat_cols and lon_cols):
            WranglerLogger.error("No lat/long cols found.")
        if not model_node_id_cols:
            WranglerLogger.error("No *model_node_id cols found.")
        msg = "Could not find lat/long, geometry columns or *model_node_id column in \
                         table necessary to convert to GeoDataFrame"
        raise ValueError(msg)

    if lat_cols and lon_cols:
        # using first found lat and lon columns
        return gpd.GeoDataFrame(
            table,
            geometry=gpd.points_from_xy(table[lon_cols[0]], table[lat_cols[0]]),
            crs="EPSG:4326",
        )

    if model_node_id_cols:
        node_id_col = model_node_id_cols[0]

        if ref_nodes_df is None:
            if ref_road_net is None:
                msg = "Must provide either nodes_df or road_net to convert \
                                 model_node_id to geometry"
                raise ValueError(msg)
            ref_nodes_df = ref_road_net.nodes_df

        WranglerLogger.debug("Converting table to GeoDataFrame using model_node_id")

        _table = table.merge(
            ref_nodes_df[["model_node_id", "geometry"]],
            left_on=node_id_col,
            right_on="model_node_id",
        )
        return gpd.GeoDataFrame(_table, geometry="geometry")
    msg = "Could not find lat/long, geometry columns or *model_node_id column in table \
                        necessary to convert to GeoDataFrame"
    raise ValueError(msg)
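The geometry-source fallback order can be illustrated without geopandas by matching column names alone. `pick_geometry_source` below is a hypothetical sketch of that dispatch, not a library function:

```python
def pick_geometry_source(columns: list[str]) -> str:
    # Mirrors the fallback order in to_points_gdf: a geometry column first,
    # then any lat/lon column pair, then a *model_node_id-style column.
    lat_cols = [c for c in columns if "lat" in c]
    lon_cols = [c for c in columns if "lon" in c]
    node_cols = [c for c in columns if c in ("model_node_id", "stop_id", "shape_model_node_id")]
    if "geometry" in columns:
        return "geometry"
    if lat_cols and lon_cols:
        return f"{lon_cols[0]}/{lat_cols[0]}"
    if node_cols:
        return node_cols[0]
    raise ValueError("no usable geometry source")

pick_geometry_source(["stop_lat", "stop_lon", "trip_id"])  # -> "stop_lon/stop_lat"
pick_geometry_source(["stop_id", "arrival_time"])          # -> "stop_id"
```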

update_nodes_in_linestring_geometry(original_df, updated_nodes_df, position)

Updates the nodes in a linestring geometry and returns updated geometry.

Parameters:

- original_df (GeoDataFrame): GeoDataFrame with the model_node_id and linestring geometry. Required.
- updated_nodes_df (GeoDataFrame): GeoDataFrame with updated node geometries. Required.
- position (int): Position in the linestring to update with the node. Required.
Source code in network_wrangler/utils/geo.py
def update_nodes_in_linestring_geometry(
    original_df: gpd.GeoDataFrame,
    updated_nodes_df: gpd.GeoDataFrame,
    position: int,
) -> gpd.GeoSeries:
    """Updates the nodes in a linestring geometry and returns updated geometry.

    Args:
        original_df: GeoDataFrame with the `model_node_id` and linestring geometry
        updated_nodes_df: GeoDataFrame with updated node geometries.
        position: position in the linestring to update with the node.
    """
    LINK_FK_NODE = ["A", "B"]
    original_index = original_df.index

    updated_df = original_df.reset_index().merge(
        updated_nodes_df[["model_node_id", "geometry"]],
        left_on=LINK_FK_NODE[position],
        right_on="model_node_id",
        suffixes=("", "_node"),
    )

    updated_df["geometry"] = updated_df.apply(
        lambda row: update_points_in_linestring(
            row["geometry"], row["geometry_node"].coords[0], position
        ),
        axis=1,
    )

    updated_df = updated_df.reset_index().set_index(original_index.names)

    WranglerLogger.debug(f"updated_df - AFTER: \n {updated_df.geometry}")
    return updated_df["geometry"]

update_point_geometry(df, ref_point_df, lon_field='X', lat_field='Y', id_field='model_node_id', ref_lon_field='X', ref_lat_field='Y', ref_id_field='model_node_id')

Returns copy of df with lat and long fields updated with geometry from ref_point_df.

NOTE: does not update “geometry” field if it exists.

Source code in network_wrangler/utils/geo.py
def update_point_geometry(
    df: pd.DataFrame,
    ref_point_df: pd.DataFrame,
    lon_field: str = "X",
    lat_field: str = "Y",
    id_field: str = "model_node_id",
    ref_lon_field: str = "X",
    ref_lat_field: str = "Y",
    ref_id_field: str = "model_node_id",
) -> pd.DataFrame:
    """Returns copy of df with lat and long fields updated with geometry from ref_point_df.

    NOTE: does not update "geometry" field if it exists.
    """
    df = copy.deepcopy(df)

    ref_df = ref_point_df.rename(
        columns={
            ref_lon_field: lon_field,
            ref_lat_field: lat_field,
            ref_id_field: id_field,
        }
    )

    updated_df = update_df_by_col_value(
        df,
        ref_df[[id_field, lon_field, lat_field]],
        id_field,
        properties=[lat_field, lon_field],
        fail_if_missing=False,
    )
    return updated_df

update_points_in_linestring(linestring, updated_coords, position)

Replaces a point in a linestring with a new point.

Parameters:

- linestring (LineString): Original linestring. Required.
- updated_coords (list[float]): Updated point coordinates. Required.
- position (int): Position in the linestring to update. Required.
Source code in network_wrangler/utils/geo.py
def update_points_in_linestring(
    linestring: LineString, updated_coords: list[float], position: int
):
    """Replaces a point in a linestring with a new point.

    Args:
        linestring (LineString): original_linestring
        updated_coords (list[float]): updated point coordinates
        position (int): position in the linestring to update
    """
    coords = [c for c in linestring.coords]
    coords[position] = updated_coords
    return LineString(coords)
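The same point replacement can be sketched on a plain list of coordinate tuples, without shapely; `replace_coord` below is a hypothetical illustration:

```python
def replace_coord(coords, new_coord, position):
    # Same idea as update_points_in_linestring, but on a plain list of
    # (x, y) tuples instead of a shapely LineString. Copies the input list,
    # so the original sequence is left untouched.
    out = list(coords)
    out[position] = new_coord
    return out

line = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
moved = replace_coord(line, (1.0, 0.5), 1)
# moved == [(0.0, 0.0), (1.0, 0.5), (2.0, 0.0)]
```

Negative positions work as well, e.g. `position=-1` replaces the final point, just as with the shapely version.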

Dataframe accessors that allow functions to be called directly on the dataframe.

DictQueryAccessor

Query link, node and shape dataframes using project selection dictionary.

Will overlook any keys which are not columns in the dataframe.

Usage:

selection_dict = {
    "lanes": [1, 2, 3],
    "name": ["6th", "Sixth", "sixth"],
    "drive_access": 1,
}
selected_links_df = links_df.dict_query(selection_dict)
Source code in network_wrangler/utils/df_accessors.py
@pd.api.extensions.register_dataframe_accessor("dict_query")
class DictQueryAccessor:
    """Query link, node and shape dataframes using project selection dictionary.

    Will overlook any keys which are not columns in the dataframe.

    Usage:

    ```
    selection_dict = {
        "lanes": [1, 2, 3],
        "name": ["6th", "Sixth", "sixth"],
        "drive_access": 1,
    }
    selected_links_df = links_df.dict_query(selection_dict)
    ```

    """

    def __init__(self, pandas_obj):
        """Initialization function for the dictionary query accessor."""
        self._obj = pandas_obj

    def __call__(self, selection_dict: dict, return_all_if_none: bool = False):
        """Queries the dataframe using the selection dictionary.

        Args:
            selection_dict (dict): dictionary of selection criteria to query with.
            return_all_if_none (bool, optional): If True, will return entire df if dict has
                 no values. Defaults to False.
        """
        _not_selection_keys = ["modes", "all", "ignore_missing"]
        _selection_dict = {
            k: v
            for k, v in selection_dict.items()
            if k not in _not_selection_keys and v is not None
        }
        missing_columns = [k for k in _selection_dict if k not in self._obj.columns]
        if missing_columns:
            msg = f"Selection fields not found in dataframe: {missing_columns}"
            raise SelectionError(msg)

        if not _selection_dict:
            if return_all_if_none:
                return self._obj
            msg = f"Relevant part of selection dictionary is empty: {selection_dict}"
            raise SelectionError(msg)

        _sel_query = dict_to_query(_selection_dict)
        # WranglerLogger.debug(f"_sel_query: \n   {_sel_query}")
        _df = self._obj.query(_sel_query, engine="python")

        if len(_df) == 0:
            WranglerLogger.warning(
                f"No records found in df \
                  using selection: {selection_dict}"
            )
        return _df

__call__(selection_dict, return_all_if_none=False)

Queries the dataframe using the selection dictionary.

Parameters:

- selection_dict (dict): Dictionary of selection criteria to query with. Required.
- return_all_if_none (bool): If True, will return the entire df if the dict has no values. Defaults to False.
Source code in network_wrangler/utils/df_accessors.py
def __call__(self, selection_dict: dict, return_all_if_none: bool = False):
    """Queries the dataframe using the selection dictionary.

    Args:
        selection_dict (dict): dictionary of selection criteria to query with.
        return_all_if_none (bool, optional): If True, will return entire df if dict has
             no values. Defaults to False.
    """
    _not_selection_keys = ["modes", "all", "ignore_missing"]
    _selection_dict = {
        k: v
        for k, v in selection_dict.items()
        if k not in _not_selection_keys and v is not None
    }
    missing_columns = [k for k in _selection_dict if k not in self._obj.columns]
    if missing_columns:
        msg = f"Selection fields not found in dataframe: {missing_columns}"
        raise SelectionError(msg)

    if not _selection_dict:
        if return_all_if_none:
            return self._obj
        msg = f"Relevant part of selection dictionary is empty: {selection_dict}"
        raise SelectionError(msg)

    _sel_query = dict_to_query(_selection_dict)
    # WranglerLogger.debug(f"_sel_query: \n   {_sel_query}")
    _df = self._obj.query(_sel_query, engine="python")

    if len(_df) == 0:
        WranglerLogger.warning(
            f"No records found in df \
              using selection: {selection_dict}"
        )
    return _df

__init__(pandas_obj)

Initialization function for the dictionary query accessor.

Source code in network_wrangler/utils/df_accessors.py
def __init__(self, pandas_obj):
    """Initialization function for the dictionary query accessor."""
    self._obj = pandas_obj

Isin_dict

Faster implementation of isin() for querying dataframes with a dictionary.

Source code in network_wrangler/utils/df_accessors.py
@pd.api.extensions.register_dataframe_accessor("isin_dict")
class Isin_dict:
    """Faster implementation of isin() for querying dataframes with a dictionary."""

    def __init__(self, pandas_obj):
        """Initialization function for the dataframe hash."""
        self._obj = pandas_obj

    def __call__(self, d: dict, **kwargs) -> pd.DataFrame:
        """Function to perform the faster dictionary isin()."""
        return isin_dict(self._obj, d, **kwargs)

__call__(d, **kwargs)

Function to perform the faster dictionary isin().

Source code in network_wrangler/utils/df_accessors.py
def __call__(self, d: dict, **kwargs) -> pd.DataFrame:
    """Function to perform the faster dictionary isin()."""
    return isin_dict(self._obj, d, **kwargs)

__init__(pandas_obj)

Initialization function for the dataframe hash.

Source code in network_wrangler/utils/df_accessors.py
def __init__(self, pandas_obj):
    """Initialization function for the dataframe hash."""
    self._obj = pandas_obj

dfHash

Creates a dataframe hash that is compatible with geopandas and various metadata.

Definitely not the fastest, but she seems to work where others have failed.

Source code in network_wrangler/utils/df_accessors.py
@pd.api.extensions.register_dataframe_accessor("df_hash")
class dfHash:
    """Creates a dataframe hash that is compatible with geopandas and various metadata.

    Definitely not the fastest, but she seems to work where others have failed.
    """

    def __init__(self, pandas_obj):
        """Initialization function for the dataframe hash."""
        self._obj = pandas_obj

    def __call__(self):
        """Function to hash the dataframe."""
        _value = str(self._obj.values).encode()
        hash = hashlib.sha1(_value).hexdigest()
        return hash

__call__()

Function to hash the dataframe.

Source code in network_wrangler/utils/df_accessors.py
def __call__(self):
    """Function to hash the dataframe."""
    _value = str(self._obj.values).encode()
    hash = hashlib.sha1(_value).hexdigest()
    return hash

__init__(pandas_obj)

Initialization function for the dataframe hash.

Source code in network_wrangler/utils/df_accessors.py
def __init__(self, pandas_obj):
    """Initialization function for the dataframe hash."""
    self._obj = pandas_obj
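The hashing approach, stringify the values and SHA-1 the bytes, can be demonstrated with a plain nested list standing in for `df.values`; `values_hash` below is a hypothetical stand-in for the accessor:

```python
import hashlib

def values_hash(values) -> str:
    # Same approach as df_hash: stringify the values and SHA-1 them.
    # Works for any object with a stable str() representation.
    return hashlib.sha1(str(values).encode()).hexdigest()

a = values_hash([[1, 2], [3, 4]])
b = values_hash([[1, 2], [3, 4]])  # identical input -> identical digest
c = values_hash([[1, 2], [3, 5]])  # any changed value -> different digest
```

Note that hashing `str(df.values)` makes the digest sensitive to pandas' string formatting, which is part of why the docstring hedges about speed and robustness.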

Logging utilities for Network Wrangler.

setup_logging(info_log_filename=None, debug_log_filename=None, std_out_level='info')

Sets up the WranglerLogger, including the debug file location and whether to log to the console.

Called by the test_logging fixture in conftest.py; it can also be called by the user to set up logging for their session. If called multiple times, the logger will be reset.

Parameters:

- info_log_filename (Optional[Path]): Location of the log file that will get created for the INFO log. The INFO log is terse, giving just the bare minimum of detail. Defaults to a file in cwd() named wrangler_[datetime].log; to turn off logging to a file, use log_filename = None. Default: None.
- debug_log_filename (Optional[Path]): Location of the log file that will get created for the DEBUG log. The DEBUG log is very noisy, for debugging. Defaults to a file in cwd() named wrangler_[datetime].log; to turn off logging to a file, use log_filename = None. Default: None.
- std_out_level (str): Level of logging to the console. One of "info", "warning", "debug". Defaults to "info"; set to ERROR if nothing provided matches.
Source code in network_wrangler/logger.py
def setup_logging(
    info_log_filename: Optional[Path] = None,
    debug_log_filename: Optional[Path] = None,
    std_out_level: str = "info",
):
    """Sets up the WranglerLogger w.r.t. the debug file location and if logging to console.

    Called by the test_logging fixture in conftest.py and can be called by the user to setup
    logging for their session. If called multiple times, the logger will be reset.

    Args:
        info_log_filename: the location of the log file that will get created to add the INFO log.
            The INFO Log is terse, just gives the bare minimum of details.
            Defaults to file in cwd() `wrangler_[datetime].log`. To turn off logging to a file,
            use log_filename = None.
        debug_log_filename: the location of the log file that will get created to add the DEBUG log
            The DEBUG log is very noisy, for debugging. Defaults to file in cwd()
            `wrangler_[datetime].log`. To turn off logging to a file, use log_filename = None.
        std_out_level: the level of logging to the console. One of "info", "warning", "debug".
            Defaults to "info" but will be set to ERROR if nothing provided matches.
    """
    # add function variable so that we know if logging has been called
    setup_logging.called = True

    DEFAULT_LOG_PATH = Path(f"wrangler_{datetime.now().strftime('%Y_%m_%d__%H_%M_%S')}.debug.log")
    debug_log_filename = debug_log_filename if debug_log_filename else DEFAULT_LOG_PATH

    # Clear handles if any exist already
    WranglerLogger.handlers = []

    WranglerLogger.setLevel(logging.DEBUG)

    FORMAT = logging.Formatter(
        "%(asctime)-15s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S,"
    )
    default_info_f = f"network_wrangler_{datetime.now().strftime('%Y_%m_%d__%H_%M_%S')}.info.log"
    info_log_filename = info_log_filename or Path.cwd() / default_info_f

    info_file_handler = logging.FileHandler(Path(info_log_filename))
    info_file_handler.setLevel(logging.INFO)
    info_file_handler.setFormatter(FORMAT)
    WranglerLogger.addHandler(info_file_handler)

    # create debug file only when debug_log_filename is provided
    if debug_log_filename:
        debug_log_handler = logging.FileHandler(Path(debug_log_filename))
        debug_log_handler.setLevel(logging.DEBUG)
        debug_log_handler.setFormatter(FORMAT)
        WranglerLogger.addHandler(debug_log_handler)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(FORMAT)
    WranglerLogger.addHandler(console_handler)
    if std_out_level in ("debug", "info"):
        console_handler.setLevel(logging.DEBUG)
    elif std_out_level == "warning":
        console_handler.setLevel(logging.WARNING)
    else:
        console_handler.setLevel(logging.ERROR)
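The handler layout that `setup_logging` builds, a terse INFO file plus a noisy DEBUG file on one logger, can be sketched with stdlib `logging` alone. The logger name and filenames below are hypothetical:

```python
import logging
import tempfile
from pathlib import Path

# Minimal stdlib sketch of the handler layout setup_logging builds:
# one terse INFO file and one noisy DEBUG file on a single logger.
log_dir = Path(tempfile.mkdtemp())
logger = logging.getLogger("wrangler_sketch")
logger.setLevel(logging.DEBUG)
fmt = logging.Formatter("%(levelname)s: %(message)s")

info_handler = logging.FileHandler(log_dir / "info.log")
info_handler.setLevel(logging.INFO)
info_handler.setFormatter(fmt)

debug_handler = logging.FileHandler(log_dir / "debug.log")
debug_handler.setLevel(logging.DEBUG)
debug_handler.setFormatter(fmt)

logger.handlers = [info_handler, debug_handler]  # reset, as setup_logging does

logger.debug("only in the debug file")  # below INFO, filtered by info_handler
logger.info("in both files")

for h in logger.handlers:
    h.flush()
```

Per-handler levels are what make this work: the logger itself passes everything at DEBUG and above, and each handler then filters to its own threshold.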

Configuration utilities.

ConfigItem

Base class to add partial dict-like interface to configuration.

Allows use of .items(), ["X"], .get("X"), and .to_dict() on the configuration.

Not to be constructed directly; to be used as a mixin for dataclasses representing config schema. Do not use "get", "to_dict", or "items" as key names.

Source code in network_wrangler/configs/utils.py
class ConfigItem:
    """Base class to add partial dict-like interface to configuration.

    Allows use of .items(), ["X"], .get("X"), and .to_dict() on the configuration.

    Not to be constructed directly. To be used as a mixin for dataclasses
    representing config schema.
    Do not use "get", "to_dict", or "items" as key names.
    """

    base_path: Optional[Path] = None

    def __getitem__(self, key):
        """Return the value for key if key is in the dictionary, else default."""
        return getattr(self, key)

    def items(self):
        """A set-like object providing a view on D's items."""
        return self.__dict__.items()

    def to_dict(self):
        """Convert the configuration to a dictionary."""
        result = {}
        for key, value in self.__dict__.items():
            if isinstance(value, ConfigItem):
                result[key] = value.to_dict()
            else:
                result[key] = value
        return result

    def get(self, key, default=None):
        """Return the value for key if key is in the dictionary, else default."""
        return self.__dict__.get(key, default)

    def update(self, data: Union[Path, list[Path], dict]):
        """Update the configuration with a dictionary of new values."""
        if not isinstance(data, dict):
            WranglerLogger.info(f"Updating configuration with {data}.")
            data = load_merge_dict(data)

        self.__dict__.update(data)
        return self

    def resolve_paths(self, base_path):
        """Resolve relative paths in the configuration."""
        base_path = Path(base_path)
        for key, value in self.__dict__.items():
            if isinstance(value, ConfigItem):
                value.resolve_paths(base_path)
            elif isinstance(value, str) and value.startswith("."):
                resolved_path = (base_path / value).resolve()
                setattr(self, key, str(resolved_path))

__getitem__(key)

Return the value for key if key is in the dictionary, else default.

Source code in network_wrangler/configs/utils.py
def __getitem__(self, key):
    """Return the value for key if key is in the dictionary, else default."""
    return getattr(self, key)

get(key, default=None)

Return the value for key if key is in the dictionary, else default.

Source code in network_wrangler/configs/utils.py
def get(self, key, default=None):
    """Return the value for key if key is in the dictionary, else default."""
    return self.__dict__.get(key, default)

items()

A set-like object providing a view on D’s items.

Source code in network_wrangler/configs/utils.py
def items(self):
    """A set-like object providing a view on D's items."""
    return self.__dict__.items()

resolve_paths(base_path)

Resolve relative paths in the configuration.

Source code in network_wrangler/configs/utils.py
def resolve_paths(self, base_path):
    """Resolve relative paths in the configuration."""
    base_path = Path(base_path)
    for key, value in self.__dict__.items():
        if isinstance(value, ConfigItem):
            value.resolve_paths(base_path)
        elif isinstance(value, str) and value.startswith("."):
            resolved_path = (base_path / value).resolve()
            setattr(self, key, str(resolved_path))

to_dict()

Convert the configuration to a dictionary.

Source code in network_wrangler/configs/utils.py
def to_dict(self):
    """Convert the configuration to a dictionary."""
    result = {}
    for key, value in self.__dict__.items():
        if isinstance(value, ConfigItem):
            result[key] = value.to_dict()
        else:
            result[key] = value
    return result

update(data)

Update the configuration with a dictionary of new values.

Source code in network_wrangler/configs/utils.py
def update(self, data: Union[Path, list[Path], dict]):
    """Update the configuration with a dictionary of new values."""
    if not isinstance(data, dict):
        WranglerLogger.info(f"Updating configuration with {data}.")
        data = load_merge_dict(data)

    self.__dict__.update(data)
    return self
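The dict-like surface `ConfigItem` provides can be sketched as a minimal dataclass mixin. `DictLikeMixin`, `SubConfig`, and `Config` below are hypothetical illustrations, not library classes:

```python
from dataclasses import dataclass, field

class DictLikeMixin:
    # Minimal sketch of ConfigItem's dict-like surface:
    # [], .get(), .items(), and a recursive .to_dict().
    def __getitem__(self, key):
        return getattr(self, key)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)

    def items(self):
        return self.__dict__.items()

    def to_dict(self):
        # Recurse into nested config objects, pass scalars through.
        return {
            k: v.to_dict() if isinstance(v, DictLikeMixin) else v
            for k, v in self.__dict__.items()
        }

@dataclass
class SubConfig(DictLikeMixin):
    crs: int = 4326

@dataclass
class Config(DictLikeMixin):
    name: str = "base"
    sub: SubConfig = field(default_factory=SubConfig)

cfg = Config()
# cfg["name"] == "base"; cfg.get("missing", 1) == 1
# cfg.to_dict() == {"name": "base", "sub": {"crs": 4326}}
```

Because the mixin reads `self.__dict__`, any dataclass field automatically participates in the dict-like view, which is why the real class warns against fields named "get", "to_dict", or "items".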

find_configs_in_dir(dir, config_type)

Find configuration files in the directory that match *config<ext>.

Source code in network_wrangler/configs/utils.py
def find_configs_in_dir(dir: Union[Path, list[Path]], config_type) -> list[Path]:
    """Find configuration files in the directory that match `*config<ext>`."""
    config_files: list[Path] = []
    if isinstance(dir, list):
        for d in dir:
            config_files.extend(find_configs_in_dir(d, config_type))
    elif dir.is_dir():
        dir = Path(dir)
        for ext in SUPPORTED_CONFIG_EXTENSIONS:
            config_like_files = list(dir.glob(f"*config{ext}"))
            config_files.extend(find_configs_in_dir(config_like_files, config_type))
    elif dir.is_file():
        try:
            config_type(load_dict(dir))
        except ValidationError:
            return config_files
        config_files.append(dir)

    if config_files:
        return [Path(config_file) for config_file in config_files]
    return []

Module for time and timespan objects.

Time

Represents a time object.

This class provides methods to initialize and manipulate time objects.

Attributes:

- datetime (datetime): The underlying datetime object representing the time.
- time_str (str): The time string representation in HH:MM:SS format.
- time_sec (int): The time in seconds since midnight.
- _raw_time_in (TimeType): The raw input value used to initialize the Time object.

Source code in network_wrangler/time.py
class Time:
    """Represents a time object.

    This class provides methods to initialize and manipulate time objects.

    Attributes:
        datetime (datetime): The underlying datetime object representing the time.
        time_str (str): The time string representation in HH:MM:SS format.
        time_sec (int): The time in seconds since midnight.

        _raw_time_in (TimeType): The raw input value used to initialize the Time object.

    """

    def __init__(self, value: TimeType):
        """Initializes a Time object.

        Args:
            value (TimeType): A time object, string in HH:MM[:SS] format, or seconds since
                midnight.

        Raises:
            TimeFormatError: If the value is not a valid time format.

        """
        if isinstance(value, datetime):
            self.datetime: datetime = value
        elif isinstance(value, time):
            self.datetime = datetime.combine(datetime.today(), value)
        elif isinstance(value, str):
            self.datetime = str_to_time(value)
        elif isinstance(value, int):
            # interpret ints as seconds since midnight (requires datetime.timedelta)
            self.datetime = datetime.combine(datetime.today(), time.min) + timedelta(seconds=value)
        else:
            msg = "time must be a string, int, or time object"
            raise TimeFormatError(msg)

        self._raw_time_in = value

    def __getitem__(self, item: Any) -> str:
        """Get the time string representation.

        Args:
            item (Any): Not used.

        Returns:
            str: The time string representation in HH:MM:SS format.
        """
        return self.time_str

    @property
    def time_str(self):
        """Get the time string representation.

        Returns:
            str: The time string representation in HH:MM:SS format.
        """
        return self.datetime.strftime("%H:%M:%S")

    @property
    def time_sec(self):
        """Get the time in seconds since midnight.

        Returns:
            int: The time in seconds since midnight.
        """
        return self.datetime.hour * 3600 + self.datetime.minute * 60 + self.datetime.second

    def __str__(self) -> str:
        """Get the string representation of the Time object.

        Returns:
            str: The time string representation in HH:MM:SS format.
        """
        return self.time_str

    def __hash__(self) -> int:
        """Get the hash value of the Time object.

        Returns:
            int: The hash value of the Time object.
        """
        return hash(str(self))

time_sec property

Get the time in seconds since midnight.

Returns:

- int: The time in seconds since midnight.

time_str property

Get the time string representation.

Returns:

- str: The time string representation in HH:MM:SS format.

__getitem__(item)

Get the time string representation.

Parameters:

- item (Any): Not used. Required.

Returns:

- str: The time string representation in HH:MM:SS format.

Source code in network_wrangler/time.py
def __getitem__(self, item: Any) -> str:
    """Get the time string representation.

    Args:
        item (Any): Not used.

    Returns:
        str: The time string representation in HH:MM:SS format.
    """
    return self.time_str

__hash__()

Get the hash value of the Time object.

Returns:

- int: The hash value of the Time object.

Source code in network_wrangler/time.py
def __hash__(self) -> int:
    """Get the hash value of the Time object.

    Returns:
        int: The hash value of the Time object.
    """
    return hash(str(self))

__init__(value)

Initializes a Time object.

Parameters:

- value (TimeType): A time object, string in HH:MM[:SS] format, or seconds since midnight. Required.

Raises:

- TimeFormatError: If the value is not a valid time format.

Source code in network_wrangler/time.py
def __init__(self, value: TimeType):
    """Initializes a Time object.

    Args:
        value (TimeType): A time object, string in HH:MM[:SS] format, or seconds since
            midnight.

    Raises:
        TimeFormatError: If the value is not a valid time format.

    """
    if isinstance(value, datetime):
        self.datetime: datetime = value
    elif isinstance(value, time):
        self.datetime = datetime.combine(datetime.today(), value)
    elif isinstance(value, str):
        self.datetime = str_to_time(value)
    elif isinstance(value, int):
        # interpret an int as seconds since midnight
        self.datetime = datetime.combine(datetime.today(), time()) + timedelta(seconds=value)
    else:
        msg = "time must be a string, int, or time object"
        raise TimeFormatError(msg)

    self._raw_time_in = value

__str__()

Get the string representation of the Time object.

Returns:

Name Type Description
str str

The time string representation in HH:MM:SS format.

Source code in network_wrangler/time.py
def __str__(self) -> str:
    """Get the string representation of the Time object.

    Returns:
        str: The time string representation in HH:MM:SS format.
    """
    return self.time_str
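The HH:MM:SS-to-seconds conversion behind `time_sec` can be sketched with the standard library alone. This is a minimal illustration independent of network_wrangler; `seconds_since_midnight` is a hypothetical helper name, not part of the library:

```python
from datetime import datetime

def seconds_since_midnight(time_str: str) -> int:
    """Parse an HH:MM:SS string and return seconds since midnight."""
    dt = datetime.strptime(time_str, "%H:%M:%S")
    return dt.hour * 3600 + dt.minute * 60 + dt.second

print(seconds_since_midnight("08:30:15"))  # 8*3600 + 30*60 + 15 = 30615
```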

Timespan

Timespan object.

This class provides methods to initialize and manipulate time objects.

If the end_time is less than the start_time, the duration will assume that it crosses over midnight.

Attributes:

Name Type Description
start_time time

The start time of the timespan.

end_time time

The end time of the timespan.

timespan_str_list str

A list of start time and end time in HH:MM:SS format.

start_time_sec int

The start time in seconds since midnight.

end_time_sec int

The end time in seconds since midnight.

duration timedelta

The duration of the timespan.

duration_sec int

The duration of the timespan in seconds.

_raw_timespan_in Any

The raw input value used to initialize the Timespan object.

Source code in network_wrangler/time.py
class Timespan:
    """Timespan object.

    This class provides methods to initialize and manipulate time objects.

    If the end_time is less than the start_time, the duration will assume that it crosses
        over midnight.

    Attributes:
        start_time (datetime.time): The start time of the timespan.
        end_time (datetime.time): The end time of the timespan.
        timespan_str_list (str): A list of start time and end time in HH:MM:SS format.
        start_time_sec (int): The start time in seconds since midnight.
        end_time_sec (int): The end time in seconds since midnight.
        duration (datetime.timedelta): The duration of the timespan.
        duration_sec (int): The duration of the timespan in seconds.

        _raw_timespan_in (Any): The raw input value used to initialize the Timespan object.

    """

    def __init__(self, value: list[TimeType]):
        """Constructor for the Timespan object.

        If the value is a list of two time strings, datetime objects, Time, or seconds from
        midnight, the start_time and end_time attributes will be set accordingly.

        Args:
            value (time): a list of two time strings, datetime objects, Time, or seconds from
              midnight.
        """
        if len(value) != 2:  # noqa: PLR2004
            msg = "timespan must be a list of 2 time strings, datetime objs, Time, or sec from midnight."
            raise TimespanFormatError(msg)

        self.start_time, self.end_time = (Time(t) for t in value)
        self._raw_timespan_in = value

    @property
    def timespan_str_list(self):
        """Get the timespan string representation."""
        return [self.start_time.time_str, self.end_time.time_str]

    @property
    def start_time_sec(self):
        """Start time in seconds since midnight."""
        return self.start_time.time_sec

    @property
    def end_time_sec(self):
        """End time in seconds since midnight."""
        return self.end_time.time_sec

    @property
    def duration(self):
        """Duration of timespan as a timedelta object."""
        return duration_dt(self.start_time, self.end_time)

    @property
    def duration_sec(self):
        """Duration of timespan in seconds.

        If end_time is less than start_time, the duration will assume that it crosses over
        midnight.
        """
        if self.end_time_sec < self.start_time_sec:
            return (24 * 3600) - self.start_time_sec + self.end_time_sec
        return self.end_time_sec - self.start_time_sec

    def __str__(self) -> str:
        """String representation of the Timespan object."""
        return str(self.timespan_str_list)

    def __hash__(self) -> int:
        """Hash value of the Timespan object."""
        return hash(str(self))

    def overlaps(self, other: Timespan) -> bool:
        """Check if two timespans overlap.

        If the start time is greater than the end time, the timespan is assumed to cross over
        midnight.

        Args:
            other (Timespan): The other timespan to compare.

        Returns:
            bool: True if the two timespans overlap, False otherwise.
        """
        real_end_time = self.end_time.datetime
        if self.end_time.datetime < self.start_time.datetime:
            real_end_time = self.end_time.datetime + timedelta(days=1)

        real_other_end_time = other.end_time.datetime
        if other.end_time.datetime < other.start_time.datetime:
            real_other_end_time = other.end_time.datetime + timedelta(days=1)
        return (
            self.start_time.datetime <= real_other_end_time
            and real_end_time >= other.start_time.datetime
        )
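The midnight-wrapping arithmetic in `duration_sec` can be mirrored with plain integers. A minimal sketch (the function name is illustrative, not the library's API):

```python
def duration_sec(start_sec: int, end_sec: int) -> int:
    """Duration in seconds; assumes the span crosses midnight when end < start."""
    if end_sec < start_sec:
        return (24 * 3600) - start_sec + end_sec
    return end_sec - start_sec

# 22:00 -> 02:00 wraps past midnight: 4 hours
print(duration_sec(22 * 3600, 2 * 3600))  # 14400
```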

duration property

Duration of timespan as a timedelta object.

duration_sec property

Duration of timespan in seconds.

If end_time is less than start_time, the duration will assume that it crosses over midnight.

end_time_sec property

End time in seconds since midnight.

start_time_sec property

Start time in seconds since midnight.

timespan_str_list property

Get the timespan string representation.

__hash__()

Hash value of the Timespan object.

Source code in network_wrangler/time.py
def __hash__(self) -> int:
    """Hash value of the Timespan object."""
    return hash(str(self))

__init__(value)

Constructor for the Timespan object.

If the value is a list of two time strings, datetime objects, Time, or seconds from midnight, the start_time and end_time attributes will be set accordingly.

Parameters:

Name Type Description Default
value time

a list of two time strings, datetime objects, Time, or seconds from midnight.

required
Source code in network_wrangler/time.py
def __init__(self, value: list[TimeType]):
    """Constructor for the Timespan object.

    If the value is a list of two time strings, datetime objects, Time, or seconds from
    midnight, the start_time and end_time attributes will be set accordingly.

    Args:
        value (time): a list of two time strings, datetime objects, Time, or seconds from
          midnight.
    """
    if len(value) != 2:  # noqa: PLR2004
        msg = "timespan must be a list of 2 time strings, datetime objs, Time, or sec from midnight."
        raise TimespanFormatError(msg)

    self.start_time, self.end_time = (Time(t) for t in value)
    self._raw_timespan_in = value

__str__()

String representation of the Timespan object.

Source code in network_wrangler/time.py
def __str__(self) -> str:
    """String representation of the Timespan object."""
    return str(self.timespan_str_list)

overlaps(other)

Check if two timespans overlap.

If the start time is greater than the end time, the timespan is assumed to cross over midnight.

Parameters:

Name Type Description Default
other Timespan

The other timespan to compare.

required

Returns:

Name Type Description
bool bool

True if the two timespans overlap, False otherwise.

Source code in network_wrangler/time.py
def overlaps(self, other: Timespan) -> bool:
    """Check if two timespans overlap.

    If the start time is greater than the end time, the timespan is assumed to cross over
    midnight.

    Args:
        other (Timespan): The other timespan to compare.

    Returns:
        bool: True if the two timespans overlap, False otherwise.
    """
    real_end_time = self.end_time.datetime
    if self.end_time.datetime < self.start_time.datetime:
        real_end_time = self.end_time.datetime + timedelta(days=1)

    real_other_end_time = other.end_time.datetime
    if other.end_time.datetime < other.start_time.datetime:
        real_other_end_time = other.end_time.datetime + timedelta(days=1)
    return (
        self.start_time.datetime <= real_other_end_time
        and real_end_time >= other.start_time.datetime
    )
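The comparison in `overlaps` amounts to a standard interval test after shifting a wrapping end time forward by a day. A simplified stdlib sketch using seconds since midnight (function names are illustrative; like the source, it only shifts each span's own end time when that span wraps):

```python
def expand(start_sec: int, end_sec: int) -> tuple[int, int]:
    """Return (start, end), shifting end past midnight when the span wraps."""
    if end_sec < start_sec:
        end_sec += 24 * 3600
    return start_sec, end_sec

def overlaps(a: tuple[int, int], b: tuple[int, int]) -> bool:
    """True if two (start_sec, end_sec) spans share any time."""
    a_start, a_end = expand(*a)
    b_start, b_end = expand(*b)
    return a_start <= b_end and a_end >= b_start

# 09:00-17:00 vs 16:00-20:00 overlap for one hour
print(overlaps((9 * 3600, 17 * 3600), (16 * 3600, 20 * 3600)))  # True
```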

Module for visualizing roadway and transit networks using Mapbox tiles.

This module provides a function net_to_mapbox that creates and serves Mapbox tiles on a local web server based on roadway and transit networks.

Example usage

net_to_mapbox(roadway, transit)

MissingMapboxTokenError

Bases: Exception

Raised when MAPBOX_ACCESS_TOKEN is not found in environment variables.

Source code in network_wrangler/viz.py
class MissingMapboxTokenError(Exception):
    """Raised when MAPBOX_ACCESS_TOKEN is not found in environment variables."""

net_to_mapbox(roadway=None, transit=None, roadway_geojson_out=Path('roadway_shapes.geojson'), transit_geojson_out=Path('transit_shapes.geojson'), mbtiles_out=Path('network.mbtiles'), overwrite=True, port='9000')

Creates and serves mapbox tiles on local web server based on roadway and transit networks.

Parameters:

Name Type Description Default
roadway Optional[Union[RoadwayNetwork, GeoDataFrame, str, Path]]

a RoadwayNetwork instance, geodataframe with roadway linestrings, or path to a geojson file. Defaults to empty GeoDataFrame.

None
transit Optional[Union[TransitNetwork, GeoDataFrame]]

a TransitNetwork instance, a geodataframe with transit linestrings, or path to a geojson file. Defaults to empty GeoDataFrame.

None
roadway_geojson_out Path

file path for roadway geojson which gets created if roadway is not a path to a geojson file. Defaults to roadway_shapes.geojson.

Path('roadway_shapes.geojson')
transit_geojson_out Path

file path for transit geojson which gets created if transit is not a path to a geojson file. Defaults to transit_shapes.geojson.

Path('transit_shapes.geojson')
mbtiles_out Path

path to output mapbox tiles. Defaults to network.mbtiles

Path('network.mbtiles')
overwrite bool

boolean indicating if can overwrite mbtiles_out and roadway_geojson_out and transit_geojson_out. Defaults to True.

True
port str

port to serve resulting tiles on. Defaults to 9000.

'9000'
Source code in network_wrangler/viz.py
def net_to_mapbox(
    roadway: Optional[Union[RoadwayNetwork, gpd.GeoDataFrame, str, Path]] = None,
    transit: Optional[Union[TransitNetwork, gpd.GeoDataFrame]] = None,
    roadway_geojson_out: Path = Path("roadway_shapes.geojson"),
    transit_geojson_out: Path = Path("transit_shapes.geojson"),
    mbtiles_out: Path = Path("network.mbtiles"),
    overwrite: bool = True,
    port: str = "9000",
):
    """Creates and serves mapbox tiles on local web server based on roadway and transit networks.

    Args:
        roadway: a RoadwayNetwork instance, geodataframe with roadway linestrings, or path to a
            geojson file. Defaults to empty GeoDataFrame.
        transit: a TransitNetwork instance, a geodataframe with transit linestrings, or path to a
            geojson file. Defaults to empty GeoDataFrame.
        roadway_geojson_out: file path for roadway geojson which gets created if roadway is not
            a path to a geojson file. Defaults to roadway_shapes.geojson.
        transit_geojson_out: file path for transit geojson which gets created if transit is not
            a path to a geojson file. Defaults to transit_shapes.geojson.
        mbtiles_out: path to output mapbox tiles. Defaults to network.mbtiles
        overwrite: boolean indicating if can overwrite mbtiles_out and roadway_geojson_out and
            transit_geojson_out. Defaults to True.
        port: port to serve resulting tiles on. Defaults to 9000.
    """
    import subprocess

    if roadway is None:
        roadway = gpd.GeoDataFrame()
    if transit is None:
        transit = gpd.GeoDataFrame()
    # test for mapbox token
    if os.getenv("MAPBOX_ACCESS_TOKEN") is None:
        WranglerLogger.error(
            "NEED TO SET MAPBOX ACCESS TOKEN IN ENVIRONMENT VARIABLES\n"
            "In command line: >>export MAPBOX_ACCESS_TOKEN='pk.0000.1111' "
            "# replace value with your mapbox public access token"
        )
        raise MissingMapboxTokenError()

    if isinstance(transit, TransitNetwork):
        transit = transit.shape_links_gdf
        transit.to_file(transit_geojson_out, driver="GeoJSON")
    elif Path(transit).exists():
        transit_geojson_out = Path(transit)
    else:
        msg = f"Don't understand transit input: {transit}"
        raise ValueError(msg)

    if isinstance(roadway, RoadwayNetwork):
        roadway = roadway.link_shapes_df
        roadway.to_file(roadway_geojson_out, driver="GeoJSON")
    elif Path(roadway).exists():
        roadway_geojson_out = Path(roadway)
    else:
        msg = f"Don't understand roadway input: {roadway}"
        raise ValueError(msg)

    tippe_options_list: list[str] = ["-zg", "-o", str(mbtiles_out)]
    if overwrite:
        tippe_options_list.append("--force")
    # tippe_options_list.append("--drop-densest-as-needed")
    tippe_options_list.append(str(roadway_geojson_out))
    tippe_options_list.append(str(transit_geojson_out))

    try:
        WranglerLogger.info(
            f"Running tippecanoe with following options: {' '.join(tippe_options_list)}"
        )
        subprocess.run(["tippecanoe", *tippe_options_list], check=False)
    except Exception as err:
        WranglerLogger.error(
            "If tippecanoe isn't installed, try `brew install tippecanoe` or \
                visit https://github.com/mapbox/tippecanoe"
        )
        raise ImportError() from err

    try:
        WranglerLogger.info(f"Running mbview on port {port} for {mbtiles_out}")
        subprocess.run(["mbview", "--port", port, str(mbtiles_out)], check=False)
    except Exception as err:
        WranglerLogger.error(
            "If mbview isn't installed, try `npm install -g @mapbox/mbview` or \
                visit https://github.com/mapbox/mbview"
        )
        raise ImportError() from err
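The tippecanoe option list assembled above can be reproduced in isolation. A minimal sketch of that assembly step (`tippecanoe_args` is a hypothetical helper, not part of network_wrangler; paths are the defaults shown above):

```python
from pathlib import Path

def tippecanoe_args(
    mbtiles_out: Path, geojson_files: list[Path], overwrite: bool = True
) -> list[str]:
    """Assemble tippecanoe CLI options the same way net_to_mapbox does."""
    args = ["-zg", "-o", str(mbtiles_out)]
    if overwrite:
        args.append("--force")  # allow overwriting an existing .mbtiles file
    args.extend(str(p) for p in geojson_files)
    return args

print(tippecanoe_args(
    Path("network.mbtiles"),
    [Path("roadway_shapes.geojson"), Path("transit_shapes.geojson")],
))
# ['-zg', '-o', 'network.mbtiles', '--force', 'roadway_shapes.geojson', 'transit_shapes.geojson']
```

The resulting list is passed to `subprocess.run(["tippecanoe", *args], ...)`, so the equivalent shell command is `tippecanoe -zg -o network.mbtiles --force roadway_shapes.geojson transit_shapes.geojson`.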

All network wrangler errors.

DataframeSelectionError

Bases: Exception

Raised when there is an issue with a selection from a dataframe.

Source code in network_wrangler/errors.py
class DataframeSelectionError(Exception):
    """Raised when there is an issue with a selection from a dataframe."""

FeedReadError

Bases: Exception

Raised when there is an error reading a transit feed.

Source code in network_wrangler/errors.py
class FeedReadError(Exception):
    """Raised when there is an error reading a transit feed."""

FeedValidationError

Bases: Exception

Raised when there is an issue with the validation of the GTFS data.

Source code in network_wrangler/errors.py
class FeedValidationError(Exception):
    """Raised when there is an issue with the validation of the GTFS data."""

InvalidScopedLinkValue

Bases: Exception

Raised when there is an issue with a scoped link value.

Source code in network_wrangler/errors.py
class InvalidScopedLinkValue(Exception):
    """Raised when there is an issue with a scoped link value."""

LinkAddError

Bases: Exception

Raised when there is an issue with adding links.

Source code in network_wrangler/errors.py
class LinkAddError(Exception):
    """Raised when there is an issue with adding links."""

LinkChangeError

Bases: Exception

Raised when there is an error in changing a link property.

Source code in network_wrangler/errors.py
class LinkChangeError(Exception):
    """Raised when there is an error in changing a link property."""

LinkCreationError

Bases: Exception

Raised when there is an issue with creating links.

Source code in network_wrangler/errors.py
class LinkCreationError(Exception):
    """Raised when there is an issue with creating links."""

LinkDeletionError

Bases: Exception

Raised when there is an issue with deleting links.

Source code in network_wrangler/errors.py
class LinkDeletionError(Exception):
    """Raised when there is an issue with deleting links."""

LinkNotFoundError

Bases: Exception

Raised when a link is not found in the links table.

Source code in network_wrangler/errors.py
class LinkNotFoundError(Exception):
    """Raised when a link is not found in the links table."""

ManagedLaneAccessEgressError

Bases: Exception

Raised when there is an issue with access/egress points to managed lanes.

Source code in network_wrangler/errors.py
class ManagedLaneAccessEgressError(Exception):
    """Raised when there is an issue with access/egress points to managed lanes."""

MissingNodesError

Bases: Exception

Raised when referenced nodes are missing from the network.

Source code in network_wrangler/errors.py
class MissingNodesError(Exception):
    """Raised when referenced nodes are missing from the network."""

NewRoadwayError

Bases: Exception

Raised when there is an issue with applying a new roadway.

Source code in network_wrangler/errors.py
class NewRoadwayError(Exception):
    """Raised when there is an issue with applying a new roadway."""

NodeAddError

Bases: Exception

Raised when there is an issue with adding nodes.

Source code in network_wrangler/errors.py
class NodeAddError(Exception):
    """Raised when there is an issue with adding nodes."""

NodeChangeError

Bases: Exception

Raised when there is an issue with applying a node change.

Source code in network_wrangler/errors.py
class NodeChangeError(Exception):
    """Raised when there is an issue with applying a node change."""

NodeDeletionError

Bases: Exception

Raised when there is an issue with deleting nodes.

Source code in network_wrangler/errors.py
class NodeDeletionError(Exception):
    """Raised when there is an issue with deleting nodes."""

NodeNotFoundError

Bases: Exception

Raised when a node is not found in the nodes table.

Source code in network_wrangler/errors.py
class NodeNotFoundError(Exception):
    """Raised when a node is not found in the nodes table."""

NodesInLinksMissingError

Bases: Exception

Raised when there is an issue with validating links and nodes.

Source code in network_wrangler/errors.py
class NodesInLinksMissingError(Exception):
    """Raised when there is an issue with validating links and nodes."""

NotLinksError

Bases: Exception

Raised when a dataframe is not a RoadLinksTable.

Source code in network_wrangler/errors.py
class NotLinksError(Exception):
    """Raised when a dataframe is not a RoadLinksTable."""

NotNodesError

Bases: Exception

Raised when a dataframe is not a RoadNodesTable.

Source code in network_wrangler/errors.py
class NotNodesError(Exception):
    """Raised when a dataframe is not a RoadNodesTable."""

ProjectCardError

Bases: Exception

Raised when a project card is not valid.

Source code in network_wrangler/errors.py
class ProjectCardError(Exception):
    """Raised when a project card is not valid."""

RoadwayDeletionError

Bases: Exception

Raised when there is an issue with applying a roadway deletion.

Source code in network_wrangler/errors.py
class RoadwayDeletionError(Exception):
    """Raised when there is an issue with applying a roadway deletion."""

RoadwayPropertyChangeError

Bases: Exception

Raised when there is an issue with applying a roadway property change.

Source code in network_wrangler/errors.py
class RoadwayPropertyChangeError(Exception):
    """Raised when there is an issue with applying a roadway property change."""

ScenarioConflictError

Bases: Exception

Raised when a conflict is detected.

Source code in network_wrangler/errors.py
class ScenarioConflictError(Exception):
    """Raised when a conflict is detected."""

ScenarioCorequisiteError

Bases: Exception

Raised when a co-requisite is not satisfied.

Source code in network_wrangler/errors.py
class ScenarioCorequisiteError(Exception):
    """Raised when a co-requisite is not satisfied."""

ScenarioPrerequisiteError

Bases: Exception

Raised when a pre-requisite is not satisfied.

Source code in network_wrangler/errors.py
class ScenarioPrerequisiteError(Exception):
    """Raised when a pre-requisite is not satisfied."""

ScopeConflictError

Bases: Exception

Raised when there is a scope conflict in a list of ScopedPropertySetItems.

Source code in network_wrangler/errors.py
class ScopeConflictError(Exception):
    """Raised when there is a scope conflict in a list of ScopedPropertySetItems."""

ScopeLinkValueError

Bases: Exception

Raised when there is an issue with ScopedLinkValueList.

Source code in network_wrangler/errors.py
class ScopeLinkValueError(Exception):
    """Raised when there is an issue with ScopedLinkValueList."""

SegmentFormatError

Bases: Exception

Error in segment format.

Source code in network_wrangler/errors.py
class SegmentFormatError(Exception):
    """Error in segment format."""

SegmentSelectionError

Bases: Exception

Error in segment selection.

Source code in network_wrangler/errors.py
class SegmentSelectionError(Exception):
    """Error in segment selection."""

SelectionError

Bases: Exception

Raised when there is an issue with a selection.

Source code in network_wrangler/errors.py
class SelectionError(Exception):
    """Raised when there is an issue with a selection."""

ShapeAddError

Bases: Exception

Raised when there is an issue with adding shapes.

Source code in network_wrangler/errors.py
class ShapeAddError(Exception):
    """Raised when there is an issue with adding shapes."""

ShapeDeletionError

Bases: Exception

Raised when there is an issue with deleting shapes from a network.

Source code in network_wrangler/errors.py
class ShapeDeletionError(Exception):
    """Raised when there is an issue with deleting shapes from a network."""

SubnetCreationError

Bases: Exception

Raised when a subnet can’t be created.

Source code in network_wrangler/errors.py
class SubnetCreationError(Exception):
    """Raised when a subnet can't be created."""

SubnetExpansionError

Bases: Exception

Raised when a subnet can’t be expanded to include a node or set of nodes.

Source code in network_wrangler/errors.py
class SubnetExpansionError(Exception):
    """Raised when a subnet can't be expanded to include a node or set of nodes."""

TimeFormatError

Bases: Exception

Time format error exception.

Source code in network_wrangler/errors.py
class TimeFormatError(Exception):
    """Time format error exception."""

TimespanFormatError

Bases: Exception

Timespan format error exception.

Source code in network_wrangler/errors.py
class TimespanFormatError(Exception):
    """Timespan format error exception."""

TransitPropertyChangeError

Bases: Exception

Error raised when applying transit property changes.

Source code in network_wrangler/errors.py
class TransitPropertyChangeError(Exception):
    """Error raised when applying transit property changes."""

TransitRoadwayConsistencyError

Bases: Exception

Error raised when transit network is inconsistent with roadway network.

Source code in network_wrangler/errors.py
class TransitRoadwayConsistencyError(Exception):
    """Error raised when transit network is inconsistent with roadway network."""

TransitRouteAddError

Bases: Exception

Error raised when applying add transit route.

Source code in network_wrangler/errors.py
class TransitRouteAddError(Exception):
    """Error raised when applying add transit route."""

TransitRoutingChangeError

Bases: Exception

Raised when there is an error in the transit routing change.

Source code in network_wrangler/errors.py
class TransitRoutingChangeError(Exception):
    """Raised when there is an error in the transit routing change."""

TransitSelectionEmptyError

Bases: Exception

Error for when no transit trips are selected.

Source code in network_wrangler/errors.py
class TransitSelectionEmptyError(Exception):
    """Error for when no transit trips are selected."""

TransitSelectionError

Bases: Exception

Base error for transit selection errors.

Source code in network_wrangler/errors.py
class TransitSelectionError(Exception):
    """Base error for transit selection errors."""

TransitSelectionNetworkConsistencyError

Bases: TransitSelectionError

Error for when transit selection dictionary is not consistent with transit network.

Source code in network_wrangler/errors.py
class TransitSelectionNetworkConsistencyError(TransitSelectionError):
    """Error for when transit selection dictionary is not consistent with transit network."""

TransitValidationError

Bases: Exception

Error raised when transit network doesn’t have expected values.

Source code in network_wrangler/errors.py
class TransitValidationError(Exception):
    """Error raised when transit network doesn't have expected values."""