In [ ]:
import os
import time
import shutil
import numpy as np
import rasterio
from rasterio.windows import from_bounds
import cv2
import ee
import geemap
from google.colab import drive
from sklearn.model_selection import train_test_split
!pip install geedim
!pip install geemap

# 1. SETUP
drive.mount('/content/drive', force_remount=True)


MY_PROJECT_ID = '[REDACTED_FOR_SECURITY]'

# Authentication Block
try:

    ee.Initialize(project=MY_PROJECT_ID)
    print(f" Earth Engine Initialized with YOUR project: {MY_PROJECT_ID}")
except:
    ee.Authenticate()
    ee.Initialize(project=MY_PROJECT_ID)
    print(f" Earth Engine Authenticated and Initialized with YOUR project: {MY_PROJECT_ID}")


ASSET_ID = '[REDACTED_FOR_SECURITY]'

SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
if not os.path.exists(SAVE_DIR): os.makedirs(SAVE_DIR)

PATCH_SIZE = 224
S2_SCALE = 5000.0

# Prithvi 3-Step Temporal Configuration
TIME_WINDOWS = [
    ('2024-10-15', '2024-11-15'),  # T1: Sowing
    ('2025-01-01', '2025-01-31'),  # T2: Peak Growth 1
    ('2025-02-15', '2025-03-15')   # T3: Peak Growth 2
]

def generate_prithvi_npy():
    print(f" Starting Prithvi Data Generation.")
    print(f"   Billing Project: {MY_PROJECT_ID}")
    print(f"   Reading Asset:   {ASSET_ID}")

    # A. Mask Ingestion
    mask_img = ee.Image(ASSET_ID)
    roi_geom = mask_img.geometry()

    mask_file = 'local_mask_prithvi.tif'
    if not os.path.exists(mask_file):
        print("   Downloading Mask...")
        # geemap will now use MY_PROJECT_ID to authorize the download
        geemap.download_ee_image(mask_img, mask_file, region=roi_geom, scale=10, crs='EPSG:4326', overwrite=True)

    # B. Define Spatial Subset
    with rasterio.open(mask_file) as src:
        b = src.bounds
        cx, cy = (b.left + b.right)/2, (b.bottom + b.top)/2
        offset = 0.04
        window = from_bounds(cx-offset, cy-offset, cx+offset, cy+offset, src.transform)

        mask = src.read(1, window=window)
        mask = np.where(mask > 0, 1.0, 0.0).astype(np.float32)

        target_h, target_w = mask.shape
        small_roi = ee.Geometry.BBox(cx-offset, cy-offset, cx+offset, cy+offset)
        print(f"   ROI Subset Defined: {target_h}x{target_w}")

    # C. Multi-Temporal Stack Generation
    stack = []

    for i, (start, end) in enumerate(TIME_WINDOWS):
        fname = f'prithvi_time_{i}.tif'

        # Robust Download Logic
        attempts = 0
        while not os.path.exists(fname) and attempts < 3:
            try:
                print(f"   Downloading Time Step {i+1}/{len(TIME_WINDOWS)}: {start} to {end}...")

                # Optical Only (6 Bands)
                img = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED') \
                    .filterBounds(small_roi) \
                    .filterDate(start, end) \
                    .median() \
                    .select(['B2','B3','B4','B8','B11','B12'])

                geemap.download_ee_image(img, fname, region=small_roi, scale=10, crs='EPSG:4326', overwrite=True)
            except Exception as e:
                print(f"     Error downloading {fname}: {e}")
                attempts += 1
                time.sleep(2)

        # Fallback Logic
        if not os.path.exists(fname):
            print(f"    Failed to download {fname}. Attempting fallback...")
            if i > 0:
                shutil.copy(f'prithvi_time_{i-1}.tif', fname)
            else:
                # Create dummy zero array if first step fails
                with rasterio.open(mask_file) as src:
                     profile = src.profile
                     profile.update(count=6, dtype=rasterio.float32)
                     with rasterio.open(fname, 'w', **profile) as dst:
                         dst.write(np.zeros((6, target_h, target_w), dtype=np.float32))

        # Processing & Normalization
        with rasterio.open(fname) as src:
            arr = src.read() # (6, H, W)
            arr = np.transpose(arr, (1, 2, 0)) # (H, W, 6)

            if arr.shape[:2] != (target_h, target_w):
                arr = cv2.resize(arr, (target_w, target_h), interpolation=cv2.INTER_LINEAR)

            # Normalization
            arr = np.clip(arr / S2_SCALE, 0, 1).astype(np.float32)
            stack.append(arr)

    # D. Stack Construction
    full_cube = np.stack(stack, axis=2)

    # E. Tiling
    x_out, y_out = [], []
    stride = PATCH_SIZE

    print("   Tiling patches...")
    for y in range(0, target_h, stride):
        for x in range(0, target_w, stride):
            img_p = full_cube[y:y+stride, x:x+stride]
            mask_p = mask[y:y+stride, x:x+stride]

            if img_p.shape[0] != PATCH_SIZE or img_p.shape[1] != PATCH_SIZE: continue
            if np.mean(mask_p) < 0.01: continue
            if np.isnan(img_p).any(): continue

            x_out.append(img_p)
            y_out.append(mask_p)

    if len(x_out) == 0: raise ValueError("No valid patches generated.")

    # F. Transpose
    X = np.array(x_out, dtype=np.float32).transpose(0, 4, 3, 1, 2)
    y = np.array(y_out, dtype=np.float32)[:, None, :, :]

    print(f" Dataset Generated. Shape: {X.shape}")

    # G. Save to Drive
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

    print(" Saving .npy files to Google Drive...")
    np.save(os.path.join(SAVE_DIR, 'train_x.npy'), X_train)
    np.save(os.path.join(SAVE_DIR, 'train_y.npy'), y_train)
    np.save(os.path.join(SAVE_DIR, 'val_x.npy'), X_val)
    np.save(os.path.join(SAVE_DIR, 'val_y.npy'), y_val)

    print(" DONE! Proceed to Training Cell.")

# Execute
generate_prithvi_npy()
Collecting geedim
  Downloading geedim-2.0.0-py3-none-any.whl.metadata (6.0 kB)
