internal/database/blocklist.go

package database

import (
	"errors"
	"honeypot/internal/types"
	"strings"
)

// GetBlockedAddresses returns every blocklist entry whose expiry is still in
// the future (expires > CURRENT_TIMESTAMP), ordered so that the entry expiring
// last comes first.
//
// The returned slice is nil when no rows match. An error is returned when the
// database handle is nil, when the query fails, when a row cannot be scanned,
// or when row iteration terminates abnormally.
func (db *Database) GetBlockedAddresses() ([]types.BlocklistEntry, error) {
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	query := `
		SELECT id, address, timestamp, expires, reason
		FROM blocklist
		WHERE expires > CURRENT_TIMESTAMP
		ORDER BY expires DESC
	`
	rows, err := db.DB.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entries []types.BlocklistEntry
	for rows.Next() {
		var entry types.BlocklistEntry
		if err := rows.Scan(
			&entry.ID,
			&entry.Address,
			&entry.Timestamp,
			&entry.Expires,
			&entry.Reason,
		); err != nil {
			return nil, err
		}
		entries = append(entries, entry)
	}
	// rows.Next returns false both at end-of-data and on failure; without this
	// check a mid-stream error would be reported as a shorter, "successful" list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return entries, nil
}

// GetBlocklistForNet returns up to limit blocklist entries matching network,
// newest first (ordered by timestamp descending).
//
// When network contains a "/", it is treated as a subnet and entries whose
// address is contained in (or equal to) that subnet are returned. Otherwise
// network is compared for exact string equality against the address column.
// Unlike GetBlockedAddresses, expired entries are not filtered out.
//
// The returned slice is nil when no rows match. An error is returned when the
// database handle is nil, when GetIpWhere rejects network, when the query
// fails, when a row cannot be scanned, or when row iteration terminates
// abnormally.
func (db *Database) GetBlocklistForNet(network string, limit int) ([]types.BlocklistEntry, error) {
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	// Only the error result of GetIpWhere is used; its other return values are
	// discarded. NOTE(review): presumably this serves as input validation for
	// network - confirm against GetIpWhere.
	if _, _, err := GetIpWhere(network); err != nil {
		return nil, err
	}

	// The 'address' column is matched directly here: by string equality for a
	// single IP, or via a cast to INET and the "contained in or equal" operator
	// for a subnet.
	query := `
			SELECT id, address, timestamp, expires, reason
			FROM blocklist
			WHERE address = ?
			ORDER BY timestamp DESC
			LIMIT ?
		`
	if strings.Contains(network, "/") {
		query = `
			SELECT id, address, timestamp, expires, reason
			FROM blocklist
			WHERE address::INET <<= ?::INET
			ORDER BY timestamp DESC
			LIMIT ?
		`
	}

	rows, err := db.DB.Query(query, network, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entries []types.BlocklistEntry
	for rows.Next() {
		var entry types.BlocklistEntry
		if err := rows.Scan(
			&entry.ID,
			&entry.Address,
			&entry.Timestamp,
			&entry.Expires,
			&entry.Reason,
		); err != nil {
			return nil, err
		}
		entries = append(entries, entry)
	}
	// rows.Next returns false both at end-of-data and on failure; surface the
	// failure instead of returning a silently truncated result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return entries, nil
}

// GetBlockCounts returns, for each distinct address in the blocklist table,
// the number of rows recorded for that address. Expired entries are included.
//
// The SQL orders rows by count, but the result is a map, so callers must not
// rely on any ordering. The map is non-nil (possibly empty) on success.
//
// An error is returned when the database handle is nil, when the query fails,
// when a row cannot be scanned, or when row iteration terminates abnormally.
func (db *Database) GetBlockCounts() (map[string]int, error) {
	// Guard against a missing connection, matching the other blocklist
	// accessors, rather than panicking on a nil handle.
	if db.DB == nil {
		return nil, errors.New("database is not connected")
	}

	query := `
		SELECT address, COUNT(*) as count
		FROM blocklist
		GROUP BY address
		ORDER BY count DESC
	`
	rows, err := db.DB.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	counts := make(map[string]int)
	for rows.Next() {
		var (
			address string
			count   int
		)
		if err := rows.Scan(&address, &count); err != nil {
			return nil, err
		}
		counts[address] = count
	}
	// rows.Next returns false both at end-of-data and on failure; surface the
	// failure instead of returning partial counts.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return counts, nil
}