"""Sequence-to-status forecasting case study using LSTM."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd

try:
    import torch
    from torch import Tensor, nn
    import torch.nn.functional as F
except ImportError as exc:  # pragma: no cover - optional dependency guard
    torch = None
    Tensor = None
    nn = None
    F = None
    _TORCH_ERROR = exc
else:
    _TORCH_ERROR = None

from .shared import SensorFeatureConfig, SensorDataPreprocessor


class WaterStressForecaster(nn.Module if nn is not None else object):
    """LSTM architecture aligned with the tomato forecasting reference.

    A single-layer LSTM (``nn.LSTM`` default) encodes each input window; its
    final hidden state goes through dropout and a linear layer that emits one
    logit per status class.

    When PyTorch is not installed ``nn`` is ``None``; the base class then falls
    back to ``object`` so that merely importing this module does not fail.
    Instantiating the model still raises ``ImportError`` in that case.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 30,
        output_dim: int = 4,
        dropout: float = 0.25,
    ) -> None:
        """Build the LSTM encoder, dropout layer and linear classifier.

        Args:
            input_dim: Number of features per time step.
            hidden_dim: Size of the LSTM hidden state.
            output_dim: Number of output logits (status classes).
            dropout: Dropout probability applied to the final hidden state.

        Raises:
            ImportError: If PyTorch is not available.
        """
        if torch is None:  # pragma: no cover - import guard
            raise ImportError("PyTorch is required for WaterStressForecaster") from _TORCH_ERROR
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits for a batch of sequences.

        Because the LSTM is built with ``batch_first=True``, ``x`` is expected as
        ``(batch, seq_len, input_dim)``; the result is ``(batch, output_dim)``.
        """
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers, batch, hidden_dim); use the last layer's final state.
        latent = self.dropout(h_n[-1])
        return self.classifier(latent)

    @staticmethod
    def build_sequences(
        panel: pd.DataFrame,
        feature_cols: Sequence[str],
        target_col: str,
        sequence_length: int = 3,
        *,
        group_col: str = "plant_id",
        time_col: Optional[str] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Slice a long-format panel into fixed-length windows per group.

        For every group of ``panel`` (grouped by ``group_col``), each run of
        ``sequence_length`` consecutive rows becomes one window; its label is
        the ``target_col`` value of the window's last row.

        Args:
            panel: Long-format data with one row per (group, time step).
            feature_cols: Columns stacked into each window (cast to float32).
            target_col: Column providing the label of each window.
            sequence_length: Number of rows per window; must be >= 1.
            group_col: Column identifying independent series (keyword-only).
            time_col: Optional column used to stably sort rows inside each
                group before windowing. When ``None`` the panel's existing row
                order is used as the temporal order (keyword-only).

        Returns:
            ``(X, y)`` where ``X`` has shape
            ``(n_windows, sequence_length, len(feature_cols))`` and ``y`` has
            shape ``(n_windows,)``. If no group has enough rows, both arrays
            are empty but keep those trailing dimensions.

        Raises:
            ValueError: If ``sequence_length`` is less than 1.
        """
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        columns = list(feature_cols)
        sequences: list[np.ndarray] = []
        labels: list[object] = []
        for _, group in panel.groupby(group_col):
            if time_col is not None:
                # mergesort is stable, so ties keep their original relative order.
                group = group.sort_values(time_col, kind="mergesort")
            features = group[columns].to_numpy(dtype=np.float32)
            targets = group[target_col].to_numpy()
            # `end` is exclusive: each window is features[end - L : end] and is
            # labelled with the target of its final row (end - 1).
            for end in range(sequence_length, len(group) + 1):
                sequences.append(features[end - sequence_length : end])
                labels.append(targets[end - 1])
        if not sequences:
            # np.stack([]) raises; return well-shaped empty arrays instead.
            empty_X = np.empty((0, sequence_length, len(columns)), dtype=np.float32)
            empty_y = panel[target_col].to_numpy()[:0]
            return empty_X, empty_y
        return np.stack(sequences), np.array(labels)


