internal/honeypot/ssh/ssh.go

package ssh

import (
	"context"
	// #nosec G501
	"crypto/md5"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/hex"
	"encoding/pem"
	"fmt"
	"log/slog"
	"net"
	"os"
	"path/filepath"
	"sync"
	"time"

	"honeypot/internal/database"
	"honeypot/internal/honeypot"
	"honeypot/internal/logger"
	"honeypot/internal/types"
	"honeypot/internal/utils"

	"golang.org/x/crypto/ssh"
)

// Package-level defaults for the SSH honeypot.
//
//   - HoneypotType / HoneypotLabel: identifiers returned by Name and Label;
//     HoneypotType is also passed to every logger call in this file.
//   - DefaultHostKeyFile: key path used when Config.HostKeyFile is empty.
//   - DefaultConnectionTimeout: absolute deadline applied to each accepted
//     connection in handleConn.
//   - DefaultRSAKeySize: modulus size in bits for a freshly generated host key.
//   - DefaultKeyFilePermissions / DefaultDirPermissions / DefaultPubKeyPermissions:
//     file modes used when writing the private key, creating its directory and
//     writing the ".pub" file respectively.
//   - ServerVersion: version banner set on the ssh.ServerConfig.
//   - ShutdownTimeout: not referenced in this file.
//     NOTE(review): presumably consumed by another package — confirm.
const (
	HoneypotType              = types.HoneypotTypeSSH
	HoneypotLabel             = "SSH"
	DefaultHostKeyFile        = "ssh_key"
	DefaultConnectionTimeout  = 3 * time.Minute
	DefaultRSAKeySize         = 3072
	DefaultKeyFilePermissions = 0600
	DefaultDirPermissions     = 0700
	DefaultPubKeyPermissions  = 0644
	ServerVersion             = "SSH-2.0-OpenSSH_9.6p1 Ubuntu-3ubuntu13.12"
	ShutdownTimeout           = 3 * time.Second
)

// Config holds the configuration for the SSH honeypot.
type Config struct {
	// ListenAddr is the host part of the listen address; it is combined with
	// each port via utils.BuildAddress.
	ListenAddr  string
	// Ports lists the TCP ports to listen on. Entries equal to 0 are skipped
	// by Start.
	Ports       []uint16
	// HostKeyFile is the path to the host key file. When empty,
	// DefaultHostKeyFile is used.
	HostKeyFile string
}

// sshHoneypot implements the honeypot.Honeypot interface.
type sshHoneypot struct {
	// config is the configuration supplied to New.
	config Config
	// logger is nil until Start assigns the logger it was given.
	logger *slog.Logger
}

// New builds an SSH honeypot from cfg.
//
// As a side effect it registers the "client_version" field for top-N tracking.
// Registration happens here, at construction time, so the field is known
// before log restoration.
func New(cfg Config) honeypot.Honeypot {
	logger.RegisterTopNField("ssh", "client_version")

	hp := new(sshHoneypot)
	hp.config = cfg
	return hp
}

// Name returns the type identifier of this honeypot, HoneypotType
// (types.HoneypotTypeSSH).
func (h *sshHoneypot) Name() types.HoneypotType {
	return HoneypotType
}

// Label returns the label of this honeypot, HoneypotLabel ("SSH").
func (h *sshHoneypot) Label() string {
	return HoneypotLabel
}

// Start runs the SSH honeypot and blocks until every listener has stopped.
//
// It stores l as the honeypot's logger, builds the shared SSH server config
// (returning its error if that fails) and then launches one listener goroutine
// per non-zero configured port. Listener failures are logged by
// startSSHServer and are not reported through the return value.
func (h *sshHoneypot) Start(ctx context.Context, l *slog.Logger) error {
	h.logger = l

	serverCfg, err := h.createServerConfig()
	if err != nil {
		return err
	}

	var listeners sync.WaitGroup
	for _, port := range h.config.Ports {
		// Port 0 means "not configured"; skip it.
		if port == 0 {
			continue
		}

		listeners.Add(1)
		go func(p uint16) {
			defer listeners.Done()
			h.startSSHServer(ctx, p, serverCfg)
		}(port)
	}
	listeners.Wait()

	logger.LogInfo(h.logger, HoneypotType, "honeypot shutdown complete", nil)
	return nil
}

