Refactor
This commit is contained in:
parent
83b51c0bdf
commit
bdce35bc55
|
@ -1,12 +1,12 @@
|
|||
FROM golang:alpine
|
||||
|
||||
COPY src /go/src/github.com/tiagoapimenta/nginx-ldap-auth
|
||||
COPY . /go/src/github.com/tiagoapimenta/nginx-ldap-auth
|
||||
|
||||
RUN cd /go/src/github.com/tiagoapimenta/nginx-ldap-auth && \
|
||||
apk add --no-cache git && \
|
||||
go get -u gopkg.in/yaml.v2 && \
|
||||
go get -u gopkg.in/ldap.v2 && \
|
||||
go build -ldflags='-s -w' -v -o /go/bin/nginx-ldap-auth .
|
||||
go build -ldflags='-s -w' -v -o /go/bin/nginx-ldap-auth ./main
|
||||
|
||||
FROM alpine
|
||||
|
||||
|
|
|
@ -8,14 +8,12 @@ Use this in order to provide a ingress authentication over LDAP for Kubernetes,
|
|||
|
||||
Configure your ingress with annotation `nginx.ingress.kubernetes.io/auth-url: http://nginx-ldap-auth.default.svc.cluster.local:5555` as described on [nginx documentation](https://kubernetes.github.io/ingress-nginx/examples/auth/external-auth/).
|
||||
|
||||
## Config
|
||||
## Configuration
|
||||
|
||||
The actual version choose a random server, in future version it is intended to have a pool of them, that is why it is a list, not a single one, but you can fill only one if you wish.
|
||||
|
||||
The prefix tell the program which protocol to use, if `ldaps://` it will try LDAP over SSL, if `ldap://` it will try plain LDAP with STARTTLS, case no prefix is given it will try to guess based on port, 636 for SSL and 389 for plain.
|
||||
|
||||
The actual version will fail if neither SSL or STARTTLS is possible, but next version will allow plain LDAP.
|
||||
|
||||
If the `user.requiredGroups` list is omited or empty all LDAP users will be allowed regardless the group, if not empty all groups will be required, the next version will have more flexible configuration.
|
||||
|
||||
If you are not sure what `filter`, `bindDN` or `baseDN` to use, here is a tip:
|
||||
|
|
2
build
2
build
|
@ -3,7 +3,7 @@
|
|||
set -e
|
||||
|
||||
base='docker.io/tpimenta/nginx-ldap-auth'
|
||||
version='v1.0.0'
|
||||
version='v1.0.1'
|
||||
image="$base:$version"
|
||||
|
||||
atexit() {
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package data
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
@ -16,30 +16,27 @@ type userpass struct {
|
|||
wrong []passtimer
|
||||
}
|
||||
|
||||
var (
|
||||
passwords = map[string]*userpass{}
|
||||
mutex = sync.RWMutex{}
|
||||
)
|
||||
|
||||
func containsWrongPassword(data *userpass, password string) (int, bool) {
|
||||
size := len(data.wrong)
|
||||
if size == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
pos := sort.Search(size, func(i int) bool {
|
||||
return data.wrong[i].password >= password
|
||||
})
|
||||
|
||||
return pos, pos < size &&
|
||||
data.wrong[pos].password == password
|
||||
type Storage struct {
|
||||
passwords map[string]*userpass
|
||||
lock sync.RWMutex
|
||||
success time.Duration
|
||||
wrong time.Duration
|
||||
}
|
||||
|
||||
func getCache(username, password string) (bool, bool) {
|
||||
defer mutex.RUnlock()
|
||||
mutex.RLock()
|
||||
func NewStorage(success, wrong time.Duration) *Storage {
|
||||
return &Storage{
|
||||
passwords: map[string]*userpass{},
|
||||
lock: sync.RWMutex{},
|
||||
success: success,
|
||||
wrong: wrong,
|
||||
}
|
||||
}
|
||||
|
||||
data, found := passwords[username]
|
||||
func (p *Storage) Get(username, password string) (bool, bool) {
|
||||
p.lock.RLock()
|
||||
defer p.lock.RUnlock()
|
||||
|
||||
data, found := p.passwords[username]
|
||||
if !found {
|
||||
return false, false
|
||||
}
|
||||
|
@ -53,25 +50,25 @@ func getCache(username, password string) (bool, bool) {
|
|||
return false, found
|
||||
}
|
||||
|
||||
func putCache(username, password string, ok bool) {
|
||||
defer mutex.Unlock()
|
||||
mutex.Lock()
|
||||
func (p *Storage) Put(username, password string, ok bool) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
data, found := passwords[username]
|
||||
data, found := p.passwords[username]
|
||||
if !found {
|
||||
data = &userpass{}
|
||||
passwords[username] = data
|
||||
p.passwords[username] = data
|
||||
}
|
||||
|
||||
timeout := config.Timeout.Wrong
|
||||
timeout := p.wrong
|
||||
if ok {
|
||||
timeout = config.Timeout.Success
|
||||
timeout = p.success
|
||||
}
|
||||
|
||||
pass := passtimer{
|
||||
password: password,
|
||||
timer: time.AfterFunc(timeout, func() {
|
||||
removeCache(username, password, ok)
|
||||
p.remove(username, password, ok)
|
||||
}),
|
||||
}
|
||||
|
||||
|
@ -92,11 +89,11 @@ func putCache(username, password string, ok bool) {
|
|||
}
|
||||
}
|
||||
|
||||
func removeCache(username, password string, ok bool) {
|
||||
defer mutex.Unlock()
|
||||
mutex.Lock()
|
||||
func (p *Storage) remove(username, password string, ok bool) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
data, found := passwords[username]
|
||||
data, found := p.passwords[username]
|
||||
if !found {
|
||||
return
|
||||
}
|
||||
|
@ -115,6 +112,20 @@ func removeCache(username, password string, ok bool) {
|
|||
}
|
||||
|
||||
if data.correct == nil && len(data.wrong) == 0 {
|
||||
delete(passwords, username)
|
||||
delete(p.passwords, username)
|
||||
}
|
||||
}
|
||||
|
||||
func containsWrongPassword(data *userpass, password string) (int, bool) {
|
||||
size := len(data.wrong)
|
||||
if size == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
pos := sort.Search(size, func(i int) bool {
|
||||
return data.wrong[i].password >= password
|
||||
})
|
||||
|
||||
return pos, pos < size &&
|
||||
data.wrong[pos].password == password
|
||||
}
|
85
data/storage_test.go
Normal file
85
data/storage_test.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package data
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
username1 = "Alice"
|
||||
username2 = "James"
|
||||
password1 = "master"
|
||||
password2 = "shadow"
|
||||
password3 = "qwerty"
|
||||
success = time.Second / 2
|
||||
wrong = time.Second / 5
|
||||
)
|
||||
|
||||
func printPassMap(t *testing.T, storage *Storage, prefix string) {
|
||||
buffer := bytes.Buffer{}
|
||||
first := true
|
||||
for k, v := range storage.passwords {
|
||||
if first {
|
||||
first = false
|
||||
} else {
|
||||
buffer.WriteByte(',')
|
||||
}
|
||||
correct := "<nil>"
|
||||
if v.correct != nil {
|
||||
correct = v.correct.password
|
||||
}
|
||||
fmt.Fprintf(&buffer, "%s:{correct:%s,wrong:%+v}", k, correct, v.wrong)
|
||||
}
|
||||
t.Logf("%s passwords: %s\n", prefix, buffer.String())
|
||||
}
|
||||
|
||||
func testCache(t *testing.T, storage *Storage, id int, username, password string, eok, efound bool) {
|
||||
ok, found := storage.Get(username, password)
|
||||
if ok != eok || found != efound {
|
||||
t.Errorf("Test %d expected (%v %v) given (%v %v)\n", id, eok, efound, ok, found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordTimeout(t *testing.T) {
|
||||
storage := NewStorage(success, wrong)
|
||||
|
||||
testCache(t, storage, 0, username1, password1, false, false)
|
||||
testCache(t, storage, 1, username1, password2, false, false)
|
||||
testCache(t, storage, 2, username1, password3, false, false)
|
||||
testCache(t, storage, 3, username2, password1, false, false)
|
||||
testCache(t, storage, 4, username2, password2, false, false)
|
||||
testCache(t, storage, 5, username2, password3, false, false)
|
||||
|
||||
storage.Put(username1, password1, true)
|
||||
storage.Put(username1, password3, false)
|
||||
printPassMap(t, storage, "add")
|
||||
|
||||
testCache(t, storage, 6, username1, password1, true, true)
|
||||
testCache(t, storage, 7, username1, password2, false, false)
|
||||
testCache(t, storage, 8, username1, password3, false, true)
|
||||
testCache(t, storage, 9, username2, password1, false, false)
|
||||
testCache(t, storage, 10, username2, password2, false, false)
|
||||
testCache(t, storage, 11, username2, password3, false, false)
|
||||
|
||||
time.Sleep(wrong + wrong/2)
|
||||
printPassMap(t, storage, "timed")
|
||||
|
||||
testCache(t, storage, 12, username1, password1, true, true)
|
||||
testCache(t, storage, 13, username1, password2, false, false)
|
||||
testCache(t, storage, 14, username1, password3, false, false)
|
||||
testCache(t, storage, 15, username2, password1, false, false)
|
||||
testCache(t, storage, 16, username2, password2, false, false)
|
||||
testCache(t, storage, 17, username2, password3, false, false)
|
||||
|
||||
time.Sleep(success - wrong)
|
||||
printPassMap(t, storage, "expired")
|
||||
|
||||
testCache(t, storage, 18, username1, password1, false, false)
|
||||
testCache(t, storage, 19, username1, password2, false, false)
|
||||
testCache(t, storage, 20, username1, password3, false, false)
|
||||
testCache(t, storage, 21, username2, password1, false, false)
|
||||
testCache(t, storage, 22, username2, password2, false, false)
|
||||
testCache(t, storage, 23, username2, password3, false, false)
|
||||
}
|
39
group/service.go
Normal file
39
group/service.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package group
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/ldap"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
pool *ldap.Pool
|
||||
base string
|
||||
filter string
|
||||
attr string
|
||||
}
|
||||
|
||||
func NewService(pool *ldap.Pool, base, filter, attr string) *Service {
|
||||
return &Service{
|
||||
pool: pool,
|
||||
base: base,
|
||||
filter: filter,
|
||||
attr: attr,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Service) Find(id string) ([]string, error) {
|
||||
ok, _, groups, err := p.pool.Search(
|
||||
p.base,
|
||||
strings.Replace(p.filter, "{0}", id, -1),
|
||||
p.attr,
|
||||
)
|
||||
|
||||
if !ok && err != nil {
|
||||
return nil, err
|
||||
} else if err != nil {
|
||||
return []string{}, nil
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
2
k8s.yaml
2
k8s.yaml
|
@ -26,7 +26,7 @@ spec:
|
|||
app: nginx-ldap-auth
|
||||
spec:
|
||||
containers:
|
||||
- image: docker.io/tpimenta/nginx-ldap-auth:v1.0.0
|
||||
- image: docker.io/tpimenta/nginx-ldap-auth:v1.0.1
|
||||
name: nginx-ldap-auth
|
||||
command:
|
||||
- "nginx-ldap-auth"
|
||||
|
|
47
ldap/connect.go
Normal file
47
ldap/connect.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
ldap "gopkg.in/ldap.v2"
|
||||
)
|
||||
|
||||
func (p *Pool) Connect() error {
|
||||
if p.url == "" {
|
||||
return errors.New("No LDAP server available!")
|
||||
}
|
||||
|
||||
if p.port == 0 {
|
||||
return fmt.Errorf("Unable to determine schema or port for \"%s\"", p.url)
|
||||
}
|
||||
|
||||
if p.conn != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", p.url, p.port)
|
||||
if p.ssl {
|
||||
conn, err := ldap.DialTLS("tcp", address, &tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.conn = conn
|
||||
} else {
|
||||
conn, err := ldap.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
log.Printf("It was not possble to start TLS, falling back to plain: %v.\n", err)
|
||||
}
|
||||
p.conn = conn
|
||||
}
|
||||
|
||||
p.admin = false
|
||||
|
||||
return p.auth()
|
||||
}
|
36
ldap/login.go
Normal file
36
ldap/login.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package ldap
|
||||
|
||||
func (p *Pool) Validate(username, password string) (bool, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
err := p.auth()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
p.admin = false
|
||||
err = p.conn.Bind(username, password)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
err = p.auth()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *Pool) auth() error {
|
||||
if p.admin || p.username == "" && p.password == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.conn.Bind(p.username, p.password)
|
||||
if err == nil {
|
||||
p.admin = true
|
||||
}
|
||||
return err
|
||||
}
|
78
ldap/pool.go
Normal file
78
ldap/pool.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"log"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
ldap "gopkg.in/ldap.v2"
|
||||
)
|
||||
|
||||
type Pool struct {
|
||||
url string
|
||||
port int
|
||||
ssl bool
|
||||
username string
|
||||
password string
|
||||
conn *ldap.Conn
|
||||
admin bool
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
func NewPool(servers []string, username, password string) *Pool {
|
||||
url := ""
|
||||
port := 0
|
||||
schema := "auto"
|
||||
|
||||
size := len(servers)
|
||||
if size != 0 {
|
||||
r := rand.New(rand.NewSource(time.Now().Unix()))
|
||||
server := servers[r.Intn(size)]
|
||||
|
||||
url = server
|
||||
if strings.HasPrefix(url, "ldaps:") {
|
||||
url = strings.TrimPrefix(strings.TrimPrefix(url, "ldaps:"), "//")
|
||||
schema = "ldaps"
|
||||
port = 636
|
||||
} else if strings.HasPrefix(url, "ldap:") {
|
||||
url = strings.TrimPrefix(strings.TrimPrefix(url, "ldap:"), "//")
|
||||
schema = "ldap"
|
||||
port = 389
|
||||
}
|
||||
|
||||
portExp := regexp.MustCompile(`:[0-9]+$`)
|
||||
if portExp.MatchString(url) {
|
||||
str := portExp.FindString(url)
|
||||
|
||||
number, err := strconv.Atoi(str[1:])
|
||||
if err == nil {
|
||||
port = number
|
||||
url = strings.TrimSuffix(url, str)
|
||||
} else {
|
||||
log.Printf("Error on parse port of \"%s\": %v\n", server, err)
|
||||
}
|
||||
}
|
||||
|
||||
if schema == "auto" {
|
||||
if port == 636 {
|
||||
schema = "ldaps"
|
||||
} else if port == 389 {
|
||||
schema = "ldap"
|
||||
} else {
|
||||
port = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &Pool{
|
||||
url: url,
|
||||
port: port,
|
||||
ssl: schema == "ldaps",
|
||||
username: username,
|
||||
password: password,
|
||||
}
|
||||
}
|
57
ldap/search.go
Normal file
57
ldap/search.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package ldap
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
ldap "gopkg.in/ldap.v2"
|
||||
)
|
||||
|
||||
func (p *Pool) Search(base, filter string, attr string) (bool, string, []string, error) {
|
||||
p.lock.Lock()
|
||||
defer p.lock.Unlock()
|
||||
|
||||
err := p.auth()
|
||||
if err != nil {
|
||||
return false, "", nil, err
|
||||
}
|
||||
|
||||
var list []string = nil
|
||||
if attr != "" {
|
||||
list = []string{attr}
|
||||
}
|
||||
|
||||
res, err := p.conn.Search(ldap.NewSearchRequest(
|
||||
base,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
filter,
|
||||
list,
|
||||
nil,
|
||||
))
|
||||
if err != nil {
|
||||
return false, "", nil, err
|
||||
}
|
||||
|
||||
if len(res.Entries) == 0 {
|
||||
return true, "", nil, fmt.Errorf("No results for %s filter %s", base, filter)
|
||||
}
|
||||
|
||||
if attr == "" && len(res.Entries) > 1 {
|
||||
return true, "", nil, fmt.Errorf("Too many results for %s filter %s", base, filter)
|
||||
}
|
||||
|
||||
var result []string = nil
|
||||
if attr != "" {
|
||||
result = []string{}
|
||||
for _, entry := range res.Entries {
|
||||
result = append(result, entry.GetAttributeValue(attr))
|
||||
}
|
||||
sort.Strings(result)
|
||||
}
|
||||
|
||||
return true, res.Entries[0].DN, result, nil
|
||||
}
|
42
main/main.go
Normal file
42
main/main.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/data"
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/group"
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/ldap"
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/rule"
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/user"
|
||||
)
|
||||
|
||||
func main() {
|
||||
file, config, err := parseConfig()
|
||||
if err != nil {
|
||||
log.Fatalln(err.Error())
|
||||
}
|
||||
|
||||
fmt.Printf("Loaded config \"%s\".\n", file)
|
||||
|
||||
pool := ldap.NewPool(config.Servers, config.Auth.BindDN, config.Auth.BindPW)
|
||||
|
||||
err = pool.Connect()
|
||||
if err != nil {
|
||||
log.Fatalf("Error on connect to LDAP: %v\n", err)
|
||||
}
|
||||
|
||||
storage := data.NewStorage(config.Timeout.Success, config.Timeout.Wrong)
|
||||
|
||||
userService := user.NewService(pool, config.User.BaseDN, config.User.Filter)
|
||||
|
||||
groupService := group.NewService(pool, config.Group.BaseDN, config.Group.Filter, config.Group.GroupAttr)
|
||||
|
||||
ruleService := rule.NewService(storage, userService, groupService, config.User.RequiredGroups)
|
||||
|
||||
fmt.Printf("Serving...\n")
|
||||
err = startServer(ruleService, config.Web, config.Path, config.Message)
|
||||
if err != nil {
|
||||
log.Fatalf("Error on start server: %v\n", err)
|
||||
}
|
||||
}
|
45
main/parser.go
Normal file
45
main/parser.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"time"
|
||||
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
func parseConfig() (string, *Config, error) {
|
||||
file := flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
data, err := ioutil.ReadFile(*file)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("Error on read file \"%s\": %v", *file, err)
|
||||
}
|
||||
|
||||
config := Config{
|
||||
Web: "0.0.0.0:5555",
|
||||
Path: "/",
|
||||
Message: "LDAP Login",
|
||||
User: UserConfig{
|
||||
Filter: "(cn={0})",
|
||||
},
|
||||
Group: GroupConfig{
|
||||
Filter: "(member={0})",
|
||||
GroupAttr: "cn",
|
||||
},
|
||||
Timeout: TimeoutConfig{
|
||||
Success: 24 * time.Hour,
|
||||
Wrong: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("Error on parse config: %v", err)
|
||||
}
|
||||
|
||||
return *file, &config, nil
|
||||
}
|
42
main/server.go
Normal file
42
main/server.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/rule"
|
||||
)
|
||||
|
||||
func startServer(service *rule.Service, server, path, message string) error {
|
||||
realm := fmt.Sprintf("Basic realm=\"%s\"", message)
|
||||
|
||||
http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
|
||||
if header != "" {
|
||||
auth := strings.SplitN(header, " ", 2)
|
||||
|
||||
if len(auth) == 2 && auth[0] == "Basic" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(auth[1])
|
||||
if err == nil {
|
||||
secret := strings.SplitN(string(decoded), ":", 2)
|
||||
|
||||
if len(secret) == 2 && service.Validate(secret[0], secret[1]) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Printf("Error decode basic auth: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("WWW-Authenticate", realm)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
return http.ListenAndServe(server, nil)
|
||||
}
|
74
rule/service.go
Normal file
74
rule/service.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package rule
|
||||
|
||||
import (
|
||||
"log"
|
||||
"sort"
|
||||
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/data"
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/group"
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/user"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
storage *data.Storage
|
||||
user *user.Service
|
||||
group *group.Service
|
||||
required []string
|
||||
}
|
||||
|
||||
func NewService(storage *data.Storage, userService *user.Service, groupService *group.Service, required []string) *Service {
|
||||
return &Service{
|
||||
storage: storage,
|
||||
user: userService,
|
||||
group: groupService,
|
||||
required: required,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Service) Validate(username, password string) bool {
|
||||
ok, found := p.storage.Get(username, password)
|
||||
if found {
|
||||
return ok
|
||||
}
|
||||
|
||||
ok, err := p.validate(username, password)
|
||||
if err != nil {
|
||||
log.Printf("Could not validade user %s: %v\n", username, err)
|
||||
return false
|
||||
}
|
||||
|
||||
p.storage.Put(username, password, ok)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (p *Service) validate(username, password string) (bool, error) {
|
||||
ok, id, err := p.user.Find(username)
|
||||
if !ok && err != nil {
|
||||
return false, err
|
||||
} else if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
ok, err = p.user.Login(id, password)
|
||||
if !ok && err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if ok || p.required == nil || len(p.required) == 0 {
|
||||
return err != nil, nil
|
||||
}
|
||||
|
||||
groups, err := p.group.Find(id)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
for _, group := range p.required {
|
||||
pos := sort.SearchStrings(groups, group)
|
||||
if pos >= len(groups) || groups[pos] != group {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
181
src/ldap.go
181
src/ldap.go
|
@ -1,181 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
ldap "gopkg.in/ldap.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
conn *ldap.Conn
|
||||
admin bool
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
func setupLDAP() error {
|
||||
size := len(config.Servers)
|
||||
if size == 0 {
|
||||
return errors.New("No LDAP server available!")
|
||||
}
|
||||
|
||||
r := rand.New(rand.NewSource(time.Now().Unix()))
|
||||
portExp := regexp.MustCompile(`:[0-9]+$`)
|
||||
|
||||
server := config.Servers[r.Intn(size)]
|
||||
url := server
|
||||
port := 0
|
||||
schema := "auto"
|
||||
|
||||
if strings.HasPrefix(url, "ldaps:") {
|
||||
url = strings.TrimPrefix(strings.TrimPrefix(url, "ldaps:"), "//")
|
||||
schema = "ldaps"
|
||||
port = 636
|
||||
} else if strings.HasPrefix(url, "ldap:") {
|
||||
url = strings.TrimPrefix(strings.TrimPrefix(url, "ldap:"), "//")
|
||||
schema = "ldap"
|
||||
port = 389
|
||||
}
|
||||
|
||||
var err error
|
||||
if portExp.MatchString(url) {
|
||||
str := portExp.FindString(url)
|
||||
port, err = strconv.Atoi(str[1:])
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to parse url \"%s\", %v", server, err)
|
||||
}
|
||||
url = strings.TrimSuffix(url, str)
|
||||
}
|
||||
|
||||
if schema == "auto" {
|
||||
if port == 636 {
|
||||
schema = "ldaps"
|
||||
} else if port == 389 {
|
||||
schema = "ldap"
|
||||
}
|
||||
}
|
||||
|
||||
if schema == "auto" || port == 0 {
|
||||
return fmt.Errorf("Unable to determine schema or port for \"%s\"", server)
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", url, port)
|
||||
fmt.Printf("Connecting LDAP %s...\n", address)
|
||||
|
||||
if schema == "ldaps" {
|
||||
conn, err = ldap.DialTLS("tcp", address, &tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
conn, err = ldap.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = conn.StartTLS(&tls.Config{InsecureSkipVerify: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
admin = false
|
||||
return auth()
|
||||
}
|
||||
|
||||
func auth() error {
|
||||
if admin || config.Auth.BindDN == "" && config.Auth.BindPW == "" {
|
||||
return nil
|
||||
}
|
||||
err := conn.Bind(config.Auth.BindDN, config.Auth.BindPW)
|
||||
if err == nil {
|
||||
admin = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ldapLogin(username, password string) (bool, error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
err := auth()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
req := ldap.NewSearchRequest(
|
||||
config.User.BaseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
strings.Replace(config.User.Filter, "{0}", username, -1),
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
res, err := conn.Search(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(res.Entries) != 1 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
admin = false
|
||||
err = conn.Bind(res.Entries[0].DN, password)
|
||||
if err != nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = auth()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if len(config.User.RequiredGroups) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
req = ldap.NewSearchRequest(
|
||||
config.Group.BaseDN,
|
||||
ldap.ScopeWholeSubtree,
|
||||
ldap.NeverDerefAliases,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
strings.Replace(config.Group.Filter, "{0}", res.Entries[0].DN, -1),
|
||||
[]string{config.Group.GroupAttr},
|
||||
nil,
|
||||
)
|
||||
|
||||
res, err = conn.Search(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
groups := []string{}
|
||||
for _, entry := range res.Entries {
|
||||
groups = append(groups, entry.GetAttributeValue(config.Group.GroupAttr))
|
||||
}
|
||||
|
||||
sort.Strings(groups)
|
||||
|
||||
for _, group := range config.User.RequiredGroups {
|
||||
pos := sort.SearchStrings(groups, group)
|
||||
if pos >= len(groups) || groups[pos] != group {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
106
src/main.go
106
src/main.go
|
@ -1,106 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
var (
|
||||
configFile = flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file")
|
||||
config = Config{
|
||||
Web: "0.0.0.0:5555",
|
||||
Path: "/",
|
||||
Message: "LDAP Login",
|
||||
User: UserConfig{
|
||||
Filter: "(cn={0})",
|
||||
},
|
||||
Group: GroupConfig{
|
||||
Filter: "(member={0})",
|
||||
GroupAttr: "cn",
|
||||
},
|
||||
Timeout: TimeoutConfig{
|
||||
Success: 24 * time.Hour,
|
||||
Wrong: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
data, err := ioutil.ReadFile(*configFile)
|
||||
if err != nil {
|
||||
log.Fatalf("Error on read file \"%s\": %v\n", *configFile, err)
|
||||
}
|
||||
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
log.Fatalf("Error on parse config: %v\n", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Loaded config \"%s\".\n", *configFile)
|
||||
|
||||
err = setupLDAP()
|
||||
if err != nil {
|
||||
log.Fatalf("Error on connect to LDAP: %v\n", err)
|
||||
}
|
||||
|
||||
http.HandleFunc(config.Path, handler)
|
||||
|
||||
fmt.Printf("Serving...\n")
|
||||
err = http.ListenAndServe(config.Web, nil)
|
||||
if err != nil {
|
||||
log.Fatalf("Error on start server: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func handler(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
|
||||
if header != "" {
|
||||
auth := strings.SplitN(header, " ", 2)
|
||||
|
||||
if len(auth) == 2 && auth[0] == "Basic" {
|
||||
decoded, err := base64.StdEncoding.DecodeString(auth[1])
|
||||
if err == nil {
|
||||
secret := strings.SplitN(string(decoded), ":", 2)
|
||||
|
||||
if len(secret) == 2 && validate(secret[0], secret[1]) {
|
||||
// TODO: match by header, e.g: X-Original-URL X-Original-Method X-Sent-From X-Auth-Request-Redirect
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
log.Printf("Error decode basic auth: %v\n", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=\"%s\"", config.Message))
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func validate(username, password string) bool {
|
||||
ok, found := getCache(username, password)
|
||||
if found {
|
||||
return ok
|
||||
}
|
||||
|
||||
ok, err := ldapLogin(username, password)
|
||||
if err != nil {
|
||||
log.Printf("Could not validade user %s: %v\n", username, err)
|
||||
return false
|
||||
}
|
||||
|
||||
putCache(username, password, ok)
|
||||
return ok
|
||||
}
|
|
@ -1,88 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
username1 = "Alice"
|
||||
username2 = "James"
|
||||
password1 = "master"
|
||||
password2 = "shadow"
|
||||
password3 = "qwerty"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
config.Timeout.Success = time.Second / 2
|
||||
config.Timeout.Wrong = time.Second / 5
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func printPassMap(t *testing.T, prefix string) {
|
||||
buffer := bytes.Buffer{}
|
||||
first := true
|
||||
for k, v := range passwords {
|
||||
if first {
|
||||
first = false
|
||||
} else {
|
||||
buffer.WriteByte(',')
|
||||
}
|
||||
correct := "<nil>"
|
||||
if v.correct != nil {
|
||||
correct = v.correct.password
|
||||
}
|
||||
fmt.Fprintf(&buffer, "%s:{correct:%s,wrong:%+v}", k, correct, v.wrong)
|
||||
}
|
||||
t.Logf("%s passwords: %s\n", prefix, buffer.String())
|
||||
}
|
||||
|
||||
func testCache(t *testing.T, id int, username, password string, eok, efound bool) {
|
||||
ok, found := getCache(username, password)
|
||||
if ok != eok || found != efound {
|
||||
t.Errorf("Test %d expected (%v %v) given (%v %v)\n", id, eok, efound, ok, found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordTimeout(t *testing.T) {
|
||||
testCache(t, 0, username1, password1, false, false)
|
||||
testCache(t, 1, username1, password2, false, false)
|
||||
testCache(t, 2, username1, password3, false, false)
|
||||
testCache(t, 3, username2, password1, false, false)
|
||||
testCache(t, 4, username2, password2, false, false)
|
||||
testCache(t, 5, username2, password3, false, false)
|
||||
|
||||
putCache(username1, password1, true)
|
||||
putCache(username1, password3, false)
|
||||
printPassMap(t, "add")
|
||||
|
||||
testCache(t, 6, username1, password1, true, true)
|
||||
testCache(t, 7, username1, password2, false, false)
|
||||
testCache(t, 8, username1, password3, false, true)
|
||||
testCache(t, 9, username2, password1, false, false)
|
||||
testCache(t, 10, username2, password2, false, false)
|
||||
testCache(t, 11, username2, password3, false, false)
|
||||
|
||||
time.Sleep(config.Timeout.Wrong + config.Timeout.Wrong/2)
|
||||
printPassMap(t, "timed")
|
||||
|
||||
testCache(t, 12, username1, password1, true, true)
|
||||
testCache(t, 13, username1, password2, false, false)
|
||||
testCache(t, 14, username1, password3, false, false)
|
||||
testCache(t, 15, username2, password1, false, false)
|
||||
testCache(t, 16, username2, password2, false, false)
|
||||
testCache(t, 17, username2, password3, false, false)
|
||||
|
||||
time.Sleep(config.Timeout.Success - config.Timeout.Wrong)
|
||||
printPassMap(t, "expired")
|
||||
|
||||
testCache(t, 18, username1, password1, false, false)
|
||||
testCache(t, 19, username1, password2, false, false)
|
||||
testCache(t, 20, username1, password3, false, false)
|
||||
testCache(t, 21, username2, password1, false, false)
|
||||
testCache(t, 22, username2, password2, false, false)
|
||||
testCache(t, 23, username2, password3, false, false)
|
||||
}
|
35
user/service.go
Normal file
35
user/service.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tiagoapimenta/nginx-ldap-auth/ldap"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
pool *ldap.Pool
|
||||
base string
|
||||
filter string
|
||||
}
|
||||
|
||||
func NewService(pool *ldap.Pool, base, filter string) *Service {
|
||||
return &Service{
|
||||
pool: pool,
|
||||
base: base,
|
||||
filter: filter,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Service) Find(username string) (bool, string, error) {
|
||||
ok, id, _, err := p.pool.Search(
|
||||
p.base,
|
||||
strings.Replace(p.filter, "{0}", username, -1),
|
||||
"",
|
||||
)
|
||||
|
||||
return ok, id, err
|
||||
}
|
||||
|
||||
func (p *Service) Login(id, password string) (bool, error) {
|
||||
return p.pool.Validate(id, password)
|
||||
}
|
Reference in a new issue