@dataclass
class WaterStressForecastingStudy:
    """Configures training and class weighting for the LSTM forecaster.

    Attributes:
        labels: Ordered, unique class names; position ``i`` corresponds to
            logit ``i`` of the model.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
    """

    labels: Sequence[str]
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4

    def __post_init__(self) -> None:
        """Validate the label set and build the label -> index lookup.

        Raises:
            ImportError: If PyTorch is not available.
            ValueError: If ``labels`` contains duplicates (they would map two
                logits to one class and leave one logit untrainable).
        """
        if torch is None:  # pragma: no cover - import guard
            raise ImportError("PyTorch is required for WaterStressForecastingStudy") from _TORCH_ERROR
        self.label_to_index = {label: idx for idx, label in enumerate(self.labels)}
        if len(self.label_to_index) != len(self.labels):
            raise ValueError(f"labels must be unique, got {list(self.labels)!r}")

    def make_optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        """Create the Adam optimizer used in the case study."""
        return torch.optim.Adam(model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)

    def compute_class_weights(self, targets: Sequence[str]) -> Optional[Tensor]:
        """Derive inverse-frequency weights to counter class imbalance.

        Each class gets ``total / count``; classes absent from ``targets`` get
        weight 0. The weights are not normalised: with ``cross_entropy``'s
        default mean reduction only their relative magnitudes matter.

        Returns:
            A float32 tensor of length ``len(self.labels)``, or ``None`` when
            ``targets`` is empty (an all-zero weight vector would make the
            weighted loss 0/0 = NaN).

        Raises:
            KeyError: If a target is not one of ``self.labels``.
        """
        counts = dict.fromkeys(self.labels, 0)
        for target in targets:
            if target not in counts:
                raise KeyError(f"Unknown label {target!r}; expected one of {list(self.labels)!r}")
            counts[target] += 1
        total = sum(counts.values())
        if total == 0:
            return None
        weights = [0.0] * len(self.labels)
        for label, idx in self.label_to_index.items():
            count = counts[label]
            weights[idx] = total / count if count else 0.0
        return torch.tensor(weights, dtype=torch.float32)

    def training_step(
        self,
        model: nn.Module,
        batch: Tuple[Tensor, Union[Tensor, Sequence[str]]],
        optimizer: torch.optim.Optimizer,
        class_weights: Optional[Tensor] = None,
    ) -> float:
        """Run one optimization step on a batch of sequences.

        Args:
            model: Module mapping ``inputs`` to logits ``(batch, len(labels))``.
            batch: ``(inputs, labels)`` where ``labels`` is either a sequence of
                label names from ``self.labels`` or an integer tensor of class
                indices.
            optimizer: Optimizer over ``model``'s parameters.
            class_weights: Optional per-class loss weights; moved to the
                logits' device and dtype before use.

        Returns:
            The scalar loss value for this batch.
        """
        model.train()
        inputs, labels = batch
        logits = model(inputs)
        if isinstance(labels, torch.Tensor):
            indices = labels.to(device=logits.device, dtype=torch.long)
        else:
            indices = torch.tensor(
                [self.label_to_index[label] for label in labels],
                dtype=torch.long,
                device=logits.device,
            )
        if class_weights is not None:
            # cross_entropy requires the weight tensor on the same device/dtype as the logits.
            class_weights = class_weights.to(device=logits.device, dtype=logits.dtype)
        loss = F.cross_entropy(logits, indices, weight=class_weights)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return float(loss.item())


# Public API of this module. SensorFeatureConfig and SensorDataPreprocessor are
# re-exported from `.shared` (imported above) so callers can import the whole
# case-study toolkit from this one module.
__all__ = [
    "SensorFeatureConfig",
    "SensorDataPreprocessor",
    "WaterStressForecaster",
    "WaterStressForecastingStudy",
]
