// cmd/sanitize-db/main.go

package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"fmt"
	"log"
	"os"

	"github.com/duckdb/duckdb-go/v2"
)

const (
	// countQuery counts the events whose fields JSON has a non-null remote_addr.
	countQuery = "SELECT COUNT(*) FROM honeypot_events WHERE fields->>'remote_addr' IS NOT NULL"
	// selectQuery streams the id and fields JSON of those same events.
	selectQuery = "SELECT id, fields FROM honeypot_events WHERE fields->>'remote_addr' IS NOT NULL"
	// progressInterval is the number of staged rows between progress messages.
	progressInterval = 100000
)

// main validates the command line and runs the sanitizer against the given
// database file, exiting with a non-zero status on any failure.
func main() {
	if len(os.Args) != 2 {
		fmt.Fprintf(os.Stderr, "Usage: %s <database_file>\n", os.Args[0])
		fmt.Fprintf(os.Stderr, "Example: %s data/honeypot.db\n", os.Args[0])
		os.Exit(1)
	}

	// All work happens in sanitizeDatabase so that its deferred Close calls
	// run before the process exits; log.Fatal itself skips deferred functions.
	if err := sanitizeDatabase(context.Background(), os.Args[1]); err != nil {
		log.Fatal(err)
	}
}

// sanitizeDatabase removes the 'remote_addr' key from the fields JSON of every
// matching honeypot_events row in databaseFile.
//
// It counts the affected rows, asks for interactive confirmation on stdin,
// stages the sanitized JSON in a temporary table through a DuckDB appender,
// and then applies it to honeypot_events with a single UPDATE statement.
// It returns nil when there is nothing to do or the user declines, and a
// descriptive error when any database step fails.
func sanitizeDatabase(ctx context.Context, databaseFile string) error {
	connector, err := duckdb.NewConnector(databaseFile, nil)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}
	defer connector.Close()

	db := sql.OpenDB(connector)
	defer db.Close()

	fmt.Printf("Sanitizing database: %s\n", databaseFile)

	var count int
	if err := db.QueryRowContext(ctx, countQuery).Scan(&count); err != nil {
		return fmt.Errorf("failed to count rows: %w", err)
	}
	if count == 0 {
		fmt.Println("No rows found with 'remote_addr' in 'fields'.")
		return nil
	}

	fmt.Printf("Found %d rows with 'remote_addr' in 'fields' column.\n", count)
	fmt.Print("Are you sure you want to remove 'remote_addr' from the 'fields' JSON for these rows? (yes/no): ")
	var confirmation string
	// The scan error is deliberately ignored: only the literal answer "yes"
	// proceeds, so empty input or EOF falls through to the cancel branch.
	_, _ = fmt.Scanln(&confirmation)
	if confirmation != "yes" {
		fmt.Println("Sanitization cancelled.")
		return nil
	}

	fmt.Println("Starting sanitization...")

	// The temp table is created, filled and read through this one connection,
	// so it is pinned for the rest of the process.
	sqlConn, err := db.Conn(ctx)
	if err != nil {
		return fmt.Errorf("Failed to get connection: %w", err)
	}
	defer sqlConn.Close()

	if _, err := sqlConn.ExecContext(ctx, "CREATE TEMP TABLE temp_sanitized (id INTEGER, fields JSON)"); err != nil {
		return fmt.Errorf("Failed to create temp table: %w", err)
	}

	appender, err := openAppender(sqlConn, "temp_sanitized")
	if err != nil {
		return fmt.Errorf("Failed to create appender: %w", err)
	}
	// Release the appender on early-return paths. This defer runs before the
	// deferred sqlConn.Close above, so the connection is still open for it.
	appenderOpen := true
	defer func() {
		if appenderOpen {
			_ = appender.Close() // best effort; the primary error is already being returned
		}
	}()

	processed, skipped, err := stageSanitizedRows(ctx, db, appender)
	if err != nil {
		return err
	}

	appenderOpen = false
	if err := appender.Close(); err != nil {
		return fmt.Errorf("Failed to close appender: %w", err)
	}

	if skipped > 0 {
		log.Printf("warning: %d rows could not be decoded and were left unchanged; they still contain 'remote_addr'", skipped)
	}

	fmt.Printf("Updating %d rows in main table...\n", processed)

	updateQuery := `
		UPDATE honeypot_events 
		SET fields = s.fields 
		FROM temp_sanitized s 
		WHERE honeypot_events.id = s.id
	`
	res, err := sqlConn.ExecContext(ctx, updateQuery)
	if err != nil {
		return fmt.Errorf("Bulk update failed: %w", err)
	}

	rowsAffected, err := res.RowsAffected()
	if err != nil {
		// The UPDATE itself succeeded; fall back to the number of staged rows.
		log.Printf("warning: could not read affected row count: %v", err)
		rowsAffected = int64(processed)
	}
	fmt.Printf("Successfully sanitized %d rows.\n", rowsAffected)

	// Explicit cleanup is best effort, so a failure is only reported.
	if _, err := sqlConn.ExecContext(ctx, "DROP TABLE temp_sanitized"); err != nil {
		log.Printf("warning: could not drop temp table: %v", err)
	}

	fmt.Println("Database sanitization complete.")
	return nil
}