Requirement already satisfied: numpy>=1.19 in /usr/local/lib/python3.12/dist-packages (from geedim) (2.0.2)
Requirement already satisfied: rasterio>=1.3.8 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.0)
Requirement already satisfied: click>=8 in /usr/local/lib/python3.12/dist-packages (from geedim) (8.3.1)
Requirement already satisfied: tqdm>=4.6 in /usr/local/lib/python3.12/dist-packages (from geedim) (4.67.1)
Requirement already satisfied: earthengine-api>=0.1.379 in /usr/local/lib/python3.12/dist-packages (from geedim) (1.5.24)
Requirement already satisfied: tabulate>=0.9 in /usr/local/lib/python3.12/dist-packages (from geedim) (0.9.0)
Requirement already satisfied: fsspec>=2025.2 in /usr/local/lib/python3.12/dist-packages (from geedim) (2025.3.0)
Requirement already satisfied: aiohttp>=3.11 in /usr/local/lib/python3.12/dist-packages (from geedim) (3.13.3)
Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (2.6.1)
Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.4.0)
Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (25.4.0)
Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.8.0)
Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (6.7.0)
Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (0.4.1)
Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp>=3.11->geedim) (1.22.0)
Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (3.7.0)
Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.187.0)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.43.0)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.3.0)
Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (0.31.0)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=0.1.379->geedim) (2.32.4)
Requirement already satisfied: affine in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2.4.0)
Requirement already satisfied: certifi in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (2026.1.4)
Requirement already satisfied: cligj>=0.5 in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (0.7.2)
Requirement already satisfied: pyparsing in /usr/local/lib/python3.12/dist-packages (from rasterio>=1.3.8->geedim) (3.3.1)
Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.12/dist-packages (from aiosignal>=1.4.0->aiohttp>=3.11->geedim) (4.15.0)
Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (2.29.0)
Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (4.2.0)
Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (6.2.4)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.4.2)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (4.9.1)
Requirement already satisfied: idna>=2.0 in /usr/local/lib/python3.12/dist-packages (from yarl<2.0,>=1.17.0->aiohttp>=3.11->geedim) (3.11)
Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.5.0)
Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (2.8.0)
Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=0.1.379->geedim) (1.8.0)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (3.4.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=0.1.379->geedim) (2.5.0)
Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.72.0)
Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (5.29.5)
Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=0.1.379->geedim) (1.27.0)
Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=0.1.379->geedim) (0.6.1)
Downloading geedim-2.0.0-py3-none-any.whl (73 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 73.1/73.1 kB 5.3 MB/s eta 0:00:00
Installing collected packages: geedim
Successfully installed geedim-2.0.0
Requirement already satisfied: geemap in /usr/local/lib/python3.12/dist-packages (0.35.3)
Requirement already satisfied: bqplot in /usr/local/lib/python3.12/dist-packages (from geemap) (0.12.45)
Requirement already satisfied: colour in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.5)
Requirement already satisfied: earthengine-api>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (1.5.24)
Requirement already satisfied: eerepr>=0.1.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.1.2)
Requirement already satisfied: folium>=0.17.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0)
Requirement already satisfied: geocoder in /usr/local/lib/python3.12/dist-packages (from geemap) (1.38.1)
Requirement already satisfied: ipyevents in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.4)
Requirement already satisfied: ipyfilechooser>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.6.0)
Requirement already satisfied: ipyleaflet>=0.19.2 in /usr/local/lib/python3.12/dist-packages (from geemap) (0.20.0)
Requirement already satisfied: ipytree in /usr/local/lib/python3.12/dist-packages (from geemap) (0.2.2)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.12/dist-packages (from geemap) (3.10.0)
Requirement already satisfied: numpy in /usr/local/lib/python3.12/dist-packages (from geemap) (2.0.2)
Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from geemap) (2.2.2)
Requirement already satisfied: plotly in /usr/local/lib/python3.12/dist-packages (from geemap) (5.24.1)
Requirement already satisfied: pyperclip in /usr/local/lib/python3.12/dist-packages (from geemap) (1.11.0)
Requirement already satisfied: pyshp>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from geemap) (3.0.3)
Requirement already satisfied: python-box in /usr/local/lib/python3.12/dist-packages (from geemap) (7.3.2)
Requirement already satisfied: scooby in /usr/local/lib/python3.12/dist-packages (from geemap) (0.11.0)
Requirement already satisfied: google-cloud-storage in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (3.7.0)
Requirement already satisfied: google-api-python-client>=1.12.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.187.0)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.43.0)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.3.0)
Requirement already satisfied: httplib2<1dev,>=0.9.2 in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (0.31.0)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from earthengine-api>=1.0.0->geemap) (2.32.4)
Requirement already satisfied: branca>=0.6.0 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (0.8.2)
Requirement already satisfied: jinja2>=2.9 in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (3.1.6)
Requirement already satisfied: xyzservices in /usr/local/lib/python3.12/dist-packages (from folium>=0.17.0->geemap) (2025.11.0)
Requirement already satisfied: ipywidgets in /usr/local/lib/python3.12/dist-packages (from ipyfilechooser>=0.6.0->geemap) (7.7.1)
Requirement already satisfied: jupyter-leaflet<0.21,>=0.20 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.20.0)
Requirement already satisfied: traittypes<3,>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from ipyleaflet>=0.19.2->geemap) (0.2.3)
Requirement already satisfied: traitlets>=4.3.0 in /usr/local/lib/python3.12/dist-packages (from bqplot->geemap) (5.7.1)
Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2.9.0.post0)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->geemap) (2025.3)
Requirement already satisfied: click in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (8.3.1)
Requirement already satisfied: future in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.0.0)
Requirement already satisfied: ratelim in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (0.1.6)
Requirement already satisfied: six in /usr/local/lib/python3.12/dist-packages (from geocoder->geemap) (1.17.0)
Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.3.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (4.61.1)
Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (1.4.9)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (25.0)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (11.3.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib->geemap) (3.3.1)
Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.12/dist-packages (from plotly->geemap) (9.1.2)
Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (2.29.0)
Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.12/dist-packages (from google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (4.2.0)
Requirement already satisfied: cachetools<7.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (6.2.4)
Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.4.2)
Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.12/dist-packages (from google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (4.9.1)
Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.17.1)
Requirement already satisfied: ipython-genutils~=0.2.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0)
Requirement already satisfied: widgetsnbextension~=3.6.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.6.10)
Requirement already satisfied: ipython>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.34.0)
Requirement already satisfied: jupyterlab-widgets>=1.0.0 in /usr/local/lib/python3.12/dist-packages (from ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.16)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2>=2.9->folium>=0.17.0->geemap) (3.0.3)
Requirement already satisfied: google-cloud-core<3.0.0,>=2.4.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.5.0)
Requirement already satisfied: google-resumable-media<3.0.0,>=2.7.2 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (2.8.0)
Requirement already satisfied: google-crc32c<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from google-cloud-storage->earthengine-api>=1.0.0->geemap) (1.8.0)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->earthengine-api>=1.0.0->geemap) (2026.1.4)
Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from ratelim->geocoder->geemap) (4.4.2)
Requirement already satisfied: googleapis-common-protos<2.0.0,>=1.56.2 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.72.0)
Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<7.0.0,>=3.19.5 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (5.29.5)
Requirement already satisfied: proto-plus<2.0.0,>=1.22.3 in /usr/local/lib/python3.12/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.1->earthengine-api>=1.0.0->geemap) (1.27.0)
Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.8.15)
Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.4.9)
Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.1)
Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.6.0)
Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.5)
Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (26.2.1)
Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.12/dist-packages (from ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.1)
Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (75.2.0)
Collecting jedi>=0.16 (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap)
  Downloading jedi-0.19.2-py2.py3-none-any.whl.metadata (22 kB)
