// internal/database/querybuilder.go

package database

import (
	"encoding/binary"
	"fmt"
	"net"
	"slices"
	"strconv"
	"strings"
	"time"
)

// HasGeoFields reports whether query needs data from the ips table. It returns
// true when any of the ASN, country, city, domain or FQDN filters is non-empty,
// or when the requested columns include country, city, latitude or longitude.
func HasGeoFields(query EventQuery) bool {
	// Any geo filter forces the join.
	switch {
	case len(query.ASNs) > 0,
		len(query.Countries) > 0,
		len(query.Cities) > 0,
		len(query.Domains) > 0,
		len(query.FQDNs) > 0:
		return true
	}

	// Otherwise the join is only needed if a geo column was requested.
	for _, col := range query.Columns {
		switch col {
		case "country", "city", "latitude", "longitude":
			return true
		}
	}
	return false
}

// buildQueryString assembles the SQL statement for query and returns it along
// with the positional arguments for its "?" placeholders.
//
// When total is true the statement is a COUNT(*) query and no pagination
// clause is appended; otherwise the ORDER BY / LIMIT / OFFSET suffix produced
// by buildPaginationClause is added after the WHERE conditions.
func buildQueryString(query EventQuery, total bool) (string, []any) {
	joinGeo := HasGeoFields(query)
	selectClause := buildSelectClause(query, total, joinGeo)
	conditions, params := buildWhereClauses(query, joinGeo)

	var sql strings.Builder
	sql.WriteString(selectClause)

	// All conditions must hold, so they are AND-ed together.
	if len(conditions) != 0 {
		sql.WriteString(" WHERE ")
		sql.WriteString(strings.Join(conditions, " AND "))
	}

	if !total {
		sql.WriteString(buildPaginationClause(query))
	}

	return sql.String(), params
}

// buildSelectClause returns the "SELECT ... FROM ..." prefix of the statement.
//
// With total set it selects COUNT(*). Otherwise it selects the entries of
// query.Columns that are on the allow-list below (unknown names are dropped);
// if none remain it falls back to the full default column list. When hasGeo is
// true, honeypot_events is joined to ips on remote_addr = ip.
func buildSelectClause(query EventQuery, total bool, hasGeo bool) string {
	from := " FROM honeypot_events"
	if hasGeo {
		from += " JOIN ips ON honeypot_events.remote_addr = ips.ip"
	}

	if total {
		return "SELECT COUNT(*)" + from
	}

	// Only allow-listed names reach the SQL text, since columns cannot be
	// bound as parameters.
	selected := make([]string, 0, len(query.Columns))
	for _, name := range query.Columns {
		switch name {
		case "id", "time", "type", "event", "remote_addr", "remote_port", "dst_port", "fields",
			"country", "city", "latitude", "longitude":
			selected = append(selected, name)
		}
	}

	// NOTE(review): explicitly requested columns are emitted unqualified even
	// in JOIN mode, unlike the default list below — confirm that ips shares no
	// column names with honeypot_events.
	var columns string
	switch {
	case len(selected) > 0:
		columns = strings.Join(selected, ", ")
	case hasGeo:
		columns = "honeypot_events.id, honeypot_events.time, honeypot_events.type, honeypot_events.event, honeypot_events.remote_addr, honeypot_events.remote_port, honeypot_events.dst_port, honeypot_events.fields, ips.country, ips.city, ips.latitude, ips.longitude"
	default:
		columns = "id, time, type, event, remote_addr, remote_port, dst_port, fields"
	}

	return "SELECT " + columns + from
}

