from typing import Tuple, Iterator, List
from pathlib import Path
from PIL import Image
from io import BytesIO
import numpy as np
import random

class DataLoaderSeq:
    def __init__(
        self,
        data_dir: str,
        batch_size: int,
        image_size: Tuple[int, int],
        drop_last: bool = True,
        seed: int | None = None,
    ):
        self.paths: List[Path] = [p for p in Path(data_dir).rglob("**/*")]
        self.batch_size = batch_size
        self.image_w, self.image_h = image_size
        self.drop_last = drop_last
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)

    def load_image(self, path: Path) -> bytes:
        with open(path, "rb") as f:
            return f.read()

    def decode_image(self, image_bytes: bytes) -> Image.Image:
        return Image.open(BytesIO(image_bytes)).convert("RGB")

    def process_image(self, image: Image.Image) -> np.ndarray:
        # resize
        img = image.resize((self.image_w, self.image_h), Image.BILINEAR)

        ## random crop
        cw = int(self.image_w * 0.9)
        ch = int(self.image_h * 0.9)
        ox = random.randint(0, self.image_w - cw)
        oy = random.randint(0, self.image_h - ch)
        img = img.crop((ox, oy, ox + cw, oy + ch)).resize((self.image_w, self.image_h))

        ## random rotate ±5°
        angle = random.uniform(-5, 5)
        img = img.rotate(angle, resample=Image.BILINEAR, expand=False)

        return np.asarray(img, dtype=np.uint8)

    def _iter_samples(self) -> Iterator[np.ndarray]:
        for p in self.paths:
            raw = self.load_image(p)
            dec = self.decode_image(raw)
            arr = self.process_image(dec)
            yield arr

    def __iter__(self) -> Iterator[np.ndarray]:
        batch = []
        for x in self._iter_samples():
            batch.append(x)
            if len(batch) == self.batch_size:
                yield np.stack(batch, axis=0)  # [B,H,W,3]
                batch = []
        if not self.drop_last and batch:
            yield np.stack(batch, axis=0)
        
    def __len__(self) -> int:
        return len(self.paths) // self.batch_size