Requirement already satisfied: pickleshare in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.5)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.52)
Requirement already satisfied: pygments in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.19.2)
Requirement already satisfied: backcall in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.0)
Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.12/dist-packages (from ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.9.0)
Requirement already satisfied: pyasn1<0.7.0,>=0.6.1 in /usr/local/lib/python3.12/dist-packages (from pyasn1-modules>=0.2.1->google-auth>=1.4.1->earthengine-api>=1.0.0->geemap) (0.6.1)
Requirement already satisfied: notebook>=4.4.1 in /usr/local/lib/python3.12/dist-packages (from widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.5.7)
Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.12/dist-packages (from jedi>=0.16->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.8.5)
Requirement already satisfied: entrypoints in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.4)
Requirement already satisfied: jupyter-core>=4.9.2 in /usr/local/lib/python3.12/dist-packages (from jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.9.1)
Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0)
Requirement already satisfied: nbformat in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (5.10.4)
Requirement already satisfied: nbconvert>=5 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.16.6)
Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0)
Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.18.1)
Requirement already satisfied: prometheus-client in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.23.1)
Requirement already satisfied: nbclassic>=0.4.7 in /usr/local/lib/python3.12/dist-packages (from notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.3)
Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.12/dist-packages (from pexpect>4.3->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.0)
Requirement already satisfied: wcwidth in /usr/local/lib/python3.12/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=4.0.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.14)
Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.12/dist-packages (from jupyter-core>=4.9.2->jupyter-client>=6.1.12->ipykernel>=4.5.1->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.5.1)
Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.12/dist-packages (from nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.2.4)
Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.13.5)
Requirement already satisfied: bleach!=5.0.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.3.0)
Requirement already satisfied: defusedxml in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.7.1)
Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.3.0)
Requirement already satisfied: mistune<4,>=2.0.3 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.2.0)
Requirement already satisfied: nbclient>=0.5.0 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.10.4)
Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.12/dist-packages (from nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1)
Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.21.2)
Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.12/dist-packages (from nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.26.0)
Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.12/dist-packages (from argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.1.0)
Requirement already satisfied: webencodings in /usr/local/lib/python3.12/dist-packages (from bleach!=5.0.0->bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.1)
Requirement already satisfied: tinycss2<1.5,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from bleach[css]!=5.0.0->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0)
Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.4.0)
Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2025.9.1)
Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.37.0)
Requirement already satisfied: rpds-py>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema>=2.6->nbformat->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.30.0)
Requirement already satisfied: jupyter-server<3,>=1.8 in /usr/local/lib/python3.12/dist-packages (from notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.14.0)
Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.0.0)
Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.8.1)
Requirement already satisfied: typing-extensions>=4.0.0 in /usr/local/lib/python3.12/dist-packages (from beautifulsoup4->nbconvert>=5->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.15.0)
Requirement already satisfied: pycparser in /usr/local/lib/python3.12/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (2.23)
Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.12.1)
Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.12.0)
Requirement already satisfied: jupyter-server-terminals>=0.4.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.5.3)
Requirement already satisfied: overrides>=5.0 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (7.7.0)
Requirement already satisfied: websocket-client>=1.7 in /usr/local/lib/python3.12/dist-packages (from jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.9.0)
Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (4.0.0)
Requirement already satisfied: pyyaml>=5.3 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (6.0.3)
Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.4)
Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.12/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (0.1.1)
Requirement already satisfied: fqdn in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.5.1)
Requirement already satisfied: isoduration in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (20.11.0)
Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (3.0.0)
Requirement already satisfied: rfc3987-syntax>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.1.0)
Requirement already satisfied: uri-template in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.0)
Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.12/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (25.10.0)
Requirement already satisfied: lark>=1.2.2 in /usr/local/lib/python3.12/dist-packages (from rfc3987-syntax>=1.1.0->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.3.1)
Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.12/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.8->notebook-shim>=0.2.3->nbclassic>=0.4.7->notebook>=4.4.1->widgetsnbextension~=3.6.0->ipywidgets->ipyfilechooser>=0.6.0->geemap) (1.4.0)
Downloading jedi-0.19.2-py2.py3-none-any.whl (1.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 86.8 MB/s eta 0:00:00
Installing collected packages: jedi
Successfully installed jedi-0.19.2
Mounted at /content/drive
 Earth Engine Authenticated and Initialized with YOUR project: local-dialect-484618-b9
 Starting Prithvi Data Generation.
   Billing Project: local-dialect-484618-b9
   Reading Asset:   projects/satmae-2026/assets/Punjab_Mask_2024_NEW
   Downloading Mask...
/usr/local/lib/python3.12/dist-packages/geemap/common.py:12471: FutureWarning: 'BaseImage' is deprecated and will be removed in a future release.  Please use the 'ee.Image.gd' accessor instead.
  img = gd.download.BaseImage(image)
...tmae-2026/assets/Punjab_Mask_2024_NEW:   0%|          |0/585 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
/usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'projects/satmae-2026/assets/Punjab_Mask_2024_NEW'.
  return STACClient().get(self.id)
   ROI Subset Defined: 891x891
   Downloading Time Step 1/3: 2024-10-15 to 2024-11-15...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
/usr/local/lib/python3.12/dist-packages/geedim/image.py:254: RuntimeWarning: Couldn't find STAC entry for: 'None'.
  return STACClient().get(self.id)
   Downloading Time Step 2/3: 2025-01-01 to 2025-01-31...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
   Downloading Time Step 3/3: 2025-02-15 to 2025-03-15...
  0%|          |0/12 tiles [00:00<?]
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: earthengine.googleapis.com. Connection pool size: 10
   Tiling patches...
 Dataset Generated. Shape: (9, 6, 3, 224, 224)
 Saving .npy files to Google Drive...
 DONE! Proceed to Training Cell.
In [ ]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset

SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'

CONFIG = {
    "EPOCHS": 5000,
    "PATIENCE": 200,
    "BATCH_SIZE": 8,
    "LEARNING_RATE": 1e-4,
    "WEIGHT_DECAY": 0.05,
    "WARMUP_EPOCHS": 20,        # Linear Warmup
    "SWA_START_EPOCH": None,    # Dynamic
    "RESUME": True,
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
    "IMG_SIZE": 224,
    "NUM_FRAMES": 3,
    "IN_CHANS": 6,
    "EMBED_DIM": 768,
    "DEPTH": 12,
    "NUM_HEADS": 12,
    "SAVE_DIR": SAVE_DIR
}

class PunjabWheatDataset(Dataset):
    def __init__(self, x_path, y_path):
        self.data = np.load(x_path, mmap_mode='r')
        self.masks = np.load(y_path, mmap_mode='r')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        mask = self.masks[idx]
        return torch.from_numpy(img.copy()).float(), torch.from_numpy(mask.copy()).float()
In [ ]:
!pip install segmentation-models-pytorch
Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Requirement already satisfied: huggingface-hub>=0.24 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.36.0)
Requirement already satisfied: numpy>=1.19.3 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.0.2)
Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (11.3.0)
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.7.0)
Requirement already satisfied: timm>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (1.0.24)
Requirement already satisfied: torch>=1.8 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (2.9.0+cu126)
Requirement already satisfied: torchvision>=0.9 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (0.24.0+cu126)
Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.12/dist-packages (from segmentation-models-pytorch) (4.67.1)
Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (3.20.2)
Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2025.3.0)
Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (25.0)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (6.0.3)
Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (2.32.4)
Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (4.15.0)
Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.24->segmentation-models-pytorch) (1.2.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (75.2.0)
Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.14.0)
Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.6.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.1.6)
Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77)
Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77)
Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.80)
Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (9.10.2.21)
Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.4.1)
Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.3.0.4)
Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (10.3.7.77)
Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (11.7.1.2)
Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.5.4.2)
Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (0.7.1)
Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (2.27.5)
Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.3.20)
Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.77)
Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (12.6.85)
Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (1.11.1.6)
Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.8->segmentation-models-pytorch) (3.5.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.8->segmentation-models-pytorch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.8->segmentation-models-pytorch) (3.0.3)
Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.4.4)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (3.11)
Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2.5.0)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface-hub>=0.24->segmentation-models-pytorch) (2026.1.4)
Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl (154 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 154.8/154.8 kB 12.7 MB/s eta 0:00:00
Installing collected packages: segmentation-models-pytorch
Successfully installed segmentation-models-pytorch-0.5.0
In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import distance_transform_edt as distance
import segmentation_models_pytorch as smp

# --- 1. HEAVY TRAINING AUGMENTATION ---
def apply_training_augmentation(x, y):
    # Random Horizontal Flip
    if np.random.rand() > 0.5:
        x = torch.flip(x, [4])
        y = torch.flip(y, [3])

    # Random Vertical Flip
    if np.random.rand() > 0.5:
        x = torch.flip(x, [3])
        y = torch.flip(y, [2])

    # Random Rotation (0, 90, 180, 270)
    k = np.random.randint(0, 4)
    x = torch.rot90(x, k, [3, 4])
    y = torch.rot90(y, k, [2, 3])

    # Random Intensity Scale (Brightness/Contrast jitter)
    if np.random.rand() > 0.5:
        noise = (torch.rand(x.shape[0], 1, 1, 1, 1, device=x.device) * 0.2) + 0.9
        x = x * noise

    # --- FIX: Force contiguous memory layout ---
    return x.contiguous(), y.contiguous()

# --- 2. VALIDATION TTA ---
class TestTimeAugmentation(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        pred_orig = self.model(x)

        x_hflip = torch.flip(x, [4])
        pred_hflip = torch.flip(self.model(x_hflip), [3])

        x_vflip = torch.flip(x, [3])
        pred_vflip = torch.flip(self.model(x_vflip), [2])

        x_rot = torch.rot90(x, 1, [3, 4])
        pred_rot = torch.rot90(self.model(x_rot), -1, [2, 3])

        return torch.stack([pred_orig, pred_hflip, pred_vflip, pred_rot]).mean(dim=0)

def validate_with_tta(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    tta_model = TestTimeAugmentation(model)

    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = tta_model(x)
            loss = criterion(preds, y)
            total_loss += loss.item()
    return total_loss / len(loader)

# --- 3. LOSS FUNCTIONS (Dice 0.7 + Hausdorff 0.3) ---
class HausdorffDTLoss(nn.Module):
    def __init__(self, alpha=2.0):
        super().__init__()
        self.alpha = alpha

    def forward(self, pred, gt):
        # Hausdorff Distance Transform (CPU based)
        with torch.no_grad():
            gt_np = gt.cpu().numpy()
            dist_map = np.zeros_like(gt_np)
            for i in range(len(gt_np)):
                mask = (gt_np[i, 0] > 0.5).astype(np.uint8)
                if mask.sum() == 0:
                    # If empty mask, distance is max everywhere
                    dist_map[i, 0] = np.ones_like(mask) * 100.0
                    continue
                d_in = distance(mask)
                d_out = distance(1 - mask)
                dist_map[i, 0] = (d_out - d_in)

            dist_map = torch.tensor(dist_map, device=pred.device, dtype=torch.float32)

        probs = torch.sigmoid(pred)
        # Weighted MSE based on distance map
        return torch.mean((probs - gt) ** 2 * (1 + self.alpha * torch.abs(dist_map)))

class CompoundLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.dice = smp.losses.DiceLoss(mode='binary', from_logits=True)
        self.hausdorff = HausdorffDTLoss(alpha=2.0)

    def forward(self, p, t):
        # 0.7 Dice + 0.3 Hausdorff
        return 0.7 * self.dice(p, t) + 0.3 * self.hausdorff(p, t)
In [ ]:
import numpy as np
import torch
import torch.nn as nn

#  CORRECT 3D POSITIONAL EMBEDDING
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    pos = pos.reshape(-1)
    out = np.einsum('m,d->md', pos, omega)

    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=1)
    return emb

def get_3d_sincos_pos_embed(embed_dim, grid_size, t_size):
    # 1. Height Embedding
    grid_h = np.arange(grid_size, dtype=np.float32)
    pos_embed_h = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_h)

    # 2. Width Embedding
    grid_w = np.arange(grid_size, dtype=np.float32)
    pos_embed_w = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_w)

    # 3. Time Embedding
    grid_t = np.arange(t_size, dtype=np.float32)
    pos_embed_t = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_t)

    # 4. Broadcast and Sum: Final = T + H + W
    pos_t = pos_embed_t[:, np.newaxis, np.newaxis, :]
    pos_h = pos_embed_h[np.newaxis, :, np.newaxis, :]
    pos_w = pos_embed_w[np.newaxis, np.newaxis, :, :]

    pos_embed = pos_t + pos_h + pos_w

    # Flatten to (T*H*W, D) to match Transformer input
    pos_embed = pos_embed.reshape(-1, embed_dim)

    return torch.from_numpy(pos_embed).float().unsqueeze(0)