// buildWhereClauses collects every WHERE condition implied by query and the
// matching positional arguments, in the same order. The caller joins the
// returned conditions with AND.
//
// When hasGeo is true the event columns id, type, event, remote_addr,
// remote_port and dst_port are qualified with "honeypot_events." because the
// statement joins the ips table.
func buildWhereClauses(query EventQuery, hasGeo bool) ([]string, []any) {
	prefix := ""
	if hasGeo {
		prefix = "honeypot_events."
	}

	clauses := []string{}
	args := []any{}

	// The helpers append clauses and args in lockstep, so the call order here
	// fixes the placeholder order of the final statement.

	// Time window.
	addTimeRangeClauses(&clauses, &args, query.TimeStart, query.TimeEnd)

	// Event-table filters (include / exclude lists).
	addInNotInClauses(&clauses, &args, prefix+"id", query.IDs)
	addAddrClauses(&clauses, &args, query.RemoteAddrs, prefix+"remote_addr")
	addPortClauses(&clauses, &args, query.RemotePorts, prefix+"remote_port")
	addPortClauses(&clauses, &args, query.DstPorts, prefix+"dst_port")
	addInNotInClauses(&clauses, &args, prefix+"type", query.Type)
	addInNotInClauses(&clauses, &args, prefix+"event", query.Event)

	// Filters on the joined ips table.
	addInNotInClauses(&clauses, &args, "ips.asn", query.ASNs)
	addInNotInClauses(&clauses, &args, "ips.country", query.Countries)
	addInNotInClauses(&clauses, &args, "ips.city", query.Cities)
	addInNotInClauses(&clauses, &args, "ips.domain", query.Domains)
	addInNotInClauses(&clauses, &args, "ips.fqdn", query.FQDNs)

	// Filters on the JSON "fields" column.
	addJSONFieldClauses(&clauses, &args, query.FieldFilters)
	addJSONFieldExistsClauses(&clauses, &args, query.FieldExists)

	return clauses, args
}

// addInNotInClauses appends include/exclude conditions on field for values and
// the matching arguments.
//
// A value prefixed with "!" (and longer than one character) is an exclusion.
// Values containing "*" or "?" are matched with GLOB, all others by equality.
// Inclusions are combined into one OR-ed condition (IN list plus one GLOB per
// pattern); exclusions become separate conditions (one NOT IN list plus one
// NOT GLOB per pattern), which the caller AND-s together.
func addInNotInClauses(clauses *[]string, args *[]any, field string, values []string) {
	if len(values) == 0 {
		return
	}

	var (
		exact, exactNot []string
		globs, globsNot []string
	)

	for _, raw := range values {
		val := raw
		negated := len(raw) > 1 && strings.HasPrefix(raw, "!")
		if negated {
			val = raw[1:]
		}

		wildcard := strings.ContainsAny(val, "*?")
		switch {
		case negated && wildcard:
			globsNot = append(globsNot, val)
		case negated:
			exactNot = append(exactNot, val)
		case wildcard:
			globs = append(globs, val)
		default:
			exact = append(exact, val)
		}
	}

	// Inclusions: any single match is enough.
	var alternatives []string
	if n := len(exact); n > 0 {
		marks := strings.Repeat("?, ", n-1) + "?"
		alternatives = append(alternatives, fmt.Sprintf("%s IN (%s)", field, marks))
		*args = append(*args, sliceToAny(exact)...)
	}
	for _, pattern := range globs {
		alternatives = append(alternatives, fmt.Sprintf("%s GLOB ?", field))
		*args = append(*args, pattern)
	}
	switch len(alternatives) {
	case 0:
		// No inclusions requested.
	case 1:
		*clauses = append(*clauses, alternatives[0])
	default:
		*clauses = append(*clauses, "("+strings.Join(alternatives, " OR ")+")")
	}

	// Exclusions: every one must hold, so each is its own condition.
	if n := len(exactNot); n > 0 {
		marks := strings.Repeat("?, ", n-1) + "?"
		*clauses = append(*clauses, fmt.Sprintf("%s NOT IN (%s)", field, marks))
		*args = append(*args, sliceToAny(exactNot)...)
	}
	for _, pattern := range globsNot {
		*clauses = append(*clauses, fmt.Sprintf("%s NOT GLOB ?", field))
		*args = append(*args, pattern)
	}
}