// startSSHServer listens on the given port and serves connections until the
// accept loop returns.
//
// A failure to bind is logged as "listen_failed" and the function returns
// without serving. Otherwise the listener is closed both by the shutdown
// goroutine (on ctx cancellation) and by the deferred Close on return.
func (h *sshHoneypot) startSSHServer(ctx context.Context, port uint16, sshConfig *ssh.ServerConfig) {
	addr := utils.BuildAddress(h.config.ListenAddr, port)

	ln, listenErr := net.Listen("tcp", addr)
	if listenErr != nil {
		logger.LogError(h.logger, HoneypotType, "listen_failed", listenErr, []any{"addr", addr})
		return
	}
	defer ln.Close()

	logger.LogInfo(h.logger, HoneypotType, "honeypot listening", []any{"port", port})

	// Arrange for the listener to close on cancellation, then block in the
	// accept loop.
	h.setupGracefulShutdown(ctx, ln)
	h.acceptConnections(ctx, ln, sshConfig)

	logger.LogInfo(h.logger, HoneypotType, "honeypot shutdown complete", []any{"port", port})
}

// setupGracefulShutdown spawns a goroutine that waits for ctx to be cancelled
// and then closes listener, which unblocks a pending Accept. A Close error is
// logged as "listener_close_error".
func (h *sshHoneypot) setupGracefulShutdown(ctx context.Context, listener net.Listener) {
	closeOnCancel := func() {
		<-ctx.Done()

		closeErr := listener.Close()
		if closeErr == nil {
			return
		}
		logger.LogError(h.logger, HoneypotType, "listener_close_error", closeErr, nil)
	}

	go closeOnCancel()
}

// acceptConnections accepts incoming connections on listener and serves each
// one on its own goroutine via handleConn. It returns once ctx is cancelled.
//
// An Accept error observed while ctx is still live is logged as
// "accept_failed" and retried after a pause that starts at 5ms and doubles up
// to a 1s cap; the pause resets after the next successful Accept. Without the
// pause, a persistent failure (for example file-descriptor exhaustion) would
// turn this loop into a busy spin that floods the log.
func (h *sshHoneypot) acceptConnections(ctx context.Context, listener net.Listener, cfg *ssh.ServerConfig) {
	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := listener.Accept()
		if err != nil {
			// The shutdown goroutine closes the listener after cancellation,
			// so an error at that point is the expected exit signal.
			if ctx.Err() != nil {
				return
			}
			logger.LogError(h.logger, HoneypotType, "accept_failed", err, nil)

			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
				if backoff > maxAcceptBackoff {
					backoff = maxAcceptBackoff
				}
			}

			// Sleep, but wake up immediately if shutdown is requested.
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
			}
			continue
		}

		backoff = 0
		go h.handleConn(ctx, conn, cfg)
	}
}

// createServerConfig builds the ssh.ServerConfig shared by all listeners.
//
// The host key is loaded from (or generated at) the path returned by
// getHostKeyFile; on failure the error is logged as
// "load_or_generate_host_key_failed" and returned. Password and public-key
// authentication are both routed to this honeypot's callbacks.
func (h *sshHoneypot) createServerConfig() (*ssh.ServerConfig, error) {
	keyPath := h.getHostKeyFile()

	signer, err := h.loadOrGenerateHostKey(keyPath)
	if err != nil {
		logger.LogError(h.logger, HoneypotType, "load_or_generate_host_key_failed", err, []any{
			"key_file", keyPath,
		})
		return nil, err
	}

	serverCfg := &ssh.ServerConfig{
		PublicKeyCallback: h.publicKeyAuthCallback,
		PasswordCallback:  h.passwordAuthCallback,
		// NOTE(review): per the x/crypto/ssh docs, zero selects the library's
		// default attempt limit rather than "unlimited" — confirm if intended.
		MaxAuthTries:  0,
		ServerVersion: ServerVersion,
	}
	serverCfg.AddHostKey(signer)

	return serverCfg, nil
}

// getHostKeyFile returns the configured host key path, falling back to
// DefaultHostKeyFile when none is configured.
func (h *sshHoneypot) getHostKeyFile() string {
	path := h.config.HostKeyFile
	if path == "" {
		path = DefaultHostKeyFile
	}
	return path
}

// passwordAuthCallback records a password authentication attempt (username,
// password and client version) and always rejects it.
func (h *sshHoneypot) passwordAuthCallback(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
	attempt := make(map[string]interface{}, 5)
	attempt["username"] = c.User()
	attempt["password"] = string(pass)
	attempt["client_version"] = string(c.ClientVersion())

	h.logAuthAttempt(c, "password", attempt)

	// Never grant access: the honeypot only observes credentials.
	return nil, fmt.Errorf("access denied")
}

// publicKeyAuthCallback records a public-key authentication attempt (username,
// key fingerprint and client version) and always rejects it.
func (h *sshHoneypot) publicKeyAuthCallback(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
	attempt := make(map[string]interface{}, 5)
	attempt["username"] = c.User()
	attempt["public_key_fingerprint"] = fingerprint(key)
	attempt["client_version"] = string(c.ClientVersion())

	h.logAuthAttempt(c, "public-key", attempt)

	// Never grant access: the honeypot only observes offered keys.
	return nil, fmt.Errorf("access denied")
}