class PatchEmbed3D(nn.Module):
    def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv3d(
            in_chans, embed_dim,
            kernel_size=(1, patch_size, patch_size),
            stride=(1, patch_size, patch_size)
        )
    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class PrithviScratch(nn.Module):
    def __init__(self):
        super().__init__()

        self.patch_embed = PatchEmbed3D(
            frames=CONFIG["NUM_FRAMES"],
            in_chans=CONFIG["IN_CHANS"],
            embed_dim=CONFIG["EMBED_DIM"]
        )

        self.register_buffer(
            "pos_embed",
            get_3d_sincos_pos_embed(
                CONFIG["EMBED_DIM"],
                CONFIG["IMG_SIZE"] // 16,
                CONFIG["NUM_FRAMES"]
            )
        )

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=CONFIG["EMBED_DIM"], nhead=CONFIG["NUM_HEADS"],
            dim_feedforward=CONFIG["EMBED_DIM"]*4, dropout=0.1,
            activation='gelu', batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=CONFIG["DEPTH"])

        in_dim = CONFIG["EMBED_DIM"] * CONFIG["NUM_FRAMES"]
        self.decoder = nn.Sequential(
            nn.Conv2d(in_dim, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),

            nn.Conv2d(64, 1, kernel_size=1)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None: nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.encoder(x)

        H_p = CONFIG["IMG_SIZE"] // 16
        x = x.transpose(1, 2).view(B, CONFIG["EMBED_DIM"], CONFIG["NUM_FRAMES"], H_p, H_p)
        x = x.reshape(B, -1, H_p, H_p)

        return self.decoder(x)
In [ ]:
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

train_x = os.path.join(SAVE_DIR, 'train_x.npy')
train_y = os.path.join(SAVE_DIR, 'train_y.npy')
val_x = os.path.join(SAVE_DIR, 'val_x.npy')
val_y = os.path.join(SAVE_DIR, 'val_y.npy')

if not os.path.exists(train_x):
    raise FileNotFoundError(f"Data not found in {SAVE_DIR}. Run Cell 1 first.")

print("Loading Data...")
train_loader = DataLoader(PunjabWheatDataset(train_x, train_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=True, num_workers=2)
val_loader = DataLoader(PunjabWheatDataset(val_x, val_y), batch_size=CONFIG["BATCH_SIZE"], shuffle=False, num_workers=2)

model = PrithviScratch().to(CONFIG["DEVICE"])
criterion = CompoundLoss().to(CONFIG["DEVICE"])
optimizer = optim.AdamW(model.parameters(), lr=CONFIG["LEARNING_RATE"], weight_decay=CONFIG["WEIGHT_DECAY"])

# --- SCHEDULER: WARMUP + COSINE DECAY ---
scheduler_warmup = LinearLR(optimizer, start_factor=0.01, total_iters=CONFIG["WARMUP_EPOCHS"])
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=CONFIG["EPOCHS"] - CONFIG["WARMUP_EPOCHS"], eta_min=1e-6)
scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_cosine], milestones=[CONFIG["WARMUP_EPOCHS"]])

