From 9bc614e287ac9c86d19fc2493c86647c217408f7 Mon Sep 17 00:00:00 2001 From: Mit Desai Date: Fri, 6 Jun 2025 08:43:22 -0700 Subject: [PATCH] [YUNIKORN-656] Add LDAP resolver for group resolution --- go.mod | 8 +- go.sum | 31 + pkg/common/configs/config.go | 18 +- pkg/common/configs/config_test.go | 50 + pkg/common/constants.go | 28 + pkg/common/security/ldap_validator.go | 383 +++++++ pkg/common/security/ldap_validator_test.go | 984 ++++++++++++++++++ pkg/common/security/usergroup.go | 22 +- .../security/usergroup_ldap_resolver.go | 375 +++++++ .../security/usergroup_ldap_resolver_test.go | 818 +++++++++++++++ .../security/usergroup_no_resolver_test.go | 100 ++ .../security/usergroup_os_resolver_test.go | 44 + pkg/common/security/usergroup_test.go | 899 +++++++++++----- .../security/usergroup_test_resolver.go | 3 + pkg/scheduler/partition.go | 2 +- 15 files changed, 3475 insertions(+), 290 deletions(-) create mode 100644 pkg/common/security/ldap_validator.go create mode 100644 pkg/common/security/ldap_validator_test.go create mode 100644 pkg/common/security/usergroup_ldap_resolver.go create mode 100644 pkg/common/security/usergroup_ldap_resolver_test.go create mode 100644 pkg/common/security/usergroup_no_resolver_test.go create mode 100644 pkg/common/security/usergroup_os_resolver_test.go diff --git a/go.mod b/go.mod index dc6a3321d..3b4801efd 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ toolchain go1.23.7 require ( github.com/apache/yunikorn-scheduler-interface v0.0.0-20250304214837-4513ff3a692d + github.com/go-ldap/ldap/v3 v3.4.11 github.com/google/btree v1.1.3 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 @@ -34,9 +35,10 @@ require ( github.com/prometheus/client_model v0.5.0 github.com/prometheus/common v0.45.0 github.com/sasha-s/go-deadlock v0.3.5 + github.com/stretchr/testify v1.8.1 go.uber.org/zap v1.27.0 golang.org/x/exp v0.0.0-20250228200357-dead58393ab7 - golang.org/x/net v0.36.0 + golang.org/x/net v0.38.0 golang.org/x/time v0.10.0 google.golang.org/grpc v1.71.0 gopkg.in/yaml.v3 v3.0.1 @@ -44,13 +46,17 @@ require ( ) require ( + github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/procfs v0.12.0 // indirect go.uber.org/multierr v1.10.0 // indirect + golang.org/x/crypto v0.36.0 // indirect golang.org/x/sys v0.30.0 // indirect golang.org/x/text v0.22.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f // indirect diff --git a/go.sum b/go.sum index b7dd89ae7..29ea13236 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +1,20 @@ +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8= +github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU= +github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= +github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= github.com/apache/yunikorn-scheduler-interface v0.0.0-20250304214837-4513ff3a692d h1:JDRId3/5HqpDlOV1RrVL8xDrZ2v0s/ucb6vpEGvkuy8= github.com/apache/yunikorn-scheduler-interface v0.0.0-20250304214837-4513ff3a692d/go.mod h1:udBVRAW3pcKRneNL8xTC9t40I5zwLjBldT+bpzw9He4= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo= +github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0= +github.com/go-ldap/ldap/v3 v3.4.11 h1:4k0Yxweg+a3OyBLjdYn5OKglv18JNvfDykSoI8bW0gU= +github.com/go-ldap/ldap/v3 v3.4.11/go.mod h1:bY7t0FLK8OAVpp/vV6sSlpz3EQDGcQwc8pF0ujLgKvM= github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -18,6 +27,20 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/go-uuid v1.0.3 h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8= +github.com/hashicorp/go-uuid v1.0.3/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= +github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= +github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= +github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= +github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= +github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= +github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= +github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= +github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= +github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8= +github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs= +github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= +github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -44,6 +67,11 @@ github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjR github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/sasha-s/go-deadlock v0.3.5 h1:tNCOEEDG6tBqrNDOX35j/7hL5FcFViG6awUGROb2NsU= github.com/sasha-s/go-deadlock v0.3.5/go.mod h1:bugP6EGbdGYObIlx7pUZtWqlvo8k9H6vCBBsiChJQ5U= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= @@ -64,6 +92,8 @@ go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/exp v0.0.0-20250228200357-dead58393ab7 h1:aWwlzYV971S4BXRS9AmqwDLAD85ouC6X+pocatKY58c= golang.org/x/exp v0.0.0-20250228200357-dead58393ab7/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk= golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= @@ -83,6 +113,7 @@ google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojt gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= diff --git a/pkg/common/configs/config.go b/pkg/common/configs/config.go index 4c09db404..512893617 100644 --- a/pkg/common/configs/config.go +++ b/pkg/common/configs/config.go @@ -45,13 +45,19 @@ type SchedulerConfig struct { // - a list of placement rule definition objects // - a list of users specifying limits on the partition // - the preemption configuration for the partition +// - user group resolver type (os, ldap, "") type PartitionConfig struct { - Name string - Queues []QueueConfig - PlacementRules []PlacementRule `yaml:",omitempty" json:",omitempty"` - Limits []Limit `yaml:",omitempty" json:",omitempty"` - Preemption PartitionPreemptionConfig `yaml:",omitempty" json:",omitempty"` - NodeSortPolicy NodeSortingPolicy `yaml:",omitempty" json:",omitempty"` + Name string + Queues []QueueConfig + PlacementRules []PlacementRule `yaml:",omitempty" json:",omitempty"` + Limits []Limit `yaml:",omitempty" json:",omitempty"` + Preemption PartitionPreemptionConfig `yaml:",omitempty" json:",omitempty"` + NodeSortPolicy NodeSortingPolicy `yaml:",omitempty" json:",omitempty"` + UserGroupResolver UserGroupResolver `yaml:",omitempty" json:",omitempty"` +} + +type UserGroupResolver struct { + Type string `yaml:"type,omitempty" json:"type,omitempty"` } // The partition preemption configuration diff --git a/pkg/common/configs/config_test.go b/pkg/common/configs/config_test.go index 307a44799..8c8bda195 100644 --- a/pkg/common/configs/config_test.go +++ b/pkg/common/configs/config_test.go @@ -2181,3 +2181,53 @@ partitions: _, err = CreateConfig(data) assert.ErrorContains(t, err, "group * max resource map[memory:90000 vcore:100000] of queue leaf is greater than immediate or ancestor parent maximum resource map[memory:10000 vcore:10000000]") } + +// TestUserGroupResolverConfig: tests the user group resolver configuration +func TestUserGroupResolverConfig(t *testing.T) { + data := ` +partitions: + - + name: default + usergroupresolver: + type: ldap + placementrules: + - name: tag + value: namespace + create: true + queues: + - name: root + submitacl: '*' + properties: + application.sort.policy: fifo + sample: value2 +` + // validate the config and check after the update + config, err := CreateConfig(data) + assert.NilError(t, err) + + // check if the user group resolver is set correctly + assert.Equal(t, "ldap", config.Partitions[0].UserGroupResolver.Type) + + // partition with no user group resolver + data = ` +partitions: + - + name: default + placementrules: + - name: tag + value: namespace + create: true + queues: + - name: root + submitacl: '*' + properties: + application.sort.policy: fifo + sample: value2 +` + // validate the config and check after the update + config, err = CreateConfig(data) + assert.NilError(t, err) + + // check if the user group resolver is set to empty + assert.Equal(t, "", config.Partitions[0].UserGroupResolver.Type) +} diff --git a/pkg/common/constants.go b/pkg/common/constants.go index 7eae1c79a..04bc2de38 100644 --- a/pkg/common/constants.go +++ b/pkg/common/constants.go @@ -29,4 +29,32 @@ const ( RecoveryQueue = "@recovery@" RecoveryQueueFull = "root." + RecoveryQueue DefaultPlacementQueue = "root.default" + LdapHost = "Host" + LdapPort = "Port" + LdapBaseDN = "BaseDN" + LdapFilter = "Filter" + LdapGroupAttr = "GroupAttr" + LdapReturnAttr = "ReturnAttr" + LdapBindUser = "BindUser" + LdapBindPassword = "BindPassword" + LdapInsecure = "Insecure" + LdapSSL = "SSL" +) + +const ( + DefaultLdapHost = "localhost" + DefaultLdapPort = 389 + DefaultLdapBaseDN = "dc=example,dc=com" + DefaultLdapFilter = "(&(sAMAccountName=%s))" + DefaultLdapGroupAttr = "memberOf" + DefaultLdapBindUser = "admin" + DefaultLdapBindPassword = "admin" + DefaultLdapInsecure = false + DefaultLdapSSL = false + DefaultLdapUserUID = "1211" +) + +var ( + LdapMountPath = "/run/secrets/ldap" + DefaultLdapReturnAttr = []string{"memberOf"} ) diff --git a/pkg/common/security/ldap_validator.go b/pkg/common/security/ldap_validator.go new file mode 100644 index 000000000..eb95ec336 --- /dev/null +++ b/pkg/common/security/ldap_validator.go @@ -0,0 +1,383 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package security + +import ( + "fmt" + "net" + "regexp" + "strconv" + "strings" + + "go.uber.org/zap" + + "github.com/apache/yunikorn-core/pkg/common" + "github.com/apache/yunikorn-core/pkg/log" +) + +// ValidationLevel defines the severity of validation issues +type ValidationLevel int + +const ( + // ValidationWarning indicates a non-critical issue that allows operation but might cause problems + ValidationWarning ValidationLevel = iota + // ValidationError indicates a critical issue that prevents proper operation + ValidationError +) + +// ValidationIssue represents a single validation problem +type ValidationIssue struct { + Field string + Message string + Level ValidationLevel +} + +// LdapValidator provides validation for LDAP configuration +type LdapValidator struct { + issues []ValidationIssue +} + +// NewLdapValidator creates a new validator instance +func NewLdapValidator() *LdapValidator { + return &LdapValidator{ + issues: make([]ValidationIssue, 0), + } +} + +// ValidateConfig validates the entire LDAP configuration +func (v *LdapValidator) ValidateConfig(config *LdapResolverConfig) bool { + v.validateHost(config.Host) + v.validatePort(config.Port) + v.validateBaseDN(config.BaseDN) + v.validateFilter(config.Filter) + v.validateGroupAttr(config.GroupAttr) + v.validateReturnAttr(config.ReturnAttr) + v.validateBindUser(config.BindUser) + v.validateBindPassword(config.BindPassword) + + // Consistency checks + v.validateConsistency(config) + + // Log all issues + v.logIssues() + + // Return true if no errors (warnings are acceptable) + return !v.hasErrors() +} + +// validateHost validates the LDAP host +func (v *LdapValidator) validateHost(host string) { + if host == "" { + v.addIssue("Host", "Host cannot be empty", ValidationError) + return + } + + // Check if it's an IP address + if net.ParseIP(host) != nil { + return // Valid IP address + } + + // Check if it's a valid hostname + hostnameRegex := regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`) + if !hostnameRegex.MatchString(host) { + v.addIssue("Host", fmt.Sprintf("Invalid hostname format: %s", host), ValidationWarning) + } +} + +// validatePort validates the LDAP port +func (v *LdapValidator) validatePort(port int) { + if port < 1 || port > 65535 { + v.addIssue("Port", fmt.Sprintf("Port must be between 1 and 65535, got: %d", port), ValidationError) + } +} + +// validateBaseDN validates the LDAP base DN +func (v *LdapValidator) validateBaseDN(baseDN string) { + if baseDN == "" { + v.addIssue("BaseDN", "BaseDN cannot be empty", ValidationError) + return + } + + // Check for at least one domain component + if !strings.Contains(strings.ToLower(baseDN), "dc=") { + v.addIssue("BaseDN", "BaseDN should contain at least one domain component (dc=)", ValidationWarning) + } + + // Check for valid DN format + dnRegex := regexp.MustCompile(`^(?:(?:[a-zA-Z0-9]+=[^,]+)(?:,(?:[a-zA-Z0-9]+=[^,]+))*)?$`) + if !dnRegex.MatchString(baseDN) { + v.addIssue("BaseDN", fmt.Sprintf("Invalid DN format: %s", baseDN), ValidationWarning) + } +} + +// validateFilter validates the LDAP filter +func (v *LdapValidator) validateFilter(filter string) { + if filter == "" { + v.addIssue("Filter", "Filter cannot be empty", ValidationError) + return + } + + // Check for username placeholder + if !strings.Contains(filter, "%s") { + v.addIssue("Filter", "Filter must contain '%s' placeholder for username substitution", ValidationError) + } + + // Check for balanced parentheses + if !hasBalancedParentheses(filter) { + v.addIssue("Filter", "Filter has unbalanced parentheses", ValidationError) + } + + // Basic filter format check + filterRegex := regexp.MustCompile(`^\(.*\)$`) + if !filterRegex.MatchString(filter) { + v.addIssue("Filter", "Filter should be enclosed in parentheses", ValidationWarning) + } +} + +// validateGroupAttr validates the LDAP group attribute +func (v *LdapValidator) validateGroupAttr(groupAttr string) { + if groupAttr == "" { + v.addIssue("GroupAttr", "GroupAttr cannot be empty", ValidationError) + return + } + + // Check for valid attribute name format + attrRegex := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\-_]*$`) + if !attrRegex.MatchString(groupAttr) { + v.addIssue("GroupAttr", fmt.Sprintf("Invalid attribute name format: %s", groupAttr), ValidationWarning) + } +} + +// validateReturnAttr validates the LDAP return attributes +func (v *LdapValidator) validateReturnAttr(returnAttr []string) { + if len(returnAttr) == 0 { + v.addIssue("ReturnAttr", "ReturnAttr cannot be empty", ValidationError) + return + } + + attrRegex := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\-_]*$`) + for _, attr := range returnAttr { + if !attrRegex.MatchString(attr) { + v.addIssue("ReturnAttr", fmt.Sprintf("Invalid attribute name format: %s", attr), ValidationWarning) + } + } +} + +// validateBindUser validates the LDAP bind user +func (v *LdapValidator) validateBindUser(bindUser string) { + if bindUser == "" { + v.addIssue("BindUser", "BindUser cannot be empty", ValidationError) + return + } + + // Check if it's a DN format + dnRegex := regexp.MustCompile(`^(?:(?:[a-zA-Z0-9]+=[^,]+)(?:,(?:[a-zA-Z0-9]+=[^,]+))*)?$`) + if dnRegex.MatchString(bindUser) { + return // Valid DN format + } + + // Check if it's a username format + usernameRegex := regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\._\-@]*$`) + if !usernameRegex.MatchString(bindUser) { + v.addIssue("BindUser", fmt.Sprintf("BindUser is neither a valid DN nor a valid username: %s", bindUser), ValidationWarning) + } +} + +// validateBindPassword validates the LDAP bind password +func (v *LdapValidator) validateBindPassword(bindPassword string) { + if bindPassword == "" { + v.addIssue("BindPassword", "BindPassword cannot be empty", ValidationError) + return + } + + // Check for minimum length + if len(bindPassword) < 3 { + v.addIssue("BindPassword", "BindPassword is too short", ValidationWarning) + } + + // We don't check for password complexity here as it depends on the LDAP server policy +} + +// validateConsistency performs cross-field validation +func (v *LdapValidator) validateConsistency(config *LdapResolverConfig) { + // Check SSL and port consistency + if config.SSL && config.Port != 636 { + v.addIssue("Port", fmt.Sprintf("SSL is enabled but port is not the default LDAPS port (636), using: %d", config.Port), ValidationWarning) + } + + // Check SSL and Insecure consistency + if config.SSL && config.Insecure { + v.addIssue("SSL/Insecure", "Both SSL and Insecure are enabled, which may indicate a security misconfiguration", ValidationWarning) + } +} + +// addIssue adds a validation issue to the list +func (v *LdapValidator) addIssue(field, message string, level ValidationLevel) { + v.issues = append(v.issues, ValidationIssue{ + Field: field, + Message: message, + Level: level, + }) +} + +// hasErrors checks if there are any validation errors +func (v *LdapValidator) hasErrors() bool { + for _, issue := range v.issues { + if issue.Level == ValidationError { + return true + } + } + return false +} + +// logIssues logs all validation issues +func (v *LdapValidator) logIssues() { + for _, issue := range v.issues { + if issue.Level == ValidationError { + log.Log(log.Security).Error("LDAP configuration validation error", + zap.String("field", issue.Field), + zap.String("message", issue.Message)) + } else { + log.Log(log.Security).Warn("LDAP configuration validation warning", + zap.String("field", issue.Field), + zap.String("message", issue.Message)) + } + } +} + +// hasBalancedParentheses checks if a string has balanced parentheses +func hasBalancedParentheses(s string) bool { + count := 0 + for _, c := range s { + if c == '(' { + count++ + } else if c == ')' { + count-- + if count < 0 { + return false + } + } + } + return count == 0 +} + +// ValidateSecretValue validates a single secret value based on its key +func ValidateSecretValue(key, value string) (interface{}, error) { + switch key { + case common.LdapHost: + return validateHostValue(value) + case common.LdapPort: + return validatePortValue(value) + case common.LdapBaseDN: + return validateBaseDNValue(value) + case common.LdapFilter: + return validateFilterValue(value) + case common.LdapGroupAttr: + return validateGroupAttrValue(value) + case common.LdapReturnAttr: + return validateReturnAttrValue(value) + case common.LdapBindUser: + return validateBindUserValue(value) + case common.LdapBindPassword: + return validateBindPasswordValue(value) + case common.LdapInsecure, common.LdapSSL: + return validateBoolValue(value) + default: + return nil, fmt.Errorf("unknown LDAP secret key: %s", key) + } +} + +// Individual validation functions for each secret type +func validateHostValue(value string) (string, error) { + if value == "" { + return "", fmt.Errorf("host cannot be empty") + } + return value, nil +} + +func validatePortValue(value string) (int, error) { + port, err := strconv.Atoi(value) + if err != nil { + return 0, fmt.Errorf("invalid port number: %s", err) + } + if port < 1 || port > 65535 { + return 0, fmt.Errorf("port must be between 1 and 65535, got: %d", port) + } + return port, nil +} + +func validateBaseDNValue(value string) (string, error) { + if value == "" { + return "", fmt.Errorf("baseDN cannot be empty") + } + return value, nil +} + +func validateFilterValue(value string) (string, error) { + if value == "" { + return "", fmt.Errorf("filter cannot be empty") + } + if !strings.Contains(value, "%s") { + return "", fmt.Errorf("filter must contain '%%s' placeholder for username substitution") + } + if !hasBalancedParentheses(value) { + return "", fmt.Errorf("filter has unbalanced parentheses") + } + return value, nil +} + +func validateGroupAttrValue(value string) (string, error) { + if value == "" { + return "", fmt.Errorf("groupAttr cannot be empty") + } + return value, nil +} + +func validateReturnAttrValue(value string) ([]string, error) { + if value == "" { + return nil, fmt.Errorf("returnAttr cannot be empty") + } + attrs := strings.Split(value, ",") + if len(attrs) == 0 { + return nil, fmt.Errorf("returnAttr must contain at least one attribute") + } + return attrs, nil +} + +func validateBindUserValue(value string) (string, error) { + if value == "" { + return "", fmt.Errorf("bindUser cannot be empty") + } + return value, nil +} + +func validateBindPasswordValue(value string) (string, error) { + if value == "" { + return "", fmt.Errorf("bindPassword cannot be empty") + } + return value, nil +} + +func validateBoolValue(value string) (bool, error) { + boolValue, err := strconv.ParseBool(value) + if err != nil { + return false, fmt.Errorf("invalid boolean value: %s", err) + } + return boolValue, nil +} diff --git a/pkg/common/security/ldap_validator_test.go b/pkg/common/security/ldap_validator_test.go new file mode 100644 index 000000000..e39c875af --- /dev/null +++ b/pkg/common/security/ldap_validator_test.go @@ -0,0 +1,984 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package security + +import ( + "testing" + + "gotest.tools/v3/assert" + + "github.com/apache/yunikorn-core/pkg/common" +) + +func TestValidateHostValue(t *testing.T) { + tests := []struct { + name string + value string + wantError bool + }{ + {"Valid hostname", "ldap.example.com", false}, + {"Valid IP", "192.168.1.1", false}, + {"Empty", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := validateHostValue(tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + } + }) + } +} + +func TestValidatePortValue(t *testing.T) { + tests := []struct { + name string + value string + wantError bool + }{ + {"Valid port", "389", false}, + {"Valid port range", "65535", false}, + {"Invalid port - too high", "65536", true}, + {"Invalid port - too low", "0", true}, + {"Invalid port - not a number", "abc", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := validatePortValue(tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + } + }) + } +} + +func TestValidateFilterValue(t *testing.T) { + tests := []struct { + name string + value string + wantError bool + }{ + {"Valid filter", "(&(objectClass=user)(sAMAccountName=%s))", false}, + {"Missing placeholder", "(&(objectClass=user)(sAMAccountName=user))", true}, + {"Unbalanced parentheses", "(&(objectClass=user)(sAMAccountName=%s)", true}, + {"Empty", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := validateFilterValue(tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + } + }) + } +} + +func TestValidateReturnAttrValue(t *testing.T) { + tests := []struct { + name string + value string + wantError bool + }{ + {"Valid single attr", "memberOf", false}, + {"Valid multiple attrs", "memberOf,cn,mail", false}, + {"Empty", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + attrs, err := validateReturnAttrValue(tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + if tt.value == "memberOf" { + assert.Equal(t, 1, len(attrs)) + assert.Equal(t, "memberOf", attrs[0]) + } else if tt.value == "memberOf,cn,mail" { + assert.Equal(t, 3, len(attrs)) + assert.Equal(t, "memberOf", attrs[0]) + assert.Equal(t, "cn", attrs[1]) + assert.Equal(t, "mail", attrs[2]) + } + } + }) + } +} + +func TestValidateBoolValue(t *testing.T) { + tests := []struct { + name string + value string + expected bool + wantError bool + }{ + {"Valid true", "true", true, false}, + {"Valid false", "false", false, false}, + {"Valid 1", "1", true, false}, + {"Valid 0", "0", false, false}, + {"Invalid", "notabool", false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := validateBoolValue(tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + assert.Equal(t, tt.expected, val) + } + }) + } +} + +func TestValidateConfig(t *testing.T) { + tests := []struct { + name string + config *LdapResolverConfig + expected bool + }{ + { + name: "Valid configuration", + config: &LdapResolverConfig{ + Host: "ldap.example.com", + Port: 389, + BaseDN: "dc=example,dc=com", + Filter: "(&(objectClass=user)(sAMAccountName=%s))", + GroupAttr: "memberOf", + ReturnAttr: []string{"memberOf"}, + BindUser: "cn=admin,dc=example,dc=com", + BindPassword: "password", + Insecure: false, + SSL: false, + }, + expected: true, + }, + { + name: "Invalid configuration - empty fields", + config: &LdapResolverConfig{ + Host: "", + Port: 0, + BaseDN: "", + Filter: "invalid-filter", + GroupAttr: "", + ReturnAttr: []string{}, + BindUser: "", + BindPassword: "", + Insecure: true, + SSL: true, + }, + expected: false, + }, + { + name: "Invalid configuration - missing placeholder in filter", + config: &LdapResolverConfig{ + Host: "ldap.example.com", + Port: 389, + BaseDN: "dc=example,dc=com", + Filter: "(&(objectClass=user)(sAMAccountName=user))", // Missing %s placeholder + GroupAttr: "memberOf", + ReturnAttr: []string{"memberOf"}, + BindUser: "cn=admin,dc=example,dc=com", + BindPassword: "password", + Insecure: false, + SSL: false, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + result := validator.ValidateConfig(tt.config) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestValidateBaseDN tests the validateBaseDN method with various inputs +func TestValidateBaseDN(t *testing.T) { + tests := []struct { + name string + baseDN string + expectWarning bool + expectError bool + }{ + { + name: "Valid BaseDN", + baseDN: "dc=example,dc=com", + expectWarning: false, + expectError: false, + }, + { + name: "Empty BaseDN", + baseDN: "", + expectWarning: false, + expectError: true, + }, + { + name: "Invalid format - missing value", + baseDN: "dc=,dc=com", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid format - missing equals", + baseDN: "dcexample,dc=com", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid format - unbalanced commas", + baseDN: "dc=example,dc=com,", + expectWarning: true, + expectError: false, + }, + { + name: "No domain component", + baseDN: "cn=admin,ou=users", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid format - extra comma", + baseDN: "dc=example,,dc=com", + expectWarning: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateBaseDN(tt.baseDN) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "BaseDN" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestValidateGroupAttr tests the validateGroupAttr method with various inputs +func TestValidateGroupAttr(t *testing.T) { + tests := []struct { + name string + groupAttr string + expectWarning bool + expectError bool + }{ + { + name: "Valid attribute name", + groupAttr: "memberOf", + expectWarning: false, + expectError: false, + }, + { + name: "Empty attribute name", + groupAttr: "", + expectWarning: false, + expectError: true, + }, + { + name: "Invalid format - starts with number", + groupAttr: "1memberOf", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid format - special characters", + groupAttr: "member@Of", + expectWarning: true, + expectError: false, + }, + { + name: "Valid with hyphen", + groupAttr: "member-of", + expectWarning: false, + expectError: false, + }, + { + name: "Valid with underscore", + groupAttr: "member_of", + expectWarning: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateGroupAttr(tt.groupAttr) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "GroupAttr" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestValidateBindUser tests the validateBindUser method with various inputs +func TestValidateBindUser(t *testing.T) { + tests := []struct { + name string + bindUser string + expectWarning bool + expectError bool + }{ + { + name: "Valid DN format", + bindUser: "cn=admin,dc=example,dc=com", + expectWarning: false, + expectError: false, + }, + { + name: "Valid username format", + bindUser: "admin", + expectWarning: false, + expectError: false, + }, + { + name: "Valid username with domain", + bindUser: "admin@example.com", + expectWarning: false, + expectError: false, + }, + { + name: "Empty bind user", + bindUser: "", + expectWarning: false, + expectError: true, + }, + { + name: "Invalid format - special characters", + bindUser: "admin!#$%", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid DN format", + bindUser: "cn=admin,dc=example,=com", + expectWarning: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateBindUser(tt.bindUser) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "BindUser" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestValidateBindPassword tests the validateBindPassword method with various inputs +func TestValidateBindPassword(t *testing.T) { + tests := []struct { + name string + bindPassword string + expectWarning bool + expectError bool + }{ + { + name: "Valid password", + bindPassword: "password123", + expectWarning: false, + expectError: false, + }, + { + name: "Empty password", + bindPassword: "", + expectWarning: false, + expectError: true, + }, + { + name: "Very short password", + bindPassword: "a", + expectWarning: true, + expectError: false, + }, + { + name: "Password with special characters", + bindPassword: "p@ssw0rd!", + expectWarning: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateBindPassword(tt.bindPassword) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "BindPassword" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestValidateReturnAttr tests the validateReturnAttr method with various inputs +func TestValidateReturnAttr(t *testing.T) { + tests := []struct { + name string + returnAttr []string + expectWarning bool + expectError bool + }{ + { + name: "Valid single attribute", + returnAttr: []string{"memberOf"}, + expectWarning: false, + expectError: false, + }, + { + name: "Valid multiple attributes", + returnAttr: []string{"memberOf", "cn", "mail"}, + expectWarning: false, + expectError: false, + }, + { + name: "Empty array", + returnAttr: []string{}, + expectWarning: false, + expectError: true, + }, + { + name: "Invalid attribute name", + returnAttr: []string{"member@Of"}, + expectWarning: true, + expectError: false, + }, + { + name: "Mix of valid and invalid", + returnAttr: []string{"memberOf", "123invalid"}, + expectWarning: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateReturnAttr(tt.returnAttr) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "ReturnAttr" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestValidateHost tests the validateHost method with various inputs +func TestValidateHost(t *testing.T) { + tests := []struct { + name string + host string + expectWarning bool + expectError bool + }{ + { + name: "Valid hostname", + host: "ldap.example.com", + expectWarning: false, + expectError: false, + }, + { + name: "Valid IP address", + host: "192.168.1.1", + expectWarning: false, + expectError: false, + }, + { + name: "Empty host", + host: "", + expectWarning: false, + expectError: true, + }, + { + name: "Invalid hostname - starts with hyphen", + host: "-ldap.example.com", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid hostname - contains invalid characters", + host: "ldap_example.com", + expectWarning: true, + expectError: false, + }, + { + name: "Invalid hostname - double dots", + host: "ldap..example.com", + expectWarning: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateHost(tt.host) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "Host" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestValidateFilter tests the validateFilter method with various inputs +func TestValidateFilter(t *testing.T) { + tests := []struct { + name string + filter string + expectWarning bool + expectError bool + }{ + { + name: "Valid filter", + filter: "(&(objectClass=user)(sAMAccountName=%s))", + expectWarning: false, + expectError: false, + }, + { + name: "Empty filter", + filter: "", + expectWarning: false, + expectError: true, + }, + { + name: "Missing placeholder", + filter: "(&(objectClass=user)(sAMAccountName=user))", + expectWarning: false, + expectError: true, + }, + { + name: "Unbalanced parentheses", + filter: "(&(objectClass=user)(sAMAccountName=%s)", + expectWarning: false, + expectError: true, + }, + { + name: "Not enclosed in parentheses", + filter: "objectClass=user&sAMAccountName=%s", + expectWarning: true, + expectError: false, // This filter does contain the %s placeholder, so it's not an error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateFilter(tt.filter) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "Filter" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +func TestValidateSecretValue(t *testing.T) { + tests := []struct { + name string + key string + value string + wantError bool + }{ + {"Valid host", common.LdapHost, "ldap.example.com", false}, + {"Valid port", common.LdapPort, "389", false}, + {"Valid baseDN", common.LdapBaseDN, "dc=example,dc=com", false}, + {"Valid filter", common.LdapFilter, "(&(objectClass=user)(sAMAccountName=%s))", false}, + {"Valid groupAttr", common.LdapGroupAttr, "memberOf", false}, + {"Valid returnAttr", common.LdapReturnAttr, "memberOf,cn", false}, + {"Valid bindUser", common.LdapBindUser, "cn=admin,dc=example,dc=com", false}, + {"Valid bindPassword", common.LdapBindPassword, "password", false}, + {"Valid insecure", common.LdapInsecure, "true", false}, + {"Valid SSL", common.LdapSSL, "false", false}, + {"Invalid key", "unknown", "value", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ValidateSecretValue(tt.key, tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + } + }) + } +} + +func TestHasBalancedParentheses(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + {"Empty string", "", true}, + {"Simple balanced", "()", true}, + {"Nested balanced", "((()))", true}, + {"Complex balanced", "(a(b)c(d(e)f)g)", true}, + {"Unbalanced - too many open", "(()", false}, + {"Unbalanced - too many closed", "())", false}, + {"Unbalanced - wrong order", ")(", false}, + {"No parentheses", "abc", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := hasBalancedParentheses(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestLdapValidatorAddIssue(t *testing.T) { + validator := NewLdapValidator() + + // Add a warning + validator.addIssue("TestField", "Test warning", ValidationWarning) + + // Add an error + validator.addIssue("TestField2", "Test error", ValidationError) + + // Check that issues were added + assert.Equal(t, 2, len(validator.issues)) + assert.Equal(t, "TestField", validator.issues[0].Field) + assert.Equal(t, "Test warning", validator.issues[0].Message) + assert.Equal(t, ValidationWarning, validator.issues[0].Level) + assert.Equal(t, "TestField2", validator.issues[1].Field) + assert.Equal(t, "Test error", validator.issues[1].Message) + assert.Equal(t, ValidationError, validator.issues[1].Level) + + // Check hasErrors + assert.Assert(t, validator.hasErrors()) +} + +func TestLdapValidatorValidateConsistency(t *testing.T) { + tests := []struct { + name string + config *LdapResolverConfig + expectWarning bool + }{ + { + name: "No warnings", + config: &LdapResolverConfig{ + SSL: false, + Insecure: false, + Port: 389, + }, + expectWarning: false, + }, + { + name: "SSL with non-standard port", + config: &LdapResolverConfig{ + SSL: true, + Insecure: false, + Port: 389, + }, + expectWarning: true, + }, + { + name: "SSL with insecure", + config: &LdapResolverConfig{ + SSL: true, + Insecure: true, + Port: 636, + }, + expectWarning: true, + }, + { + name: "SSL with standard port", + config: &LdapResolverConfig{ + SSL: true, + Insecure: false, + Port: 636, + }, + expectWarning: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validateConsistency(tt.config) + + hasWarnings := len(validator.issues) > 0 + assert.Equal(t, tt.expectWarning, hasWarnings) + }) + } +} + +// TestValidatePort tests the validatePort method with various inputs +func TestValidatePort(t *testing.T) { + tests := []struct { + name string + port int + expectWarning bool + expectError bool + }{ + { + name: "Valid port - LDAP", + port: 389, + expectWarning: false, + expectError: false, + }, + { + name: "Valid port - LDAPS", + port: 636, + expectWarning: false, + expectError: false, + }, + { + name: "Valid port - custom", + port: 1389, + expectWarning: false, + expectError: false, + }, + { + name: "Invalid port - too low", + port: 0, + expectWarning: false, + expectError: true, + }, + { + name: "Invalid port - too high", + port: 65536, + expectWarning: false, + expectError: true, + }, + { + name: "Valid port - minimum", + port: 1, + expectWarning: false, + expectError: false, + }, + { + name: "Valid port - maximum", + port: 65535, + expectWarning: false, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator := NewLdapValidator() + validator.validatePort(tt.port) + + hasError := false + hasWarning := false + + for _, issue := range validator.issues { + if issue.Field == "Port" { + if issue.Level == ValidationError { + hasError = true + } else if issue.Level == ValidationWarning { + hasWarning = true + } + } + } + + assert.Equal(t, tt.expectError, hasError, "Expected error: %v, got: %v", tt.expectError, hasError) + assert.Equal(t, tt.expectWarning, hasWarning, "Expected warning: %v, got: %v", tt.expectWarning, hasWarning) + }) + } +} + +// TestLogIssues tests the logIssues method +func TestLogIssues(t *testing.T) { + validator := NewLdapValidator() + + // Add a warning + validator.addIssue("TestField1", "Test warning message", ValidationWarning) + + // Add an error + validator.addIssue("TestField2", "Test error message", ValidationError) + + // Call logIssues - we can't easily capture the log output in a unit test, + // but we can at least verify it doesn't panic + validator.logIssues() + + // Verify the issues are still present after logging + assert.Equal(t, 2, len(validator.issues)) + assert.Equal(t, "TestField1", validator.issues[0].Field) + assert.Equal(t, ValidationWarning, validator.issues[0].Level) + assert.Equal(t, "Test warning message", validator.issues[0].Message) + assert.Equal(t, "TestField2", validator.issues[1].Field) + assert.Equal(t, ValidationError, validator.issues[1].Level) + assert.Equal(t, "Test error message", validator.issues[1].Message) +} + +// TestValidateBindUserValue tests the validateBindUserValue function +func TestValidateBindUserValue(t *testing.T) { + tests := []struct { + name string + value string + wantError bool + }{ + {"Valid DN", "cn=admin,dc=example,dc=com", false}, + {"Valid username", "admin", false}, + {"Empty", "", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := validateBindUserValue(tt.value) + if tt.wantError { + assert.Assert(t, err != nil) + } else { + assert.NilError(t, err) + } + }) + } +} + +// TestValidateBaseDNValueEdgeCases tests edge cases for validateBaseDNValue +func TestValidateBaseDNValueEdgeCases(t *testing.T) { + // Test empty value + _, err := validateBaseDNValue("") + assert.Assert(t, err != nil) + assert.ErrorContains(t, err, "baseDN cannot be empty") +} + +// TestValidateGroupAttrValueEdgeCases tests edge cases for validateGroupAttrValue +func TestValidateGroupAttrValueEdgeCases(t *testing.T) { + // Test empty value + _, err := validateGroupAttrValue("") + assert.Assert(t, err != nil) + assert.ErrorContains(t, err, "groupAttr cannot be empty") +} + +// TestValidateReturnAttrValueEdgeCases tests edge cases for validateReturnAttrValue +func TestValidateReturnAttrValueEdgeCases(t *testing.T) { + // Test empty value + _, err := validateReturnAttrValue("") + assert.Assert(t, err != nil) + assert.ErrorContains(t, err, "returnAttr cannot be empty") +} + +// TestValidateBindPasswordValueEdgeCases tests edge cases for validateBindPasswordValue +func TestValidateBindPasswordValueEdgeCases(t *testing.T) { + // Test empty value + _, err := validateBindPasswordValue("") + assert.Assert(t, err != nil) + assert.ErrorContains(t, err, "bindPassword cannot be empty") +} diff --git a/pkg/common/security/usergroup.go b/pkg/common/security/usergroup.go index d9a1966c7..37155db91 100644 --- a/pkg/common/security/usergroup.go +++ b/pkg/common/security/usergroup.go @@ -66,20 +66,32 @@ type UserGroup struct { resolved int64 } +const ( + Default = "" + Ldap = "ldap" + Test = "test" + Os = "os" +) + // Get the resolver for the user and group info. // Current setup allows three resolvers: // * NO resolver: default, no user or group resolution just return the info (k8s use case) // * OS resolver: uses the OS libraries to resolve user and group memberships // * Test resolver: fake resolution for testing -func GetUserGroupCache(resolver string) *UserGroupCache { +// * Ldap resolver: uses the LDAP protocol to resolve user and group memberships +func GetUserGroupCache(ugr configs.UserGroupResolver) *UserGroupCache { + resolver := ugr.Type once.Do(func() { switch resolver { - case "test": + case Test: log.Log(log.Security).Info("creating test user group resolver") instance = GetUserGroupCacheTest() - case "os": + case Os: log.Log(log.Security).Info("creating OS user group resolver") instance = GetUserGroupCacheOS() + case Ldap: + log.Log(log.Security).Info("creating LDAP user group resolver") + instance = GetUserGroupCacheLdap() default: log.Log(log.Security).Info("creating UserGroupCache without resolver") instance = GetUserGroupNoResolve() @@ -231,6 +243,10 @@ func (c *UserGroupCache) Stop() { if !stopped.Load() { log.Log(log.Security).Info("Stopping UserGroupCache background cleanup") close(c.stop) + // Clear the cache before resetting the instance + c.lock.Lock() + c.ugs = make(map[string]*UserGroup) + c.lock.Unlock() once = &sync.Once{} // re-init so that GetUserGroupCache() can create a new instance again instance = nil stopped.Store(true) diff --git a/pkg/common/security/usergroup_ldap_resolver.go b/pkg/common/security/usergroup_ldap_resolver.go new file mode 100644 index 000000000..818b2c7ad --- /dev/null +++ b/pkg/common/security/usergroup_ldap_resolver.go @@ -0,0 +1,375 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package security + +import ( + "crypto/tls" + "fmt" + "os" + "os/user" + "path/filepath" + "strings" + "time" + + "go.uber.org/zap" + + "github.com/go-ldap/ldap/v3" + + "github.com/apache/yunikorn-core/pkg/common" + "github.com/apache/yunikorn-core/pkg/log" +) + +// This file contains the implementation of the LDAP resolver for user groups + +// LdapAccess defines the interface for LDAP operations +type LdapAccess interface { + // DialURL establishes a connection to the LDAP server + DialURL(url string, options ...ldap.DialOpt) (*ldap.Conn, error) + + // Bind authenticates with the LDAP server + Bind(conn *ldap.Conn, username, password string) error + + // Search performs an LDAP search operation + Search(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) + + // Close closes the LDAP connection + Close(conn *ldap.Conn) +} + +// LdapAccessImpl implements the LdapAccess interface with real LDAP operations +type LdapAccessImpl struct{} + +func (l *LdapAccessImpl) DialURL(url string, options ...ldap.DialOpt) (*ldap.Conn, error) { + return ldap.DialURL(url, options...) +} + +func (l *LdapAccessImpl) Bind(conn *ldap.Conn, username, password string) error { + return conn.Bind(username, password) +} + +func (l *LdapAccessImpl) Search(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + return conn.Search(searchRequest) +} + +func (l *LdapAccessImpl) Close(conn *ldap.Conn) { + conn.Close() +} + +// ldapAccessFactory is a function type that creates LdapAccess instances +type ldapAccessFactory func(config *LdapResolverConfig) LdapAccess + +// defaultLdapAccessFactory is the default factory function that creates real LdapAccessImpl instances +var defaultLdapAccessFactory ldapAccessFactory = func(config *LdapResolverConfig) LdapAccess { + return &LdapAccessImpl{} +} + +// newLdapAccessImpl creates a new LdapAccess instance using the current factory +// This can be replaced in tests to return mock implementations +var newLdapAccessImpl = defaultLdapAccessFactory + +// resetLdapAccessFactory resets the factory to the default implementation +// This is used in tests to ensure the global state is restored +func resetLdapAccessFactory() { + newLdapAccessImpl = defaultLdapAccessFactory +} + +// LDAPResolverConfig holds the configuration for the LDAP resolver +type LdapResolverConfig struct { + Host string + Port int + BaseDN string + Filter string + GroupAttr string + ReturnAttr []string + BindUser string + BindPassword string + Insecure bool + SSL bool +} + +// Default values for the LDAP resolver +var ldapConf = LdapResolverConfig{ + Host: common.DefaultLdapHost, + Port: common.DefaultLdapPort, + BaseDN: common.DefaultLdapBaseDN, + Filter: common.DefaultLdapFilter, + GroupAttr: common.DefaultLdapGroupAttr, + ReturnAttr: common.DefaultLdapReturnAttr, + BindUser: common.DefaultLdapBindUser, + BindPassword: common.DefaultLdapBindPassword, + Insecure: common.DefaultLdapInsecure, + SSL: common.DefaultLdapSSL, +} + +// read secrets from the secrets directory +// returns true if at least one secret was loaded and the configuration is valid, false otherwise +var readSecrets = func() bool { + secretsDir := common.LdapMountPath + + // Read all files from secrets directory + files, err := os.ReadDir(secretsDir) + if err != nil { + log.Log(log.Security).Error("Unable to access LDAP secrets directory", + zap.String("directory", secretsDir), + zap.Error(err)) + return false + } + + secretCount := 0 + validSecrets := make(map[string]interface{}) + + // Iterate over all secret files in the secrets directory + for _, file := range files { + fileName := file.Name() + + // Skip non-secret entries such as Kubernetes internal metadata (e.g., symlinks like "..data" or directories like "..timestamp") + if strings.HasPrefix(fileName, "..") || file.IsDir() { + log.Log(log.Security).Info("Ignoring non-secret entry (Kubernetes metadata entry or directory)", + zap.String("name", fileName)) + continue + } + + secretKey := fileName + secretValueBytes, err := os.ReadFile(filepath.Join(secretsDir, secretKey)) + if err != nil { + log.Log(log.Security).Warn("Could not read secret file", + zap.String("file", secretKey), + zap.Error(err)) + continue + } + secretValue := strings.TrimSpace(string(secretValueBytes)) + + // Validate the secret value + validatedValue, err := ValidateSecretValue(secretKey, secretValue) + if err != nil { + log.Log(log.Security).Warn("Invalid LDAP secret value", + zap.String("key", secretKey), + zap.Error(err)) + continue + } + + // Store the validated value + validSecrets[secretKey] = validatedValue + secretCount++ + + log.Log(log.Security).Debug("Loaded LDAP secret", + zap.String("key", secretKey)) + } + + // Apply validated values to the configuration + if host, ok := validSecrets[common.LdapHost].(string); ok { + ldapConf.Host = host + } + if port, ok := validSecrets[common.LdapPort].(int); ok { + ldapConf.Port = port + } + if baseDN, ok := validSecrets[common.LdapBaseDN].(string); ok { + ldapConf.BaseDN = baseDN + } + if filter, ok := validSecrets[common.LdapFilter].(string); ok { + ldapConf.Filter = filter + } + if groupAttr, ok := validSecrets[common.LdapGroupAttr].(string); ok { + ldapConf.GroupAttr = groupAttr + } + if returnAttr, ok := validSecrets[common.LdapReturnAttr].([]string); ok { + ldapConf.ReturnAttr = returnAttr + } + if bindUser, ok := validSecrets[common.LdapBindUser].(string); ok { + ldapConf.BindUser = bindUser + } + if bindPassword, ok := validSecrets[common.LdapBindPassword].(string); ok { + ldapConf.BindPassword = bindPassword + } + if insecure, ok := validSecrets[common.LdapInsecure].(bool); ok { + ldapConf.Insecure = insecure + } + if ssl, ok := validSecrets[common.LdapSSL].(bool); ok { + ldapConf.SSL = ssl + } + + // Validate the entire configuration + validator := NewLdapValidator() + isValid := validator.ValidateConfig(&ldapConf) + + // Check if all required fields were provided in the secrets + requiredFields := []string{ + common.LdapHost, + common.LdapPort, + common.LdapBaseDN, + common.LdapFilter, + common.LdapGroupAttr, + common.LdapReturnAttr, + common.LdapBindUser, + common.LdapBindPassword, + } + + missingFields := []string{} + for _, field := range requiredFields { + if _, ok := validSecrets[field]; !ok { + missingFields = append(missingFields, field) + } + } + + if len(missingFields) > 0 { + log.Log(log.Security).Error("Missing required LDAP configuration fields", + zap.Strings("missingFields", missingFields)) + isValid = false + } + + log.Log(log.Security).Info("Finished loading LDAP secrets", + zap.Int("numberOfSecretsLoaded", secretCount), + zap.Bool("configurationValid", isValid), + zap.Int("missingRequiredFields", len(missingFields))) + + return secretCount > 0 && isValid && len(missingFields) == 0 +} + +func GetUserGroupCacheLdap() *UserGroupCache { + secretsLoaded := readSecrets() + + if !secretsLoaded { + // Log a FATAL level message - this is very prominent and will typically cause the application to exit + log.Log(log.Security).Fatal("LDAP configuration not found or invalid. No secrets were loaded from the secrets directory.", + zap.String("secretsPath", common.LdapMountPath), + zap.String("resolution", "Ensure LDAP secrets are properly mounted and accessible")) + + // If the Fatal log doesn't cause an exit (depends on logger configuration), + // we could also panic here to ensure the application stops + panic("LDAP configuration not found or invalid") + } + + return &UserGroupCache{ + ugs: map[string]*UserGroup{}, + interval: cleanerInterval * time.Second, + lookup: LdapLookupUser, + lookupGroupID: LdapLookupGroupID, + groupIds: LDAPLookupGroupIds, + stop: make(chan struct{}), + } +} + +// Default linux behaviour: a user is member of the primary group with the same name +func LdapLookupUser(userName string) (*user.User, error) { + log.Log(log.Security).Debug("Performing LDAP user lookup", + zap.String("username", userName), + zap.String("defaultUID", common.DefaultLdapUserUID)) + return &user.User{ + Uid: common.DefaultLdapUserUID, + Gid: userName, + Username: userName, + }, nil +} + +func LdapLookupGroupID(gid string) (*user.Group, error) { + log.Log(log.Security).Debug("Looking up LDAP group ID", + zap.String("groupID", gid)) + group := user.Group{Gid: gid} + group.Name = gid + return &group, nil +} + +func LDAPLookupGroupIds(osUser *user.User) ([]string, error) { + ldapAccess := newLdapAccessImpl(&ldapConf) + sr, err := LdapSearch(ldapAccess, osUser.Username) + if err != nil { + log.Log(log.Security).Error("Failed to connect to LDAP for group lookup", + zap.String("user", osUser.Username), + zap.Error(err)) + return nil, err + } + + var groups []string + for _, entry := range sr.Entries { + attr := entry.GetAttributeValues("memberOf") + log.Log(log.Security).Debug("LDAP 'memberOf' attributes for user", + zap.String("user", osUser.Username), + zap.Strings("attributes", attr)) + for i := range attr { + s := strings.Split(attr[i], ",") + newgroup := strings.Split(s[0], "CN=") + groups = append(groups, newgroup[1]) + } + } + return groups, nil +} + +// LdapSearch performs an LDAP search for the specified username +// This replaces the old LDAPConn_Bind function with a more testable approach +func LdapSearch(ldapAccess LdapAccess, userName string) (*ldap.SearchResult, error) { + var LDAP_URI string + if ldapConf.SSL { + LDAP_URI = "ldaps" + } else { + LDAP_URI = "ldap" + } + + ldapaddr := fmt.Sprintf("%s://%s:%d", LDAP_URI, ldapConf.Host, ldapConf.Port) + log.Log(log.Security).Debug("Attempting LDAP connection", + zap.String("address", ldapaddr), + zap.Bool("ssl", ldapConf.SSL), + zap.Bool("insecureSkipVerify", ldapConf.Insecure)) + + l, err := ldapAccess.DialURL(ldapaddr, + ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify: ldapConf.Insecure})) // #nosec G402 + if err != nil { + log.Log(log.Security).Error("Error connecting to LDAP server", + zap.String("address", ldapaddr), + zap.Error(err)) + return nil, err + } + defer ldapAccess.Close(l) + + log.Log(log.Security).Debug("LDAP connection successful, attempting bind", + zap.String("bindUser", ldapConf.BindUser)) + err = ldapAccess.Bind(l, ldapConf.BindUser, ldapConf.BindPassword) + if err != nil { + log.Log(log.Security).Error("Failed to bind with LDAP server", + zap.String("bindDN", ldapConf.BindUser), + zap.Error(err)) + return nil, err + } + + filter := fmt.Sprintf(ldapConf.Filter, userName) + log.Log(log.Security).Debug("Executing LDAP search", + zap.String("baseDN", ldapConf.BaseDN), + zap.String("filter", filter), + zap.Strings("attributesToReturn", ldapConf.ReturnAttr)) + + searchRequest := ldap.NewSearchRequest( + ldapConf.BaseDN, + ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false, + filter, + ldapConf.ReturnAttr, + nil, + ) + sr, err := ldapAccess.Search(l, searchRequest) + if err != nil { + log.Log(log.Security).Error("Failed to execute LDAP search query", + zap.String("filter", filter), + zap.String("baseDN", ldapConf.BaseDN), + zap.Error(err)) + return nil, err + } + + log.Log(log.Security).Debug("LDAP search completed successfully", + zap.String("username", userName), + zap.Int("entriesFound", len(sr.Entries))) + return sr, nil +} diff --git a/pkg/common/security/usergroup_ldap_resolver_test.go b/pkg/common/security/usergroup_ldap_resolver_test.go new file mode 100644 index 000000000..0e3880165 --- /dev/null +++ b/pkg/common/security/usergroup_ldap_resolver_test.go @@ -0,0 +1,818 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package security + +import ( + "errors" + "fmt" + "os" + "os/user" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "gotest.tools/v3/assert" + + "github.com/go-ldap/ldap/v3" + + "github.com/apache/yunikorn-core/pkg/common" +) + +// GetUserGroupCacheLdapMock returns a UserGroupCache with mocked LDAP functions for testing +func GetUserGroupCacheLdapMock() *UserGroupCache { + return &UserGroupCache{ + ugs: map[string]*UserGroup{}, + interval: time.Second, + lookup: mockLdapLookupUser, + lookupGroupID: mockLdapLookupGroupID, + groupIds: mockLDAPLookupGroupIds, + stop: make(chan struct{}), + } +} + +// mockLdapLookupUser is a mock implementation of LdapLookupUser for testing +func mockLdapLookupUser(userName string) (*user.User, error) { + // Similar to the test resolver, but with LDAP-specific behavior + if userName == Testuser1 || userName == Testuser { + return &user.User{ + Uid: "1000", + Gid: "1000", + Username: userName, + }, nil + } + if userName == Testuser2 { + return &user.User{ + Uid: "100", + Gid: "100", + Username: "testuser2", + }, nil + } + if userName == Testuser3 { + return &user.User{ + Uid: "1001", + Gid: "1001", + Username: "testuser3", + }, nil + } + if userName == Testuser4 { + return &user.User{ + Uid: "901", + Gid: "901", + Username: "testuser4", + }, nil + } + if userName == Testuser5 { + return &user.User{ + Uid: "1001", + Gid: "1001", + Username: "testuser5", + }, nil + } + if userName == "invalid-gid-user" { + return &user.User{ + Uid: "1001", + Gid: "1_001", + Username: "invalid-gid-user", + }, nil + } + // All other users fail + return nil, fmt.Errorf("lookup failed for user: %s", userName) +} + +// mockLdapLookupGroupID is a mock implementation of LdapLookupGroupID for testing +func mockLdapLookupGroupID(gid string) (*user.Group, error) { + // For testing, we'll use a simple pattern + if gid == "100" { + return nil, fmt.Errorf("lookup failed for group: %s", gid) + } + // Special case for invalid-gid-user + if gid == "1_001" { + return nil, fmt.Errorf("lookup failed for group: %s", gid) + } + group := user.Group{Gid: gid} + group.Name = "group" + gid + return &group, nil +} + +// mockLDAPLookupGroupIds is a mock implementation of LDAPLookupGroupIds for testing +func mockLDAPLookupGroupIds(osUser *user.User) ([]string, error) { + if osUser.Username == Testuser1 || osUser.Username == Testuser { + return []string{"1001"}, nil + } + if osUser.Username == Testuser2 { + return []string{"1001", "1002"}, nil + } + // Group list might return primary group ID also + if osUser.Username == Testuser3 { + return []string{"1002", "1001", "1003", "1004"}, nil + } + if osUser.Username == Testuser4 { + return []string{"901", "902"}, nil + } + return nil, fmt.Errorf("lookup failed for user: %s", osUser.Username) +} + +// Mock LDAP search result for testing +func mockLdapSearchResult(username string) (*ldap.SearchResult, error) { + if username == Testuser1 || username == Testuser { + return &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: "memberOf", + Values: []string{"CN=group1001,OU=groups,DC=example,DC=com"}, + }, + }, + }, + }, + }, nil + } + if username == Testuser2 { + return &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: "memberOf", + Values: []string{"CN=group1001,OU=groups,DC=example,DC=com", "CN=group1002,OU=groups,DC=example,DC=com"}, + }, + }, + }, + }, + }, nil + } + if username == Testuser3 { + return &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: "memberOf", + Values: []string{"CN=group1002,OU=groups,DC=example,DC=com", "CN=group1001,OU=groups,DC=example,DC=com", "CN=group1003,OU=groups,DC=example,DC=com", "CN=group1004,OU=groups,DC=example,DC=com"}, + }, + }, + }, + }, + }, nil + } + if username == Testuser4 { + return &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: "memberOf", + Values: []string{"CN=group901,OU=groups,DC=example,DC=com", "CN=group902,OU=groups,DC=example,DC=com"}, + }, + }, + }, + }, + }, nil + } + return nil, fmt.Errorf("ldap lookup failed for user: %s", username) +} + +// LdapAccessMock implements the LdapAccess interface for testing +type LdapAccessMock struct { + DialURLFunc func(url string, options ...ldap.DialOpt) (*ldap.Conn, error) + BindFunc func(conn *ldap.Conn, username, password string) error + SearchFunc func(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) + CloseFunc func(conn *ldap.Conn) + SearchResult *ldap.SearchResult + Error error +} + +func (m *LdapAccessMock) DialURL(url string, options ...ldap.DialOpt) (*ldap.Conn, error) { + if m.DialURLFunc != nil { + return m.DialURLFunc(url, options...) + } + return &ldap.Conn{}, nil +} + +func (m *LdapAccessMock) Bind(conn *ldap.Conn, username, password string) error { + if m.BindFunc != nil { + return m.BindFunc(conn, username, password) + } + return nil +} + +func (m *LdapAccessMock) Search(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + if m.SearchFunc != nil { + return m.SearchFunc(conn, searchRequest) + } + return m.SearchResult, m.Error +} + +func (m *LdapAccessMock) Close(conn *ldap.Conn) { + if m.CloseFunc != nil { + m.CloseFunc(conn) + } +} + +// Helper function to create a mock LDAP access with predefined search results +func newMockLdapAccess(searchResult *ldap.SearchResult, err error) *LdapAccessMock { + return &LdapAccessMock{ + SearchResult: searchResult, + Error: err, + } +} + +// TestLdapSearch tests the new LdapSearch function with a mock LdapAccess +func TestLdapSearch(t *testing.T) { + // Create a mock search result + mockResult := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: "memberOf", + Values: []string{"CN=group1,OU=groups,DC=example,DC=com", "CN=group2,OU=groups,DC=example,DC=com"}, + }, + }, + }, + }, + } + + // Create a mock LDAP access with the mock result + mockAccess := newMockLdapAccess(mockResult, nil) + + // Call LdapSearch with the mock access + result, err := LdapSearch(mockAccess, "testuser") + + // Verify results + assert.NilError(t, err) + assert.Assert(t, result != nil) + assert.Equal(t, 1, len(result.Entries)) + assert.Equal(t, 1, len(result.Entries[0].Attributes)) + assert.Equal(t, "memberOf", result.Entries[0].Attributes[0].Name) + assert.Equal(t, 2, len(result.Entries[0].Attributes[0].Values)) + assert.Equal(t, "CN=group1,OU=groups,DC=example,DC=com", result.Entries[0].Attributes[0].Values[0]) + assert.Equal(t, "CN=group2,OU=groups,DC=example,DC=com", result.Entries[0].Attributes[0].Values[1]) +} + +// TestLdapSearchError tests the error handling in LdapSearch +func TestLdapSearchError(t *testing.T) { + // Test cases for different error scenarios + testCases := []struct { + name string + dialError error + bindError error + searchError error + }{ + { + name: "Dial Error", + dialError: errors.New("dial error"), + bindError: nil, + searchError: nil, + }, + { + name: "Bind Error", + dialError: nil, + bindError: errors.New("bind error"), + searchError: nil, + }, + { + name: "Search Error", + dialError: nil, + bindError: nil, + searchError: errors.New("search error"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create a mock LDAP access with the appropriate error + mockAccess := &LdapAccessMock{ + DialURLFunc: func(url string, options ...ldap.DialOpt) (*ldap.Conn, error) { + return &ldap.Conn{}, tc.dialError + }, + BindFunc: func(conn *ldap.Conn, username, password string) error { + return tc.bindError + }, + SearchFunc: func(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + return nil, tc.searchError + }, + } + + // Call LdapSearch with the mock access + result, err := LdapSearch(mockAccess, "testuser") + + // Verify error + assert.Assert(t, err != nil) + assert.Assert(t, result == nil) + + // Check for specific error + switch { + case tc.dialError != nil: + assert.Equal(t, tc.dialError.Error(), err.Error()) + case tc.bindError != nil: + assert.Equal(t, tc.bindError.Error(), err.Error()) + case tc.searchError != nil: + assert.Equal(t, tc.searchError.Error(), err.Error()) + } + }) + } +} + +func TestLdapLookups(t *testing.T) { + tests := []struct { + name string + testType string + id string + validate func(t *testing.T, result interface{}, err error) + }{ + { + name: "Lookup user", + testType: "user", + id: "testuser", + validate: func(t *testing.T, result interface{}, err error) { + assert.NilError(t, err) + u, ok := result.(*user.User) + assert.Assert(t, ok, "invalid result type") + assert.Equal(t, "testuser", u.Username) + assert.Equal(t, "testuser", u.Gid) + assert.Equal(t, "1211", u.Uid) + }, + }, + { + name: "Lookup group", + testType: "group", + id: "testgroup", + validate: func(t *testing.T, result interface{}, err error) { + assert.NilError(t, err) + g, ok := result.(*user.Group) + assert.Assert(t, ok, "invalid result type") + assert.Equal(t, "testgroup", g.Gid) + assert.Equal(t, "testgroup", g.Name) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.testType == "user" { + u, err := LdapLookupUser(tt.id) + tt.validate(t, u, err) + } else if tt.testType == "group" { + g, err := LdapLookupGroupID(tt.id) + tt.validate(t, g, err) + } + }) + } +} + +func TestLDAPLookupGroupIds(t *testing.T) { + // Save the original newLdapAccessImpl function and ensure it's restored + defer resetLdapAccessFactory() + + // Create a mock search result + mockResult := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + Attributes: []*ldap.EntryAttribute{ + { + Name: "memberOf", + Values: []string{"CN=group1,OU=groups,DC=example,DC=com", "CN=group2,OU=groups,DC=example,DC=com"}, + }, + }, + }, + }, + } + + // Mock the newLdapAccessImpl function to return our mock + mockFactory := func(config *LdapResolverConfig) LdapAccess { + return newMockLdapAccess(mockResult, nil) + } + + // Replace the factory function + newLdapAccessImpl = mockFactory + + u := &user.User{Username: "testuser"} + groups, err := LDAPLookupGroupIds(u) + assert.NilError(t, err) + assert.Assert(t, strings.Contains(strings.Join(groups, ","), "group1")) + assert.Assert(t, strings.Contains(strings.Join(groups, ","), "group2")) +} + +func TestLDAPLookupGroupIdsError(t *testing.T) { + // Ensure we restore the original factory at the end of the test + defer resetLdapAccessFactory() + + // Mock the newLdapAccessImpl function to return our mock with an error + mockFactory := func(config *LdapResolverConfig) LdapAccess { + return newMockLdapAccess(nil, errors.New("ldap error")) + } + + // Replace the factory function + newLdapAccessImpl = mockFactory + + u := &user.User{Username: "testuser"} + groups, err := LDAPLookupGroupIds(u) + assert.Error(t, err, "ldap error") + assert.Assert(t, groups == nil) +} + +// Helper to reset ldapConf to defaults before each test +func resetLdapConfDefaults() { + ldapConf = LdapResolverConfig{ + Host: common.DefaultLdapHost, + Port: common.DefaultLdapPort, + BaseDN: common.DefaultLdapBaseDN, + Filter: common.DefaultLdapFilter, + GroupAttr: common.DefaultLdapGroupAttr, + ReturnAttr: common.DefaultLdapReturnAttr, + BindUser: common.DefaultLdapBindUser, + BindPassword: common.DefaultLdapBindPassword, + Insecure: common.DefaultLdapInsecure, + SSL: common.DefaultLdapSSL, + } +} + +//nolint:funlen // Table-driven test for coverage, helpers used to reduce length +func TestReadSecrets(t *testing.T) { + tests := getReadSecretsTestCases() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resetLdapConfDefaults() + _, cleanup := tt.setupFunc(t) + defer cleanup() + result := readSecrets() + assert.Equal(t, tt.expectedResult, result) + tt.validateFunc(t) + }) + } +} + +//nolint:funlen // Table-driven test helper for coverage, intentionally long +func getReadSecretsTestCases() []struct { + name string + setupFunc func(t *testing.T) (string, func()) + expectedResult bool + validateFunc func(t *testing.T) +} { + return []struct { + name string + setupFunc func(t *testing.T) (string, func()) + expectedResult bool + validateFunc func(t *testing.T) + }{ + { + name: "Skips K8s metadata and directories", + setupFunc: func(t *testing.T) (string, func()) { + tmpDir := t.TempDir() + err := os.Mkdir(filepath.Join(tmpDir, "..data"), 0755) + assert.NilError(t, err) + err = os.Mkdir(filepath.Join(tmpDir, "dir1"), 0755) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "key1"), []byte("value1"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, "..timestamp"), []byte("meta"), 0600) + assert.NilError(t, err) + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = tmpDir + return tmpDir, func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: false, + validateFunc: func(t *testing.T) { + assert.Equal(t, common.DefaultLdapHost, ldapConf.Host) + assert.Equal(t, common.DefaultLdapPort, ldapConf.Port) + assert.Equal(t, common.DefaultLdapBaseDN, ldapConf.BaseDN) + assert.Equal(t, common.DefaultLdapFilter, ldapConf.Filter) + assert.Equal(t, common.DefaultLdapGroupAttr, ldapConf.GroupAttr) + assert.Equal(t, strings.Join(common.DefaultLdapReturnAttr, ","), strings.Join(ldapConf.ReturnAttr, ",")) + assert.Equal(t, common.DefaultLdapBindUser, ldapConf.BindUser) + assert.Equal(t, common.DefaultLdapBindPassword, ldapConf.BindPassword) + assert.Equal(t, common.DefaultLdapInsecure, ldapConf.Insecure) + assert.Equal(t, common.DefaultLdapSSL, ldapConf.SSL) + }, + }, + { + name: "Handles missing secrets directory", + setupFunc: func(t *testing.T) (string, func()) { + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = "/nonexistent" + return "/nonexistent", func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: false, + validateFunc: func(t *testing.T) { + assert.Equal(t, common.DefaultLdapHost, ldapConf.Host) + assert.Equal(t, common.DefaultLdapPort, ldapConf.Port) + assert.Equal(t, common.DefaultLdapBaseDN, ldapConf.BaseDN) + assert.Equal(t, common.DefaultLdapFilter, ldapConf.Filter) + assert.Equal(t, common.DefaultLdapGroupAttr, ldapConf.GroupAttr) + assert.Equal(t, strings.Join(common.DefaultLdapReturnAttr, ","), strings.Join(ldapConf.ReturnAttr, ",")) + assert.Equal(t, common.DefaultLdapBindUser, ldapConf.BindUser) + assert.Equal(t, common.DefaultLdapBindPassword, ldapConf.BindPassword) + assert.Equal(t, common.DefaultLdapInsecure, ldapConf.Insecure) + assert.Equal(t, common.DefaultLdapSSL, ldapConf.SSL) + }, + }, + { + name: "Handles unknown key", + setupFunc: func(t *testing.T) (string, func()) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, "unknownKey"), []byte("somevalue"), 0600) + assert.NilError(t, err) + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = tmpDir + return tmpDir, func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: false, + validateFunc: func(t *testing.T) { + assert.Equal(t, common.DefaultLdapHost, ldapConf.Host) + assert.Equal(t, common.DefaultLdapPort, ldapConf.Port) + assert.Equal(t, common.DefaultLdapBaseDN, ldapConf.BaseDN) + assert.Equal(t, common.DefaultLdapFilter, ldapConf.Filter) + assert.Equal(t, common.DefaultLdapGroupAttr, ldapConf.GroupAttr) + assert.Equal(t, strings.Join(common.DefaultLdapReturnAttr, ","), strings.Join(ldapConf.ReturnAttr, ",")) + assert.Equal(t, common.DefaultLdapBindUser, ldapConf.BindUser) + assert.Equal(t, common.DefaultLdapBindPassword, ldapConf.BindPassword) + assert.Equal(t, common.DefaultLdapInsecure, ldapConf.Insecure) + assert.Equal(t, common.DefaultLdapSSL, ldapConf.SSL) + }, + }, + { + name: "Handles invalid port and bool values", + setupFunc: func(t *testing.T) (string, func()) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, common.LdapPort), []byte("notanint"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapInsecure), []byte("notabool"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapSSL), []byte("notabool"), 0600) + assert.NilError(t, err) + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = tmpDir + return tmpDir, func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: false, + validateFunc: func(t *testing.T) { + // Assert that ldapConf.Port is set to DefaultLdapPort when invalid int value is provided + assert.Equal(t, common.DefaultLdapPort, ldapConf.Port) + + // Assert that rest of ldap conf is set to default values + assert.Equal(t, common.DefaultLdapHost, ldapConf.Host) + assert.Equal(t, common.DefaultLdapBaseDN, ldapConf.BaseDN) + assert.Equal(t, common.DefaultLdapFilter, ldapConf.Filter) + assert.Equal(t, common.DefaultLdapGroupAttr, ldapConf.GroupAttr) + assert.Equal(t, strings.Join(common.DefaultLdapReturnAttr, ","), strings.Join(ldapConf.ReturnAttr, ",")) + assert.Equal(t, common.DefaultLdapBindUser, ldapConf.BindUser) + assert.Equal(t, common.DefaultLdapBindPassword, ldapConf.BindPassword) + assert.Equal(t, common.DefaultLdapInsecure, ldapConf.Insecure) + assert.Equal(t, common.DefaultLdapSSL, ldapConf.SSL) + }, + }, + { + name: "Sets custom values", + setupFunc: func(t *testing.T) (string, func()) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, common.LdapHost), []byte("myhost"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapPort), []byte("1234"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapBaseDN), []byte("dc=test,dc=com"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapFilter), []byte("(&(uid=%s))"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapGroupAttr), []byte("groups"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapReturnAttr), []byte("memberOf,groups"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapBindUser), []byte("binduser"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapBindPassword), []byte("bindpass"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapInsecure), []byte("true"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapSSL), []byte("true"), 0600) + assert.NilError(t, err) + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = tmpDir + return tmpDir, func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: true, + validateFunc: func(t *testing.T) { + assert.Equal(t, "myhost", ldapConf.Host) + + // Use strconv to verify the port value to ensure the import is used + portStr := "1234" + expectedPort, err := strconv.Atoi(portStr) + assert.NilError(t, err, "failed to convert port string to int") + assert.Equal(t, expectedPort, ldapConf.Port) + + assert.Equal(t, "dc=test,dc=com", ldapConf.BaseDN) + assert.Equal(t, "(&(uid=%s))", ldapConf.Filter) + assert.Equal(t, "groups", ldapConf.GroupAttr) + assert.Equal(t, "memberOf,groups", strings.Join(ldapConf.ReturnAttr, ",")) + assert.Equal(t, "binduser", ldapConf.BindUser) + assert.Equal(t, "bindpass", ldapConf.BindPassword) + + // Use strconv to verify boolean values + insecureStr := "true" + expectedInsecure, err := strconv.ParseBool(insecureStr) + assert.NilError(t, err, "failed to convert insecure string to bool") + assert.Equal(t, expectedInsecure, ldapConf.Insecure) + + sslStr := "true" + expectedSSL, err := strconv.ParseBool(sslStr) + assert.NilError(t, err, "failed to convert ssl string to bool") + assert.Equal(t, expectedSSL, ldapConf.SSL) + }, + }, + { + name: "Missing required fields", + setupFunc: func(t *testing.T) (string, func()) { + tmpDir := t.TempDir() + err := os.WriteFile(filepath.Join(tmpDir, common.LdapHost), []byte("ldap.example.com"), 0600) + assert.NilError(t, err) + err = os.WriteFile(filepath.Join(tmpDir, common.LdapPort), []byte("389"), 0600) + assert.NilError(t, err) + // Missing BaseDN, Filter, GroupAttr, ReturnAttr, BindUser, BindPassword + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = tmpDir + return tmpDir, func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: false, + validateFunc: func(t *testing.T) { + // No specific validation needed - we're testing the return value + }, + }, + { + name: "All required fields present", + setupFunc: func(t *testing.T) (string, func()) { + tmpDir := t.TempDir() + requiredFields := map[string]string{ + common.LdapHost: "ldap.example.com", + common.LdapPort: "389", + common.LdapBaseDN: "dc=example,dc=com", + common.LdapFilter: "(&(objectClass=user)(sAMAccountName=%s))", + common.LdapGroupAttr: "memberOf", + common.LdapReturnAttr: "memberOf", + common.LdapBindUser: "cn=admin,dc=example,dc=com", + common.LdapBindPassword: "password", + } + + for key, value := range requiredFields { + err := os.WriteFile(filepath.Join(tmpDir, key), []byte(value), 0600) + if err != nil { + t.Fatalf("failed to write file %s: %v", key, err) + } + } + + origLdapMountPath := common.LdapMountPath + common.LdapMountPath = tmpDir + return tmpDir, func() { common.LdapMountPath = origLdapMountPath } + }, + expectedResult: true, + validateFunc: func(t *testing.T) { + // No specific validation needed - we're testing the return value + }, + }, + } +} + +func TestUserGroupCacheLdap(t *testing.T) { + tests := []struct { + name string + validateFunc func(t *testing.T, cache *UserGroupCache) + }{ + { + name: "Cache initialization", + validateFunc: func(t *testing.T, cache *UserGroupCache) { + assert.Assert(t, cache != nil) + assert.Assert(t, cache.ugs != nil) + assert.Assert(t, cache.lookup != nil) + assert.Assert(t, cache.lookupGroupID != nil) + assert.Assert(t, cache.groupIds != nil) + }, + }, + { + name: "Cache interval", + validateFunc: func(t *testing.T, cache *UserGroupCache) { + interval := cache.interval + expectedInterval := cleanerInterval * time.Second // 60 seconds + assert.Equal(t, expectedInterval, interval, "LDAP resolver interval should be 60 seconds") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save original function and restore after test + origReadSecrets := readSecrets + + // Mock readSecrets to return true (successful configuration) + readSecrets = func() bool { + return true + } + defer func() { readSecrets = origReadSecrets }() + + // Get the LDAP user group cache + cache := GetUserGroupCacheLdap() + + // Run the validation function + tt.validateFunc(t, cache) + }) + } +} + +func TestMockLdapSearchResult(t *testing.T) { + // Test valid users + testCases := []struct { + username string + expectedCount int + expectError bool + }{ + {"testuser1", 1, false}, + {"testuser", 1, false}, + {"testuser2", 2, false}, + {"testuser3", 4, false}, + {"testuser4", 2, false}, + {"unknown", 0, true}, + } + + for _, tc := range testCases { + t.Run(tc.username, func(t *testing.T) { + result, err := mockLdapSearchResult(tc.username) + + if tc.expectError { + assert.Assert(t, err != nil, "Expected error for user %s but got none", tc.username) + assert.Assert(t, result == nil, "Expected nil result for user %s but got %v", tc.username, result) + assert.ErrorContains(t, err, "ldap lookup failed for user: "+tc.username) + } else { + assert.NilError(t, err, "Unexpected error for user %s: %v", tc.username, err) + assert.Assert(t, result != nil, "Expected non-nil result for user %s", tc.username) + assert.Assert(t, len(result.Entries) > 0, "Expected entries for user %s", tc.username) + assert.Assert(t, len(result.Entries[0].Attributes) > 0, "Expected attributes for user %s", tc.username) + + memberOfAttr := result.Entries[0].Attributes[0] + assert.Equal(t, "memberOf", memberOfAttr.Name, "Expected 'memberOf' attribute for user %s", tc.username) + assert.Equal(t, tc.expectedCount, len(memberOfAttr.Values), + "Expected %d group values for user %s but got %d", + tc.expectedCount, tc.username, len(memberOfAttr.Values)) + } + }) + } +} + +func TestLdapAccessImpl(t *testing.T) { + // Create a mock LDAP access implementation + mockAccess := &LdapAccessMock{ + DialURLFunc: func(url string, options ...ldap.DialOpt) (*ldap.Conn, error) { + return &ldap.Conn{}, nil + }, + BindFunc: func(conn *ldap.Conn, username, password string) error { + return nil + }, + SearchFunc: func(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + return &ldap.SearchResult{}, nil + }, + CloseFunc: func(conn *ldap.Conn) {}, + } + + assert.Assert(t, mockAccess != nil) + conn, err := mockAccess.DialURL("testurl") + assert.NilError(t, err) + assert.Assert(t, conn != nil) + assert.NilError(t, mockAccess.Bind(&ldap.Conn{}, "user", "pass")) + result, err := mockAccess.Search(&ldap.Conn{}, &ldap.SearchRequest{}) + assert.NilError(t, err) + assert.Assert(t, result != nil) + mockAccess.Close(&ldap.Conn{}) +} + +// TestLdapAccessImplMethods tests the LdapAccessImpl methods +func TestLdapAccessImplMethods(t *testing.T) { + // Create a real implementation + impl := &LdapAccessImpl{} + + // We can't actually connect to an LDAP server in unit tests, + // but we can verify the methods don't panic when called with nil + + // Test DialURL - should return error with invalid URL + conn, err := impl.DialURL("invalid://url") + assert.Assert(t, err != nil) + assert.Assert(t, conn == nil) + + // Other methods would panic if called with nil, so we can't test them directly + // In a real scenario, we'd use a mock LDAP server or dependency injection +} diff --git a/pkg/common/security/usergroup_no_resolver_test.go b/pkg/common/security/usergroup_no_resolver_test.go new file mode 100644 index 000000000..1e78eb07e --- /dev/null +++ b/pkg/common/security/usergroup_no_resolver_test.go @@ -0,0 +1,100 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package security + +import ( + "os/user" + "testing" + + "gotest.tools/v3/assert" +) + +func TestNoLookupUser(t *testing.T) { + // Test with various usernames + testCases := []struct { + name string + username string + }{ + {"Empty username", ""}, + {"Standard username", "testuser"}, + {"Username with special chars", "test-user_123"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + u, err := noLookupUser(tc.username) + + // Should never return error + assert.NilError(t, err) + + // Verify user properties + assert.Equal(t, tc.username, u.Username) + assert.Equal(t, "-1", u.Uid) + assert.Equal(t, tc.username, u.Gid) + }) + } +} + +func TestNoLookupGroupID(t *testing.T) { + // Test with various group IDs + testCases := []struct { + name string + gid string + }{ + {"Empty GID", ""}, + {"Numeric GID", "1000"}, + {"String GID", "users"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + g, err := noLookupGroupID(tc.gid) + + // Should never return error + assert.NilError(t, err) + + // Verify group properties + assert.Equal(t, tc.gid, g.Gid) + assert.Equal(t, tc.gid, g.Name) + }) + } +} + +func TestNoLookupGroupIds(t *testing.T) { + // Test with various users + testCases := []struct { + name string + user *user.User + }{ + {"Standard user", &user.User{Username: "testuser", Uid: "1000", Gid: "1000"}}, + {"Empty user", &user.User{}}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + groups, err := noLookupGroupIds(tc.user) + + // Should never return error + assert.NilError(t, err) + + // Should always return empty slice + assert.Equal(t, 0, len(groups)) + }) + } +} diff --git a/pkg/common/security/usergroup_os_resolver_test.go b/pkg/common/security/usergroup_os_resolver_test.go new file mode 100644 index 000000000..acd7d5ab9 --- /dev/null +++ b/pkg/common/security/usergroup_os_resolver_test.go @@ -0,0 +1,44 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package security + +import ( + "os/user" + "testing" +) + +func TestWrappedGroupIds(t *testing.T) { + // Create a mock user + // Note: This test will behave differently depending on the system + // We'll just verify it doesn't panic and returns the expected type + u := &user.User{ + Username: "testuser", + Uid: "1000", + Gid: "1000", + } + + // Call the function - we can't predict the exact result + groups, err := wrappedGroupIds(u) + + // Log the result for informational purposes + t.Logf("Groups: %v, Error: %v", groups, err) + + // We can only verify the function doesn't panic + // The actual result depends on the OS and user configuration +} diff --git a/pkg/common/security/usergroup_test.go b/pkg/common/security/usergroup_test.go index 6c0d9f457..b0ee83cf0 100644 --- a/pkg/common/security/usergroup_test.go +++ b/pkg/common/security/usergroup_test.go @@ -27,10 +27,67 @@ import ( "gotest.tools/v3/assert" + "github.com/go-ldap/ldap/v3" + "github.com/apache/yunikorn-core/pkg/common" + "github.com/apache/yunikorn-core/pkg/common/configs" "github.com/apache/yunikorn-scheduler-interface/lib/go/si" ) +// Helper function to set up the mock LDAP implementation for testing +func setupMockLdap() { + // Save the original newLdapAccessImpl function + originalLdapAccessImpl := newLdapAccessImpl + + // Replace with mock implementation + newLdapAccessImpl = func(config *LdapResolverConfig) LdapAccess { + // Use the mockLdapSearchResult function from usergroup_ldap_resolver_mock.go + return &LdapAccessMock{ + SearchFunc: func(conn *ldap.Conn, searchRequest *ldap.SearchRequest) (*ldap.SearchResult, error) { + // Extract username from the search filter + username := "" + if searchRequest != nil && searchRequest.Filter != "" { + // Simple extraction - this assumes the filter format is consistent + parts := strings.Split(searchRequest.Filter, "=") + if len(parts) > 1 { + username = strings.TrimRight(parts[len(parts)-1], ")") + } + } + return mockLdapSearchResult(username) + }, + } + } + + // Mock readSecrets to return true (successful configuration) + originalReadSecrets := readSecrets + readSecrets = func() bool { + return true + } + + // Store the original functions to be restored in teardown + originalFunctions["newLdapAccessImpl"] = originalLdapAccessImpl + originalFunctions["readSecrets"] = originalReadSecrets +} + +// Helper function to tear down the mock LDAP implementation after testing +func teardownMockLdap() { + // Restore the original functions + if originalImpl, ok := originalFunctions["newLdapAccessImpl"]; ok { + if factory, ok := originalImpl.(ldapAccessFactory); ok { + newLdapAccessImpl = factory + } + } + + if originalRead, ok := originalFunctions["readSecrets"]; ok { + if readFunc, ok := originalRead.(func() bool); ok { + readSecrets = readFunc + } + } +} + +// Map to store original functions for restoration +var originalFunctions = make(map[string]interface{}) + func (c *UserGroupCache) getUGsize() int { c.lock.RLock() defer c.lock.RUnlock() @@ -50,308 +107,592 @@ func (c *UserGroupCache) getUGmap() map[string]*UserGroup { return c.ugs } -func TestGetUserGroupCache(t *testing.T) { - // get the cache with the test resolver set - testCache := GetUserGroupCache("test") - assert.Assert(t, testCache != nil, "Cache create failed") - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - testCache.Stop() - assert.Assert(t, instance == nil, "instance should be nil") - assert.Assert(t, stopped.Load()) - - // get the cache with the os resolver set - testCache = GetUserGroupCache("os") - assert.Assert(t, testCache != nil, "Cache create failed") - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - testCache.Stop() - assert.Assert(t, instance == nil, "instance should be nil") - assert.Assert(t, stopped.Load()) - - // get the cache with the default resolver set - testCache = GetUserGroupCache("unknown") - assert.Assert(t, testCache != nil, "Cache create failed") - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - testCache.Stop() - assert.Assert(t, instance == nil, "instance should be nil") - assert.Assert(t, stopped.Load()) - - // test for re stop again - testCache.Stop() - assert.Assert(t, instance == nil, "instance should be nil") - assert.Assert(t, stopped.Load()) +// UserGroupResolver Config for the test +var testResolver = configs.UserGroupResolver{ + Type: "test", } -func TestGetUserGroup(t *testing.T) { - testCache := GetUserGroupCache("test") - testCache.resetCache() - // test cache should be empty now - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - ugi := &si.UserGroupInformation{ - User: "testuser1", - Groups: nil, - } - ug, err := testCache.GetUserGroup(ugi.User) - assert.NilError(t, err, "Lookup should not have failed: testuser1") - if ug.failed { - t.Errorf("lookup failed which should not have: %t", ug.failed) - } - if len(testCache.ugs) != 1 { - t.Errorf("Cache not updated should have 1 entry %d", len(testCache.ugs)) - } - // check returned info: primary and secondary groups etc - if ug.User != ugi.User || len(ug.Groups) != 2 || ug.resolved == 0 || ug.failed { - t.Errorf("User 'testuser1' not resolved correctly: %v", ug) - } - testCache.lock.Lock() - cachedUG := testCache.ugs[ugi.User] - if ug.resolved != cachedUG.resolved { - t.Errorf("User 'testuser1' not cached correctly resolution time differs: %d got %d", ug.resolved, cachedUG.resolved) - } - // click over the clock: if we do not get the cached version the new time will differ from the cache update - cachedUG.resolved -= 5 - testCache.lock.Unlock() - - ug, err = testCache.GetUserGroup(ugi.User) - if err != nil || ug.resolved != cachedUG.resolved { - t.Errorf("User 'testuser1' not returned from Cache, resolution time differs: %d got %d (err = %v)", ug.resolved, cachedUG.resolved, err) - } +// UserGroupResolver Config for the os resolver +var osResolver = configs.UserGroupResolver{ + Type: "os", } -func TestBrokenUserGroup(t *testing.T) { - testCache := GetUserGroupCache("test") - testCache.resetCache() - // test cache should be empty now - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - ug, err := testCache.GetUserGroup("testuser2") - if err != nil { - t.Error("Lookup should not have failed: testuser2") - } +// UserGroupResolver Config for the unknown resolver +var unknownResolver = configs.UserGroupResolver{ + Type: "unknown", +} - assert.Equal(t, 1, testCache.getUGsize(), "Cache not updated should have 1 entry %d", testCache.getUGmap()) - // check returned info: 3 groups etc - if ug.User != "testuser2" || len(ug.Groups) != 3 || ug.resolved == 0 || ug.failed { - t.Errorf("User 'testuser2' not resolved correctly: %v", ug) - } - // first group should have failed resolution: just the ID expected - if ug.Groups[0] != "100" { - t.Errorf("User 'testuser2' primary group resolved while it should not: %v", ug) - } +// UserGroupResolver Config for the LDAP resolver +var ldapResolver = configs.UserGroupResolver{ + Type: "ldap", +} - ug, err = testCache.GetUserGroup("testuser3") - if err != nil { - t.Error("Lookup should not have failed: testuser3") +func TestGetUserGroupCache(t *testing.T) { + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + // Get the cache with the resolver set + testCache := GetUserGroupCache(tc.resolver) + assert.Assert(t, testCache != nil, "Cache create failed") + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + + testCache.Stop() + assert.Assert(t, instance == nil, "instance should be nil") + assert.Assert(t, stopped.Load()) + + // Test for re-stop + testCache.Stop() + assert.Assert(t, instance == nil, "instance should be nil") + assert.Assert(t, stopped.Load()) + }) } - - assert.Equal(t, 2, testCache.getUGsize(), "Cache not updated should have 2 entries %d", len(testCache.ugs)) - assert.Equal(t, 4, testCache.getUGGroupSize("testuser3"), "User 'testuser3' not resolved correctly: duplicate primary group not filtered %v", ug) - - ug, err = testCache.GetUserGroup("unknown") - assert.ErrorContains(t, err, "lookup failed for user: unknown") - - ug, err = testCache.GetUserGroup("testuser4") - assert.NilError(t, err) - - ug, err = testCache.GetUserGroup("testuser5") - assert.ErrorContains(t, err, "lookup failed for user: testuser5") - - ug, err = testCache.GetUserGroup("invalid-gid-user") - assert.ErrorContains(t, err, "lookup failed for user: invalid-gid-user") - exceptedGroup := []string{"1_001"} - assert.Assert(t, reflect.DeepEqual(ug.Groups, exceptedGroup), fmt.Errorf("group should be: %v, but got: %v", exceptedGroup, ug.Groups)) } -func TestGetUserGroupFail(t *testing.T) { - testCache := GetUserGroupCache("test") - testCache.resetCache() - // test cache should be empty now - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - // resolve an empty user - ug, err := testCache.GetUserGroup("") - if err == nil { - t.Error("Lookup should have failed: empty user") - } - // ug is empty everything should be nil.. - if ug.User != "" || len(ug.Groups) != 0 || ug.resolved != 0 || ug.failed { - t.Errorf("UserGroup is not empty: %v", ug) - } +// Tests for the LDAP resolver using the mock implementation - // resolve a non existing user - ugi := &si.UserGroupInformation{ - User: "unknown", - Groups: nil, - } - ug, err = testCache.GetUserGroup(ugi.User) - if err == nil { - t.Error("Lookup should have failed: unknown user") - } - // ug is partially filled and failed flag is set - if ug.User != ugi.User || len(ug.Groups) != 0 || !ug.failed { - t.Errorf("UserGroup is not empty: %v", ug) +func TestGetUserGroup(t *testing.T) { + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + testCache := GetUserGroupCache(tc.resolver) + testCache.resetCache() + // test cache should be empty now + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + ugi := &si.UserGroupInformation{ + User: "testuser1", + Groups: nil, + } + ug, err := testCache.GetUserGroup(ugi.User) + assert.NilError(t, err, "Lookup should not have failed: testuser1") + if ug.failed { + t.Errorf("lookup failed which should not have: %t", ug.failed) + } + if len(testCache.ugs) != 1 { + t.Errorf("Cache not updated should have 1 entry %d", len(testCache.ugs)) + } + // check returned info: primary and secondary groups etc + if ug.User != ugi.User || len(ug.Groups) != 2 || ug.resolved == 0 || ug.failed { + t.Errorf("User 'testuser1' not resolved correctly: %v", ug) + } + testCache.lock.Lock() + cachedUG := testCache.ugs[ugi.User] + if ug.resolved != cachedUG.resolved { + t.Errorf("User 'testuser1' not cached correctly resolution time differs: %d got %d", ug.resolved, cachedUG.resolved) + } + // click over the clock: if we do not get the cached version the new time will differ from the cache update + cachedUG.resolved -= 5 + testCache.lock.Unlock() + + ug, err = testCache.GetUserGroup(ugi.User) + if err != nil || ug.resolved != cachedUG.resolved { + t.Errorf("User 'testuser1' not returned from Cache, resolution time differs: %d got %d (err = %v)", ug.resolved, cachedUG.resolved, err) + } + }) } +} - ug, err = testCache.GetUserGroup(ugi.User) - if err == nil { - t.Error("Lookup should have failed: unknown user") - } - // ug is partially filled and failed flag is set: error message should show that the cache was returned - if err != nil && !strings.Contains(err.Error(), "cached data returned") { - t.Errorf("UserGroup not returned from Cache: %v, error: %v", ug, err) +func TestBrokenUserGroup(t *testing.T) { + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + testCache := GetUserGroupCache(tc.resolver) + testCache.resetCache() + // test cache should be empty now + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + + ug, err := testCache.GetUserGroup("testuser2") + if err != nil { + t.Error("Lookup should not have failed: testuser2") + } + + assert.Equal(t, 1, testCache.getUGsize(), "Cache not updated should have 1 entry %d", testCache.getUGmap()) + // check returned info: 3 groups etc + if ug.User != "testuser2" || len(ug.Groups) != 3 || ug.resolved == 0 || ug.failed { + t.Errorf("User 'testuser2' not resolved correctly: %v", ug) + } + // first group should have failed resolution: just the ID expected + if ug.Groups[0] != "100" { + t.Errorf("User 'testuser2' primary group resolved while it should not: %v", ug) + } + + ug, err = testCache.GetUserGroup("testuser3") + if err != nil { + t.Error("Lookup should not have failed: testuser3") + } + + assert.Equal(t, 2, testCache.getUGsize(), "Cache not updated should have 2 entries %d", len(testCache.ugs)) + assert.Equal(t, 4, testCache.getUGGroupSize("testuser3"), "User 'testuser3' not resolved correctly: duplicate primary group not filtered %v", ug) + + ug, err = testCache.GetUserGroup("unknown") + assert.ErrorContains(t, err, "lookup failed for user: unknown") + + ug, err = testCache.GetUserGroup("testuser4") + assert.NilError(t, err) + + ug, err = testCache.GetUserGroup("testuser5") + assert.ErrorContains(t, err, "lookup failed for user: testuser5") + + ug, err = testCache.GetUserGroup("invalid-gid-user") + assert.ErrorContains(t, err, "lookup failed for user: invalid-gid-user") + exceptedGroup := []string{"1_001"} + assert.Assert(t, reflect.DeepEqual(ug.Groups, exceptedGroup), fmt.Errorf("group should be: %v, but got: %v", exceptedGroup, ug.Groups)) + }) } } -func TestCacheCleanUp(t *testing.T) { - testCache := GetUserGroupCache("test") - testCache.resetCache() - // test cache should be empty now - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - // resolve an existing user - _, err := testCache.GetUserGroup("testuser1") - if err != nil { - t.Error("Lookup should not have failed: testuser1 user") - } - _, err = testCache.GetUserGroup("testuser2") - if err != nil { - t.Error("Lookup should not have failed: testuser2 user") +func TestGetUserGroupFail(t *testing.T) { + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + testCache := GetUserGroupCache(tc.resolver) + testCache.resetCache() + // test cache should be empty now + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + + // resolve an empty user + ug, err := testCache.GetUserGroup("") + if err == nil { + t.Error("Lookup should have failed: empty user") + } + // ug is empty everything should be nil.. + if ug.User != "" || len(ug.Groups) != 0 || ug.resolved != 0 || ug.failed { + t.Errorf("UserGroup is not empty: %v", ug) + } + + // resolve a non existing user + ugi := &si.UserGroupInformation{ + User: "unknown", + Groups: nil, + } + ug, err = testCache.GetUserGroup(ugi.User) + if err == nil { + t.Error("Lookup should have failed: unknown user") + } + // ug is partially filled and failed flag is set + if ug.User != ugi.User || len(ug.Groups) != 0 || !ug.failed { + t.Errorf("UserGroup is not empty: %v", ug) + } + + ug, err = testCache.GetUserGroup(ugi.User) + if err == nil { + t.Error("Lookup should have failed: unknown user") + } + // ug is partially filled and failed flag is set: error message should show that the cache was returned + if err != nil && !strings.Contains(err.Error(), "cached data returned") { + t.Errorf("UserGroup not returned from Cache: %v, error: %v", ug, err) + } + }) } +} - testCache.lock.Lock() - ug := testCache.ugs["testuser1"] - if ug.failed { - t.Error("User 'testuser1' not resolved as a success") - } - // expire the successful lookup - ug.resolved -= 2 * poscache - testCache.lock.Unlock() - - // resolve a non existing user - _, err = testCache.GetUserGroup("unknown") - if err == nil { - t.Error("Lookup should have failed: unknown user") - } - testCache.lock.Lock() - ug = testCache.ugs["unknown"] - if !ug.failed { - t.Error("User 'unknown' not resolved as a failure") +func TestCacheCleanUp(t *testing.T) { + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + testCache := GetUserGroupCache(tc.resolver) + testCache.resetCache() + // test cache should be empty now + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + + // resolve an existing user + _, err := testCache.GetUserGroup("testuser1") + if err != nil { + t.Error("Lookup should not have failed: testuser1 user") + } + _, err = testCache.GetUserGroup("testuser2") + if err != nil { + t.Error("Lookup should not have failed: testuser2 user") + } + + testCache.lock.Lock() + ug := testCache.ugs["testuser1"] + if ug.failed { + t.Error("User 'testuser1' not resolved as a success") + } + // expire the successful lookup + ug.resolved -= 2 * poscache + testCache.lock.Unlock() + + // resolve a non existing user + _, err = testCache.GetUserGroup("unknown") + if err == nil { + t.Error("Lookup should have failed: unknown user") + } + testCache.lock.Lock() + ug = testCache.ugs["unknown"] + if !ug.failed { + t.Error("User 'unknown' not resolved as a failure") + } + // expire the failed lookup + ug.resolved -= 2 * negcache + testCache.lock.Unlock() + + testCache.cleanUpCache() + assert.Equal(t, 1, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + }) } - // expire the failed lookup - ug.resolved -= 2 * negcache - testCache.lock.Unlock() - - testCache.cleanUpCache() - assert.Equal(t, 1, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) } func TestIntervalCacheCleanUp(t *testing.T) { - testCache := GetUserGroupCache("test") - testCache.resetCache() - // test cache should be empty now - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - // resolve an existing user - user1ug, err := testCache.GetUserGroup("testuser1") - assert.NilError(t, err, "Lookup should not have failed: testuser1 user") - assert.Assert(t, !user1ug.failed, "User 'testuser1' not resolved as a success") - - _, err = testCache.GetUserGroup("testuser2") - assert.NilError(t, err, "Lookup should not have failed: testuser1 user") - - // expire the successful lookup - testCache.lock.Lock() - ug := testCache.ugs["testuser1"] - ug.resolved -= 2 * poscache - - testCache.lock.Unlock() - // resolve a non existing user - _, err = testCache.GetUserGroup("unknown") - assert.Assert(t, err != nil, "Lookup should have failed: unknown user") - testCache.lock.Lock() - ug = testCache.ugs["unknown"] - assert.Assert(t, ug.failed, "User 'unknown' not resolved as a failure") - - // expire the failed lookup - ug.resolved -= 2 * negcache - testCache.lock.Unlock() - - // sleep to wait for interval, it will trigger cleanUpCache - time.Sleep(testCache.interval + time.Second) - assert.Equal(t, 1, testCache.getUGsize(), "Cache not cleaned up : %v", testCache.getUGmap()) + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + testCache := GetUserGroupCache(tc.resolver) + testCache.resetCache() + // test cache should be empty now + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + + // resolve an existing user + user1ug, err := testCache.GetUserGroup("testuser1") + assert.NilError(t, err, "Lookup should not have failed: testuser1 user") + assert.Assert(t, !user1ug.failed, "User 'testuser1' not resolved as a success") + + _, err = testCache.GetUserGroup("testuser2") + assert.NilError(t, err, "Lookup should not have failed: testuser1 user") + + // expire the successful lookup + testCache.lock.Lock() + ug := testCache.ugs["testuser1"] + ug.resolved -= 2 * poscache + + testCache.lock.Unlock() + // resolve a non existing user + _, err = testCache.GetUserGroup("unknown") + assert.Assert(t, err != nil, "Lookup should have failed: unknown user") + testCache.lock.Lock() + ug = testCache.ugs["unknown"] + assert.Assert(t, ug.failed, "User 'unknown' not resolved as a failure") + + // expire the failed lookup + ug.resolved -= 2 * negcache + testCache.lock.Unlock() + + // sleep to wait for interval, it will trigger cleanUpCache + time.Sleep(testCache.interval + time.Second) + assert.Equal(t, 1, testCache.getUGsize(), "Cache not cleaned up : %v", testCache.getUGmap()) + }) + } } func TestConvertUGI(t *testing.T) { - testCache := GetUserGroupCache("test") - testCache.resetCache() - // test cache should be empty now - assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) - - ugi := &si.UserGroupInformation{ - User: "", - Groups: nil, - } - ug, err := testCache.ConvertUGI(ugi, false) - if err == nil { - t.Errorf("empty user convert should have failed and did not: %v", ug) - } - // try known user without groups - ugi.User = "testuser1" - ug, err = testCache.ConvertUGI(ugi, false) - if err != nil { - t.Errorf("known user, no groups, convert should not have failed: %v", err) - } - if ug.User != "testuser1" || len(ug.Groups) != 2 || ug.resolved == 0 || ug.failed { - t.Errorf("User 'testuser1' not resolved correctly: %v", ug) - } - // try unknown user without groups - ugi.User = "unknown" - ug, err = testCache.ConvertUGI(ugi, false) - if err == nil { - t.Errorf("unknown user, no groups, convert should have failed: %v", ug) + testCases := []struct { + name string + resolver configs.UserGroupResolver + setup func() + teardown func() + }{ + { + name: "TestResolver", + resolver: testResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "OsResolver", + resolver: osResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "UnknownResolver", + resolver: unknownResolver, + setup: func() {}, + teardown: func() {}, + }, + { + name: "LdapResolver", + resolver: ldapResolver, + setup: setupMockLdap, + teardown: teardownMockLdap, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup the test environment + tc.setup() + defer tc.teardown() + + testCache := GetUserGroupCache(tc.resolver) + testCache.resetCache() + // test cache should be empty now + assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v", testCache.getUGmap()) + + ugi := &si.UserGroupInformation{ + User: "", + Groups: nil, + } + ug, err := testCache.ConvertUGI(ugi, false) + if err == nil { + t.Errorf("empty user convert should have failed and did not: %v", ug) + } + // try known user without groups + ugi.User = "testuser1" + ug, err = testCache.ConvertUGI(ugi, false) + if err != nil { + t.Errorf("known user, no groups, convert should not have failed: %v", err) + } + if ug.User != "testuser1" || len(ug.Groups) != 2 || ug.resolved == 0 || ug.failed { + t.Errorf("User 'testuser1' not resolved correctly: %v", ug) + } + // try unknown user without groups + ugi.User = "unknown" + ug, err = testCache.ConvertUGI(ugi, false) + if err == nil { + t.Errorf("unknown user, no groups, convert should have failed: %v", ug) + } + // try empty user when forced + ugi.User = "" + ug, err = testCache.ConvertUGI(ugi, true) + if err != nil { + t.Errorf("empty user but forced, convert should not have failed: %v", err) + } + // try unknown user with groups + ugi.User = "unknown2" + group := "passedin" + ugi.Groups = []string{group} + ug, err = testCache.ConvertUGI(ugi, false) + if err != nil { + t.Errorf("unknown user with groups, convert should not have failed: %v", err) + } + if ug.User != "unknown2" || len(ug.Groups) != 1 || ug.resolved == 0 || ug.failed { + t.Fatalf("User 'unknown2' not resolved correctly: %v", ug) + } + if ug.Groups[0] != group { + t.Errorf("groups not initialised correctly on convert: expected '%s' got '%s'", group, ug.Groups[0]) + } + // try valid username with groups + ugi.User = "validuserABCD1234@://#" + ugi.Groups = []string{group} + ug, err = testCache.ConvertUGI(ugi, false) + if err != nil { + t.Errorf("valid username with groups, convert should not have failed: %v", err) + } + // try invalid username with groups + ugi.User = "invaliduser><+" + ugi.Groups = []string{group} + ug, err = testCache.ConvertUGI(ugi, false) + if err == nil { + t.Errorf("invalid username, convert should have failed: %v", err) + } + + // try unknown user with empty group when forced + ugi.User = "unknown" + ugi.Groups = []string{} + ug, err = testCache.ConvertUGI(ugi, true) + exceptedGroup := []string{common.AnonymousGroup} + assert.Assert(t, reflect.DeepEqual(ug.Groups, exceptedGroup), "group should be: %v, but got: %v", exceptedGroup, ug.Groups) + assert.NilError(t, err, "unknown user, no groups, convert should not have failed") + }) } - // try empty user when forced - ugi.User = "" - ug, err = testCache.ConvertUGI(ugi, true) - if err != nil { - t.Errorf("empty user but forced, convert should not have failed: %v", err) - } - // try unknown user with groups - ugi.User = "unknown2" - group := "passedin" - ugi.Groups = []string{group} - ug, err = testCache.ConvertUGI(ugi, false) - if err != nil { - t.Errorf("unknown user with groups, convert should not have failed: %v", err) - } - if ug.User != "unknown2" || len(ug.Groups) != 1 || ug.resolved == 0 || ug.failed { - t.Fatalf("User 'unknown2' not resolved correctly: %v", ug) - } - if ug.Groups[0] != group { - t.Errorf("groups not initialised correctly on convert: expected '%s' got '%s'", group, ug.Groups[0]) - } - // try valid username with groups - ugi.User = "validuserABCD1234@://#" - ugi.Groups = []string{group} - ug, err = testCache.ConvertUGI(ugi, false) - if err != nil { - t.Errorf("valid username with groups, convert should not have failed: %v", err) - } - // try invalid username with groups - ugi.User = "invaliduser><+" - ugi.Groups = []string{group} - ug, err = testCache.ConvertUGI(ugi, false) - if err == nil { - t.Errorf("invalid username, convert should have failed: %v", err) - } - - // try unknown user with empty group when forced - ugi.User = "unknown" - ugi.Groups = []string{} - ug, err = testCache.ConvertUGI(ugi, true) - exceptedGroup := []string{common.AnonymousGroup} - assert.Assert(t, reflect.DeepEqual(ug.Groups, exceptedGroup), "group should be: %v, but got: %v", exceptedGroup, ug.Groups) - assert.NilError(t, err, "unknown user, no groups, convert should not have failed") } diff --git a/pkg/common/security/usergroup_test_resolver.go b/pkg/common/security/usergroup_test_resolver.go index c35dd6bb9..47058c260 100644 --- a/pkg/common/security/usergroup_test_resolver.go +++ b/pkg/common/security/usergroup_test_resolver.go @@ -26,9 +26,12 @@ import ( ) const ( + Testuser = "testuser" Testuser1 = "testuser1" Testuser2 = "testuser2" Testuser3 = "testuser3" + Testuser4 = "testuser4" + Testuser5 = "testuser5" ) // Get the cache with a test resolver diff --git a/pkg/scheduler/partition.go b/pkg/scheduler/partition.go index 49a2fa2a9..61a6c2f17 100644 --- a/pkg/scheduler/partition.go +++ b/pkg/scheduler/partition.go @@ -134,7 +134,7 @@ func (pc *PartitionContext) initialPartitionFromConfig(conf configs.PartitionCon // Placing an application will not have a lock on the partition context. pc.placementManager = placement.NewPlacementManager(conf.PlacementRules, pc.GetQueue, silence) // get the user group cache for the partition - pc.userGroupCache = security.GetUserGroupCache("") + pc.userGroupCache = security.GetUserGroupCache(conf.UserGroupResolver) pc.updateNodeSortingPolicy(conf, silence) pc.updatePreemption(conf)