diff --git a/emissions/aep/srgspec_osm.go b/emissions/aep/srgspec_osm.go index 4deb98952..0870b0e59 100644 --- a/emissions/aep/srgspec_osm.go +++ b/emissions/aep/srgspec_osm.go @@ -23,6 +23,7 @@ import ( "encoding/json" "fmt" "io" + "sync" backoff "github.com/cenkalti/backoff/v4" "github.com/ctessum/geom" @@ -57,6 +58,8 @@ type SrgSpecOSM struct { // in MergeNames. MergeMultipliers []float64 `json:"merge_multipliers"` + connectPostGISOnce sync.Once + postGISURL string conn *pgxpool.Pool } @@ -74,39 +77,43 @@ type SrgSpecOSM struct { // and the PostGIS database should have the "hstore" extension installed before // loading the data. func ReadSrgSpecOSM(ctx context.Context, r io.Reader, postGISURL string) (*SrgSpecs, error) { - if postGISURL == "" { - return nil, fmt.Errorf("PostGIS URL is required") - } - // Connect to database. - var conn *pgxpool.Pool - var err error - err = backoff.Retry(func() error { - conn, err = pgxpool.Connect(ctx, postGISURL) - if err != nil { - return err - } - return nil - }, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 10)) - if err != nil { - return nil, fmt.Errorf("Unable to connect to PostGIS database %s after 10 tries: %w\n", postGISURL, err) - } - // Read the surrogate specification. d := json.NewDecoder(r) var o []*SrgSpecOSM - if err = d.Decode(&o); err != nil { + if err := d.Decode(&o); err != nil { return nil, err } // Add the db connection to each surrogate. srgs := NewSrgSpecs() for _, s := range o { - s.conn = conn + s.postGISURL = postGISURL srgs.Add(s) } return srgs, nil } +func (s *SrgSpecOSM) connectPostGIS() { + if s.postGISURL == "" { + panic(fmt.Errorf("PostGIS URL is required")) + } + + // Connect to database. + var conn *pgxpool.Pool + var err error + err = backoff.Retry(func() error { + conn, err = pgxpool.Connect(context.Background(), s.postGISURL) + if err != nil { + return err + } + return nil + }, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 10)) + if err != nil { + panic(fmt.Errorf("unable to connect to PostGIS database %s after 10 tries: %w", s.postGISURL, err)) + } + s.conn = conn +} + func (srg *SrgSpecOSM) backupSurrogateNames() []string { return srg.BackupSurrogateNames } func (srg *SrgSpecOSM) region() Country { return srg.Region } func (srg *SrgSpecOSM) code() string { return srg.Code } @@ -164,6 +171,8 @@ func (srg *SrgSpecOSM) getSrgData(gridData *GridDef, inputLoc *Location, tol flo } tagKeys = tagKeys[:len(tagKeys)-1] // Remove trailing comma. + srg.connectPostGISOnce.Do(srg.connectPostGIS) + rows, err := srg.conn.Query(ctx, ` SELECT hstore_to_array(tags) tags, ST_AsBinary(way)