# --- DYNAMIC SWA ---
swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=5e-5)
swa_active = False

checkpoint_path = os.path.join(CONFIG["SAVE_DIR"], "checkpoint.pth")
start_epoch = 0
best_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': []}

if CONFIG["RESUME"] and os.path.exists(checkpoint_path):
    print("Found checkpoint. Attempting to load...")
    try:
        ckpt = torch.load(checkpoint_path)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        start_epoch = ckpt['epoch'] + 1
        best_loss = ckpt.get('best_loss', float('inf'))
        history = ckpt.get('history', {'train_loss': [], 'val_loss': []})

        if 'swa_model' in ckpt:
            swa_model.load_state_dict(ckpt['swa_model'])
        if 'swa_active' in ckpt:
            swa_active = ckpt['swa_active']

        print(f"Successfully resumed from Epoch {start_epoch} (SWA: {swa_active})")
    except Exception as e:
        print(f"Checkpoint corrupted ({e}). Starting from Epoch 0.")
        start_epoch = 0

print("Starting Training...")
for epoch in range(start_epoch, CONFIG["EPOCHS"]):
    model.train()
    train_loss = 0

    for x, y in train_loader:
        x, y = x.to(CONFIG["DEVICE"]), y.to(CONFIG["DEVICE"])

        # Heavy Augmentation
        x, y = apply_training_augmentation(x, y)

        optimizer.zero_grad()
        preds = model(x)

        # FIX: Ensure tensors are contiguous before Loss
        loss = criterion(preds, y.contiguous())

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    # Validation with TTA
    avg_val_loss = validate_with_tta(model, val_loader, criterion, CONFIG["DEVICE"])

    history['train_loss'].append(avg_train_loss)
    history['val_loss'].append(avg_val_loss)

    # --- DYNAMIC SWA LOGIC ---
    SWA_TRIGGER = CONFIG["PATIENCE"] // 2

    if patience_counter >= SWA_TRIGGER and not swa_active:
        print(f"Stagnation Detected (Patience {patience_counter}). Activating SWA.")
        swa_active = True
        swa_model.update_parameters(model)

    if swa_active:
        swa_model.update_parameters(model)
        swa_scheduler.step()
    else:
        scheduler.step()

    checkpoint_dict = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
        'swa_model': swa_model.state_dict(),
        'history': history,
        'swa_active': swa_active
    }
    torch.save(checkpoint_dict, checkpoint_path)

    if epoch % 10 == 0:
        backup_path = os.path.join(CONFIG["SAVE_DIR"], f"checkpoint_epoch_{epoch}.pth")
        torch.save(checkpoint_dict, backup_path)

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "best_model.pth"))
        patience_counter = 0
        print(f"New Best Model! Loss: {best_loss:.4f}")
    else:
        patience_counter += 1

    if patience_counter >= CONFIG["PATIENCE"]:
        print(f"Early Stopping at Epoch {epoch}")
        if swa_active:
            torch.save(swa_model.state_dict(), os.path.join(CONFIG["SAVE_DIR"], "swa_model_final.pth"))
        break

    if epoch % 5 == 0:
        mode = "SWA" if swa_active else "STD"
        print(f"Ep {epoch} [{mode}] | Train: {avg_train_loss:.4f} | Val: {avg_val_loss:.4f}")
