From a1228eddc3a932a1488518889cbe6353d2960021 Mon Sep 17 00:00:00 2001 From: Tiago Augusto Pimenta Date: Sat, 15 Sep 2018 17:43:44 -0300 Subject: [PATCH] Check timeout with test --- src/main.go | 6 ++-- src/timeout.go | 37 +++++++++---------- src/timeout_test.go | 88 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 20 deletions(-) create mode 100644 src/timeout_test.go diff --git a/src/main.go b/src/main.go index 0ae8bb3..068f374 100644 --- a/src/main.go +++ b/src/main.go @@ -9,8 +9,10 @@ import ( yaml "gopkg.in/yaml.v2" ) -var configFile = flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file") -var config Config +var ( + configFile = flag.String("config", "/etc/nginx-ldap-auth/config.yaml", "Configuration file") + config Config +) func main() { flag.Parse() diff --git a/src/timeout.go b/src/timeout.go index c49fe2c..432efba 100644 --- a/src/timeout.go +++ b/src/timeout.go @@ -17,7 +17,7 @@ type userpass struct { } var ( - passwords map[string]userpass + passwords = map[string]*userpass{} mutex = sync.RWMutex{} ) @@ -28,7 +28,7 @@ func containsWrongPassword(data *userpass, password string) (int, bool) { } pos := sort.Search(size, func(i int) bool { - return data.wrong[i].password < password + return data.wrong[i].password >= password }) return pos, pos < size && @@ -48,7 +48,7 @@ func getCache(username, password string) (bool, bool) { return true, true } - _, found = containsWrongPassword(&data, password) + _, found = containsWrongPassword(data, password) return false, found } @@ -59,28 +59,29 @@ func putCache(username, password string, ok bool) { data, found := passwords[username] if !found { - data = userpass{} + data = &userpass{} passwords[username] = data } + timeout := config.Timeout.Wrong + if ok { + timeout = config.Timeout.Success + } + + pass := passtimer{ + password: password, + timer: time.AfterFunc(timeout, func() { + removeCache(username, password, ok) + }), + } + if ok { if data.correct != nil { data.correct.timer.Stop() } - data.correct = &passtimer{ - password: password, - timer: time.AfterFunc(config.Timeout.Success, func() { - removeCache(username, "", true) - }), - } + data.correct = &pass } else { - pass := passtimer{ - password: password, - timer: time.AfterFunc(config.Timeout.Wrong, func() { - removeCache(username, password, false) - }), - } - pos, found := containsWrongPassword(&data, password) + pos, found := containsWrongPassword(data, password) if found { data.wrong[pos].timer.Stop() } else { @@ -106,7 +107,7 @@ func removeCache(username, password string, ok bool) { data.correct = nil } } else { - pos, found := containsWrongPassword(&data, password) + pos, found := containsWrongPassword(data, password) if found { data.wrong[pos].timer.Stop() data.wrong = data.wrong[:pos+copy(data.wrong[pos:], data.wrong[pos+1:])] diff --git a/src/timeout_test.go b/src/timeout_test.go new file mode 100644 index 0000000..daa07cf --- /dev/null +++ b/src/timeout_test.go @@ -0,0 +1,88 @@ +package main + +import ( + "bytes" + "fmt" + "os" + "testing" + "time" +) + +const ( + username1 = "Alice" + username2 = "James" + password1 = "master" + password2 = "shadow" + password3 = "qwerty" +) + +func TestMain(m *testing.M) { + config.Timeout.Success = time.Second / 2 + config.Timeout.Wrong = time.Second / 5 + os.Exit(m.Run()) +} + +func printPassMap(t *testing.T, prefix string) { + buffer := bytes.Buffer{} + first := true + for k, v := range passwords { + if first { + first = false + } else { + buffer.WriteByte(',') + } + correct := "" + if v.correct != nil { + correct = v.correct.password + } + fmt.Fprintf(&buffer, "%s:{correct:%s,wrong:%+v}", k, correct, v.wrong) + } + t.Logf("%s passwords: %s\n", prefix, buffer.String()) +} + +func testCache(t *testing.T, id int, username, password string, eok, efound bool) { + ok, found := getCache(username, password) + if ok != eok || found != efound { + t.Errorf("Test %d expected (%v %v) given (%v %v)\n", id, eok, efound, ok, found) + } +} + +func TestPasswordTimeout(t *testing.T) { + testCache(t, 0, username1, password1, false, false) + testCache(t, 1, username1, password2, false, false) + testCache(t, 2, username1, password3, false, false) + testCache(t, 3, username2, password1, false, false) + testCache(t, 4, username2, password2, false, false) + testCache(t, 5, username2, password3, false, false) + + putCache(username1, password1, true) + putCache(username1, password3, false) + printPassMap(t, "add") + + testCache(t, 6, username1, password1, true, true) + testCache(t, 7, username1, password2, false, false) + testCache(t, 8, username1, password3, false, true) + testCache(t, 9, username2, password1, false, false) + testCache(t, 10, username2, password2, false, false) + testCache(t, 11, username2, password3, false, false) + + time.Sleep(config.Timeout.Wrong + config.Timeout.Wrong/2) + printPassMap(t, "timed") + + testCache(t, 12, username1, password1, true, true) + testCache(t, 13, username1, password2, false, false) + testCache(t, 14, username1, password3, false, false) + testCache(t, 15, username2, password1, false, false) + testCache(t, 16, username2, password2, false, false) + testCache(t, 17, username2, password3, false, false) + + time.Sleep(config.Timeout.Success - config.Timeout.Wrong) + printPassMap(t, "expired") + + testCache(t, 18, username1, password1, false, false) + testCache(t, 19, username1, password2, false, false) + testCache(t, 20, username1, password3, false, false) + testCache(t, 21, username2, password1, false, false) + testCache(t, 22, username2, password2, false, false) + testCache(t, 23, username2, password3, false, false) +}