"""Multimodal case study combining images with SmartPlant telemetry."""

from __future__ import annotations

from typing import Tuple

try:
    import torch
    from torch import Tensor, nn
    import torch.nn.functional as F
except ImportError as exc:  # pragma: no cover - optional dependency guard
    torch = None
    Tensor = None
    F = None
    _TORCH_ERROR = exc

    class nn:  # type: ignore[no-redef]
        """Stand-in so ``class X(nn.Module)`` still parses without torch installed."""

        Module = object
else:
    _TORCH_ERROR = None

from .shared import SensorFeatureConfig, SensorDataPreprocessor


class VisionEncoder(nn.Module):
    """Light-weight CNN turning RGB+NIR crops into embeddings."""

    def __init__(self, in_channels: int = 4, embed_dim: int = 256) -> None:
        if torch is None:  # pragma: no cover - import guard
            raise ImportError("PyTorch is required for VisionEncoder") from _TORCH_ERROR
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.projection = nn.Linear(128, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        # Adaptive pooling leaves a (batch, 128, 1, 1) map; flatten it and
        # project into the shared embedding space with unit L2 norm.
        latent = self.network(x).flatten(start_dim=1)
        return F.normalize(self.projection(latent), dim=-1)

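
# Illustrative sketch (not part of the public API): how VisionEncoder maps a
# batch of 4-channel crops to unit-length embeddings. The batch size and the
# 64x64 crop size are assumptions for demonstration only; any spatial size
# works because of the adaptive average pooling.
def _demo_vision_encoder() -> None:  # pragma: no cover - example only
    encoder = VisionEncoder(in_channels=4, embed_dim=256)
    crops = torch.randn(8, 4, 64, 64)  # batch of 8 RGB+NIR crops
    embeddings = encoder(crops)
    assert embeddings.shape == (8, 256)
    # Embeddings are L2-normalized, so each row has unit norm.
    assert torch.allclose(embeddings.norm(dim=-1), torch.ones(8), atol=1e-5)
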

class SensorSequenceEncoder(nn.Module):
    """Temporal encoder mapping SmartPlant sequences to embeddings."""

    def __init__(self, input_dim: int, embed_dim: int = 256, hidden_dim: int = 128) -> None:
        if torch is None:  # pragma: no cover - import guard
            raise ImportError("PyTorch is required for SensorSequenceEncoder") from _TORCH_ERROR
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.projection = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        # Summarize the sequence with the final hidden state of the LSTM,
        # then project into the shared embedding space with unit L2 norm.
        _, (h_n, _) = self.lstm(x)
        return F.normalize(self.projection(h_n[-1]), dim=-1)

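
# Illustrative sketch: SensorSequenceEncoder consumes (batch, time, features)
# telemetry windows. The window length (48 steps) and feature count (6) below
# are assumptions for demonstration, not requirements of the encoder.
def _demo_sensor_encoder() -> None:  # pragma: no cover - example only
    encoder = SensorSequenceEncoder(input_dim=6, embed_dim=256)
    windows = torch.randn(8, 48, 6)  # e.g. 48 readings of 6 SmartPlant channels
    embeddings = encoder(windows)
    assert embeddings.shape == (8, 256)
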

class MultimodalWaterStressStudy(nn.Module):
    """CLIP-style head fusing vision and sensor embeddings."""

    def __init__(
        self,
        vision_channels: int,
        sensor_dim: int,
        embed_dim: int = 256,
        forecast_classes: int = 4,
    ) -> None:
        if torch is None:  # pragma: no cover - import guard
            raise ImportError("PyTorch is required for MultimodalWaterStressStudy") from _TORCH_ERROR
        super().__init__()
        self.vision_encoder = VisionEncoder(vision_channels, embed_dim)
        self.sensor_encoder = SensorSequenceEncoder(sensor_dim, embed_dim)
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim * 2),
            nn.Linear(embed_dim * 2, embed_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(embed_dim, forecast_classes),
        )

    def forward(self, image_batch: Tensor, sensor_batch: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        image_embed = self.vision_encoder(image_batch)
        sensor_embed = self.sensor_encoder(sensor_batch)
        # Late fusion: concatenate the modality embeddings before classifying.
        joint = torch.cat([image_embed, sensor_embed], dim=-1)
        logits = self.classifier(joint)
        return logits, image_embed, sensor_embed

    @staticmethod
    def contrastive_loss(image_embed: Tensor, sensor_embed: Tensor, temperature: float = 0.07) -> Tensor:
        """Symmetric InfoNCE alignment objective inspired by CLIP.

        Matched image/sensor pairs along the batch diagonal are treated as
        positives; all other pairings in the batch act as negatives.
        """
        image_embed = F.normalize(image_embed, dim=-1)
        sensor_embed = F.normalize(sensor_embed, dim=-1)
        logits_per_image = image_embed @ sensor_embed.t() / temperature
        # The sensor-to-image logits are simply the transpose of the above.
        logits_per_sensor = logits_per_image.t()
        labels = torch.arange(image_embed.size(0), device=image_embed.device)
        loss_i = F.cross_entropy(logits_per_image, labels)
        loss_s = F.cross_entropy(logits_per_sensor, labels)
        return (loss_i + loss_s) / 2

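
# Illustrative training-step sketch combining the classification head with the
# contrastive alignment term. The 0.5 loss weight, batch shapes, and label
# generation are assumptions for demonstration only.
def _demo_training_step() -> None:  # pragma: no cover - example only
    model = MultimodalWaterStressStudy(vision_channels=4, sensor_dim=6)
    images = torch.randn(8, 4, 64, 64)
    sensors = torch.randn(8, 48, 6)
    targets = torch.randint(0, 4, (8,))  # one label per sample, 4 classes
    logits, image_embed, sensor_embed = model(images, sensors)
    loss = F.cross_entropy(logits, targets)
    loss = loss + 0.5 * model.contrastive_loss(image_embed, sensor_embed)
    loss.backward()  # gradients flow through both encoders and the classifier
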

__all__ = [
    "SensorFeatureConfig",
    "SensorDataPreprocessor",
    "VisionEncoder",
    "SensorSequenceEncoder",
    "MultimodalWaterStressStudy",
]
