Skip to content

Commit

Permalink
Trying to reproduce the error with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
louisPoulain committed Sep 10, 2024
1 parent 1dbb5fd commit bb6e408
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 0 deletions.
282 changes: 282 additions & 0 deletions tests/stations_cloud.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
{
"1_10125": {
"type_id": 1,
"point_id": 10125,
"name": "1_10125",
"nat_abbr": "TGARE",
"fullname": "Arenenberg",
"longitude": 9.059916250260539,
"latitude": 47.67102750634816,
"height_masl": 470.0
},
"1_4046": {
"type_id": 1,
"point_id": 4046,
"name": "1_4046",
"nat_abbr": "GRCHI",
"fullname": "Chur / A13",
"longitude": 9.507070914334435,
"latitude": 46.85479789431043,
"height_masl": 562.0
},
"1_1770": {
"type_id": 1,
"point_id": 1770,
"name": "1_1770",
"nat_abbr": "ZELUSC",
"fullname": "Schlund",
"longitude": 8.29790440714527,
"latitude": 47.01832749844614,
"height_masl": 458.0
},
"1_10358": {
"type_id": 1,
"point_id": 10358,
"name": "1_10358",
"nat_abbr": "VDVAU",
"fullname": "Vaulion",
"longitude": 6.422616930273129,
"latitude": 46.69630027513211,
"height_masl": 885.0
},
"1_5022": {
"type_id": 1,
"point_id": 5022,
"name": "1_5022",
"nat_abbr": "RGFRG",
"fullname": "Fribourg / H\u00f4pital cantonal",
"longitude": 7.137642389029125,
"latitude": 46.80148913242183,
"height_masl": 700.0
},
"1_10204": {
"type_id": 1,
"point_id": 10204,
"name": "1_10204",
"nat_abbr": "WNSTDP",
"fullname": "Corseaux / La Pichette",
"longitude": 6.811496444780631,
"latitude": 46.46916763731881,
"height_masl": 372.0
},
"1_4270": {
"type_id": 1,
"point_id": 4270,
"name": "1_4270",
"nat_abbr": "BEKSE",
"fullname": "Kleine Scheidegg",
"longitude": 7.959165725281996,
"latitude": 46.584194394655235,
"height_masl": 2060.0
},
"1_3979": {
"type_id": 1,
"point_id": 3979,
"name": "1_3979",
"nat_abbr": "ZHWIN",
"fullname": "Winterthur",
"longitude": 8.735701370850412,
"latitude": 47.495324087775884,
"height_masl": 455.0
},
"1_2": {
"type_id": 1,
"point_id": 2,
"name": "1_2",
"nat_abbr": "RAG",
"fullname": "Bad Ragaz",
"longitude": 9.502592086511703,
"latitude": 47.01663199943618,
"height_masl": 496.0
},
"1_10110": {
"type_id": 1,
"point_id": 10110,
"name": "1_10110",
"nat_abbr": "TGFRA",
"fullname": "Frauenfeld",
"longitude": 8.894147679667007,
"latitude": 47.57327217311134,
"height_masl": 391.0
},
"1_10116": {
"type_id": 1,
"point_id": 10116,
"name": "1_10116",
"nat_abbr": "TGDUS",
"fullname": "Dussnang",
"longitude": 8.962560021925889,
"latitude": 47.42824020287953,
"height_masl": 590.0
},
"1_10118": {
"type_id": 1,
"point_id": 10118,
"name": "1_10118",
"nat_abbr": "TGAMR",
"fullname": "Amriswil",
"longitude": 9.311959628217751,
"latitude": 47.53668384537617,
"height_masl": 485.0
},
"2_606800": {
"type_id": 2,
"point_id": 606800,
"name": "2_606800",
"nat_abbr": null,
"fullname": "Melchsee-Frutt",
"longitude": 8.270720520530185,
"latitude": 46.77486076058689,
"height_masl": 1926.0
},
"2_842200": {
"type_id": 2,
"point_id": 842200,
"name": "2_842200",
"nat_abbr": null,
"fullname": "Pfungen",
"longitude": 8.64216334764632,
"latitude": 47.51604405919474,
"height_masl": 410.0
},
"2_510700": {
"type_id": 2,
"point_id": 510700,
"name": "2_510700",
"nat_abbr": null,
"fullname": "Schinznach Dorf",
"longitude": 8.144143406127775,
"latitude": 47.44753661629843,
"height_masl": 375.0
},
"2_893300": {
"type_id": 2,
"point_id": 893300,
"name": "2_893300",
"nat_abbr": null,
"fullname": "Maschwanden",
"longitude": 8.427520358316777,
"latitude": 47.23453650497392,
"height_masl": 405.0
},
"2_301100": {
"type_id": 2,
"point_id": 301100,
"name": "2_301100",
"nat_abbr": null,
"fullname": "Bern",
"longitude": 7.445856892894557,
"latitude": 46.947709413598716,
"height_masl": 542.0
},
"2_121200": {
"type_id": 2,
"point_id": 121200,
"name": "2_121200",
"nat_abbr": null,
"fullname": "Grand-Lancy",
"longitude": 6.122141810734864,
"latitude": 46.17609440756448,
"height_masl": 397.0
},
"2_885402": {
"type_id": 2,
"point_id": 885402,
"name": "2_885402",
"nat_abbr": null,
"fullname": "Galgenen",
"longitude": 8.87151969910517,
"latitude": 47.180057288047664,
"height_masl": 434.0
},
"2_111400": {
"type_id": 2,
"point_id": 111400,
"name": "2_111400",
"nat_abbr": null,
"fullname": "Colombier VD",
"longitude": 6.470478710341658,
"latitude": 46.55646045564312,
"height_masl": 532.0
},
"3_1120": {
"type_id": 3,
"point_id": 1120,
"name": "3_1120",
"nat_abbr": "VPASER",
"fullname": "Col de Seron",
"longitude": 7.167322863534754,
"latitude": 46.38031670770484,
"height_masl": 2152.0
},
"3_505": {
"type_id": 3,
"point_id": 505,
"name": "3_505",
"nat_abbr": "VSHMUT",
"fullname": "Mutthornh\u00fctte SAC",
"longitude": 7.830114138368265,
"latitude": 46.486240077766176,
"height_masl": 2901.0
},
"3_543": {
"type_id": 3,
"point_id": 543,
"name": "3_543",
"nat_abbr": "VSHSLZ",
"fullname": "Cabane de Saleinaz CAS",
"longitude": 7.070331138327478,
"latitude": 45.97646077502695,
"height_masl": 2691.0
},
"3_837": {
"type_id": 3,
"point_id": 837,
"name": "3_837",
"nat_abbr": "VGIIBH",
"fullname": "Inners Barrhorn",
"longitude": 7.737304988369621,
"latitude": 46.14970652648962,
"height_masl": 3583.0
},
"3_1099": {
"type_id": 3,
"point_id": 1099,
"name": "3_1099",
"nat_abbr": "VSHCTR",
"fullname": "Capanna Tremorgio",
"longitude": 8.724929820188311,
"latitude": 46.48166771035242,
"height_masl": 1849.0
},
"3_523": {
"type_id": 3,
"point_id": 523,
"name": "3_523",
"nat_abbr": "VSHSUS",
"fullname": "Sustlih\u00fctte SAC",
"longitude": 8.470729860151124,
"latitude": 46.75168641736309,
"height_masl": 2257.0
},
"3_597": {
"type_id": 3,
"point_id": 597,
"name": "3_597",
"nat_abbr": "VPAGOB",
"fullname": "Gottschalkenberg",
"longitude": 8.647519869364434,
"latitude": 47.152754578098296,
"height_masl": 1162.0
},
"3_1156": {
"type_id": 3,
"point_id": 1156,
"name": "3_1156",
"nat_abbr": "VGWISP",
"fullname": "Wildspitz (ZG)",
"longitude": 8.577613642637207,
"latitude": 47.084488107476,
"height_masl": 1580.0
}
}
9 changes: 9 additions & 0 deletions tests/stations_cloud_splits.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
station_split:
train: 0.7
val: 0.3
test:
- "1_2"
- "1_10110"
- "1_10116"
- "1_10118"
- "1_10125"
10 changes: 10 additions & 0 deletions tests/test_model_selection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Mapping, Optional, Any
from dataclasses import dataclass
import json
import yaml

import pandas as pd
import numpy as np
Expand Down Expand Up @@ -74,6 +76,12 @@ def __post_init__(self):
elif self.station == "mixed":
self.station_split = {"train": 0.7, "val": 0.3, "test": self.stations[-5:]}
self.station_split_method = "random"
elif self.station == "cloud_mixed":
with open("./stations_cloud.json", "r") as f:
self.stations = json.load(f)
with open("./stations_cloud.yaml", "r") as f:
self.station_split = yaml.safe_load(f)
self.station_split_method = "random"

def time_split_lists(self):
frac = {"train": 0.6, "val": 0.2, "test": 0.2}
Expand Down Expand Up @@ -103,6 +111,8 @@ class TestDataSplitter:
ValidDataSplitterOptions(time="lists", station="fractions"),
ValidDataSplitterOptions(time="lists", station="mixed"),
ValidDataSplitterOptions(time="mixed", station="fractions"),
ValidDataSplitterOptions(time="mixed", station="mixed"),
ValidDataSplitterOptions(time="mixed", station="cloud_mixed"),
]

@pytest.mark.parametrize(
Expand Down

0 comments on commit bb6e408

Please sign in to comment.