// ipIntColumn maps an address column name to the integer column used for
// numeric IPv4 comparisons: remote_addr (optionally qualified with
// "honeypot_events.") maps to remote_ip_int with the same qualifier, and any
// other name gets an "_int" suffix.
func ipIntColumn(field string) string {
	switch field {
	case "remote_addr":
		return "remote_ip_int"
	case "honeypot_events.remote_addr":
		return "honeypot_events.remote_ip_int"
	default:
		return field + "_int"
	}
}

// addAddrClauses appends address conditions on field for values, plus the
// matching arguments.
//
// Each value is a single address or a CIDR subnet; a leading "!" (on a value
// longer than one character) turns it into an exclusion. IPv4 addresses and
// subnets are compared numerically against the integer column returned by
// ipIntColumn (assumed to hold the address as a big-endian uint32 — confirm
// against the schema). Anything else falls back to the textual column: an
// INET containment test for values containing "/", plain equality otherwise.
//
// Inclusions are OR-ed into one parenthesised condition and exclusions are
// AND-ed into another; include arguments are appended before exclude ones.
func addAddrClauses(clauses *[]string, args *[]any, values []string, field string) {
	if len(values) == 0 {
		return
	}

	intField := ipIntColumn(field)

	var inParts, notInParts []string
	var inArgs, notInArgs []any

	// emit files one condition and its arguments under the include or the
	// exclude list.
	emit := func(exclude bool, cond string, condArgs ...any) {
		if exclude {
			notInParts = append(notInParts, cond)
			notInArgs = append(notInArgs, condArgs...)
			return
		}
		inParts = append(inParts, cond)
		inArgs = append(inArgs, condArgs...)
	}

	for _, v := range values {
		exclude := false
		if strings.HasPrefix(v, "!") && len(v) > 1 {
			exclude = true
			v = v[1:]
		}

		if strings.Contains(v, "/") {
			// Subnet matching.
			_, ipNet, err := net.ParseCIDR(v)
			if err == nil && ipNet.IP.To4() != nil {
				// IPv4 subnet. For an IPv4-mapped IPv6 CIDR such as
				// ::ffff:10.0.0.0/120, ParseCIDR yields a 16-byte mask, so
				// the IPv4 part is always the last four bytes.
				v4Mask := ipNet.Mask[len(ipNet.Mask)-net.IPv4len:]
				mask := binary.BigEndian.Uint32(v4Mask)
				start := binary.BigEndian.Uint32(ipNet.IP.To4())
				end := start | ^mask

				cond := fmt.Sprintf("(%s >= ? AND %s <= ?)", intField, intField)
				if exclude {
					cond = "NOT " + cond
				}
				emit(exclude, cond, start, end)
				continue
			}

			// IPv6 subnet, or a value that did not parse: let the database
			// evaluate the containment on the textual column.
			cond := fmt.Sprintf("%s::INET <<= ?::INET", field)
			if exclude {
				cond = "NOT " + cond
			}
			emit(exclude, cond, v)
			continue
		}

		// Single address.
		op := "="
		if exclude {
			op = "!="
		}
		if ip := net.ParseIP(v); ip != nil && ip.To4() != nil {
			emit(exclude, fmt.Sprintf("%s %s ?", intField, op), binary.BigEndian.Uint32(ip.To4()))
			continue
		}
		// IPv6 address, or a value that did not parse: compare as text.
		emit(exclude, fmt.Sprintf("%s %s ?", field, op), v)
	}

	if len(inParts) > 0 {
		*clauses = append(*clauses, "("+strings.Join(inParts, " OR ")+")")
		*args = append(*args, inArgs...)
	}
	if len(notInParts) > 0 {
		*clauses = append(*clauses, "("+strings.Join(notInParts, " AND ")+")")
		*args = append(*args, notInArgs...)
	}
}