// handleConn serves a single accepted connection and closes it on return.
//
// The connection is bounded in two ways:
//   - an absolute deadline of DefaultConnectionTimeout, so a peer cannot hold
//     the socket open indefinitely; if the deadline cannot be set the
//     connection is dropped and the error logged as "set_deadline_failed";
//   - cancellation of ctx, which closes the connection immediately so that
//     in-flight sessions do not outlive honeypot shutdown.
//
// A failed SSH handshake returns silently. Both authentication callbacks in
// this file always reject, and they log the attempt before the handshake
// fails, so nothing further is logged here.
func (h *sshHoneypot) handleConn(ctx context.Context, conn net.Conn, cfg *ssh.ServerConfig) {
	defer conn.Close()

	// Closing the socket unblocks any read/write pending inside the SSH
	// library. stop() detaches the callback once this handler is finished.
	stop := context.AfterFunc(ctx, func() { _ = conn.Close() })
	defer stop()

	if err := conn.SetDeadline(time.Now().Add(DefaultConnectionTimeout)); err != nil {
		logger.LogError(h.logger, HoneypotType, "set_deadline_failed", err, nil)
		return
	}

	sshConn, chans, reqs, err := ssh.NewServerConn(conn, cfg)
	if err != nil {
		return
	}
	defer sshConn.Close()

	h.handleSSHRequests(ctx, reqs)
	h.handleSSHChannels(ctx, chans)
}

// handleSSHRequests drains and discards all global SSH requests on a
// background goroutine. ctx is currently unused.
func (h *sshHoneypot) handleSSHRequests(ctx context.Context, reqs <-chan *ssh.Request) {
	go func() {
		ssh.DiscardRequests(reqs)
	}()
}

// handleSSHChannels rejects every channel the client tries to open and returns
// once chans is closed. ctx is currently unused.
func (h *sshHoneypot) handleSSHChannels(ctx context.Context, chans <-chan ssh.NewChannel) {
	for {
		newCh, open := <-chans
		if !open {
			return
		}
		// Best-effort: the result of Reject is intentionally ignored.
		_ = newCh.Reject(ssh.Prohibited, "access denied")
	}
}

// logAuthAttempt emits an EventAuthAttempt log event for the connection c.
//
// The remote host/port and the local (destination) port are taken from c.
// fields is used as the event's field map (a new map is created when it is
// nil) and is extended in place with "auth_method" and
// "supported_algorithms".
func (h *sshHoneypot) logAuthAttempt(c ssh.ConnMetadata, authMethod string, fields map[string]interface{}) {
	srcHost, srcPort := utils.SplitAddr(c.RemoteAddr().String(), h.logger)
	_, localPort := utils.SplitAddr(c.LocalAddr().String(), h.logger)

	if fields == nil {
		fields = map[string]interface{}{}
	}
	fields["auth_method"] = authMethod
	fields["supported_algorithms"] = supportedAlgorithms()

	event := types.LogEvent{
		Type:       HoneypotType,
		Event:      types.EventAuthAttempt,
		RemoteAddr: srcHost,
		RemotePort: srcPort,
		DstPort:    localPort,
		Fields:     fields,
	}
	logger.LogEvent(h.logger, event)
}

// loadOrGenerateHostKey returns the host key signer for keyFile.
//
// If keyFile exists it is loaded. If it does not exist, a new key is generated
// and saved there. Any other Stat failure (for example a permission error) is
// returned as an error instead of being treated as "missing", so an existing
// but temporarily inaccessible host key is never silently replaced.
func (h *sshHoneypot) loadOrGenerateHostKey(keyFile string) (ssh.Signer, error) {
	_, err := os.Stat(keyFile)
	switch {
	case err == nil:
		return h.loadExistingHostKey(keyFile)
	case os.IsNotExist(err):
		return h.generateAndSaveHostKey(keyFile)
	default:
		return nil, fmt.Errorf("failed to stat key file: %w", err)
	}
}

// loadExistingHostKey reads keyFile and parses it as an SSH private key.
// Read and parse failures are returned wrapped; success is logged as
// "loaded existing host key".
func (h *sshHoneypot) loadExistingHostKey(keyFile string) (ssh.Signer, error) {
	raw, readErr := os.ReadFile(keyFile)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read key file: %w", readErr)
	}

	hostKey, parseErr := ssh.ParsePrivateKey(raw)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse private key: %w", parseErr)
	}

	logger.LogInfo(h.logger, HoneypotType, "loaded existing host key", []any{"key_file", keyFile})
	return hostKey, nil
}