// openAppender creates a DuckDB appender for table on the driver connection
// underlying sqlConn, so appended rows land in that connection's temp table.
func openAppender(sqlConn *sql.Conn, table string) (*duckdb.Appender, error) {
	var appender *duckdb.Appender
	err := sqlConn.Raw(func(driverConn any) error {
		conn, ok := driverConn.(driver.Conn)
		if !ok {
			return fmt.Errorf("unexpected driver connection type %T", driverConn)
		}
		var err error
		appender, err = duckdb.NewAppenderFromConn(conn, "", table)
		return err
	})
	if err != nil {
		return nil, err
	}
	return appender, nil
}

// stageSanitizedRows streams every matching honeypot_events row from db,
// strips the sensitive keys from its fields JSON, and appends the result to
// appender.
//
// It returns the number of rows appended and the number of rows skipped
// because their fields value could not be decoded. Skipped rows are logged
// and left untouched in the main table. A scan, append, or iteration failure
// aborts staging with an error.
func stageSanitizedRows(ctx context.Context, db *sql.DB, appender *duckdb.Appender) (int, int, error) {
	// The query goes through the pool rather than the pinned connection that
	// the appender is writing on.
	rows, err := db.QueryContext(ctx, selectQuery)
	if err != nil {
		return 0, 0, fmt.Errorf("Failed to query rows: %w", err)
	}
	defer rows.Close()

	processed, skipped := 0, 0
	for rows.Next() {
		var id int
		var raw any
		if err := rows.Scan(&id, &raw); err != nil {
			return processed, skipped, fmt.Errorf("Row scan failed: %w", err)
		}

		fields, err := decodeFields(raw)
		if err != nil {
			skipped++
			log.Printf("warning: skipping row %d: %v", id, err)
			continue
		}

		delete(fields, "remote_addr")
		// NOTE(review): "port" is removed as well, although the confirmation
		// prompt only mentions 'remote_addr' — confirm this is intended.
		delete(fields, "port")

		// The map is appended as-is rather than as a serialized string; the
		// original author's note says the driver stores it as JSON, whereas a
		// string would be kept as a string literal.
		if err := appender.AppendRow(id, fields); err != nil {
			return processed, skipped, fmt.Errorf("Appender failed at row %d: %w", id, err)
		}

		processed++
		if processed%progressInterval == 0 {
			fmt.Printf("Processed %d rows...\n", processed)
		}
	}
	// rows.Next also returns false when the stream fails midway; without this
	// check a truncated read would look like a complete one.
	if err := rows.Err(); err != nil {
		return processed, skipped, fmt.Errorf("Row iteration failed: %w", err)
	}

	return processed, skipped, nil
}

// decodeFields converts a scanned fields column value into a JSON object map.
// It accepts an already-decoded map, or JSON text as []byte or string, and
// returns an error for any other type or for text that is not a JSON object.
func decodeFields(raw any) (map[string]any, error) {
	var data []byte
	switch v := raw.(type) {
	case map[string]any:
		return v, nil
	case []byte:
		data = v
	case string:
		data = []byte(v)
	default:
		return nil, fmt.Errorf("unsupported fields type %T", raw)
	}

	var fields map[string]any
	if err := json.Unmarshal(data, &fields); err != nil {
		return nil, fmt.Errorf("invalid fields JSON: %w", err)
	}
	return fields, nil
}