// addPortClauses appends port conditions on field for values, plus the
// matching integer arguments.
//
// Each value is either a single port ("80") or an inclusive range
// ("8000-9000"); a leading "!" (on a value longer than one character) makes it
// an exclusion. Surrounding whitespace around numbers is ignored, and values
// that do not parse as integers are skipped without adding a condition.
//
// Inclusions are OR-ed into one parenthesised condition and exclusions are
// AND-ed into another; include arguments are appended before exclude ones.
func addPortClauses(clauses *[]string, args *[]any, values []string, field string) {
	if len(values) == 0 {
		return
	}

	var (
		allow, deny         []string
		allowArgs, denyArgs []any
	)

	for _, raw := range values {
		spec := raw
		negated := len(raw) > 1 && strings.HasPrefix(raw, "!")
		if negated {
			spec = raw[1:]
		}

		// A range has exactly one "-" with an integer on each side.
		if before, after, found := strings.Cut(spec, "-"); found && !strings.Contains(after, "-") {
			lo, errLo := strconv.Atoi(strings.TrimSpace(before))
			hi, errHi := strconv.Atoi(strings.TrimSpace(after))
			if errLo == nil && errHi == nil {
				if negated {
					deny = append(deny, fmt.Sprintf("%s NOT BETWEEN ? AND ?", field))
					denyArgs = append(denyArgs, lo, hi)
				} else {
					allow = append(allow, fmt.Sprintf("%s BETWEEN ? AND ?", field))
					allowArgs = append(allowArgs, lo, hi)
				}
				continue
			}
		}

		// Not a usable range: try the value as a single port number.
		port, err := strconv.Atoi(strings.TrimSpace(spec))
		if err != nil {
			continue
		}
		if negated {
			deny = append(deny, fmt.Sprintf("%s != ?", field))
			denyArgs = append(denyArgs, port)
		} else {
			allow = append(allow, fmt.Sprintf("%s = ?", field))
			allowArgs = append(allowArgs, port)
		}
	}

	if len(allow) != 0 {
		*clauses = append(*clauses, "("+strings.Join(allow, " OR ")+")")
		*args = append(*args, allowArgs...)
	}
	if len(deny) != 0 {
		*clauses = append(*clauses, "("+strings.Join(deny, " AND ")+")")
		*args = append(*args, denyArgs...)
	}
}

// addTimeRangeClauses appends inclusive bounds on the time column. A zero
// timeStart or timeEnd leaves that side of the range open. Each bound is
// passed as an RFC 3339 formatted string argument.
func addTimeRangeClauses(clauses *[]string, args *[]any, timeStart, timeEnd time.Time) {
	bounds := [...]struct {
		condition string
		at        time.Time
	}{
		{"time >= ?", timeStart},
		{"time <= ?", timeEnd},
	}

	for _, b := range bounds {
		if b.at.IsZero() {
			continue
		}
		*clauses = append(*clauses, b.condition)
		*args = append(*args, b.at.Format(time.RFC3339))
	}
}