// generateAndSaveHostKey creates a new RSA host key of DefaultRSAKeySize bits,
// persists it at keyFile and returns a signer for it.
//
// The key's directory is created if needed and the private key must be written
// successfully; writing the companion public key is best-effort.
func (h *sshHoneypot) generateAndSaveHostKey(keyFile string) (ssh.Signer, error) {
	privKey, genErr := rsa.GenerateKey(rand.Reader, DefaultRSAKeySize)
	if genErr != nil {
		return nil, fmt.Errorf("failed to generate key: %w", genErr)
	}

	dirErr := h.ensureKeyFileDirectory(keyFile)
	if dirErr != nil {
		return nil, dirErr
	}

	saveErr := h.savePrivateKey(keyFile, privKey)
	if saveErr != nil {
		return nil, saveErr
	}

	// Optional artefact; savePublicKey reports no error to the caller.
	h.savePublicKey(keyFile, &privKey.PublicKey)

	hostKey, signErr := ssh.NewSignerFromKey(privKey)
	if signErr != nil {
		return nil, fmt.Errorf("failed to create signer: %w", signErr)
	}

	logger.LogInfo(h.logger, HoneypotType, "generated new host key", []any{"key_file", keyFile})
	return hostKey, nil
}

// ensureKeyFileDirectory makes sure the parent directory of keyFile exists,
// creating it (and any missing parents) with DefaultDirPermissions. Nothing is
// done when keyFile has no directory component.
func (h *sshHoneypot) ensureKeyFileDirectory(keyFile string) error {
	parent := filepath.Dir(keyFile)
	if parent == "." || parent == "" {
		return nil
	}

	if mkErr := os.MkdirAll(parent, DefaultDirPermissions); mkErr != nil {
		return fmt.Errorf("failed to create directory for key file: %w", mkErr)
	}
	return nil
}

// savePrivateKey writes key to keyFile as a PKCS#1 "RSA PRIVATE KEY" PEM block,
// passing DefaultKeyFilePermissions as the file mode to os.WriteFile.
func (h *sshHoneypot) savePrivateKey(keyFile string, key *rsa.PrivateKey) error {
	block := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	}

	err := os.WriteFile(keyFile, pem.EncodeToMemory(block), DefaultKeyFilePermissions)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to write key file: %w", err)
}

// savePublicKey writes pubKey in authorized_keys format to keyFile + ".pub".
//
// The public key file is optional: every failure, whether converting the key
// or writing the file, is logged as "failed to save public key" and otherwise
// ignored, so the caller never fails because of it.
func (h *sshHoneypot) savePublicKey(keyFile string, pubKey *rsa.PublicKey) {
	pubKeyFile := keyFile + ".pub"

	sshPubKey, err := ssh.NewPublicKey(pubKey)
	if err != nil {
		// Previously swallowed silently; log it like the write failure below.
		logger.LogError(h.logger, HoneypotType, "failed to save public key", err, []any{
			"pub_key_file", pubKeyFile,
		})
		return
	}

	pubKeyBytes := ssh.MarshalAuthorizedKey(sshPubKey)
	if err := os.WriteFile(pubKeyFile, pubKeyBytes, DefaultPubKeyPermissions); err != nil {
		// Log but don't fail - public key is optional.
		logger.LogError(h.logger, HoneypotType, "failed to save public key", err, []any{
			"pub_key_file", pubKeyFile,
		})
	}
}

// fingerprint returns the MD5 digest of the key's wire encoding as a lowercase
// hex string without separators. MD5 is used only as an identifier for
// logging, not for security.
func fingerprint(k ssh.PublicKey) string {
	wire := k.Marshal()
	// #nosec G401
	digest := md5.Sum(wire)
	return hex.EncodeToString(digest[:])
}

// supportedAlgorithms returns the fixed kex/cipher/MAC description string that
// is attached to every logged authentication attempt.
func supportedAlgorithms() string {
	const algorithms = "kex:curve25519-sha256; ciphers:chacha20-poly1305@openssh.com; macs:hmac-sha2-256"
	return algorithms
}

// GetScores always returns an empty ScoreMap for the SSH honeypot; db and
// interval are not used here.
func (h *sshHoneypot) GetScores(db *database.Database, interval string) honeypot.ScoreMap {
	// auth_attempt scores are handled in the honeypot.authAttemptScores function
	return honeypot.ScoreMap{}
}

// Ports returns the ports configured for this honeypot. The returned slice is
// the one stored in the config, not a copy.
func (h *sshHoneypot) Ports() []uint16 {
	return h.config.Ports
}