diff --git a/main.go b/main.go index 385ecf7..e316da7 100644 --- a/main.go +++ b/main.go @@ -12,13 +12,23 @@ import ( const burstLimit = 2 func main() { + ctx := context.Background() throttles := make(map[string]<-chan time.Time) calledLastHour := make([]string, 0) http.HandleFunc("/", serveIndexFile) http.HandleFunc("/input.txt", serveInputFile) http.HandleFunc("/result", func(w http.ResponseWriter, r *http.Request) { - tryResult(w, r, throttles, calledLastHour) + + ctx, cancel := context.WithCancel(ctx) + throttle := throttling.CreateThrottle(ctx, burstLimit) + a := throttleWithCancel{ + throttle: throttle, + cancel: cancel, + } + con := context.WithValue(ctx, "throttle", a) + + tryResult(ctx, w, r, throttles, calledLastHour) }) err := http.ListenAndServe(":3333", nil) @@ -27,6 +37,11 @@ func main() { } } +type throttleWithCancel struct { + throttle <-chan time.Time + cancel context.CancelFunc +} + func serveIndexFile(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "static/index-with-text.html") } @@ -36,17 +51,17 @@ func serveInputFile(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, "static/input.txt") } -func tryResult(w http.ResponseWriter, r *http.Request, throttles map[string]<-chan time.Time, calledLastHour []string) { +func tryResult(ctx context.Context, w http.ResponseWriter, r *http.Request, throttles map[string]<-chan time.Time, calledLastHour []string) { clientIP := r.RemoteAddr clientResult := r.URL.Query().Get("result") fmt.Println(clientIP, clientResult) - if slices.Contains(calledLastHour, clientIP) { + if !slices.Contains(calledLastHour, clientIP) { calledLastHour = append(calledLastHour, clientIP) } throttle, ok := throttles[clientIP] if !ok { fmt.Println("Creating new throttle") - throttle = throttling.CreateThrottle(r.Context(), burstLimit) + throttle = throttling.CreateThrottle(ctx, burstLimit) throttles[clientIP] = throttle } payload := throttling.Payload{ @@ -55,7 +70,7 @@ func tryResult(w http.ResponseWriter, r *http.Request, throttles map[string]<-ch ClientResult: clientResult, } - throttling.CallFunction(context.TODO(), &CheckResult{}, &payload, throttle) + throttling.CallFunction(ctx, &CheckResult{}, &payload, throttle) } type CheckResult struct { diff --git a/throttling/rateLimit.go b/throttling/rateLimit.go index 6a10de2..3e9c610 100644 --- a/throttling/rateLimit.go +++ b/throttling/rateLimit.go @@ -24,8 +24,8 @@ type Payload struct { // CallFunction allows burst rate limiting client calls with the // payloads. func CallFunction(ctx context.Context, client Client, payload *Payload, throttle <-chan time.Time) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() + //ctx, cancel := context.WithCancel(ctx) + // defer cancel() <-throttle // rate limit our client calls client.Call(payload) @@ -43,11 +43,18 @@ func CreateThrottle(ctx context.Context, burstLimit int) <-chan time.Time { for t := range ticker.C { select { case throttle <- t: + { + fmt.Println("Add bucket to throttle") + } case <-ctx.Done(): { fmt.Println("Ticker done") return // exit goroutine when surrounding function returns } + default: + { + fmt.Println("Dropping bucket") + } } } }()