"""Tree-based water-stress classification case study."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Mapping, Sequence

import numpy as np

try:
    from sklearn.ensemble import RandomForestClassifier
except ImportError as exc:  # pragma: no cover - guard for optional dependency
    RandomForestClassifier = None
    _SKLEARN_ERROR = exc
else:
    _SKLEARN_ERROR = None

from .shared import SensorFeatureConfig, SensorDataPreprocessor


@dataclass
class WaterStressClassificationStudy:
    """Random-forest classifier for water-stress status classification.

    Wraps a scikit-learn ``RandomForestClassifier`` and translates between the
    string labels supplied in ``labels`` and the integer class indices
    (positions within ``labels``) that the forest is trained on.

    Attributes:
        labels: Class names; a label's position is its numeric class index.
        n_estimators: Number of trees, forwarded to the forest.
        max_depth: Maximum tree depth, forwarded to the forest unchanged.
        random_state: Seed forwarded to the forest.
        label_to_index: Label -> class index mapping built in ``__post_init__``.
        model: The underlying ``RandomForestClassifier`` instance.
    """

    labels: Sequence[str]
    n_estimators: int = 400
    max_depth: int | None = None
    random_state: int = 42

    def __post_init__(self) -> None:
        """Build the label lookup table and the underlying forest.

        Raises:
            ImportError: If scikit-learn could not be imported.
        """
        if RandomForestClassifier is None:  # pragma: no cover - import guard
            raise ImportError(
                "scikit-learn is required for WaterStressClassificationStudy"
            ) from _SKLEARN_ERROR
        self.label_to_index = {label: idx for idx, label in enumerate(self.labels)}
        self.model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            random_state=self.random_state,
            n_jobs=-1,
            class_weight="balanced_subsample",
        )

    def fit(self, features: np.ndarray, targets: Sequence[str]) -> None:
        """Train the RandomForest baseline referenced in the case study.

        Args:
            features: Feature matrix passed straight to the forest's ``fit``.
            targets: One label per observation; each must appear in ``labels``.

        Raises:
            KeyError: If a target label is not one of ``labels``.
        """
        try:
            numeric = np.array([self.label_to_index[label] for label in targets])
        except KeyError as exc:
            # Same exception type as a raw dict miss, but say what was expected.
            raise KeyError(
                f"unknown target label {exc.args[0]!r}; "
                f"expected one of {list(self.labels)!r}"
            ) from exc
        self.model.fit(features, numeric)

    def predict(self, features: np.ndarray) -> list[str]:
        """Predict the water-stress status label for each observation.

        Args:
            features: Feature matrix passed straight to the forest's ``predict``.

        Returns:
            One entry of ``labels`` per observation.
        """
        numeric = self.model.predict(features)
        # int() keeps the lookup valid for any Sequence, not only list/tuple.
        return [self.labels[int(idx)] for idx in numeric]

    def predict_proba(self, features: np.ndarray) -> np.ndarray:
        """Return per-class probabilities useful for irrigation decisions.

        Args:
            features: Feature matrix passed to the forest's ``predict_proba``.

        Returns:
            Array of shape ``(n_samples, len(labels))`` whose column ``i`` is
            the probability of ``labels[i]``.
        """
        raw = np.asarray(self.model.predict_proba(features))
        # scikit-learn orders the columns by ``classes_``, which only holds the
        # classes seen during fit. Scatter them into a full-width array so the
        # columns always line up with ``labels``; unseen classes get 0.
        aligned = np.zeros((raw.shape[0], len(self.labels)), dtype=raw.dtype)
        aligned[:, np.asarray(self.model.classes_, dtype=int)] = raw
        return aligned

    def feature_importances(self) -> Mapping[str, float]:
        """Expose feature contributions to aid agronomic interpretation.

        Returns:
            Mapping of feature name to importance. When the fitted model has
            no ``feature_names_in_`` (scikit-learn only sets it for input with
            string column names), positional names ``x0``, ``x1``, ... are used.
        """
        importances = self.model.feature_importances_
        names = getattr(self.model, "feature_names_in_", None)
        if names is None:
            names = [f"x{position}" for position in range(len(importances))]
        return {str(name): float(score) for name, score in zip(names, importances)}


# Public API of this module: the two helpers imported from ``.shared`` above
# are re-exported alongside the study class defined here.
__all__ = [
    "SensorFeatureConfig",
    "SensorDataPreprocessor",
    "WaterStressClassificationStudy",
]
