diff --git a/Dockerfile b/Dockerfile index 04c5aa1..c98ffca 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,12 +1,12 @@ FROM golang:alpine -COPY src /go/src/github.com/tiagoapimenta/nginx-ldap-auth +COPY . /go/src/github.com/tiagoapimenta/nginx-ldap-auth RUN cd /go/src/github.com/tiagoapimenta/nginx-ldap-auth && \ apk add --no-cache git && \ go get -u gopkg.in/yaml.v2 && \ go get -u gopkg.in/ldap.v2 && \ - go build -ldflags='-s -w' -v -o /go/bin/nginx-ldap-auth . + go build -ldflags='-s -w' -v -o /go/bin/nginx-ldap-auth ./main FROM alpine diff --git a/README.md b/README.md index 3df436c..47c903e 100644 --- a/README.md +++ b/README.md @@ -8,14 +8,12 @@ Use this in order to provide a ingress authentication over LDAP for Kubernetes, Configure your ingress with annotation `nginx.ingress.kubernetes.io/auth-url: http://nginx-ldap-auth.default.svc.cluster.local:5555` as described on [nginx documentation](https://kubernetes.github.io/ingress-nginx/examples/auth/external-auth/). -## Config +## Configuration The actual version choose a random server, in future version it is intended to have a pool of them, that is why it is a list, not a single one, but you can fill only one if you wish. The prefix tell the program which protocol to use, if `ldaps://` it will try LDAP over SSL, if `ldap://` it will try plain LDAP with STARTTLS, case no prefix is given it will try to guess based on port, 636 for SSL and 389 for plain. -The actual version will fail if neither SSL or STARTTLS is possible, but next version will allow plain LDAP. - If the `user.requiredGroups` list is omited or empty all LDAP users will be allowed regardless the group, if not empty all groups will be required, the next version will have more flexible configuration. If you are not sure what `filter`, `bindDN` or `baseDN` to use, here is a tip: diff --git a/build b/build index 806bafb..20eaed5 100755 --- a/build +++ b/build @@ -3,7 +3,7 @@ set -e base='docker.io/tpimenta/nginx-ldap-auth' -version='v1.0.0' +version='v1.0.1' image="$base:$version" atexit() { diff --git a/src/timeout.go b/data/storage.go similarity index 63% rename from src/timeout.go rename to data/storage.go index 432efba..5748569 100644 --- a/src/timeout.go +++ b/data/storage.go @@ -1,4 +1,4 @@ -package main +package data import ( "sort" @@ -16,30 +16,27 @@ type userpass struct { wrong []passtimer } -var ( - passwords = map[string]*userpass{} - mutex = sync.RWMutex{} -) - -func containsWrongPassword(data *userpass, password string) (int, bool) { - size := len(data.wrong) - if size == 0 { - return 0, false - } - - pos := sort.Search(size, func(i int) bool { - return data.wrong[i].password >= password - }) - - return pos, pos < size && - data.wrong[pos].password == password +type Storage struct { + passwords map[string]*userpass + lock sync.RWMutex + success time.Duration + wrong time.Duration } -func getCache(username, password string) (bool, bool) { - defer mutex.RUnlock() - mutex.RLock() +func NewStorage(success, wrong time.Duration) *Storage { + return &Storage{ + passwords: map[string]*userpass{}, + lock: sync.RWMutex{}, + success: success, + wrong: wrong, + } +} - data, found := passwords[username] +func (p *Storage) Get(username, password string) (bool, bool) { + p.lock.RLock() + defer p.lock.RUnlock() + + data, found := p.passwords[username] if !found { return false, false } @@ -53,25 +50,25 @@ func getCache(username, password string) (bool, bool) { return false, found } -func putCache(username, password string, ok bool) { - defer mutex.Unlock() - mutex.Lock() +func (p *Storage) Put(username, password string, ok bool) { + p.lock.Lock() + defer p.lock.Unlock() - data, found := passwords[username] + data, found := p.passwords[username] if !found { data = &userpass{} - passwords[username] = data + p.passwords[username] = data } - timeout := config.Timeout.Wrong + timeout := p.wrong if ok { - timeout = config.Timeout.Success + timeout = p.success } pass := passtimer{ password: password, timer: time.AfterFunc(timeout, func() { - removeCache(username, password, ok) + p.remove(username, password, ok) }), } @@ -92,11 +89,11 @@ func putCache(username, password string, ok bool) { } } -func removeCache(username, password string, ok bool) { - defer mutex.Unlock() - mutex.Lock() +func (p *Storage) remove(username, password string, ok bool) { + p.lock.Lock() + defer p.lock.Unlock() - data, found := passwords[username] + data, found := p.passwords[username] if !found { return } @@ -115,6 +112,20 @@ func removeCache(username, password string, ok bool) { } if data.correct == nil && len(data.wrong) == 0 { - delete(passwords, username) + delete(p.passwords, username) } } + +func containsWrongPassword(data *userpass, password string) (int, bool) { + size := len(data.wrong) + if size == 0 { + return 0, false + } + + pos := sort.Search(size, func(i int) bool { + return data.wrong[i].password >= password + }) + + return pos, pos < size && + data.wrong[pos].password == password +} diff --git a/data/storage_test.go b/data/storage_test.go new file mode 100644 index 0000000..8bf3100 --- /dev/null +++ b/data/storage_test.go @@ -0,0 +1,85 @@ +package data + +import ( + "bytes" + "fmt" + "testing" + "time" +) + +const ( + username1 = "Alice" + username2 = "James" + password1 = "master" + password2 = "shadow" + password3 = "qwerty" + success = time.Second / 2 + wrong = time.Second / 5 +) + +func printPassMap(t *testing.T, storage *Storage, prefix string) { + buffer := bytes.Buffer{} + first := true + for k, v := range storage.passwords { + if first { + first = false + } else { + buffer.WriteByte(',') + } + correct := "" + if v.correct != nil { + correct = v.correct.password + } + fmt.Fprintf(&buffer, "%s:{correct:%s,wrong:%+v}", k, correct, v.wrong) + } + t.Logf("%s passwords: %s\n", prefix, buffer.String()) +} + +func testCache(t *testing.T, storage *Storage, id int, username, password string, eok, efound bool) { + ok, found := storage.Get(username, password) + if ok != eok || found != efound { + t.Errorf("Test %d expected (%v %v) given (%v %v)\n", id, eok, efound, ok, found) + } +} + +func TestPasswordTimeout(t *testing.T) { + storage := NewStorage(success, wrong) + + testCache(t, storage, 0, username1, password1, false, false) + testCache(t, storage, 1, username1, password2, false, false) + testCache(t, storage, 2, username1, password3, false, false) + testCache(t, storage, 3, username2, password1, false, false) + testCache(t, storage, 4, username2, password2, false, false) + testCache(t, storage, 5, username2, password3, false, false) + + storage.Put(username1, password1, true) + storage.Put(username1, password3, false) + printPassMap(t, storage, "add") + + testCache(t, storage, 6, username1, password1, true, true) + testCache(t, storage, 7, username1, password2, false, false) + testCache(t, storage, 8, username1, password3, false, true) + testCache(t, storage, 9, username2, password1, false, false) + testCache(t, storage, 10, username2, password2, false, false) + testCache(t, storage, 11, username2, password3, false, false) + + time.Sleep(wrong + wrong/2) + printPassMap(t, storage, "timed") + + testCache(t, storage, 12, username1, password1, true, true) + testCache(t, storage, 13, username1, password2, false, false) + testCache(t, storage, 14, username1, password3, false, false) + testCache(t, storage, 15, username2, password1, false, false) + testCache(t, storage, 16, username2, password2, false, false) + testCache(t, storage, 17, username2, password3, false, false) + + time.Sleep(success - wrong) + printPassMap(t, storage, "expired") + + testCache(t, storage, 18, username1, password1, false, false) + testCache(t, storage, 19, username1, password2, false, false) + testCache(t, storage, 20, username1, password3, false, false) + testCache(t, storage, 21, username2, password1, false, false) + testCache(t, storage, 22, username2, password2, false, false) + testCache(t, storage, 23, username2, password3, false, false) +} diff --git a/group/service.go b/group/service.go new file mode 100644 index 0000000..4ae2d6d --- /dev/null +++ b/group/service.go @@ -0,0 +1,39 @@ +package group + +import ( + "strings" + + "github.com/tiagoapimenta/nginx-ldap-auth/ldap" +) + +type Service struct { + pool *ldap.Pool + base string + filter string + attr string +} + +func NewService(pool *ldap.Pool, base, filter, attr string) *Service { + return &Service{ + pool: pool, + base: base, + filter: filter, + attr: attr, + } +} + +func (p *Service) Find(id string) ([]string, error) { + ok, _, groups, err := p.pool.Search( + p.base, + strings.Replace(p.filter, "{0}", id, -1), + p.attr, + ) + + if !ok && err != nil { + return nil, err + } else if err != nil { + return []string{}, nil + } + + return groups, nil +} diff --git a/k8s.yaml b/k8s.yaml index c8921bf..3807d1c 100644 --- a/k8s.yaml +++ b/k8s.yaml @@ -26,7 +26,7 @@ spec: app: nginx-ldap-auth spec: containers: - - image: docker.io/tpimenta/nginx-ldap-auth:v1.0.0 + - image: docker.io/tpimenta/nginx-ldap-auth:v1.0.1 name: nginx-ldap-auth command: - "nginx-ldap-auth" diff --git a/ldap/connect.go b/ldap/connect.go new file mode 100644 index 0000000..f9bfa62 --- /dev/null +++ b/ldap/connect.go @@ -0,0 +1,47 @@ +package ldap + +import ( + "crypto/tls" + "errors" + "fmt" + "log" + + ldap "gopkg.in/ldap.v2" +) + +func (p *Pool) Connect() error { + if p.url == "" { + return errors.New("No LDAP server available!") + } + + if p.port == 0 { + return fmt.Errorf("Unable to determine schema or port for \"%s\"", p.url) + } + + if p.conn != nil { + return nil + } + + address := fmt.Sprintf("%s:%d", p.url, p.port) + if p.ssl { + conn, err := ldap.DialTLS("tcp", address, &tls.Config{InsecureSkipVerify: true}) + if err != nil { + return err + } + p.conn = conn + } else { + conn, err := ldap.Dial("tcp", address) + if err != nil { + return err + } + err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) + if err != nil { + log.Printf("It was not possble to start TLS, falling back to plain: %v.\n", err) + } + p.conn = conn + } + + p.admin = false + + return p.auth() +} diff --git a/ldap/login.go b/ldap/login.go new file mode 100644 index 0000000..61d25be --- /dev/null +++ b/ldap/login.go @@ -0,0 +1,36 @@ +package ldap + +func (p *Pool) Validate(username, password string) (bool, error) { + p.lock.Lock() + defer p.lock.Unlock() + + err := p.auth() + if err != nil { + return false, err + } + + p.admin = false + err = p.conn.Bind(username, password) + if err != nil { + return true, err + } + + err = p.auth() + if err != nil { + return false, err + } + + return true, nil +} + +func (p *Pool) auth() error { + if p.admin || p.username == "" && p.password == "" { + return nil + } + + err := p.conn.Bind(p.username, p.password) + if err == nil { + p.admin = true + } + return err +} diff --git a/ldap/pool.go b/ldap/pool.go new file mode 100644 index 0000000..5f4e202 --- /dev/null +++ b/ldap/pool.go @@ -0,0 +1,78 @@ +package ldap + +import ( + "log" + "math/rand" + "regexp" + "strconv" + "strings" + "sync" + "time" + + ldap "gopkg.in/ldap.v2" +) + +type Pool struct { + url string + port int + ssl bool + username string + password string + conn *ldap.Conn + admin bool + lock sync.Mutex +} + +func NewPool(servers []string, username, password string) *Pool { + url := "" + port := 0 + schema := "auto" + + size := len(servers) + if size != 0 { + r := rand.New(rand.NewSource(time.Now().Unix())) + server := servers[r.Intn(size)] + + url = server + if strings.HasPrefix(url, "ldaps:") { + url = strings.TrimPrefix(strings.TrimPrefix(url, "ldaps:"), "//") + schema = "ldaps" + port = 636 + } else if strings.HasPrefix(url, "ldap:") { + url = strings.TrimPrefix(strings.TrimPrefix(url, "ldap:"), "//") + schema = "ldap" + port = 389 + } + + portExp := regexp.MustCompile(`:[0-9]+$`) + if portExp.MatchString(url) { + str := portExp.FindString(url) + + number, err := strconv.Atoi(str[1:]) + if err == nil { + port = number + url = strings.TrimSuffix(url, str) + } else { + log.Printf("Error on parse port of \"%s\": %v\n", server, err) + } + } + + if schema == "auto" { + if port == 636 { + schema = "ldaps" + } else if port == 389 { + schema = "ldap" + } else { + port = 0 + } + } + } + + return &Pool{ + url: url, + port: port, + ssl: schema == "ldaps", + username: username, + password: password, + } +} diff --git a/ldap/search.go b/ldap/search.go new file mode 100644 index 0000000..6d998db --- /dev/null +++ b/ldap/search.go @@ -0,0 +1,57 @@ +package ldap + +import ( + "fmt" + "sort" + + ldap "gopkg.in/ldap.v2" +) + +func (p *Pool) Search(base, filter string, attr string) (bool, string, []string, error) { + p.lock.Lock() + defer p.lock.Unlock() + + err := p.auth() + if err != nil { + return false, "", nil, err + } + + var list []string = nil + if attr != "" { + list = []string{attr} + } + + res, err := p.conn.Search(ldap.NewSearchRequest( + base, + ldap.ScopeWholeSubtree, + ldap.NeverDerefAliases, + 0, + 0, + false, + filter, + list, + nil, + )) + if err != nil { + return false, "", nil, err + } + + if len(res.Entries) == 0 { + return true, "", nil, fmt.Errorf("No results for %s filter %s", base, filter) + } + + if attr == "" && len(res.Entries) > 1 { + return true, "", nil, fmt.Errorf("Too many results for %s filter %s", base, filter) + } + + var result []string = nil + if attr != "" { + result = []string{} + for _, entry := range res.Entries { + result = append(result, entry.GetAttributeValue(attr)) + } + sort.Strings(result) + } + + return true, res.Entries[0].DN, result, nil +} diff --git a/src/config.go b/main/config.go similarity index 100% rename from src/config.go rename to main/config.go diff --git a/main/main.go b/main/main.go new file mode 100644 index 0000000..4ef39dc --- /dev/null +++ b/main/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "fmt" + "log" + + "github.com/tiagoapimenta/nginx-ldap-auth/data" + "github.com/tiagoapimenta/nginx-ldap-auth/group" + "github.com/tiagoapimenta/nginx-ldap-auth/ldap" + "github.com/tiagoapimenta/nginx-ldap-auth/rule" + "github.com/tiagoapimenta/nginx-ldap-auth/user" +) + +func main() { + file, config, err := parseConfig() + if err != nil { + log.Fatalln(err.Error()) + } + + fmt.Printf("Loaded config \"%s\".\n", file) + + pool := ldap.NewPool(config.Servers, config.Auth.BindDN, config.Auth.BindPW) + + err = pool.Connect() + if err != nil { + log.Fatalf("Error on connect to LDAP: %v\n", err) + } + + storage := data.NewStorage(config.Timeout.Success, config.Timeout.Wrong) + + userService := user.NewService(pool, config.User.BaseDN, config.User.Filter) + + groupService := group.NewService(pool, config.Group.BaseDN, config.Group.Filter, config.Group.GroupAttr) + + ruleService := rule.NewService(storage, userService, groupService, config.User.RequiredGroups) + + fmt.Printf("Serving...\n") + err = startServer(ruleService, config.Web, config.Path, config.Message) + if err != nil { + log.Fatalf("Error on start server: %v\n", err) + } +} diff --git a/main/parser.go b/main/parser.go new file mode 100644 index 0000000..4b1bf3d --- /dev/null +++ b/main/parser.go @@ -0,0 +1,45 @@ +package main + +import ( + "flag" + "fmt" + "io/ioutil" + "time" + + yaml "gopkg.in/yaml.v2" +) + +func parseConfig() (string, *Config, error) { + file := flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file") + + flag.Parse() + + data, err := ioutil.ReadFile(*file) + if err != nil { + return "", nil, fmt.Errorf("Error on read file \"%s\": %v", *file, err) + } + + config := Config{ + Web: "0.0.0.0:5555", + Path: "/", + Message: "LDAP Login", + User: UserConfig{ + Filter: "(cn={0})", + }, + Group: GroupConfig{ + Filter: "(member={0})", + GroupAttr: "cn", + }, + Timeout: TimeoutConfig{ + Success: 24 * time.Hour, + Wrong: 5 * time.Minute, + }, + } + + err = yaml.Unmarshal(data, &config) + if err != nil { + return "", nil, fmt.Errorf("Error on parse config: %v", err) + } + + return *file, &config, nil +} diff --git a/main/server.go b/main/server.go new file mode 100644 index 0000000..ce91552 --- /dev/null +++ b/main/server.go @@ -0,0 +1,42 @@ +package main + +import ( + "encoding/base64" + "fmt" + "log" + "net/http" + "strings" + + "github.com/tiagoapimenta/nginx-ldap-auth/rule" +) + +func startServer(service *rule.Service, server, path, message string) error { + realm := fmt.Sprintf("Basic realm=\"%s\"", message) + + http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + header := r.Header.Get("Authorization") + + if header != "" { + auth := strings.SplitN(header, " ", 2) + + if len(auth) == 2 && auth[0] == "Basic" { + decoded, err := base64.StdEncoding.DecodeString(auth[1]) + if err == nil { + secret := strings.SplitN(string(decoded), ":", 2) + + if len(secret) == 2 && service.Validate(secret[0], secret[1]) { + w.WriteHeader(http.StatusOK) + return + } + } else { + log.Printf("Error decode basic auth: %v\n", err) + } + } + } + + w.Header().Set("WWW-Authenticate", realm) + w.WriteHeader(http.StatusUnauthorized) + }) + + return http.ListenAndServe(server, nil) +} diff --git a/rule/service.go b/rule/service.go new file mode 100644 index 0000000..296c36e --- /dev/null +++ b/rule/service.go @@ -0,0 +1,74 @@ +package rule + +import ( + "log" + "sort" + + "github.com/tiagoapimenta/nginx-ldap-auth/data" + "github.com/tiagoapimenta/nginx-ldap-auth/group" + "github.com/tiagoapimenta/nginx-ldap-auth/user" +) + +type Service struct { + storage *data.Storage + user *user.Service + group *group.Service + required []string +} + +func NewService(storage *data.Storage, userService *user.Service, groupService *group.Service, required []string) *Service { + return &Service{ + storage: storage, + user: userService, + group: groupService, + required: required, + } +} + +func (p *Service) Validate(username, password string) bool { + ok, found := p.storage.Get(username, password) + if found { + return ok + } + + ok, err := p.validate(username, password) + if err != nil { + log.Printf("Could not validade user %s: %v\n", username, err) + return false + } + + p.storage.Put(username, password, ok) + return ok +} + +func (p *Service) validate(username, password string) (bool, error) { + ok, id, err := p.user.Find(username) + if !ok && err != nil { + return false, err + } else if err != nil { + return false, nil + } + + ok, err = p.user.Login(id, password) + if !ok && err != nil { + return false, err + } + + if ok || p.required == nil || len(p.required) == 0 { + return err != nil, nil + } + + groups, err := p.group.Find(id) + if err != nil { + return false, err + } + + for _, group := range p.required { + pos := sort.SearchStrings(groups, group) + if pos >= len(groups) || groups[pos] != group { + return false, nil + } + } + + return true, nil +} diff --git a/src/ldap.go b/src/ldap.go deleted file mode 100644 index 7f3a388..0000000 --- a/src/ldap.go +++ /dev/null @@ -1,181 +0,0 @@ -package main - -import ( - "crypto/tls" - "errors" - "fmt" - "math/rand" - "regexp" - "sort" - "strconv" - "strings" - "sync" - "time" - - ldap "gopkg.in/ldap.v2" -) - -var ( - conn *ldap.Conn - admin bool - lock sync.Mutex -) - -func setupLDAP() error { - size := len(config.Servers) - if size == 0 { - return errors.New("No LDAP server available!") - } - - r := rand.New(rand.NewSource(time.Now().Unix())) - portExp := regexp.MustCompile(`:[0-9]+$`) - - server := config.Servers[r.Intn(size)] - url := server - port := 0 - schema := "auto" - - if strings.HasPrefix(url, "ldaps:") { - url = strings.TrimPrefix(strings.TrimPrefix(url, "ldaps:"), "//") - schema = "ldaps" - port = 636 - } else if strings.HasPrefix(url, "ldap:") { - url = strings.TrimPrefix(strings.TrimPrefix(url, "ldap:"), "//") - schema = "ldap" - port = 389 - } - - var err error - if portExp.MatchString(url) { - str := portExp.FindString(url) - port, err = strconv.Atoi(str[1:]) - if err != nil { - return fmt.Errorf("Unable to parse url \"%s\", %v", server, err) - } - url = strings.TrimSuffix(url, str) - } - - if schema == "auto" { - if port == 636 { - schema = "ldaps" - } else if port == 389 { - schema = "ldap" - } - } - - if schema == "auto" || port == 0 { - return fmt.Errorf("Unable to determine schema or port for \"%s\"", server) - } - - address := fmt.Sprintf("%s:%d", url, port) - fmt.Printf("Connecting LDAP %s...\n", address) - - if schema == "ldaps" { - conn, err = ldap.DialTLS("tcp", address, &tls.Config{InsecureSkipVerify: true}) - if err != nil { - return err - } - } else { - conn, err = ldap.Dial("tcp", address) - if err != nil { - return err - } - err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true}) - if err != nil { - return err - } - } - - admin = false - return auth() -} - -func auth() error { - if admin || config.Auth.BindDN == "" && config.Auth.BindPW == "" { - return nil - } - err := conn.Bind(config.Auth.BindDN, config.Auth.BindPW) - if err == nil { - admin = true - } - return nil -} - -func ldapLogin(username, password string) (bool, error) { - lock.Lock() - defer lock.Unlock() - - err := auth() - if err != nil { - return false, err - } - - req := ldap.NewSearchRequest( - config.User.BaseDN, - ldap.ScopeWholeSubtree, - ldap.NeverDerefAliases, - 0, - 0, - false, - strings.Replace(config.User.Filter, "{0}", username, -1), - nil, - nil, - ) - - res, err := conn.Search(req) - if err != nil { - return false, err - } - - if len(res.Entries) != 1 { - return false, nil - } - - admin = false - err = conn.Bind(res.Entries[0].DN, password) - if err != nil { - return false, nil - } - - err = auth() - if err != nil { - return false, err - } - - if len(config.User.RequiredGroups) == 0 { - return true, nil - } - - req = ldap.NewSearchRequest( - config.Group.BaseDN, - ldap.ScopeWholeSubtree, - ldap.NeverDerefAliases, - 0, - 0, - false, - strings.Replace(config.Group.Filter, "{0}", res.Entries[0].DN, -1), - []string{config.Group.GroupAttr}, - nil, - ) - - res, err = conn.Search(req) - if err != nil { - return false, err - } - - groups := []string{} - for _, entry := range res.Entries { - groups = append(groups, entry.GetAttributeValue(config.Group.GroupAttr)) - } - - sort.Strings(groups) - - for _, group := range config.User.RequiredGroups { - pos := sort.SearchStrings(groups, group) - if pos >= len(groups) || groups[pos] != group { - return false, nil - } - } - - return true, nil -} diff --git a/src/main.go b/src/main.go deleted file mode 100644 index fa5d794..0000000 --- a/src/main.go +++ /dev/null @@ -1,106 +0,0 @@ -package main - -import ( - "encoding/base64" - "flag" - "fmt" - "io/ioutil" - "log" - "net/http" - "strings" - "time" - - yaml "gopkg.in/yaml.v2" -) - -var ( - configFile = flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file") - config = Config{ - Web: "0.0.0.0:5555", - Path: "/", - Message: "LDAP Login", - User: UserConfig{ - Filter: "(cn={0})", - }, - Group: GroupConfig{ - Filter: "(member={0})", - GroupAttr: "cn", - }, - Timeout: TimeoutConfig{ - Success: 24 * time.Hour, - Wrong: 5 * time.Minute, - }, - } -) - -func main() { - flag.Parse() - - data, err := ioutil.ReadFile(*configFile) - if err != nil { - log.Fatalf("Error on read file \"%s\": %v\n", *configFile, err) - } - - err = yaml.Unmarshal(data, &config) - if err != nil { - log.Fatalf("Error on parse config: %v\n", err) - } - - fmt.Printf("Loaded config \"%s\".\n", *configFile) - - err = setupLDAP() - if err != nil { - log.Fatalf("Error on connect to LDAP: %v\n", err) - } - - http.HandleFunc(config.Path, handler) - - fmt.Printf("Serving...\n") - err = http.ListenAndServe(config.Web, nil) - if err != nil { - log.Fatalf("Error on start server: %v\n", err) - } -} - -func handler(w http.ResponseWriter, r *http.Request) { - header := r.Header.Get("Authorization") - - if header != "" { - auth := strings.SplitN(header, " ", 2) - - if len(auth) == 2 && auth[0] == "Basic" { - decoded, err := base64.StdEncoding.DecodeString(auth[1]) - if err == nil { - secret := strings.SplitN(string(decoded), ":", 2) - - if len(secret) == 2 && validate(secret[0], secret[1]) { - // TODO: match by header, e.g: X-Original-URL X-Original-Method X-Sent-From X-Auth-Request-Redirect - - w.WriteHeader(http.StatusOK) - return - } - } else { - log.Printf("Error decode basic auth: %v\n", err) - } - } - } - - w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", config.Message)) - w.WriteHeader(http.StatusUnauthorized) -} - -func validate(username, password string) bool { - ok, found := getCache(username, password) - if found { - return ok - } - - ok, err := ldapLogin(username, password) - if err != nil { - log.Printf("Could not validade user %s: %v\n", username, err) - return false - } - - putCache(username, password, ok) - return ok -} diff --git a/src/timeout_test.go b/src/timeout_test.go deleted file mode 100644 index daa07cf..0000000 --- a/src/timeout_test.go +++ /dev/null @@ -1,88 +0,0 @@ -package main - -import ( - "bytes" - "fmt" - "os" - "testing" - "time" -) - -const ( - username1 = "Alice" - username2 = "James" - password1 = "master" - password2 = "shadow" - password3 = "qwerty" -) - -func TestMain(m *testing.M) { - config.Timeout.Success = time.Second / 2 - config.Timeout.Wrong = time.Second / 5 - os.Exit(m.Run()) -} - -func printPassMap(t *testing.T, prefix string) { - buffer := bytes.Buffer{} - first := true - for k, v := range passwords { - if first { - first = false - } else { - buffer.WriteByte(',') - } - correct := "" - if v.correct != nil { - correct = v.correct.password - } - fmt.Fprintf(&buffer, "%s:{correct:%s,wrong:%+v}", k, correct, v.wrong) - } - t.Logf("%s passwords: %s\n", prefix, buffer.String()) -} - -func testCache(t *testing.T, id int, username, password string, eok, efound bool) { - ok, found := getCache(username, password) - if ok != eok || found != efound { - t.Errorf("Test %d expected (%v %v) given (%v %v)\n", id, eok, efound, ok, found) - } -} - -func TestPasswordTimeout(t *testing.T) { - testCache(t, 0, username1, password1, false, false) - testCache(t, 1, username1, password2, false, false) - testCache(t, 2, username1, password3, false, false) - testCache(t, 3, username2, password1, false, false) - testCache(t, 4, username2, password2, false, false) - testCache(t, 5, username2, password3, false, false) - - putCache(username1, password1, true) - putCache(username1, password3, false) - printPassMap(t, "add") - - testCache(t, 6, username1, password1, true, true) - testCache(t, 7, username1, password2, false, false) - testCache(t, 8, username1, password3, false, true) - testCache(t, 9, username2, password1, false, false) - testCache(t, 10, username2, password2, false, false) - testCache(t, 11, username2, password3, false, false) - - time.Sleep(config.Timeout.Wrong + config.Timeout.Wrong/2) - printPassMap(t, "timed") - - testCache(t, 12, username1, password1, true, true) - testCache(t, 13, username1, password2, false, false) - testCache(t, 14, username1, password3, false, false) - testCache(t, 15, username2, password1, false, false) - testCache(t, 16, username2, password2, false, false) - testCache(t, 17, username2, password3, false, false) - - time.Sleep(config.Timeout.Success - config.Timeout.Wrong) - printPassMap(t, "expired") - - testCache(t, 18, username1, password1, false, false) - testCache(t, 19, username1, password2, false, false) - testCache(t, 20, username1, password3, false, false) - testCache(t, 21, username2, password1, false, false) - testCache(t, 22, username2, password2, false, false) - testCache(t, 23, username2, password3, false, false) -} diff --git a/user/service.go b/user/service.go new file mode 100644 index 0000000..dbf1c6a --- /dev/null +++ b/user/service.go @@ -0,0 +1,35 @@ +package user + +import ( + "strings" + + "github.com/tiagoapimenta/nginx-ldap-auth/ldap" +) + +type Service struct { + pool *ldap.Pool + base string + filter string +} + +func NewService(pool *ldap.Pool, base, filter string) *Service { + return &Service{ + pool: pool, + base: base, + filter: filter, + } +} + +func (p *Service) Find(username string) (bool, string, error) { + ok, id, _, err := p.pool.Search( + p.base, + strings.Replace(p.filter, "{0}", username, -1), + "", + ) + + return ok, id, err +} + +func (p *Service) Login(id, password string) (bool, error) { + return p.pool.Validate(id, password) +}