// internal/honeypot/http/http_test.go

package http

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// TestGetRemoteAddr runs httpHoneypot.getRemoteAddr against a table of remote
// addresses and proxy headers and compares the returned host with the
// expected one.
//
// The table expects X-Real-Ip and X-Forwarded-For to be used only when the
// request's RemoteAddr is a loopback address (127.0.0.1 or ::1), X-Real-Ip to
// win when both headers are present, and the first entry of a comma-separated
// X-Forwarded-For to be chosen. Only the first return value is checked.
func TestGetRemoteAddr(t *testing.T) {
	type testCase struct {
		name       string
		remoteAddr string
		headers    map[string]string
		wantHost   string
	}

	hp := &httpHoneypot{}

	cases := []testCase{
		// Direct connections and loopback connections without proxy headers.
		{name: "Normal remote IP", remoteAddr: "1.2.3.4:12345", wantHost: "1.2.3.4"},
		{name: "Localhost without X-Real-Ip", remoteAddr: "127.0.0.1:12345", wantHost: "127.0.0.1"},
		{name: "Localhost (IPv6) without X-Real-Ip", remoteAddr: "[::1]:12345", wantHost: "::1"},

		// X-Real-Ip handling.
		{
			name:       "Localhost with X-Real-Ip",
			remoteAddr: "127.0.0.1:12345",
			headers:    map[string]string{"X-Real-Ip": "5.6.7.8"},
			wantHost:   "5.6.7.8",
		},
		{
			name:       "Localhost (IPv6) with X-Real-Ip",
			remoteAddr: "[::1]:12345",
			headers:    map[string]string{"X-Real-Ip": "9.10.11.12"},
			wantHost:   "9.10.11.12",
		},
		{
			name:       "Non-localhost with X-Real-Ip (ignored)",
			remoteAddr: "1.1.1.1:12345",
			headers:    map[string]string{"X-Real-Ip": "2.2.2.2"},
			wantHost:   "1.1.1.1",
		},

		// X-Forwarded-For handling.
		{
			name:       "Localhost with X-Forwarded-For",
			remoteAddr: "127.0.0.1:12345",
			headers:    map[string]string{"X-Forwarded-For": "5.6.7.8"},
			wantHost:   "5.6.7.8",
		},
		{
			name:       "Localhost with multiple X-Forwarded-For",
			remoteAddr: "127.0.0.1:12345",
			headers:    map[string]string{"X-Forwarded-For": "5.6.7.8, 10.0.0.1"},
			wantHost:   "5.6.7.8",
		},
		{
			name:       "Localhost with X-Real-Ip and X-Forwarded-For (X-Real-Ip wins)",
			remoteAddr: "127.0.0.1:12345",
			headers:    map[string]string{"X-Real-Ip": "5.6.7.8", "X-Forwarded-For": "1.1.1.1"},
			wantHost:   "5.6.7.8",
		},
		{
			name:       "Non-localhost with X-Forwarded-For (ignored)",
			remoteAddr: "1.1.1.1:12345",
			headers:    map[string]string{"X-Forwarded-For": "2.2.2.2"},
			wantHost:   "1.1.1.1",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r := httptest.NewRequest("GET", "/", nil)
			r.RemoteAddr = tc.remoteAddr
			for name, value := range tc.headers {
				r.Header.Set(name, value)
			}

			// The second return value is not checked by this test.
			host, _ := hp.getRemoteAddr(r)
			if host == tc.wantHost {
				return
			}
			t.Errorf("getRemoteAddr() host = %v, want %v", host, tc.wantHost)
		})
	}
}

// TestPortMiddleware verifies the destination port that
// httpHoneypot.portMiddleware stores in the request context under dstPortKey.
//
// The table expects the listen port to be used by default, and
// X-Forwarded-Proto ("https" -> 443, "http" -> 80) to override it only when
// the request's RemoteAddr is a loopback address (127.0.0.1 or ::1).
//
// The inner handler only records what it observed; all assertions run after
// ServeHTTP returns. This way a middleware that never invokes the next handler
// fails the test instead of passing without checking anything.
func TestPortMiddleware(t *testing.T) {
	h := &httpHoneypot{}

	tests := []struct {
		name        string
		remoteAddr  string
		headers     map[string]string
		listenPort  uint16
		wantDstPort uint16
	}{
		{
			name:        "Normal request",
			remoteAddr:  "1.2.3.4:12345",
			listenPort:  8080,
			wantDstPort: 8080,
		},
		{
			name:        "Localhost request without header",
			remoteAddr:  "127.0.0.1:12345",
			listenPort:  8080,
			wantDstPort: 8080,
		},
		{
			name:        "Localhost request with X-Forwarded-Proto https",
			remoteAddr:  "127.0.0.1:12345",
			headers:     map[string]string{"X-Forwarded-Proto": "https"},
			listenPort:  8080,
			wantDstPort: 443,
		},
		{
			name:        "Localhost request with X-Forwarded-Proto http",
			remoteAddr:  "127.0.0.1:12345",
			headers:     map[string]string{"X-Forwarded-Proto": "http"},
			listenPort:  8080,
			wantDstPort: 80,
		},
		{
			name:        "IPv6 Localhost request with X-Forwarded-Proto https",
			remoteAddr:  "[::1]:12345",
			headers:     map[string]string{"X-Forwarded-Proto": "https"},
			listenPort:  8080,
			wantDstPort: 443,
		},
		{
			name:        "Non-localhost request with X-Forwarded-Proto ignored",
			remoteAddr:  "1.2.3.4:12345",
			headers:     map[string]string{"X-Forwarded-Proto": "https"},
			listenPort:  8080,
			wantDstPort: 8080,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tt.remoteAddr
			for k, v := range tt.headers {
				req.Header.Set(k, v)
			}

			var (
				called  bool   // set once the wrapped handler runs
				gotPort uint16 // port read from the request context
				found   bool   // whether dstPortKey held a uint16
			)
			next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
				called = true
				gotPort, found = r.Context().Value(dstPortKey).(uint16)
			})

			handler := h.portMiddleware(next, tt.listenPort)
			handler.ServeHTTP(httptest.NewRecorder(), req)

			if !called {
				t.Fatal("portMiddleware did not call the next handler")
			}
			if !found {
				t.Fatal("dstPortKey not found in context")
			}
			if gotPort != tt.wantDstPort {
				t.Errorf("portMiddleware context port = %v, want %v", gotPort, tt.wantDstPort)
			}
		})
	}
}