Loading Data...
Starting Training...
New Best Model! Loss: 1.4913
Ep 0 [STD] | Train: 1.4160 | Val: 1.4913
New Best Model! Loss: 1.2440
New Best Model! Loss: 0.8907
Ep 5 [STD] | Train: 1.0207 | Val: 0.8907
New Best Model! Loss: 0.8533
Ep 10 [STD] | Train: 0.7532 | Val: 0.8762
Ep 15 [STD] | Train: 0.7301 | Val: 0.8679
New Best Model! Loss: 0.8505
New Best Model! Loss: 0.7767
New Best Model! Loss: 0.6073
Ep 20 [STD] | Train: 0.7010 | Val: 0.6073
Ep 25 [STD] | Train: 0.4928 | Val: 0.8917
Ep 30 [STD] | Train: 0.4096 | Val: 0.8945
New Best Model! Loss: 0.5724
New Best Model! Loss: 0.5573
Ep 35 [STD] | Train: 0.3281 | Val: 0.5573
New Best Model! Loss: 0.5204
New Best Model! Loss: 0.4516
New Best Model! Loss: 0.3923
New Best Model! Loss: 0.3576
New Best Model! Loss: 0.3575
Ep 40 [STD] | Train: 0.3195 | Val: 0.3575
New Best Model! Loss: 0.3457
New Best Model! Loss: 0.3014
New Best Model! Loss: 0.2913
Ep 45 [STD] | Train: 0.2988 | Val: 0.2972
Ep 50 [STD] | Train: 0.2828 | Val: 0.3110
New Best Model! Loss: 0.2756
New Best Model! Loss: 0.2655
New Best Model! Loss: 0.2621
New Best Model! Loss: 0.2613
Ep 55 [STD] | Train: 0.2701 | Val: 0.2613
Ep 60 [STD] | Train: 0.2748 | Val: 0.2653
New Best Model! Loss: 0.2567
New Best Model! Loss: 0.2563
Ep 65 [STD] | Train: 0.2697 | Val: 0.2563
New Best Model! Loss: 0.2558
Ep 70 [STD] | Train: 0.2699 | Val: 0.2564
New Best Model! Loss: 0.2539
Ep 75 [STD] | Train: 0.2644 | Val: 0.2539
Ep 80 [STD] | Train: 0.2602 | Val: 0.2589
New Best Model! Loss: 0.2536
New Best Model! Loss: 0.2522
New Best Model! Loss: 0.2517
New Best Model! Loss: 0.2502
Ep 85 [STD] | Train: 0.2627 | Val: 0.2502
New Best Model! Loss: 0.2487
Ep 90 [STD] | Train: 0.2536 | Val: 0.2597
Ep 95 [STD] | Train: 0.2594 | Val: 0.2517
New Best Model! Loss: 0.2474
Ep 100 [STD] | Train: 0.2469 | Val: 0.2514
Ep 105 [STD] | Train: 0.2486 | Val: 0.2518
New Best Model! Loss: 0.2471
New Best Model! Loss: 0.2461
New Best Model! Loss: 0.2449
New Best Model! Loss: 0.2426
Ep 110 [STD] | Train: 0.2565 | Val: 0.2426
New Best Model! Loss: 0.2423
Ep 115 [STD] | Train: 0.2402 | Val: 0.2498
Ep 120 [STD] | Train: 0.2469 | Val: 0.2491
Ep 125 [STD] | Train: 0.2452 | Val: 0.2521
Ep 130 [STD] | Train: 0.2408 | Val: 0.2504
Ep 135 [STD] | Train: 0.2365 | Val: 0.2515
Ep 140 [STD] | Train: 0.2506 | Val: 0.2509
Ep 145 [STD] | Train: 0.2356 | Val: 0.2450
Ep 150 [STD] | Train: 0.2304 | Val: 0.2486
Ep 155 [STD] | Train: 0.2301 | Val: 0.2494
Ep 160 [STD] | Train: 0.2270 | Val: 0.2480
Ep 165 [STD] | Train: 0.2268 | Val: 0.2476
Ep 170 [STD] | Train: 0.2314 | Val: 0.2550
Ep 175 [STD] | Train: 0.2247 | Val: 0.2519
Ep 180 [STD] | Train: 0.2059 | Val: 0.2533
Ep 185 [STD] | Train: 0.2057 | Val: 0.2474
Ep 190 [STD] | Train: 0.2229 | Val: 0.2552
Ep 195 [STD] | Train: 0.2053 | Val: 0.2520
Ep 200 [STD] | Train: 0.2148 | Val: 0.2574
Ep 205 [STD] | Train: 0.2072 | Val: 0.2529
Ep 210 [STD] | Train: 0.2096 | Val: 0.2477
Stagnation Detected (Patience 100). Activating SWA.
Ep 215 [SWA] | Train: 0.2010 | Val: 0.2515
Ep 220 [SWA] | Train: 0.2073 | Val: 0.2522
Ep 225 [SWA] | Train: 0.2085 | Val: 0.2508
Ep 230 [SWA] | Train: 0.1971 | Val: 0.2531
Ep 235 [SWA] | Train: 0.2049 | Val: 0.2537
Ep 240 [SWA] | Train: 0.1915 | Val: 0.2501
Ep 245 [SWA] | Train: 0.1790 | Val: 0.2558
Ep 250 [SWA] | Train: 0.1950 | Val: 0.2504
Ep 255 [SWA] | Train: 0.1838 | Val: 0.2571
Ep 260 [SWA] | Train: 0.1806 | Val: 0.2510
Ep 265 [SWA] | Train: 0.1949 | Val: 0.2492
Ep 270 [SWA] | Train: 0.1942 | Val: 0.2477
Ep 275 [SWA] | Train: 0.1859 | Val: 0.2481
Ep 280 [SWA] | Train: 0.1840 | Val: 0.2563
Ep 285 [SWA] | Train: 0.1765 | Val: 0.2496
Ep 290 [SWA] | Train: 0.1741 | Val: 0.2524
Ep 295 [SWA] | Train: 0.1721 | Val: 0.2519
Ep 300 [SWA] | Train: 0.1744 | Val: 0.2473
Ep 305 [SWA] | Train: 0.1691 | Val: 0.2512
Ep 310 [SWA] | Train: 0.1663 | Val: 0.2504
Early Stopping at Epoch 311
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import os
import json
from sklearn.metrics import confusion_matrix

# --- 1. CONFIGURATION ---
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Auto-detect best model in the Scratch Directory
if os.path.exists(os.path.join(SAVE_DIR, "swa_model_final.pth")):
    CHECKPOINT_PATH = os.path.join(SAVE_DIR, "swa_model_final.pth")
    MODEL_TYPE = "SWA (Best Generalization)"
elif os.path.exists(os.path.join(SAVE_DIR, "best_model.pth")):
    CHECKPOINT_PATH = os.path.join(SAVE_DIR, "best_model.pth")
    MODEL_TYPE = "Best Valid Loss"
else:
    CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint.pth")
    MODEL_TYPE = "Latest Epoch"

print(f" EVALUATING SCRATCH MODEL: {MODEL_TYPE}")
print(f" Path: {CHECKPOINT_PATH}")

# --- 2. MODEL DEFINITION (Prithvi Scratch Architecture) ---
class PatchEmbed3D(nn.Module):
    def __init__(self, patch_size=16, frames=3, in_chans=6, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)

