Check timeout with test

This commit is contained in:
Tiago Augusto Pimenta 2018-09-15 17:43:44 -03:00
parent 134c5508c1
commit a1228eddc3
3 changed files with 111 additions and 20 deletions

View file

@ -9,8 +9,10 @@ import (
yaml "gopkg.in/yaml.v2"
)
var configFile = flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file")
var config Config
var (
configFile = flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file")
config Config
)
func main() {
flag.Parse()

View file

@ -17,7 +17,7 @@ type userpass struct {
}
var (
passwords map[string]userpass
passwords = map[string]*userpass{}
mutex = sync.RWMutex{}
)
@ -28,7 +28,7 @@ func containsWrongPassword(data *userpass, password string) (int, bool) {
}
pos := sort.Search(size, func(i int) bool {
return data.wrong[i].password < password
return data.wrong[i].password >= password
})
return pos, pos < size &&
@ -48,7 +48,7 @@ func getCache(username, password string) (bool, bool) {
return true, true
}
_, found = containsWrongPassword(&data, password)
_, found = containsWrongPassword(data, password)
return false, found
}
@ -59,28 +59,29 @@ func putCache(username, password string, ok bool) {
data, found := passwords[username]
if !found {
data = userpass{}
data = &userpass{}
passwords[username] = data
}
timeout := config.Timeout.Wrong
if ok {
timeout = config.Timeout.Success
}
pass := passtimer{
password: password,
timer: time.AfterFunc(timeout, func() {
removeCache(username, password, ok)
}),
}
if ok {
if data.correct != nil {
data.correct.timer.Stop()
}
data.correct = &passtimer{
password: password,
timer: time.AfterFunc(config.Timeout.Success, func() {
removeCache(username, "", true)
}),
}
data.correct = &pass
} else {
pass := passtimer{
password: password,
timer: time.AfterFunc(config.Timeout.Wrong, func() {
removeCache(username, password, false)
}),
}
pos, found := containsWrongPassword(&data, password)
pos, found := containsWrongPassword(data, password)
if found {
data.wrong[pos].timer.Stop()
} else {
@ -106,7 +107,7 @@ func removeCache(username, password string, ok bool) {
data.correct = nil
}
} else {
pos, found := containsWrongPassword(&data, password)
pos, found := containsWrongPassword(data, password)
if found {
data.wrong[pos].timer.Stop()
data.wrong = data.wrong[:pos+copy(data.wrong[pos:], data.wrong[pos+1:])]

88
src/timeout_test.go Normal file
View file

@ -0,0 +1,88 @@
package main
import (
"bytes"
"fmt"
"os"
"testing"
"time"
)
const (
username1 = "Alice"
username2 = "James"
password1 = "master"
password2 = "shadow"
password3 = "qwerty"
)
func TestMain(m *testing.M) {
config.Timeout.Success = time.Second / 2
config.Timeout.Wrong = time.Second / 5
os.Exit(m.Run())
}
func printPassMap(t *testing.T, prefix string) {
buffer := bytes.Buffer{}
first := true
for k, v := range passwords {
if first {
first = false
} else {
buffer.WriteByte(',')
}
correct := "<nil>"
if v.correct != nil {
correct = v.correct.password
}
fmt.Fprintf(&buffer, "%s:{correct:%s,wrong:%+v}", k, correct, v.wrong)
}
t.Logf("%s passwords: %s\n", prefix, buffer.String())
}
func testCache(t *testing.T, id int, username, password string, eok, efound bool) {
ok, found := getCache(username, password)
if ok != eok || found != efound {
t.Errorf("Test %d expected (%v %v) given (%v %v)\n", id, eok, efound, ok, found)
}
}
func TestPasswordTimeout(t *testing.T) {
testCache(t, 0, username1, password1, false, false)
testCache(t, 1, username1, password2, false, false)
testCache(t, 2, username1, password3, false, false)
testCache(t, 3, username2, password1, false, false)
testCache(t, 4, username2, password2, false, false)
testCache(t, 5, username2, password3, false, false)
putCache(username1, password1, true)
putCache(username1, password3, false)
printPassMap(t, "add")
testCache(t, 6, username1, password1, true, true)
testCache(t, 7, username1, password2, false, false)
testCache(t, 8, username1, password3, false, true)
testCache(t, 9, username2, password1, false, false)
testCache(t, 10, username2, password2, false, false)
testCache(t, 11, username2, password3, false, false)
time.Sleep(config.Timeout.Wrong + config.Timeout.Wrong/2)
printPassMap(t, "timed")
testCache(t, 12, username1, password1, true, true)
testCache(t, 13, username1, password2, false, false)
testCache(t, 14, username1, password3, false, false)
testCache(t, 15, username2, password1, false, false)
testCache(t, 16, username2, password2, false, false)
testCache(t, 17, username2, password3, false, false)
time.Sleep(config.Timeout.Success - config.Timeout.Wrong)
printPassMap(t, "expired")
testCache(t, 18, username1, password1, false, false)
testCache(t, 19, username1, password2, false, false)
testCache(t, 20, username1, password3, false, false)
testCache(t, 21, username2, password1, false, false)
testCache(t, 22, username2, password2, false, false)
testCache(t, 23, username2, password3, false, false)
}