Skip to content

Add metadata overrides for sensitive connection string values (URL and DSN support) #3825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions bindings/postgres/metadata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,38 @@ metadata:
- "simple_protocol"
example: "cache_describe"
default: ""
- name: host
required: false
description: The host of the PostgreSQL database
example: "localhost"
type: string
- name: hostaddr
required: false
description: The host address of the PostgreSQL database
example: "127.0.0.1"
type: string
- name: port
required: false
description: The port of the PostgreSQL database
example: "5432"
type: string
- name: database
required: false
description: The database of the PostgreSQL database
example: "postgres"
type: string
- name: user
required: false
description: The user of the PostgreSQL database
example: "postgres"
type: string
- name: password
required: false
description: The password of the PostgreSQL database
example: "password"
type: string
- name: sslRootCert
required: false
description: The path to the SSL root certificate file
example: "/path/to/ssl/root/cert.pem"
type: string
10 changes: 5 additions & 5 deletions bindings/postgres/metadata_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestMetadata(t *testing.T) {
t.Run("has connection string", func(t *testing.T) {
m := psqlMetadata{}
props := map[string]string{
"connectionString": "foo",
"connectionString": "foo=bar",
}

err := m.InitWithMetadata(props)
Expand All @@ -44,7 +44,7 @@ func TestMetadata(t *testing.T) {
t.Run("default timeout", func(t *testing.T) {
m := psqlMetadata{}
props := map[string]string{
"connectionString": "foo",
"connectionString": "foo=bar",
}

err := m.InitWithMetadata(props)
Expand All @@ -55,7 +55,7 @@ func TestMetadata(t *testing.T) {
t.Run("invalid timeout", func(t *testing.T) {
m := psqlMetadata{}
props := map[string]string{
"connectionString": "foo",
"connectionString": "foo=bar",
"timeout": "NaN",
}

Expand All @@ -66,7 +66,7 @@ func TestMetadata(t *testing.T) {
t.Run("positive timeout", func(t *testing.T) {
m := psqlMetadata{}
props := map[string]string{
"connectionString": "foo",
"connectionString": "foo=bar",
"timeout": "42",
}

Expand All @@ -78,7 +78,7 @@ func TestMetadata(t *testing.T) {
t.Run("zero timeout", func(t *testing.T) {
m := psqlMetadata{}
props := map[string]string{
"connectionString": "foo",
"connectionString": "foo=bar",
"timeout": "0",
}

Expand Down
140 changes: 136 additions & 4 deletions common/authentication/postgresql/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
Expand All @@ -32,6 +34,13 @@ import (
// PostgresAuthMetadata contains authentication metadata for PostgreSQL components.
type PostgresAuthMetadata struct {
ConnectionString string `mapstructure:"connectionString" mapstructurealiases:"url"`
Host string `mapstructure:"host"`
HostAddr string `mapstructure:"hostaddr"`
Port string `mapstructure:"port"`
Database string `mapstructure:"database"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
SslRootCert string `mapstructure:"sslRootCert"`
ConnectionMaxIdleTime time.Duration `mapstructure:"connectionMaxIdleTime"`
MaxConns int `mapstructure:"maxConns"`
UseAzureAD bool `mapstructure:"useAzureAD"`
Expand All @@ -45,6 +54,13 @@ type PostgresAuthMetadata struct {
// Reset the object.
func (m *PostgresAuthMetadata) Reset() {
m.ConnectionString = ""
m.Host = ""
m.HostAddr = ""
m.Port = ""
m.Database = ""
m.User = ""
m.Password = ""
m.SslRootCert = ""
m.ConnectionMaxIdleTime = 0
m.MaxConns = 0
m.UseAzureAD = false
Expand All @@ -62,8 +78,9 @@ type InitWithMetadataOpts struct {
// This is different from the "useAzureAD" property from the user, which is provided by the user and instructs the component to authenticate using Azure AD.
func (m *PostgresAuthMetadata) InitWithMetadata(meta map[string]string, opts InitWithMetadataOpts) (err error) {
// Validate input
if m.ConnectionString == "" {
return errors.New("missing connection string")
_, err = m.buildConnectionString()
if err != nil {
return err
}
switch {
case opts.AzureADEnabled && m.UseAzureAD:
Expand All @@ -87,6 +104,118 @@ func (m *PostgresAuthMetadata) InitWithMetadata(meta map[string]string, opts Ini
return nil
}

// buildConnectionString builds the connection string from the metadata.
// It supports both DSN-style and URL-style connection strings.
Copy link

@alicejgibbons alicejgibbons May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yeah ignore me - deleted

// Metadata fields override existing values in the connection string.
func (m *PostgresAuthMetadata) buildConnectionString() (string, error) {
metadata := m.getConnectionStringMetadata()
if strings.HasPrefix(m.ConnectionString, "postgres://") || strings.HasPrefix(m.ConnectionString, "postgresql://") {
return m.buildURLConnectionString(metadata)
}
return m.buildDSNConnectionString(metadata)
}

func (m *PostgresAuthMetadata) buildDSNConnectionString(metadata map[string]string) (string, error) {
connectionString := ""
parts := strings.Split(m.ConnectionString, " ")
for _, part := range parts {
kv := strings.SplitN(part, "=", 2)
if len(kv) == 2 {
key := kv[0]
if value, ok := metadata[key]; ok {
connectionString += fmt.Sprintf("%s=%s ", key, value)
delete(metadata, key)
} else {
connectionString += fmt.Sprintf("%s=%s ", key, kv[1])
}
}
}
for k, v := range metadata {
connectionString += fmt.Sprintf("%s=%s ", k, v)
}

if connectionString == "" {
return "", errors.New("failed to build connection string")
}

return strings.TrimSpace(connectionString), nil
}

func (m *PostgresAuthMetadata) getConnectionStringMetadata() map[string]string {
metadata := make(map[string]string)
if m.User != "" {
metadata["user"] = m.User
}
if m.Host != "" {
metadata["host"] = m.Host
}
if m.HostAddr != "" {
metadata["hostaddr"] = m.HostAddr
}
if m.Port != "" {
metadata["port"] = m.Port
}
if m.Database != "" {
metadata["database"] = m.Database
}
if m.Password != "" {
metadata["password"] = m.Password
}
if m.SslRootCert != "" {
metadata["sslrootcert"] = m.SslRootCert
}
return metadata
}

func (m *PostgresAuthMetadata) buildURLConnectionString(metadata map[string]string) (string, error) {
u, err := url.Parse(m.ConnectionString)
if err != nil {
return "", fmt.Errorf("invalid URL connection string: %w", err)
}

var username string
var password string
if u.User != nil {
username = u.User.Username()
pw, set := u.User.Password()
if set {
password = pw
}
}

if val, ok := metadata["user"]; ok {
username = val
}
if val, ok := metadata["password"]; ok {
password = val
}
if username != "" {
u.User = url.UserPassword(username, password)
}

if val, ok := metadata["host"]; ok {
u.Host = val
}
if val, ok := metadata["hostaddr"]; ok {
u.Host = val
}
if m.Port != "" {
u.Host = fmt.Sprintf("%s:%s", u.Host, m.Port)
}

if val, ok := metadata["database"]; ok {
u.Path = "/" + strings.TrimPrefix(val, "/")
}

q := u.Query()
if val, ok := metadata["sslrootcert"]; ok {
q.Set("sslrootcert", val)
}
u.RawQuery = q.Encode()

return u.String(), nil
}

func (m *PostgresAuthMetadata) BuildAwsIamOptions(logger logger.Logger, properties map[string]string) (*aws.Options, error) {
awsRegion, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "AWSRegion")
region, _ := metadata.GetMetadataProperty(m.awsEnv.Metadata, "region")
Expand Down Expand Up @@ -132,8 +261,11 @@ func (m *PostgresAuthMetadata) BuildAwsIamOptions(logger logger.Logger, properti

// GetPgxPoolConfig returns the pgxpool.Config object that contains the credentials for connecting to PostgreSQL.
func (m *PostgresAuthMetadata) GetPgxPoolConfig() (*pgxpool.Config, error) {
// Get the config from the connection string
config, err := pgxpool.ParseConfig(m.ConnectionString)
connectionString, err := m.buildConnectionString()
if err != nil {
return nil, err
}
config, err := pgxpool.ParseConfig(connectionString)
if err != nil {
return nil, fmt.Errorf("failed to parse connection string: %w", err)
}
Expand Down
Loading