class PrithviScratch(nn.Module):
    def __init__(self):
        super().__init__()
        # Standard Prithvi Config
        self.patch_embed = PatchEmbed3D(patch_size=16, frames=3, in_chans=6, embed_dim=768)
        # Pos Embed (Fixed Size)
        self.register_buffer("pos_embed", torch.zeros(1, 3 * (224//16)**2, 768))

        # Encoder (12 Layers)
        encoder_layer = nn.TransformerEncoderLayer(d_model=768, nhead=12, dim_feedforward=768*4, dropout=0.1, activation='gelu', batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=12)

        # Decoder (U-Net Style)
        in_dim = 768 * 3
        self.decoder = nn.Sequential(
            nn.Conv2d(in_dim, 512, 3, padding=1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(512, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, 1, 1)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        # Robust addition for Pos Embed
        if self.pos_embed.shape[1] == x.shape[1]:
            x = x + self.pos_embed
        x = self.encoder(x)

        # Reshape for Decoder
        H_p = 14 # 224 / 16
        x = x.transpose(1, 2).view(B, 768, 3, H_p, H_p)
        x = x.reshape(B, -1, H_p, H_p)
        return self.decoder(x)

# --- 3. HELPER FUNCTIONS ---
def plot_training_curves(ckpt_path):
    if not os.path.exists(ckpt_path):
        print(" History file not found.")
        return

    # Load history from checkpoint.pth (always contains full history)
    hist_path = os.path.join(SAVE_DIR, "checkpoint.pth")
    ckpt = torch.load(hist_path, map_location='cpu')
    history = ckpt.get('history', None)

    if history is None:
        print(" No training history found.")
        return

    plt.figure(figsize=(12, 5))
    plt.plot(history['train_loss'], 'b-', label='Training Loss', alpha=0.7)
    plt.plot(history['val_loss'], 'r-', label='Validation Loss', linewidth=2)

    # Mark Min Val Loss
    val_loss = history['val_loss']
    if len(val_loss) > 0:
        min_loss = min(val_loss)
        min_epoch = val_loss.index(min_loss)
        plt.scatter(min_epoch, min_loss, c='red', zorder=5)
        plt.text(min_epoch, min_loss, f" Best: {min_loss:.4f}", verticalalignment='bottom')

    plt.title('Prithvi Scratch: Training Progress')
    plt.xlabel('Epochs'); plt.ylabel('Loss'); plt.legend(); plt.grid(True, linestyle='--')
    plt.show()

def visualize_predictions(model, loader, device, num_samples=3):
    model.eval()
    try:
        x, y = next(iter(loader))
    except StopIteration:
        print(" Loader is empty.")
        return

    x, y = x.to(device), y.to(device)

    with torch.no_grad():
        preds = (torch.sigmoid(model(x)) > 0.5).float().cpu().numpy()

    x_np, y_np = x.cpu().numpy(), y.cpu().numpy()

    # SAFETY: Ensure we don't index out of bounds if batch size < num_samples
    actual_samples = min(num_samples, len(x_np))

    fig, axs = plt.subplots(actual_samples, 3, figsize=(15, 5 * actual_samples))
    plt.suptitle("Prithvi Scratch Results: Input vs Truth vs Prediction", fontsize=16)

    for i in range(actual_samples):
        # Create FCC: NIR(3), Red(2), Green(1) from Time Step 1
        img = np.stack([x_np[i,3,1], x_np[i,2,1], x_np[i,1,1]], axis=2)
        # Normalize for Display
        img = (img - img.min()) / (img.max() - img.min() + 1e-6)

        # Handle axes if only 1 sample
        ax = axs[i] if actual_samples > 1 else axs

        ax[0].imshow(img)
        ax[0].set_title("Satellite Input (NIR-R-G)")

        ax[1].imshow(y_np[i,0], cmap='gray')
        ax[1].set_title("Ground Truth")

        ax[2].imshow(preds[i,0], cmap='gray')
        ax[2].set_title("Prediction")

        for a in ax: a.axis('off')

    plt.tight_layout()
    plt.show()

# --- 4. EXECUTION ---
# A. Plot Curves
plot_training_curves(CHECKPOINT_PATH)

# B. Load Model & Visualize
model = PrithviScratch().to(DEVICE)
print(f"Loading weights...")

if os.path.exists(CHECKPOINT_PATH):
    ckpt = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt

    # Robust Key Fixer (Handles 'blocks' vs 'encoder.layers')
    new_state_dict = {}
    for k, v in state_dict.items():
        k = k.replace('blocks', 'encoder.layers')
        if 'pos_embed' in k and v.shape != model.pos_embed.shape: continue
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict, strict=False)
    print(" Weights Loaded Successfully.")

    # C. Run Viz
    if 'val_loader' in locals():
        visualize_predictions(model, val_loader, DEVICE, num_samples=3)
    else:
        print(" 'val_loader' not found. Please run the Dataset Loading cell first.")
else:
    print(f" Checkpoint not found at {CHECKPOINT_PATH}")
📊 EVALUATING SCRATCH MODEL: SWA (Best Generalization)
📂 Path: /content/drive/MyDrive/Prithvi_Scratch_Results_1/swa_model_final.pth
No description has been provided for this image
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
  warnings.warn(
Loading weights...
✅ Weights Loaded Successfully.
No description has been provided for this image
In [ ]:
import time
import torch
import numpy as np
import os
import json
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import confusion_matrix
from scipy.ndimage import binary_dilation

# --- 1. CONFIGURATION ---
# Using the path defined in Cell 7 (or manually set here)
if 'CHECKPOINT_PATH' not in locals():
    # Fallback if Cell 7 wasn't run
    SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
    CHECKPOINT_PATH = os.path.join(SAVE_DIR, "swa_model_final.pth")
    if not os.path.exists(CHECKPOINT_PATH):
        CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint.pth")

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f" STARTING DEEP EVALUATION")
print(f" Model: {CHECKPOINT_PATH}")

# --- 2. METRIC FUNCTIONS ---
def compute_boundary_iou(pred_mask, gt_mask, dilation=2):
    """
    Computes IoU specifically along the edges of the fields.
    Crucial for agricultural segmentation where boundaries matter.
    """
    # Get edges by dilating and subtracting original
    pred_edges = binary_dilation(pred_mask, iterations=dilation) ^ pred_mask
    gt_edges = binary_dilation(gt_mask, iterations=dilation) ^ gt_mask

    intersection = (pred_edges & gt_edges).sum()
    union = (pred_edges | gt_edges).sum()

    if union == 0: return 1.0 # Perfect match (both empty edges)
    return intersection / union

def compute_hausdorff_distance(pred_mask, gt_mask):
    """
    Computes 95th Percentile Hausdorff Distance.
    Measures the worst-case distance between prediction and ground truth contours.
    Lower is better.
    """
    if pred_mask.sum() == 0 or gt_mask.sum() == 0:
        return 0.0 # Handle empty masks safely

    # Get coordinates of all 'True' pixels
    pred_coords = np.argwhere(pred_mask)
    gt_coords = np.argwhere(gt_mask)

    # Calculate directed distances
    d_forward = directed_hausdorff(pred_coords, gt_coords)[0]
    d_backward = directed_hausdorff(gt_coords, pred_coords)[0]

    return max(d_forward, d_backward)

# --- 3. EVALUATION LOOP ---
def evaluate_model_depth(loader, model_path):
    # A. Load Model
    # Ensure PrithviScratch class is defined (from Cell 7)
    model = PrithviScratch().to(DEVICE)

    if not os.path.exists(model_path):
        print(f" Error: Model file not found at {model_path}")
        return None

    ckpt = torch.load(model_path, map_location=DEVICE)
    state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt

    # Auto-Fix Keys
    new_state_dict = {}
    for k, v in state_dict.items():
        k = k.replace('blocks', 'encoder.layers')
        if 'pos_embed' in k and v.shape != model.pos_embed.shape: continue
        new_state_dict[k] = v

    model.load_state_dict(new_state_dict, strict=False)
    model.eval()

    # B. Initialize Accumulators
    tp_total, fp_total, fn_total, tn_total = 0, 0, 0, 0
    boundary_ious = []
    hausdorff_dists = []
    inference_times = []

    print(" Processing Validation Set...", end="")

    with torch.no_grad():
        for i, (x, y) in enumerate(loader):
            x = x.to(DEVICE)
            y_true = y.cpu().numpy() # (B, 1, H, W)

            # 1. Inference Speed Test
            start_time = time.time()
            preds = model(x)
            if torch.cuda.is_available(): torch.cuda.synchronize()
            end_time = time.time()
            inference_times.append((end_time - start_time) / x.shape[0]) # Time per image

            # 2. Thresholding
            probs = torch.sigmoid(preds)
            y_pred = (probs > 0.5).float().cpu().numpy()

            # 3. Global Pixel Metrics (Confusion Matrix)
            # Flatten arrays to 1D for fast calculation
            y_true_flat = y_true.flatten().astype(int)
            y_pred_flat = y_pred.flatten().astype(int)

            tn, fp, fn, tp = confusion_matrix(y_true_flat, y_pred_flat, labels=[0, 1]).ravel()
            tp_total += tp
            fp_total += fp
            fn_total += fn
            tn_total += tn

            # 4. Shape Metrics (Per Image)
            for b in range(x.shape[0]):
                p_m = y_pred[b, 0].astype(bool)
                g_m = y_true[b, 0].astype(bool)

                # Boundary IoU
                boundary_ious.append(compute_boundary_iou(p_m, g_m))

                # Hausdorff (Only if both masks are non-empty)
                if p_m.sum() > 0 and g_m.sum() > 0:
                    hausdorff_dists.append(compute_hausdorff_distance(p_m, g_m))

            # Progress Bar
            if i % 10 == 0: print(".", end="")

    print("\n Evaluation Complete.")

    # C. Calculate Final Scores
    epsilon = 1e-6

    # Pixel Scores
    pixel_acc = (tp_total + tn_total) / (tp_total + tn_total + fp_total + fn_total + epsilon)
    iou = tp_total / (tp_total + fp_total + fn_total + epsilon)
    precision = tp_total / (tp_total + fp_total + epsilon)
    recall = tp_total / (tp_total + fn_total + epsilon)
    f1 = 2 * (precision * recall) / (precision + recall + epsilon)

    # Shape Scores
    avg_boundary_iou = np.mean(boundary_ious) if boundary_ious else 0.0
    avg_hausdorff = np.mean(hausdorff_dists) if hausdorff_dists else 0.0

    # Performance
    fps = 1.0 / np.mean(inference_times)

    # D. Format Results
    results = {
        "Model": "Prithvi Scratch",
        "Path": model_path,
        "Pixel_Accuracy": round(pixel_acc, 4),
        "IoU (Jaccard)": round(iou, 4),
        "F1-Score (Dice)": round(f1, 4),
        "Precision": round(precision, 4),
        "Recall": round(recall, 4),
        "Boundary_IoU": round(avg_boundary_iou, 4),
        "Hausdorff_Dist_px": round(avg_hausdorff, 2),
        "Inference_Speed_FPS": round(fps, 2)
    }

    print("\n" + "="*40)
    print("  FINAL PERFORMANCE REPORT")
    print("="*40)
    print(json.dumps(results, indent=4))
    print("="*40)

    return results

# --- 4. EXECUTION ---
if 'val_loader' in locals():
    metrics = evaluate_model_depth(val_loader, CHECKPOINT_PATH)
else:
    print(" Error: 'val_loader' is not defined. Please run the Dataset Loading cell first.")
 STARTING DEEP EVALUATION
 Model: /content/drive/MyDrive/Prithvi_Scratch_Results_1/swa_model_final.pth
/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True
  warnings.warn(
 Processing Validation Set....
 Evaluation Complete.

========================================
  FINAL PERFORMANCE REPORT
========================================
{
    "Model": "Prithvi Scratch",
    "Path": "/content/drive/MyDrive/Prithvi_Scratch_Results_1/swa_model_final.pth",
    "Pixel_Accuracy": 0.7088,
    "IoU (Jaccard)": 0.7088,
    "F1-Score (Dice)": 0.8296,
    "Precision": 0.7088,
    "Recall": 1.0,
    "Boundary_IoU": 0.0,
    "Hausdorff_Dist_px": 20.62,
    "Inference_Speed_FPS": 17.17
}
========================================
In [ ]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os

# --- CONFIGURATION ---
# Ensure this matches your setup
SAVE_DIR = '/content/drive/MyDrive/Prithvi_Scratch_Results_1/'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_PATH = os.path.join(SAVE_DIR, "checkpoint.pth") # or "swa_model_final.pth"

# --- FIXED VISUALIZATION FUNCTION ---
def visualize_robust(model, loader):
    model.eval()

    try:
        x, y = next(iter(loader))
    except StopIteration:
        print(" Error: Validation loader is empty.")
        return

    x, y = x.to(DEVICE), y.to(DEVICE)

    with torch.no_grad():
        preds = torch.sigmoid(model(x))
        masks = (preds > 0.5).float().cpu().numpy()

    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()

    # --- CRITICAL FIX: DYNAMIC BATCH SIZE ---
    # We take the minimum of 3 OR the actual batch size (e.g., 2)
    num_samples = min(3, x_np.shape[0])

    print(f" Visualizing {num_samples} samples (Batch Size: {x_np.shape[0]})")

    # Create subplots dynamically based on actual samples
    fig, axs = plt.subplots(num_samples, 3, figsize=(15, 5 * num_samples))
    plt.suptitle("RGB Input | Ground Truth | Prediction", fontsize=16)

    for i in range(num_samples):
        # Handle axes indexing (Matplotlib returns 1D array if rows=1, 2D if rows>1)
        if num_samples == 1:
            ax = axs
        else:
            ax = axs[i]

        # 1. RGB Composite (Red, Green, Blue)
        # Indices: 2=Red, 1=Green, 0=Blue. Time Step: 1 (Middle)
        rgb = np.stack([x_np[i, 2, 1], x_np[i, 1, 1], x_np[i, 0, 1]], axis=2)

        # Robust Normalization
        p2, p98 = np.percentile(rgb, (2, 98))
        rgb = np.clip((rgb - p2) / (p98 - p2 + 1e-6), 0, 1)

        # 2. Plotting
        ax[0].imshow(rgb)
        ax[0].set_title(f"Sample {i+1} (RGB)")

        ax[1].imshow(y_np[i, 0], cmap='gray')
        ax[1].set_title("Ground Truth")

        ax[2].imshow(masks[i, 0], cmap='gray')
        ax[2].set_title(f"Prediction (Mean: {masks[i,0].mean():.2f})")

        for a in ax: a.axis('off')

    plt.tight_layout()
    plt.show()



if 'model' in locals() and 'val_loader' in locals():
    visualize_robust(model, val_loader)
else:
    print(" Model or Loader missing. Please re-run the 'Recovery & Visualization' cell above.")
🖼️ Visualizing 2 samples (Batch Size: 2)
No description has been provided for this image