// addJSONFieldClauses appends conditions on values extracted from the JSON
// "fields" column, plus the matching arguments.
//
// fieldFilters maps a JSON key to its filter values; the key is bound as the
// JSON path "$.<key>". A value prefixed with "!" (and longer than one
// character) is an exclusion. Values containing "*" or "?" are matched with
// GLOB, all others by equality. Per key, inclusions are combined into one
// OR-ed condition and exclusions become separate conditions.
//
// Keys are processed in sorted order so that the same filters always produce
// the same SQL text and argument order, regardless of map iteration order.
func addJSONFieldClauses(clauses *[]string, args *[]any, fieldFilters map[string][]string) {
	keys := make([]string, 0, len(fieldFilters))
	for key := range fieldFilters {
		keys = append(keys, key)
	}
	slices.Sort(keys)

	for _, key := range keys {
		// Exact values are collected as []any so they can be appended to
		// args without a conversion step.
		var inVals, notInVals []any
		var inGlobVals, notInGlobVals []string

		for _, v := range fieldFilters[key] {
			exclude := false
			if strings.HasPrefix(v, "!") && len(v) > 1 {
				exclude = true
				v = v[1:]
			}

			hasWildcard := strings.ContainsAny(v, "*?")
			switch {
			case exclude && hasWildcard:
				notInGlobVals = append(notInGlobVals, v)
			case exclude:
				notInVals = append(notInVals, v)
			case hasWildcard:
				inGlobVals = append(inGlobVals, v)
			default:
				inVals = append(inVals, v)
			}
		}

		jsonPath := fmt.Sprintf("$.%s", key)

		// Inclusions: any single match is enough. Every occurrence of the
		// extraction expression needs its own path argument.
		var parts []string
		if len(inVals) > 0 {
			placeholders := strings.TrimRight(strings.Repeat("?, ", len(inVals)), ", ")
			parts = append(parts, fmt.Sprintf("json_extract_string(fields, ?) IN (%s)", placeholders))
			*args = append(*args, jsonPath)
			*args = append(*args, inVals...)
		}
		for _, g := range inGlobVals {
			parts = append(parts, "json_extract_string(fields, ?) GLOB ?")
			*args = append(*args, jsonPath, g)
		}
		switch len(parts) {
		case 0:
			// No inclusions for this key.
		case 1:
			*clauses = append(*clauses, parts[0])
		default:
			*clauses = append(*clauses, "("+strings.Join(parts, " OR ")+")")
		}

		// Exclusions: every one must hold, so each is its own condition.
		if len(notInVals) > 0 {
			placeholders := strings.TrimRight(strings.Repeat("?, ", len(notInVals)), ", ")
			*clauses = append(*clauses, fmt.Sprintf("json_extract_string(fields, ?) NOT IN (%s)", placeholders))
			*args = append(*args, jsonPath)
			*args = append(*args, notInVals...)
		}
		for _, g := range notInGlobVals {
			*clauses = append(*clauses, "json_extract_string(fields, ?) NOT GLOB ?")
			*args = append(*args, jsonPath, g)
		}
	}
}

// addJSONFieldExistsClauses appends, for every key in fieldExists, a condition
// requiring that extracting "$.<key>" from the JSON "fields" column yields a
// non-NULL value. The JSON path is bound as the condition's argument.
func addJSONFieldExistsClauses(clauses *[]string, args *[]any, fieldExists []string) {
	const existsCondition = "json_extract(fields, ?) IS NOT NULL"

	for i := range fieldExists {
		jsonPath := fmt.Sprintf("$.%s", fieldExists[i])
		*args = append(*args, jsonPath)
		*clauses = append(*clauses, existsCondition)
	}
}

// buildPaginationClause returns the ORDER BY / LIMIT / OFFSET suffix for
// query, with a leading space, or "" when none of the three applies.
//
// query.OrderDirection ends up in the SQL text itself (a sort direction cannot
// be bound as a parameter), so it is only honoured when it is ASC or DESC,
// compared case-insensitively after trimming surrounding whitespace. Any other
// value is ignored and no ORDER BY is emitted. Limit and Offset are only
// emitted when positive.
func buildPaginationClause(query EventQuery) string {
	var parts []string

	// Allow-list the direction: concatenating an unchecked string here would
	// let a caller inject arbitrary SQL.
	if dir := strings.TrimSpace(query.OrderDirection); strings.EqualFold(dir, "ASC") || strings.EqualFold(dir, "DESC") {
		parts = append(parts, "ORDER BY time "+dir)
	}
	if query.Limit > 0 {
		parts = append(parts, "LIMIT "+strconv.Itoa(query.Limit))
	}
	if query.Offset > 0 {
		parts = append(parts, "OFFSET "+strconv.Itoa(query.Offset))
	}

	if len(parts) == 0 {
		return ""
	}

	return " " + strings.Join(parts, " ")
}

// sliceToAny copies s into a new []any of the same length so the strings can
// be passed as variadic query arguments. The result is never nil, even for a
// nil or empty input.
func sliceToAny(s []string) []any {
	boxed := make([]any, 0, len(s))
	for _, str := range s {
		boxed = append(boxed, str)
	}
	return boxed
}